mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
109 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2de98d32be | |||
| 83127467c5 | |||
| e07c94f49d | |||
| b87146a284 | |||
| 186d9f7f44 | |||
| 3a8ffc2104 | |||
| e54570162e | |||
| 34bb97a40e | |||
| f5c1a16f8a | |||
| b29d7d2166 | |||
| 3ea0db69ea | |||
| 4304a5e899 | |||
| 4019c1e4f7 | |||
| 30ad7c1d0b | |||
| e33564c569 | |||
| 5ff28445fd | |||
| 13d177e5d0 | |||
| 3ffc995f27 | |||
| b2bd016135 | |||
| 812dedaea2 | |||
| f65b6737f2 | |||
| 5d45aa196b | |||
| debb39f56c | |||
| 7ce6f4fd9e | |||
| c2f2bdb3d3 | |||
| 201d14804e | |||
| 7e54710d4a | |||
| 88870be4d2 | |||
| 46bf809715 | |||
| e19e9642a2 | |||
| 32675b8b35 | |||
| aecce001ee | |||
| 32d73171fd | |||
| 265fd2ec0c | |||
| efebf2eba6 | |||
| f7b655ae33 | |||
| 35982b41ad | |||
| 788e3b71fd | |||
| 3496bc2684 | |||
| 997c7d15ff | |||
| 83246e47d5 | |||
| 50e7b78c33 | |||
| b937af3056 | |||
| a5e995c750 | |||
| e95e08a699 | |||
| bcaf92f62a | |||
| ead4afbfe6 | |||
| 685aaf207f | |||
| 76ff6c9639 | |||
| 1cf24ee5de | |||
| c9637090fa | |||
| 0ff0ff42ab | |||
| a4fb32ff2b | |||
| 7d2f078111 | |||
| b0b66941ab | |||
| cbb7387a72 | |||
| 19430b0ecb | |||
| 8e3cfeede5 | |||
| 4fa5775974 | |||
| 4e7d823ee4 | |||
| 7a16c76adc | |||
| 70a21ee73a | |||
| 28d2de8f39 | |||
| 7f192ae850 | |||
| 9f6746ded9 | |||
| 7514d3a0ff | |||
| c83281a52b | |||
| 4515bb92c2 | |||
| e326b84204 | |||
| 1b93049b8e | |||
| 4912449dda | |||
| b70cce4f34 | |||
| 4c566836b2 | |||
| bb3261883a | |||
| 512d0f16ce | |||
| 8159431ce4 | |||
| 9f9f265fb3 | |||
| 9d38349091 | |||
| fec8bac800 | |||
| e76f5f3d45 | |||
| 1ad493c5c7 | |||
| ea6ddc8792 | |||
| 6d4e8bcec5 | |||
| e2ed345280 | |||
| e542eb797e | |||
| e631fc1b17 | |||
| 290c5a4774 | |||
| 287d60c31e | |||
| 3d45d98895 | |||
| db4be4f9a2 | |||
| 80093e69ed | |||
| ef519ba517 | |||
| d79eb1f0fa | |||
| ac8ee6525d | |||
| e35e8382d6 | |||
| fbb3408a25 | |||
| 44fed9a647 | |||
| e7f11487b9 | |||
| 054c417603 | |||
| 94d62a6ef0 | |||
| 91e6dfd2c8 | |||
| b6a0c4b44c | |||
| 8eb0fa855a | |||
| 3bf696c546 | |||
| 3e461a0539 | |||
| a2ece01ecf | |||
| 623c9fb5ad | |||
| 139506f336 | |||
| 6d424554ad |
@@ -0,0 +1,79 @@
|
||||
name: Bug Report
|
||||
description: Report a bug or issue with Kit
|
||||
title: "fix: "
|
||||
labels: ["bug"]
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Bug Description
|
||||
description: What happened? What did you expect to happen?
|
||||
placeholder: |
|
||||
The BorderColor field in ToolRenderConfig is documented but never applied
|
||||
during tool rendering. I expected the tool block to render with my custom
|
||||
color, but it uses the default styling instead.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
attributes:
|
||||
label: Steps to Reproduce
|
||||
description: Provide clear steps to reproduce the issue
|
||||
placeholder: |
|
||||
1. Create an extension with `api.RegisterToolRenderer(ext.ToolRenderConfig{...})`
|
||||
2. Set `BorderColor: "#89b4fa"` in the config
|
||||
3. Run a tool that uses this renderer
|
||||
4. Observe the border color is not applied
|
||||
render: markdown
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: code
|
||||
attributes:
|
||||
label: Relevant Code / Configuration
|
||||
description: Paste any code, configuration, or error messages
|
||||
placeholder: |
|
||||
```go
|
||||
api.RegisterToolRenderer(ext.ToolRenderConfig{
|
||||
ToolName: "bash",
|
||||
DisplayName: "Shell",
|
||||
BorderColor: "#a6e3a1", // This is ignored!
|
||||
Background: "#1e1e2e", // This is ignored!
|
||||
})
|
||||
```
|
||||
render: go
|
||||
|
||||
- type: input
|
||||
id: component
|
||||
attributes:
|
||||
label: Affected Component
|
||||
description: Which part of Kit is affected?
|
||||
placeholder: e.g., extensions, ui, tool rendering, session management
|
||||
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: Kit Version
|
||||
description: What version of Kit are you running?
|
||||
placeholder: e.g., v0.1.0, commit hash, or "main"
|
||||
|
||||
- type: textarea
|
||||
id: context
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Any other context, proposed fixes, or related issues
|
||||
placeholder: |
|
||||
The issue appears to be in `internal/ui/messages.go:RenderToolMessage()`
|
||||
which ignores the BorderColor and Background fields from ToolRendererData.
|
||||
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I've searched existing issues and this hasn't been reported yet
|
||||
required: true
|
||||
- label: I've tested with the latest version of Kit
|
||||
required: false
|
||||
@@ -0,0 +1,11 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Kit Documentation
|
||||
url: https://github.com/mark3labs/kit/tree/main/www/pages
|
||||
about: Check the documentation before filing an issue
|
||||
- name: Extension Examples
|
||||
url: https://github.com/mark3labs/kit/tree/main/examples/extensions
|
||||
about: See working extension examples for reference
|
||||
- name: Discussions
|
||||
url: https://github.com/mark3labs/kit/discussions
|
||||
about: For questions, ideas, or general discussion
|
||||
@@ -0,0 +1,40 @@
|
||||
name: Documentation Issue
|
||||
description: Report missing, incorrect, or unclear documentation
|
||||
title: "docs: "
|
||||
labels: ["documentation"]
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Documentation Issue
|
||||
description: What's wrong or missing in the documentation?
|
||||
placeholder: |
|
||||
The ToolRenderConfig documentation mentions BorderColor and Background fields,
|
||||
but the code doesn't actually use them. The docs should either be updated
|
||||
to reflect reality, or the bug should be fixed.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: location
|
||||
attributes:
|
||||
label: Documentation Location
|
||||
description: Where is the affected documentation?
|
||||
placeholder: e.g., README.md, examples/extensions/tool-renderer-demo.go, pkg/kit docs
|
||||
|
||||
- type: textarea
|
||||
id: suggestion
|
||||
attributes:
|
||||
label: Suggested Improvement
|
||||
description: How should the documentation be improved?
|
||||
placeholder: |
|
||||
Add a note that BorderColor and Background are not yet implemented,
|
||||
or fix the bug and document the correct behavior.
|
||||
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I've checked that this documentation issue still exists in the latest version
|
||||
required: true
|
||||
@@ -0,0 +1,64 @@
|
||||
name: Feature Request
|
||||
description: Suggest a new feature or enhancement for Kit
|
||||
title: "feat: "
|
||||
labels: ["enhancement"]
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Feature Description
|
||||
description: What would you like to see added or changed?
|
||||
placeholder: |
|
||||
I'd like to be able to customize the border color of tool result blocks
|
||||
dynamically based on the tool type or result status.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: motivation
|
||||
attributes:
|
||||
label: Motivation / Use Case
|
||||
description: Why is this feature needed? What problem does it solve?
|
||||
placeholder: |
|
||||
When running multiple tools in sequence, it's hard to visually distinguish
|
||||
between file reads (blue), shell commands (green), and errors (red)
|
||||
without custom border colors.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: proposed
|
||||
attributes:
|
||||
label: Proposed Implementation
|
||||
description: How do you think this should work? (optional)
|
||||
placeholder: |
|
||||
Extend `ToolRenderConfig` to accept a function that receives the tool
|
||||
result and returns a color based on the content:
|
||||
|
||||
```go
|
||||
BorderColorFunc: func(result string, isError bool) string {
|
||||
if isError {
|
||||
return "#f38ba8"
|
||||
}
|
||||
return "#89b4fa"
|
||||
}
|
||||
```
|
||||
render: go
|
||||
|
||||
- type: checkboxes
|
||||
id: alternatives
|
||||
attributes:
|
||||
label: Alternatives Considered
|
||||
options:
|
||||
- label: I've considered workarounds or alternative approaches
|
||||
required: false
|
||||
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I've searched existing issues and this hasn't been requested yet
|
||||
required: true
|
||||
- label: This feature aligns with Kit's design philosophy (TUI-first, extension-based)
|
||||
required: false
|
||||
@@ -3,6 +3,7 @@
|
||||
.env
|
||||
.kit/*
|
||||
!.kit/extensions/
|
||||
!.kit/prompts/
|
||||
aidocs/
|
||||
*.log
|
||||
/kit
|
||||
|
||||
@@ -28,11 +28,15 @@ type lintResult struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// Package-level state: set of .go files edited during the current agent turn.
|
||||
var editedFiles map[string]bool
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint on Go file edits")
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint after agent turns that edit Go files")
|
||||
})
|
||||
|
||||
// Track edited .go files — don't lint yet.
|
||||
api.OnToolResult(func(e ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult {
|
||||
if e.IsError || !isEditOrWrite(e.ToolName) {
|
||||
return nil
|
||||
@@ -43,30 +47,72 @@ func Init(api ext.API) {
|
||||
return nil
|
||||
}
|
||||
|
||||
report := runGoDiagnostics(ctx.CWD, absPath)
|
||||
|
||||
// Check if there are issues and add explicit prompt for the LLM to react
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
var enhanced string
|
||||
if hasIssues {
|
||||
enhanced = e.Content + "\n\n" + report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review the issues above and fix them before proceeding."
|
||||
} else {
|
||||
enhanced = e.Content + "\n\n" + report
|
||||
if editedFiles == nil {
|
||||
editedFiles = make(map[string]bool)
|
||||
}
|
||||
editedFiles[absPath] = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// After the agent turn ends, lint all collected files.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if len(editedFiles) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Show TUI message block for diagnostics visibility (only if there are issues)
|
||||
// Snapshot and reset immediately so the next turn starts clean.
|
||||
files := editedFiles
|
||||
editedFiles = nil
|
||||
|
||||
// Skip lint on errored turns.
|
||||
if e.StopReason == "error" {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique directories and file list for gopls.
|
||||
var allGoplsOutput []string
|
||||
for absPath := range files {
|
||||
res := runGopls(ctx.CWD, absPath)
|
||||
formatted := formatToolResult(res, "")
|
||||
if formatted != "" {
|
||||
allGoplsOutput = append(allGoplsOutput, fmt.Sprintf("# %s\n%s", filepath.Base(absPath), formatted))
|
||||
}
|
||||
}
|
||||
|
||||
lintRes := runGolangCILint(ctx.CWD, "./...")
|
||||
|
||||
goplsSection := "No diagnostics."
|
||||
if len(allGoplsOutput) > 0 {
|
||||
goplsSection = strings.Join(allGoplsOutput, "\n\n")
|
||||
}
|
||||
lintSection := formatToolResult(lintRes, "No lint issues.")
|
||||
|
||||
// Build file list for the report header.
|
||||
var fileNames []string
|
||||
for absPath := range files {
|
||||
fileNames = append(fileNames, filepath.Base(absPath))
|
||||
}
|
||||
|
||||
report := fmt.Sprintf(
|
||||
"<go_diagnostics files=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
strings.Join(fileNames, ", "),
|
||||
goplsSection,
|
||||
lintSection,
|
||||
)
|
||||
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
if hasIssues {
|
||||
// Show TUI block so the user sees it too.
|
||||
var msgLines []string
|
||||
msgLines = append(msgLines, fmt.Sprintf("File: %s", filepath.Base(absPath)))
|
||||
msgLines = append(msgLines, fmt.Sprintf("Files: %s", strings.Join(fileNames, ", ")))
|
||||
if goplsIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("gopls: %d issue(s)", goplsIssues))
|
||||
}
|
||||
if lintIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("golangci-lint: %d issue(s)", lintIssues))
|
||||
}
|
||||
msgLines = append(msgLines, "", "⚠️ Please fix these issues before proceeding.")
|
||||
|
||||
borderColor := "#f9e2af" // yellow
|
||||
if goplsIssues > 0 && lintIssues > 0 {
|
||||
@@ -78,9 +124,16 @@ func Init(api ext.API) {
|
||||
BorderColor: borderColor,
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
|
||||
return &ext.ToolResultResult{Content: &enhanced}
|
||||
// Inject a follow-up message so the agent fixes the issues.
|
||||
ctx.SendMessage(report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review and fix the issues above.")
|
||||
} else {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: fmt.Sprintf("Files: %s\n✓ All clean", strings.Join(fileNames, ", ")),
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -106,18 +159,6 @@ func resolveGoFilePath(inputJSON, cwd string) (string, bool) {
|
||||
return absPath, true
|
||||
}
|
||||
|
||||
func runGoDiagnostics(cwd, absPath string) string {
|
||||
gopls := runGopls(cwd, absPath)
|
||||
lint := runGolangCILint(cwd, "./...")
|
||||
|
||||
return fmt.Sprintf(
|
||||
"<go_diagnostics file=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
filepath.Base(absPath),
|
||||
formatToolResult(gopls, "No diagnostics."),
|
||||
formatToolResult(lint, "No lint issues."),
|
||||
)
|
||||
}
|
||||
|
||||
func runGopls(cwd, absPath string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
@@ -178,7 +219,9 @@ func formatToolResult(res lintResult, emptyFallback string) string {
|
||||
out := strings.TrimSpace(res.Output)
|
||||
if out == "" {
|
||||
if res.Err == nil {
|
||||
lines = append(lines, emptyFallback)
|
||||
if emptyFallback != "" {
|
||||
lines = append(lines, emptyFallback)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, out)
|
||||
@@ -197,17 +240,15 @@ func truncate(s string, max int) string {
|
||||
}
|
||||
|
||||
func countIssues(report string) (goplsCount, lintCount int) {
|
||||
// Extract gopls section
|
||||
goplsStart := strings.Index(report, "[gopls]")
|
||||
lintStart := strings.Index(report, "[golangci-lint]")
|
||||
endTag := strings.Index(report, "</go_diagnostics>")
|
||||
|
||||
if goplsStart != -1 && lintStart != -1 {
|
||||
goplsSection := report[goplsStart:lintStart]
|
||||
// Count non-empty lines excluding the header and "No diagnostics." message
|
||||
for _, line := range strings.Split(goplsSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." {
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." && !strings.HasPrefix(line, "#") {
|
||||
goplsCount++
|
||||
}
|
||||
}
|
||||
@@ -215,7 +256,6 @@ func countIssues(report string) (goplsCount, lintCount int) {
|
||||
|
||||
if lintStart != -1 && endTag != -1 {
|
||||
lintSection := report[lintStart:endTag]
|
||||
// Count non-empty lines excluding the header and "No lint issues." message
|
||||
for _, line := range strings.Split(lintSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[golangci-lint]" && line != "No lint issues." {
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
---
|
||||
description: Run ACP smoke test against opencode/kimi-k2.5 to verify JSON-RPC stdio works
|
||||
---
|
||||
|
||||
Run the ACP smoke test to verify the Kit ACP server works correctly over JSON-RPC stdio with streaming responses.
|
||||
|
||||
## Steps
|
||||
|
||||
1. Build the kit binary:
|
||||
```bash
|
||||
go build -o output/kit ./cmd/kit
|
||||
```
|
||||
|
||||
2. Run the smoke test Python script against opencode/kimi-k2.5:
|
||||
```bash
|
||||
python3 scripts/acp_smoke_test.py
|
||||
```
|
||||
|
||||
3. Verify the output shows:
|
||||
- `session/new` returns a valid `sessionId`
|
||||
- `session/prompt` streams `agent_thought_chunk` notifications (reasoning)
|
||||
- `session/prompt` streams `agent_message_chunk` notifications (response)
|
||||
- Final result has `stopReason: "end_turn"`
|
||||
- `✓ SMOKE TEST PASSED` at the end
|
||||
|
||||
4. If the test fails, check:
|
||||
- `output/kit` binary exists and is executable
|
||||
- `OPENCODE_API_KEY` or `OPENCODE_ZEN_API_KEY` environment variable is set
|
||||
- `scripts/acp_smoke_test.py` exists
|
||||
- The model `opencode/kimi-k2.5` is available (`kit models opencode | grep kimi-k2.5`)
|
||||
|
||||
5. For testing with a different model, edit the script or set the `MODEL` variable:
|
||||
```bash
|
||||
MODEL=anthropic/claude-sonnet-4-5 python3 scripts/acp_smoke_test.py
|
||||
```
|
||||
|
||||
The smoke test exercises the full ACP protocol: session lifecycle, streaming notifications, and tool-free prompt completion.
|
||||
@@ -0,0 +1,30 @@
|
||||
---
|
||||
description: Stage, commit, and push changes with an auto-generated conventional commit message
|
||||
---
|
||||
|
||||
Review the current git status and diff, then stage all changes, write a concise conventional commit message, commit, and push to the current branch.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Check status**: `git status` — understand what has changed
|
||||
2. **Review the diff**: `git diff` (and `git diff --cached` if anything is already staged) — read the actual changes
|
||||
3. **Stage everything**: `git add -A`
|
||||
4. **Craft the commit message** following Conventional Commits:
|
||||
- Format: `<type>(<scope>): <short summary>`
|
||||
- Types: `feat`, `fix`, `refactor`, `chore`, `docs`, `test`, `perf`, `build`
|
||||
- Scope: optional, the subsystem affected (e.g. `ui`, `cmd`, `config`)
|
||||
- Summary: imperative mood, lowercase, no trailing period, ≤72 chars
|
||||
- Body: add a blank line then bullet points for non-trivial changes
|
||||
- Do **not** include "Generated by" or similar noise
|
||||
5. **Commit**: `git commit -m "<message>"`
|
||||
6. **Push**: `git push`
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Read the actual diff — do not guess from filenames alone
|
||||
- Prefer one well-scoped commit; do not split unless the changes are clearly unrelated
|
||||
- Keep the subject line under 72 characters
|
||||
- Use the body to explain *what* and *why*, not *how*
|
||||
- If there is nothing to commit, say so and stop
|
||||
|
||||
$@
|
||||
@@ -0,0 +1,86 @@
|
||||
---
|
||||
description: Create a feature request using the GitHub template
|
||||
---
|
||||
|
||||
Create a feature request for the Kit repository. The user wants to request: $@
|
||||
|
||||
## Feature Request Template
|
||||
|
||||
This prompt uses the `feature_request` GitHub template which requires:
|
||||
|
||||
| Field | Required | Purpose |
|
||||
|-------|----------|---------|
|
||||
| **Feature Description** | Yes | What should be added or changed |
|
||||
| **Motivation / Use Case** | Yes | Why is this needed? What problem does it solve? |
|
||||
| **Proposed Implementation** | No | How do you think this should work? |
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the request** from `$@`
|
||||
- What capability is missing?
|
||||
- What would the ideal behavior look like?
|
||||
|
||||
2. **Ask clarifying questions** if needed:
|
||||
- "What problem does this solve for you?"
|
||||
- "How would you expect this to work?"
|
||||
- "Are there similar features in other tools you use?"
|
||||
|
||||
3. **Craft the title** using conventional format:
|
||||
- `feat: <short description>`
|
||||
- Lowercase, imperative mood, ≤72 chars
|
||||
- Good examples:
|
||||
- `feat: add keyboard shortcut for clearing input`
|
||||
- `feat: support custom themes per extension`
|
||||
- `feat: add fuzzy matching to model selector`
|
||||
- Bad examples:
|
||||
- `Feature request: can we have...` (too vague)
|
||||
- `It would be nice if...` (not imperative)
|
||||
|
||||
4. **Build the body** with the template fields:
|
||||
|
||||
**Feature Description:**
|
||||
- Clear statement of what to add/change
|
||||
- Be specific about the behavior
|
||||
- Include UI/UX details if relevant
|
||||
|
||||
**Motivation / Use Case:**
|
||||
- What problem does this solve?
|
||||
- Current workaround (if any) and why it's insufficient
|
||||
- Who benefits from this feature?
|
||||
|
||||
**Proposed Implementation** (optional but helpful):
|
||||
- High-level approach
|
||||
- API changes if applicable
|
||||
- Example usage code
|
||||
|
||||
5. **Create the issue**:
|
||||
```bash
|
||||
gh issue create --template feature_request --title "feat: ..." --body "..."
|
||||
```
|
||||
|
||||
6. **Confirm success**:
|
||||
- Show the issue URL and number
|
||||
- Mention it was created with the feature_request template
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Focus on the *problem* first, then the solution
|
||||
- Include concrete examples of how the feature would be used
|
||||
- Consider edge cases and mention them
|
||||
- If proposing API changes, show before/after code
|
||||
- Check if similar features exist in related tools (mention them for reference)
|
||||
- Align with Kit's philosophy: TUI-first, extension-based, keyboard-driven
|
||||
|
||||
## Example
|
||||
|
||||
User: `/feature-request I want to be able to customize tool border colors dynamically`
|
||||
|
||||
You:
|
||||
1. Title: `feat: dynamic border colors for tool results based on status`
|
||||
2. Body:
|
||||
- **Feature Description**: Allow `ToolRenderConfig` to accept a function that determines border color based on tool result content or status, enabling dynamic visual feedback.
|
||||
- **Motivation**: When running multiple tools, it's hard to distinguish file reads (blue), shell commands (green), and errors (red) without custom colors per result.
|
||||
- **Proposed Implementation**: Add `BorderColorFunc` callback that receives `(result string, isError bool)` and returns a color string.
|
||||
|
||||
3. Execute: `gh issue create --template feature_request --title "feat: ..." --body "..."`
|
||||
4. Confirm: Created issue #43 using feature_request template
|
||||
@@ -0,0 +1,100 @@
|
||||
---
|
||||
description: File a GitHub issue using the appropriate template
|
||||
---
|
||||
|
||||
File a GitHub issue for the Kit repository. The user wants to create an issue about: $@
|
||||
|
||||
## Issue Templates Available
|
||||
|
||||
This repository has structured issue templates. You MUST use the appropriate template:
|
||||
|
||||
| Type | Template | Use For |
|
||||
|------|----------|---------|
|
||||
| `bug` | `bug_report` | Something is broken, not working as expected |
|
||||
| `feat` | `feature_request` | New feature, enhancement, improvement |
|
||||
| `docs` | `documentation` | Missing, incorrect, or unclear documentation |
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Determine the issue type** from `$@`:
|
||||
- Bug → use `--template bug_report`
|
||||
- Feature → use `--template feature_request`
|
||||
- Documentation → use `--template documentation`
|
||||
|
||||
2. **Ask clarifying questions** if critical info is missing:
|
||||
- For bugs: "What were you doing when this happened?" (reproduction steps)
|
||||
- For features: "What problem does this solve?" (motivation)
|
||||
- For docs: "Where did you look for this information?" (location)
|
||||
|
||||
3. **Craft the title** using conventional format:
|
||||
- `<type>: <short description>`
|
||||
- Lowercase, imperative mood, ≤72 chars
|
||||
- Examples:
|
||||
- `fix: ToolRenderConfig BorderColor ignored during rendering`
|
||||
- `feat: add keyboard shortcut for clearing input`
|
||||
- `docs: clarify extension widget lifecycle`
|
||||
|
||||
4. **File the issue** using the template:
|
||||
```bash
|
||||
# For bugs
|
||||
gh issue create --template bug_report --title "fix: ..." --body "..."
|
||||
|
||||
# For features
|
||||
gh issue create --template feature_request --title "feat: ..." --body "..."
|
||||
|
||||
# For documentation
|
||||
gh issue create --template documentation --title "docs: ..." --body "..."
|
||||
```
|
||||
|
||||
The template will guide the user through the required fields. You need to provide:
|
||||
- **Bug reports**: Description, reproduction steps, expected vs actual behavior
|
||||
- **Feature requests**: Description, motivation/use case, optional proposed implementation
|
||||
- **Documentation**: Description, location of docs, suggested improvement
|
||||
|
||||
5. **Confirm success** by showing:
|
||||
- The issue URL
|
||||
- The issue number
|
||||
- Which template was used
|
||||
|
||||
## Template Field Guide
|
||||
|
||||
### Bug Report (`bug_report`)
|
||||
Required fields in the body:
|
||||
- **Bug Description** - what happened vs expected
|
||||
- **Steps to Reproduce** - numbered list to recreate the bug
|
||||
- **Relevant Code** - code snippets, configuration, error messages
|
||||
- **Component** - which part of Kit (ui, extensions, session, etc.)
|
||||
- **Version** - Kit version or commit hash
|
||||
|
||||
### Feature Request (`feature_request`)
|
||||
Required fields in the body:
|
||||
- **Feature Description** - what to add/change
|
||||
- **Motivation / Use Case** - why this is needed
|
||||
- **Proposed Implementation** - how it could work (optional)
|
||||
|
||||
### Documentation (`documentation`)
|
||||
Required fields in the body:
|
||||
- **Documentation Issue** - what's wrong or missing
|
||||
- **Documentation Location** - file or URL where docs exist
|
||||
- **Suggested Improvement** - how to fix the docs
|
||||
|
||||
## Guidelines
|
||||
|
||||
- ALWAYS use `--template <name>` instead of bare `gh issue create`
|
||||
- Include file paths and line numbers when you know them
|
||||
- Use triple backticks for code blocks
|
||||
- Keep the body factual - avoid speculation unless in "Proposed Fix" section
|
||||
- If you're unsure about technical details, say so in the issue
|
||||
- For UI bugs, describe what you see vs what you expect
|
||||
- For API bugs, include the relevant struct/function names
|
||||
|
||||
## Example Usage
|
||||
|
||||
User: `/file-issue The ToolRenderConfig BorderColor field is documented but never used in rendering`
|
||||
|
||||
You:
|
||||
1. Determine this is a **bug** (documented field doesn't work)
|
||||
2. Use `--template bug_report`
|
||||
3. Gather: reproduction steps (register renderer with BorderColor), expected (custom color), actual (default color)
|
||||
4. Create issue with title `fix: ToolRenderConfig BorderColor and Background fields are ignored`
|
||||
5. Confirm: Created issue #42 using bug_report template
|
||||
@@ -0,0 +1,47 @@
|
||||
---
|
||||
description: Scaffold a new prompt template in .kit/prompts/
|
||||
---
|
||||
|
||||
Create a new kit prompt template. The user wants a prompt that does: $@
|
||||
|
||||
## What a prompt template is
|
||||
|
||||
A prompt template is a `.md` file in `.kit/prompts/` (project-local) or `~/.kit/prompts/` (global).
|
||||
It becomes a `/slug` slash command in the kit input box — typed as `/filename` with optional arguments.
|
||||
|
||||
## File format
|
||||
|
||||
```
|
||||
---
|
||||
description: One-line description shown in autocomplete
|
||||
---
|
||||
|
||||
Body text of the prompt. Use $@ for all user-supplied arguments,
|
||||
$1 $2 etc. for positional arguments.
|
||||
```
|
||||
|
||||
- **Filename** → slug: `commit-push.md` becomes `/commit-push`
|
||||
- **Frontmatter**: only `description` is recognised; keep it under ~80 chars
|
||||
- **Body**: plain markdown; the full text is submitted as the user's message when the template fires
|
||||
- **Arguments**: `$@` expands to everything the user typed after the slash command name;
|
||||
`$1`, `$2` for individual positional args; omit entirely if no arguments are needed
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the workflow** the user described in `$@` — ask a clarifying question if the intent is ambiguous
|
||||
2. **Choose a filename**: short, lowercase, hyphen-separated, descriptive (e.g. `code-review.md`)
|
||||
3. **Write the description**: one sentence, imperative, fits in autocomplete
|
||||
4. **Draft the body**:
|
||||
- Open with a single sentence stating the goal
|
||||
- Use `## Steps` for multi-step workflows; use plain prose for simple prompts
|
||||
- Be specific: name commands, flags, and file paths where relevant
|
||||
- End with `$@` on its own line if the user might want to pass context or a hint; omit if the prompt is self-contained
|
||||
5. **Write the file** to `.kit/prompts/<slug>.md`
|
||||
6. **Confirm** by showing the final file content and the slash command that activates it
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Keep prompts action-oriented — they should tell kit *what to do*, not just *what to think about*
|
||||
- Prefer concrete steps over vague instructions
|
||||
- A prompt that does one thing well beats one that tries to cover every edge case
|
||||
- If the workflow already exists as a prompt, suggest extending it instead of duplicating
|
||||
@@ -0,0 +1,70 @@
|
||||
---
|
||||
description: Semantic version tagging workflow - analyzes commits and tags releases
|
||||
---
|
||||
|
||||
# Release Tagging Workflow
|
||||
|
||||
Tag a new version of this Go project following semantic versioning.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Fetch remote tags**: `git fetch --tags origin`
|
||||
|
||||
2. **Find latest version**: `git tag -l | sort -V | tail -5` to see recent tags
|
||||
|
||||
3. **Analyze changes since last tag**:
|
||||
- `git log <latest-tag>..HEAD --oneline` - list commits
|
||||
- `git diff <latest-tag>..HEAD --stat` - see file stats
|
||||
- `git diff <latest-tag>..HEAD --name-only` - see changed files
|
||||
|
||||
4. **Determine version bump** (Semantic Versioning):
|
||||
- **MAJOR (X.0.0)**: Breaking API changes, incompatible modifications
|
||||
- **MINOR (0.X.0)**: New features, backward-compatible additions
|
||||
- **PATCH (0.0.X)**: Bug fixes, backward-compatible fixes
|
||||
|
||||
Look for indicators:
|
||||
- `feat:` or `feature:` commits → MINOR
|
||||
- `fix:` or `bugfix:` commits → PATCH
|
||||
- `breaking:` or `BREAKING CHANGE:` → MAJOR
|
||||
- Breaking API changes in `pkg/` or public interfaces → MAJOR
|
||||
- New commands, flags, or features → MINOR
|
||||
- Documentation-only changes → PATCH (or skip)
|
||||
|
||||
5. **Calculate new version**: Increment appropriate segment, reset lower segments to 0
|
||||
|
||||
6. **Draft tag message**:
|
||||
- Summarize key changes from commits
|
||||
- Group by type (Features, Fixes, Breaking Changes)
|
||||
- Keep concise but informative
|
||||
|
||||
7. **Create annotated tag**: `git tag -a vX.Y.Z -m "vX.Y.Z - <summary>\n\n<detailed list>"`
|
||||
|
||||
8. **Push tag**: `git push origin vX.Y.Z`
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Always fetch remote tags first to avoid conflicts
|
||||
- Use annotated tags (`-a`) with descriptive messages
|
||||
- Follow semver strictly - when in doubt, prefer conservative bump (patch over minor)
|
||||
- For Go projects, changes to `pkg/` or exported APIs warrant careful version consideration
|
||||
- If no changes since last tag, suggest skipping the release
|
||||
- Include commit summaries in the tag message body
|
||||
|
||||
## Example Tag Message Format
|
||||
|
||||
```
|
||||
v0.30.1 - Bug fixes for model handling and UI improvements
|
||||
|
||||
Fixes:
|
||||
- Properly handle think tags from Qwen/DeepSeek models
|
||||
- Handle custom provider model persistence and bare model names
|
||||
|
||||
Improvements:
|
||||
- UI style refactoring and cleanup
|
||||
```
|
||||
|
||||
Wait for the user to confirm the version and message before executing tag commands.
|
||||
|
||||
---
|
||||
|
||||
$@
|
||||
@@ -100,3 +100,21 @@ Positional args are the prompt. `@file` args attach file content. Key flags: `--
|
||||
- Never guess or manually search the filesystem for external projects
|
||||
- Example: `btca ask -r https://github.com/user/repo -q "How does X work?"`
|
||||
- See `.agents/skills/btca-cli/SKILL.md` for full btca usage
|
||||
|
||||
## BTCA Configured Resources
|
||||
The following external repositories are configured in `btca.config.jsonc` for research:
|
||||
|
||||
- bubbletea
|
||||
- lipgloss
|
||||
- bubbles
|
||||
- glamour
|
||||
- fantasy
|
||||
- catwalk
|
||||
- crush
|
||||
- pi
|
||||
- iteratr
|
||||
- yaegi
|
||||
- acp-go-sdk
|
||||
- opencode
|
||||
- herald
|
||||
- herald-md
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
# Autoscroll Fix - Final Summary
|
||||
|
||||
## Root Cause
|
||||
|
||||
The autoscroll was failing for streaming assistant messages due to a bug in how `GotoBottom()` calculated item heights.
|
||||
|
||||
### The Problem
|
||||
|
||||
1. **Reasoning blocks** (`StreamingMessageItem` with `role="reasoning"`) are **never cached** because they have live duration counters that update every render
|
||||
2. The `Height()` method returns `0` when `cachedRender == ""`
|
||||
3. `GotoBottom()` was calling:
|
||||
```go
|
||||
itemHeight := item.Height() // Returns 0 for reasoning
|
||||
if itemHeight == 0 {
|
||||
item.Render(s.width) // Renders but doesn't cache (reasoning)
|
||||
itemHeight = item.Height() // Still returns 0!
|
||||
}
|
||||
```
|
||||
4. This caused incorrect scroll position calculations, especially during reasoning → assistant transitions
|
||||
|
||||
## The Solution
|
||||
|
||||
Changed `GotoBottom()` and `AtBottom()` to calculate height **directly from the rendered string** instead of relying on the cached height:
|
||||
|
||||
```go
|
||||
// OLD: item.Height() which checks cached render
|
||||
itemHeight := item.Height()
|
||||
if itemHeight == 0 {
|
||||
item.Render(s.width)
|
||||
itemHeight = item.Height() // Still might be 0!
|
||||
}
|
||||
|
||||
// NEW: Calculate from rendered string directly
|
||||
rendered := item.Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
```
|
||||
|
||||
This works for **all** items regardless of whether they cache their render or not.
|
||||
|
||||
## Files Changed
|
||||
|
||||
### `internal/ui/scrolllist.go`
|
||||
- **`GotoBottom()`**: Calculate height from rendered string (2 loops)
|
||||
- **`AtBottom()`**: Calculate height from rendered string (1 loop)
|
||||
|
||||
### `internal/ui/model.go`
|
||||
- **`appendStreamingChunk()`**: For existing messages, call `GotoBottom()` directly (iteratr pattern)
|
||||
- **`refreshContent()`**: Simplified to only call `SetItems()` (removed redundant `GotoBottom()`)
|
||||
- **Bash streaming handler**: Removed redundant `GotoBottom()` after `refreshContent()`
|
||||
|
||||
## Testing Results
|
||||
|
||||
✅ **Test prompt**: "explore this repo"
|
||||
|
||||
**Before fix**:
|
||||
- Autoscroll stopped after reasoning block completed
|
||||
- Viewport stuck showing end of reasoning ("Thought for 203ms")
|
||||
- Assistant response streamed off-screen below
|
||||
|
||||
**After fix**:
|
||||
- Autoscroll works throughout reasoning block
|
||||
- Autoscroll continues during reasoning → assistant transition
|
||||
- Viewport stays at bottom showing latest assistant content
|
||||
- Final position shows end of response (build commands section)
|
||||
|
||||
## Behavior Verified
|
||||
|
||||
1. ✅ Streaming text auto-scrolls to bottom
|
||||
2. ✅ Works across reasoning → assistant transition
|
||||
3. ✅ Manual scroll up (PgUp) disables autoscroll
|
||||
4. ✅ Scroll to bottom (Alt+End) re-enables autoscroll
|
||||
5. ✅ Accurate positioning with no offset errors
|
||||
|
||||
## Performance Note
|
||||
|
||||
The fix calls `Render()` on all items during `GotoBottom()` calculations. This is acceptable because:
|
||||
- `Render()` is already optimized with caching for non-reasoning items
|
||||
- `GotoBottom()` is only called during content updates (not every frame)
|
||||
- Reasoning blocks need to render anyway for live duration updates
|
||||
- This matches iteratr's approach of ensuring items are rendered before height calculations
|
||||
@@ -477,7 +477,7 @@ During an interactive session, use these slash commands:
|
||||
| `/import <path>` | Import and switch to a session from a JSONL file |
|
||||
| `/share` | Upload session to GitHub Gist and get a shareable viewer URL |
|
||||
| `/tree` | Navigate the session tree |
|
||||
| `/fork` | Branch from an earlier message |
|
||||
| `/fork` | Fork to new session from an earlier message |
|
||||
| `/new` | Start a fresh session |
|
||||
|
||||
## Go SDK
|
||||
@@ -531,7 +531,12 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
NoSession: true, // Ephemeral mode
|
||||
|
||||
// Tool options
|
||||
ExtraTools: []kit.Tool{...}, // Additional tools alongside defaults
|
||||
Tools: []kit.Tool{...}, // Replace default tool set entirely
|
||||
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
|
||||
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
|
||||
|
||||
// Configuration
|
||||
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
|
||||
|
||||
// Compaction
|
||||
AutoCompact: true, // Auto-compact near context limit
|
||||
@@ -540,6 +545,28 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
### Custom Tools
|
||||
|
||||
Create custom tools with automatic schema generation — no external dependencies needed:
|
||||
|
||||
```go
|
||||
type SearchInput struct {
|
||||
Query string `json:"query" description:"Search query"`
|
||||
}
|
||||
|
||||
searchTool := kit.NewTool("search", "Search the codebase",
|
||||
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("Found: ..."), nil
|
||||
},
|
||||
)
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{searchTool}, // adds alongside built-in tools
|
||||
})
|
||||
```
|
||||
|
||||
Use `kit.NewParallelTool` for tools safe to run concurrently. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
|
||||
|
||||
### With Callbacks
|
||||
|
||||
```go
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
acp "github.com/coder/acp-go-sdk"
|
||||
|
||||
"github.com/mark3labs/kit/internal/acpserver"
|
||||
@@ -54,6 +55,8 @@ func runACP(cmd *cobra.Command, _ []string) error {
|
||||
conn.SetLogger(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
})))
|
||||
// Also set charmbracelet/log level for acpserver package logging
|
||||
log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
// Wait for either the client to disconnect or a signal.
|
||||
|
||||
+314
-100
@@ -7,11 +7,10 @@ import (
|
||||
"image/color"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
@@ -19,6 +18,8 @@ import (
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/prompts"
|
||||
"github.com/mark3labs/kit/internal/ui"
|
||||
"github.com/mark3labs/kit/internal/ui/commands"
|
||||
"github.com/mark3labs/kit/internal/watcher"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
@@ -38,7 +39,6 @@ var (
|
||||
noExitFlag bool
|
||||
maxSteps int
|
||||
streamFlag bool // Enable streaming output
|
||||
compactMode bool // Enable compact output mode
|
||||
autoCompactFlag bool // Enable auto-compaction near context limit
|
||||
|
||||
// Session management
|
||||
@@ -50,12 +50,14 @@ var (
|
||||
noSessionFlag bool // --no-session: ephemeral mode, no persistence
|
||||
|
||||
// Model generation parameters
|
||||
maxTokens int
|
||||
temperature float32
|
||||
topP float32
|
||||
topK int32
|
||||
stopSequences []string
|
||||
thinkingLevel string
|
||||
maxTokens int
|
||||
temperature float32
|
||||
topP float32
|
||||
topK int32
|
||||
frequencyPenalty float32
|
||||
presencePenalty float32
|
||||
stopSequences []string
|
||||
thinkingLevel string
|
||||
|
||||
// Ollama-specific parameters
|
||||
numGPU int32
|
||||
@@ -156,6 +158,9 @@ func InitConfig() {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Rebuild the model registry now that viper has the config loaded,
|
||||
// so customModels defined in the config file are picked up.
|
||||
models.ReloadGlobalRegistry()
|
||||
}
|
||||
|
||||
// LoadConfigWithEnvSubstitution loads a config file with environment variable
|
||||
@@ -219,29 +224,10 @@ func configToUiTheme(cfg config.Theme) ui.Theme {
|
||||
}
|
||||
}
|
||||
|
||||
// kitBanner returns the KIT ASCII art title with KITT scanner lights,
|
||||
// rendered with a KITT red gradient.
|
||||
// kitBanner returns the KIT ASCII art title with KITT scanner lights.
|
||||
// Delegates to ui.KitBanner() which owns the logo rendering.
|
||||
func kitBanner() string {
|
||||
kittDark := lipgloss.Color("#8B0000")
|
||||
kittBright := lipgloss.Color("#FF2200")
|
||||
lines := []string{
|
||||
" ██╗ ██╗ ██╗ ████████╗",
|
||||
" ██║ ██╔╝ ██║ ╚══██╔══╝",
|
||||
" █████╔╝ ██║ ██║",
|
||||
" ██╔═██╗ ██║ ██║",
|
||||
" ██║ ██╗ ██║ ██║",
|
||||
" ╚═╝ ╚═╝ ╚═╝ ╚═╝",
|
||||
" ░░░░░░▒▒▒▒▒▓▓▓▓███████████████▓▓▓▓▒▒▒▒▒░░░░░░",
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
result.WriteString("\n")
|
||||
}
|
||||
result.WriteString(ui.ApplyGradient(line, kittDark, kittBright))
|
||||
}
|
||||
return result.String()
|
||||
return ui.KitBanner()
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -280,8 +266,6 @@ func init() {
|
||||
IntVar(&maxSteps, "max-steps", 0, "maximum number of agent steps (0 for unlimited)")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&streamFlag, "stream", true, "enable streaming output for faster response display")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&compactMode, "compact", false, "enable compact output mode without fancy styling")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&autoCompactFlag, "auto-compact", false, "auto-compact conversation when near context limit")
|
||||
rootCmd.PersistentFlags().
|
||||
@@ -311,6 +295,8 @@ func init() {
|
||||
flags.Float32Var(&temperature, "temperature", 0.7, "controls randomness in responses (0.0-1.0)")
|
||||
flags.Float32Var(&topP, "top-p", 0.95, "controls diversity via nucleus sampling (0.0-1.0)")
|
||||
flags.Int32Var(&topK, "top-k", 40, "controls diversity by limiting top K tokens to sample from")
|
||||
flags.Float32Var(&frequencyPenalty, "frequency-penalty", 0.0, "penalizes tokens based on frequency of appearance (0.0-2.0)")
|
||||
flags.Float32Var(&presencePenalty, "presence-penalty", 0.0, "penalizes tokens based on whether they have appeared (0.0-2.0)")
|
||||
flags.StringSliceVar(&stopSequences, "stop-sequences", nil, "custom stop sequences (comma-separated)")
|
||||
flags.StringVar(&thinkingLevel, "thinking-level", "off", "extended thinking level: off, minimal, low, medium, high")
|
||||
|
||||
@@ -325,7 +311,6 @@ func init() {
|
||||
_ = viper.BindPFlag("debug", rootCmd.PersistentFlags().Lookup("debug"))
|
||||
_ = viper.BindPFlag("max-steps", rootCmd.PersistentFlags().Lookup("max-steps"))
|
||||
_ = viper.BindPFlag("stream", rootCmd.PersistentFlags().Lookup("stream"))
|
||||
_ = viper.BindPFlag("compact", rootCmd.PersistentFlags().Lookup("compact"))
|
||||
_ = viper.BindPFlag("auto-compact", rootCmd.PersistentFlags().Lookup("auto-compact"))
|
||||
|
||||
_ = viper.BindPFlag("provider-url", rootCmd.PersistentFlags().Lookup("provider-url"))
|
||||
@@ -334,6 +319,8 @@ func init() {
|
||||
_ = viper.BindPFlag("temperature", rootCmd.PersistentFlags().Lookup("temperature"))
|
||||
_ = viper.BindPFlag("top-p", rootCmd.PersistentFlags().Lookup("top-p"))
|
||||
_ = viper.BindPFlag("top-k", rootCmd.PersistentFlags().Lookup("top-k"))
|
||||
_ = viper.BindPFlag("frequency-penalty", rootCmd.PersistentFlags().Lookup("frequency-penalty"))
|
||||
_ = viper.BindPFlag("presence-penalty", rootCmd.PersistentFlags().Lookup("presence-penalty"))
|
||||
_ = viper.BindPFlag("stop-sequences", rootCmd.PersistentFlags().Lookup("stop-sequences"))
|
||||
_ = viper.BindPFlag("thinking-level", rootCmd.PersistentFlags().Lookup("thinking-level"))
|
||||
_ = viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers"))
|
||||
@@ -411,21 +398,21 @@ func runKit(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// extensionCommandsForUI converts extension-registered CommandDefs into the
|
||||
// ui.ExtensionCommand type used by the interactive TUI. Command names are
|
||||
// commands.ExtensionCommand type used by the interactive TUI. Command names are
|
||||
// normalised to start with "/" so they integrate with the slash-command
|
||||
// autocomplete and dispatch pipeline.
|
||||
func extensionCommandsForUI(k *kit.Kit) []ui.ExtensionCommand {
|
||||
func extensionCommandsForUI(k *kit.Kit) []commands.ExtensionCommand {
|
||||
defs := k.Extensions().Commands()
|
||||
if len(defs) == 0 {
|
||||
return nil
|
||||
}
|
||||
cmds := make([]ui.ExtensionCommand, 0, len(defs))
|
||||
cmds := make([]commands.ExtensionCommand, 0, len(defs))
|
||||
for _, d := range defs {
|
||||
name := d.Name
|
||||
if len(name) > 0 && name[0] != '/' {
|
||||
name = "/" + name
|
||||
}
|
||||
ec := ui.ExtensionCommand{
|
||||
ec := commands.ExtensionCommand{
|
||||
Name: name,
|
||||
Description: d.Description,
|
||||
Execute: func(args string) (string, error) {
|
||||
@@ -728,7 +715,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
var spinnerFunc kit.SpinnerFunc
|
||||
if !quietFlag {
|
||||
spinnerFunc = func(fn func() error) error {
|
||||
tempCli, tempErr := ui.NewCLI(viper.GetBool("debug"), viper.GetBool("compact"))
|
||||
tempCli, tempErr := ui.NewCLI(viper.GetBool("debug"))
|
||||
if tempErr == nil {
|
||||
return tempCli.ShowSpinner(fn)
|
||||
}
|
||||
@@ -738,13 +725,33 @@ func runNormalMode(ctx context.Context) error {
|
||||
|
||||
// Build Kit options from CLI flags and create the SDK instance.
|
||||
// kit.New() handles: config → skills → agent → session → extension bridge.
|
||||
authHandler, authErr := kit.NewCLIMCPAuthHandler()
|
||||
if authErr != nil {
|
||||
// Non-fatal: OAuth just won't be available for remote MCP servers.
|
||||
fmt.Fprintf(os.Stderr, "Warning: Failed to create OAuth handler: %v\n", authErr)
|
||||
}
|
||||
|
||||
// appInstancePtr is used to break the circular dependency between
|
||||
// kit.New (which needs the OnMCPServerLoaded callback) and app.New
|
||||
// (which is needed by the callback to send events to the TUI).
|
||||
var appInstancePtr *app.App
|
||||
|
||||
kitOpts := &kit.Options{
|
||||
Quiet: quietFlag,
|
||||
Debug: debugMode,
|
||||
NoSession: noSessionFlag,
|
||||
Continue: continueFlag,
|
||||
SessionPath: sessionPath,
|
||||
AutoCompact: autoCompactFlag,
|
||||
Quiet: quietFlag,
|
||||
Debug: debugMode,
|
||||
NoSession: noSessionFlag,
|
||||
Continue: continueFlag,
|
||||
SessionPath: sessionPath,
|
||||
AutoCompact: autoCompactFlag,
|
||||
MCPAuthHandler: authHandler,
|
||||
// This callback is called when each MCP server finishes loading.
|
||||
// We use a closure that captures appInstancePtr which is set after
|
||||
// app.New() is called below.
|
||||
OnMCPServerLoaded: func(serverName string, toolCount int, err error) {
|
||||
if appInstancePtr != nil {
|
||||
appInstancePtr.NotifyMCPServerLoaded(serverName, toolCount, err)
|
||||
}
|
||||
},
|
||||
CLI: &kit.CLIOptions{
|
||||
MCPConfig: mcpConfig,
|
||||
ShowSpinner: true,
|
||||
@@ -792,7 +799,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
|
||||
// Load existing messages from resumed/continued sessions.
|
||||
treeSession := kitInstance.GetTreeSession()
|
||||
var messages []fantasy.Message
|
||||
var messages []kit.LLMMessage
|
||||
if treeSession != nil {
|
||||
messages = treeSession.GetLLMMessages()
|
||||
}
|
||||
@@ -815,8 +822,16 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
|
||||
appInstance := app.New(appOpts, messages)
|
||||
appInstancePtr = appInstance // Wire up the MCP server loaded callback.
|
||||
defer appInstance.Close()
|
||||
|
||||
// Wire OAuth handler to route messages through the TUI once it's running.
|
||||
if authHandler != nil {
|
||||
authHandler.NotifyFunc = func(serverName, message string) {
|
||||
appInstance.PrintFromExtension("info", message)
|
||||
}
|
||||
}
|
||||
|
||||
// Buffer for extension messages during startup (printed after startup banner).
|
||||
var startupExtensionMessages []string
|
||||
|
||||
@@ -840,7 +855,37 @@ func runNormalMode(ctx context.Context) error {
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
@@ -1261,7 +1306,37 @@ func runNormalMode(ctx context.Context) error {
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
@@ -1561,6 +1636,49 @@ func runNormalMode(ctx context.Context) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Build prompt template and skill item provider callbacks for hot-reload.
|
||||
// These are called by the TUI when ContentReloadEvent fires.
|
||||
getPromptTemplates := func() []*prompts.PromptTemplate {
|
||||
if noPromptTemplates {
|
||||
return nil
|
||||
}
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
cwd, _ := os.Getwd()
|
||||
tpls, _, err := prompts.LoadAll(prompts.LoadOptions{
|
||||
Cwd: cwd,
|
||||
HomeDir: homeDir,
|
||||
ExtraPaths: promptTemplatePaths,
|
||||
ConfigPaths: viper.GetStringSlice("prompts"),
|
||||
IncludeDefaults: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to reload prompt templates: %v", err)
|
||||
}
|
||||
return tpls
|
||||
}
|
||||
|
||||
getSkillItems := func() []ui.SkillItem {
|
||||
// Re-discover skills from disk.
|
||||
if err := kitInstance.ReloadSkills(); err != nil {
|
||||
log.Printf("Warning: failed to reload skills: %v", err)
|
||||
return nil
|
||||
}
|
||||
cwd, _ := os.Getwd()
|
||||
var items []ui.SkillItem
|
||||
for _, s := range kitInstance.GetSkills() {
|
||||
source := "user"
|
||||
if strings.HasPrefix(s.Path, cwd) {
|
||||
source = "project"
|
||||
}
|
||||
items = append(items, ui.SkillItem{
|
||||
Name: s.Name,
|
||||
Path: s.Path,
|
||||
Source: source,
|
||||
})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
// Build extension UI providers once (shared between both modes).
|
||||
getWidgets := widgetProviderForUI(kitInstance)
|
||||
getHeader := headerProviderForUI(kitInstance)
|
||||
@@ -1572,10 +1690,29 @@ func runNormalMode(ctx context.Context) error {
|
||||
emitBeforeFork := beforeForkProviderForUI(kitInstance)
|
||||
emitBeforeSessionSwitch := beforeSessionSwitchProviderForUI(kitInstance)
|
||||
getGlobalShortcuts := globalShortcutsProviderForUI(kitInstance)
|
||||
getExtensionCommands := func() []ui.ExtensionCommand {
|
||||
getExtensionCommands := func() []commands.ExtensionCommand {
|
||||
return extensionCommandsForUI(kitInstance)
|
||||
}
|
||||
|
||||
// Build dynamic tool name and MCP tool count providers. These are called
|
||||
// by the TUI when MCPToolsReadyEvent fires to refresh the /tools list
|
||||
// and startup info bar after background MCP tool loading completes.
|
||||
getToolNames := func() []string {
|
||||
return kitInstance.GetToolNames()
|
||||
}
|
||||
getMCPToolCount := func() int {
|
||||
return kitInstance.GetMCPToolCount()
|
||||
}
|
||||
|
||||
// Start a goroutine that waits for background MCP tool loading to
|
||||
// complete and notifies the TUI so it can refresh tool names and counts.
|
||||
if len(mcpConfig.MCPServers) > 0 {
|
||||
go func() {
|
||||
_ = kitInstance.WaitForMCPTools()
|
||||
appInstance.NotifyMCPToolsReady()
|
||||
}()
|
||||
}
|
||||
|
||||
// Build model switching callbacks for the /model command.
|
||||
setModelForUI := func(modelString string) error {
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
@@ -1629,9 +1766,81 @@ func runNormalMode(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build extension reload callback for the /reload-ext command.
|
||||
reloadExtensionsForUI := func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start file watcher for automatic extension hot-reload.
|
||||
extraPaths := viper.GetStringSlice("extension")
|
||||
watchDirs := extensions.WatchedDirs(extraPaths)
|
||||
if len(watchDirs) > 0 {
|
||||
extWatcher, watchErr := extensions.NewWatcher(watchDirs, func() {
|
||||
if err := reloadExtensionsForUI(); err != nil {
|
||||
log.Printf("auto-reload extensions failed: %v", err)
|
||||
}
|
||||
})
|
||||
if watchErr != nil {
|
||||
log.Printf("extension file watcher not started: %v", watchErr)
|
||||
} else {
|
||||
go extWatcher.Start(ctx)
|
||||
defer func() { _ = extWatcher.Close() }()
|
||||
}
|
||||
}
|
||||
|
||||
// Start file watchers for automatic prompt and skill hot-reload.
|
||||
{
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
cwd, _ := os.Getwd()
|
||||
|
||||
// Collect prompt template directories.
|
||||
promptDirs := watcher.CollectDirs(
|
||||
[]string{
|
||||
filepath.Join(homeDir, ".kit", "prompts"),
|
||||
filepath.Join(cwd, ".kit", "prompts"),
|
||||
},
|
||||
append(promptTemplatePaths, viper.GetStringSlice("prompts")...),
|
||||
)
|
||||
|
||||
// Collect skill directories.
|
||||
skillDirs := watcher.CollectDirs(
|
||||
[]string{
|
||||
filepath.Join(homeDir, ".config", "kit", "skills"),
|
||||
filepath.Join(cwd, ".agents", "skills"),
|
||||
filepath.Join(cwd, ".kit", "skills"),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// Combine all content directories and start a single watcher.
|
||||
allContentDirs := append(promptDirs, skillDirs...)
|
||||
if len(allContentDirs) > 0 {
|
||||
contentWatcher, watchErr := watcher.New(watcher.Options{
|
||||
Dirs: allContentDirs,
|
||||
Extensions: []string{".md", ".txt"},
|
||||
Label: "prompts/skills",
|
||||
OnReload: func() {
|
||||
log.Printf("auto-reloading prompts and skills")
|
||||
appInstance.NotifyContentReload()
|
||||
},
|
||||
})
|
||||
if watchErr != nil {
|
||||
log.Printf("content file watcher not started: %v", watchErr)
|
||||
} else {
|
||||
go contentWatcher.Start(ctx)
|
||||
defer func() { _ = contentWatcher.Close() }()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if running in non-interactive mode
|
||||
if positionalPrompt != "" {
|
||||
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI)
|
||||
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI)
|
||||
}
|
||||
|
||||
// Quiet mode is not allowed in interactive mode
|
||||
@@ -1639,7 +1848,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return fmt.Errorf("--quiet requires a prompt")
|
||||
}
|
||||
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, startupExtensionMessages)
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages)
|
||||
}
|
||||
|
||||
// runNonInteractiveModeApp executes a single prompt via the app layer and exits,
|
||||
@@ -1652,7 +1861,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
//
|
||||
// When --no-exit is set, after the prompt completes the interactive BubbleTea
|
||||
// TUI is started so the user can continue the conversation.
|
||||
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []ui.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error) error {
|
||||
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error {
|
||||
// Expand @file references in the prompt before sending to the agent.
|
||||
if cwd, err := os.Getwd(); err == nil {
|
||||
prompt = ui.ProcessFileAttachments(prompt, cwd)
|
||||
@@ -1695,7 +1904,7 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
|
||||
|
||||
// If --no-exit was requested, hand off to the interactive TUI.
|
||||
if noExit {
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, nil)
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1793,7 +2002,19 @@ func writeJSONError(err error) {
|
||||
// 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit).
|
||||
//
|
||||
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []ui.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, startupExtensionMessages []string) error {
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
|
||||
// Redirect all log output (stdlib and charm) to a file so that log
|
||||
// messages don't write to stderr and corrupt the TUI. Bubble Tea
|
||||
// captures stdout for rendering; any stray stderr output from
|
||||
// background goroutines (watchers, extension handlers, SDK internals)
|
||||
// will visually corrupt the terminal.
|
||||
logDir := filepath.Join(os.TempDir(), "kit")
|
||||
_ = os.MkdirAll(logDir, 0o700)
|
||||
logFile, logErr := tea.LogToFile(filepath.Join(logDir, "kit.log"), "kit")
|
||||
if logErr == nil {
|
||||
defer func() { _ = logFile.Close() }()
|
||||
}
|
||||
|
||||
// Determine terminal size; fall back gracefully.
|
||||
termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd()))
|
||||
if err != nil || termWidth == 0 {
|
||||
@@ -1804,54 +2025,47 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
|
||||
cwd, _ := os.Getwd()
|
||||
|
||||
appModel := ui.NewAppModel(appInstance, ui.AppModelOptions{
|
||||
CompactMode: viper.GetBool("compact"),
|
||||
ModelName: modelName,
|
||||
ProviderName: providerName,
|
||||
LoadingMessage: loadingMessage,
|
||||
Cwd: cwd,
|
||||
Width: termWidth,
|
||||
Height: termHeight,
|
||||
ServerNames: serverNames,
|
||||
ToolNames: toolNames,
|
||||
MCPToolCount: mcpToolCount,
|
||||
ExtensionToolCount: extensionToolCount,
|
||||
UsageTracker: usageTracker,
|
||||
ExtensionCommands: extCommands,
|
||||
PromptTemplates: promptTemplates,
|
||||
ContextPaths: contextPaths,
|
||||
SkillItems: skillItems,
|
||||
GetWidgets: getWidgets,
|
||||
GetHeader: getHeader,
|
||||
GetFooter: getFooter,
|
||||
GetToolRenderer: getToolRenderer,
|
||||
GetEditorInterceptor: getEditorInterceptor,
|
||||
GetUIVisibility: getUIVisibility,
|
||||
GetStatusBarEntries: getStatusBarEntries,
|
||||
EmitBeforeFork: emitBeforeFork,
|
||||
EmitBeforeSessionSwitch: emitBeforeSessionSwitch,
|
||||
GetGlobalShortcuts: getGlobalShortcuts,
|
||||
GetExtensionCommands: getExtensionCommands,
|
||||
SetModel: setModel,
|
||||
EmitModelChange: emitModelChange,
|
||||
ThinkingLevel: thinkingLevel,
|
||||
IsReasoningModel: isReasoningModel,
|
||||
SetThinkingLevel: setThinkingLevel,
|
||||
SwitchSession: switchSession,
|
||||
ShowSessionPicker: resumeFlag,
|
||||
ModelName: modelName,
|
||||
ProviderName: providerName,
|
||||
LoadingMessage: loadingMessage,
|
||||
Cwd: cwd,
|
||||
Width: termWidth,
|
||||
Height: termHeight,
|
||||
ServerNames: serverNames,
|
||||
ToolNames: toolNames,
|
||||
GetToolNames: getToolNames,
|
||||
GetMCPToolCount: getMCPToolCount,
|
||||
MCPToolCount: mcpToolCount,
|
||||
ExtensionToolCount: extensionToolCount,
|
||||
UsageTracker: usageTracker,
|
||||
ExtensionCommands: extCommands,
|
||||
PromptTemplates: promptTemplates,
|
||||
GetPromptTemplates: getPromptTemplates,
|
||||
ContextPaths: contextPaths,
|
||||
SkillItems: skillItems,
|
||||
GetSkillItems: getSkillItems,
|
||||
StartupExtensionMessages: startupExtensionMessages,
|
||||
GetWidgets: getWidgets,
|
||||
GetHeader: getHeader,
|
||||
GetFooter: getFooter,
|
||||
GetToolRenderer: getToolRenderer,
|
||||
GetEditorInterceptor: getEditorInterceptor,
|
||||
GetUIVisibility: getUIVisibility,
|
||||
GetStatusBarEntries: getStatusBarEntries,
|
||||
EmitBeforeFork: emitBeforeFork,
|
||||
EmitBeforeSessionSwitch: emitBeforeSessionSwitch,
|
||||
GetGlobalShortcuts: getGlobalShortcuts,
|
||||
GetExtensionCommands: getExtensionCommands,
|
||||
SetModel: setModel,
|
||||
EmitModelChange: emitModelChange,
|
||||
ThinkingLevel: thinkingLevel,
|
||||
IsReasoningModel: isReasoningModel,
|
||||
SetThinkingLevel: setThinkingLevel,
|
||||
SwitchSession: switchSession,
|
||||
ReloadExtensions: reloadExtensions,
|
||||
ShowSessionPicker: resumeFlag,
|
||||
})
|
||||
|
||||
// Print startup info to stdout before Bubble Tea takes over the screen.
|
||||
appModel.PrintStartupInfo()
|
||||
|
||||
// Print any extension messages that were captured during startup.
|
||||
if len(startupExtensionMessages) > 0 {
|
||||
fmt.Println()
|
||||
for _, msg := range startupExtensionMessages {
|
||||
fmt.Println(msg)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
program := tea.NewProgram(appModel)
|
||||
|
||||
// Register the program with the app layer so agent events are sent to the TUI.
|
||||
|
||||
@@ -41,7 +41,6 @@ func BuildAppOptions(mcpConfig *config.Config, modelName string, serverNames, to
|
||||
StreamingEnabled: viper.GetBool("stream"),
|
||||
Quiet: quietFlag,
|
||||
Debug: viper.GetBool("debug"),
|
||||
CompactMode: viper.GetBool("compact"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +130,6 @@ func SetupCLIForNonInteractive(k *kit.Kit) (*ui.CLI, error) {
|
||||
Agent: agentAdapter,
|
||||
ModelString: viper.GetString("model"),
|
||||
Debug: viper.GetBool("debug"),
|
||||
Compact: viper.GetBool("compact"),
|
||||
Quiet: quietFlag,
|
||||
ShowDebug: false,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
|
||||
@@ -7,10 +7,12 @@
|
||||
// development: edit your extension source, then type /reload to pick up
|
||||
// changes immediately.
|
||||
//
|
||||
// Event handlers, slash commands, tool renderers, message renderers, and
|
||||
// keyboard shortcuts update immediately. Extension-defined tools are NOT
|
||||
// updated (they are baked into the agent at creation time and require a
|
||||
// restart).
|
||||
// Note: Extensions in autoloaded directories (~/.config/kit/extensions/
|
||||
// and .kit/extensions/) are automatically reloaded on save. The /reload
|
||||
// command is useful for extensions loaded via -e from other locations.
|
||||
//
|
||||
// Event handlers, slash commands, tool definitions, tool renderers,
|
||||
// message renderers, and keyboard shortcuts all update immediately.
|
||||
//
|
||||
// Commands:
|
||||
// /reload — hot-reload all extensions from disk
|
||||
|
||||
@@ -168,6 +168,10 @@ var (
|
||||
// Test
|
||||
pendingTest *PendingTest
|
||||
|
||||
// Typing indicator
|
||||
typingTicker *time.Ticker
|
||||
typingStop chan struct{}
|
||||
|
||||
// Latest context for background goroutines
|
||||
latestCtx ext.Context
|
||||
latestCtxSet bool
|
||||
@@ -203,8 +207,23 @@ func configDir() string {
|
||||
return filepath.Join(home, ".config", "kit")
|
||||
}
|
||||
|
||||
func globalConfigDir() string {
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".config", "kit")
|
||||
}
|
||||
|
||||
func configPath() string {
|
||||
return filepath.Join(configDir(), "kit-telegram.json")
|
||||
// Prefer project-local config, fall back to global config.
|
||||
local := filepath.Join(configDir(), "kit-telegram.json")
|
||||
if _, err := os.Stat(local); err == nil {
|
||||
return local
|
||||
}
|
||||
global := filepath.Join(globalConfigDir(), "kit-telegram.json")
|
||||
if _, err := os.Stat(global); err == nil {
|
||||
return global
|
||||
}
|
||||
// Neither exists — return local path (will be created on connect).
|
||||
return local
|
||||
}
|
||||
|
||||
func failureLogDir() string {
|
||||
@@ -387,6 +406,14 @@ func tgEditMessageText(token string, chatID int64, messageID int, text string) (
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func tgSendChatAction(token string, chatID int64, action string) error {
|
||||
_, err := telegramRequest(token, "sendChatAction", map[string]any{
|
||||
"chat_id": chatID,
|
||||
"action": action,
|
||||
}, 15)
|
||||
return err
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Error classification
|
||||
// ──────────────────────────────────────────────
|
||||
@@ -637,6 +664,48 @@ func clearHealthTimer() {
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Typing indicator
|
||||
// ──────────────────────────────────────────────
|
||||
|
||||
func startTypingLoop() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if typingTicker != nil {
|
||||
return
|
||||
}
|
||||
cfg := config
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return
|
||||
}
|
||||
token := cfg.BotToken
|
||||
chatID := cfg.ChatID
|
||||
typingTicker = time.NewTicker(4 * time.Second)
|
||||
typingStop = make(chan struct{})
|
||||
// Send immediately, then every 4 seconds.
|
||||
go func() {
|
||||
tgSendChatAction(token, chatID, "typing")
|
||||
for {
|
||||
select {
|
||||
case <-typingTicker.C:
|
||||
tgSendChatAction(token, chatID, "typing")
|
||||
case <-typingStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func stopTypingLoop() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if typingTicker != nil {
|
||||
typingTicker.Stop()
|
||||
close(typingStop)
|
||||
typingTicker = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Polling lifecycle
|
||||
// ──────────────────────────────────────────────
|
||||
@@ -2105,6 +2174,7 @@ func Init(api ext.API) {
|
||||
mu.Unlock()
|
||||
|
||||
sendShutdownDisconnectedMessage()
|
||||
stopTypingLoop()
|
||||
stopPolling()
|
||||
clearHealthTimer()
|
||||
clearFooter()
|
||||
@@ -2128,6 +2198,7 @@ func Init(api ext.API) {
|
||||
mu.Unlock()
|
||||
|
||||
report("run.start", fmt.Sprintf("runId=%d", run.ID))
|
||||
startTypingLoop()
|
||||
ensureProgressMessage()
|
||||
updateProgressMessage()
|
||||
})
|
||||
@@ -2140,6 +2211,8 @@ func Init(api ext.API) {
|
||||
run := activeRun
|
||||
mu.Unlock()
|
||||
|
||||
stopTypingLoop()
|
||||
|
||||
if run != nil {
|
||||
// Capture final response from event
|
||||
if e.Response != "" {
|
||||
|
||||
@@ -9,13 +9,19 @@ require (
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.2
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/aymanbagabas/go-udiff v0.4.1
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b
|
||||
github.com/clipperhouse/displaywidth v0.11.0
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0
|
||||
github.com/coder/acp-go-sdk v0.6.3
|
||||
github.com/indaco/herald v0.10.0
|
||||
github.com/indaco/herald-md v0.1.0
|
||||
github.com/mark3labs/mcp-go v0.46.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/indaco/herald v0.13.0
|
||||
github.com/indaco/herald-md v0.3.0
|
||||
github.com/mark3labs/mcp-go v0.47.1
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
@@ -25,16 +31,15 @@ require (
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/auth v0.19.0 // indirect
|
||||
cloud.google.com/go/auth v0.20.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
|
||||
@@ -42,47 +47,41 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/aws/smithy-go v1.24.3 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.3 // indirect
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143 // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143 // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
github.com/charmbracelet/x/windows v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.11.0 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.3.0 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.3.1 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.7 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.19 // indirect
|
||||
@@ -104,21 +103,21 @@ require (
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
|
||||
go.opentelemetry.io/otel v1.42.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.42.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.42.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
|
||||
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.273.0 // indirect
|
||||
google.golang.org/genai v1.52.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect
|
||||
google.golang.org/grpc v1.79.3 // indirect
|
||||
google.golang.org/api v0.275.0 // indirect
|
||||
google.golang.org/genai v1.52.1 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
@@ -129,13 +128,13 @@ require (
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.21 // indirect
|
||||
github.com/mattn/go-isatty v0.0.21 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.23 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.35.0
|
||||
)
|
||||
|
||||
@@ -10,20 +10,20 @@ charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs=
|
||||
charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.19.0 h1:DGYwtbcsGsT1ywuxsIoWi1u/vlks0moIblQHgSDgQkQ=
|
||||
cloud.google.com/go/auth v0.19.0/go.mod h1:2Aph7BT2KnaSFOM0JDPyiYgNh6PL9vGMiP8CUIXZ+IY=
|
||||
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
|
||||
cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0/go.mod h1:7dCRMLwisfRH3dBupKeNCioWYUZ4SS09Z14H+7i8ZoY=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
|
||||
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
|
||||
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
|
||||
@@ -38,10 +38,10 @@ github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.13 h1:5KgbxMaS2coSWRrx9TX/QtWbqzgQkOdEa3sZPhBhCSg=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.13/go.mod h1:8zz7wedqtCbw5e9Mi2doEwDyEgHcEE9YOJp6a8jdSMY=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.13 h1:mA59E3fokBvyEGHKFdnpNNrvaR351cqiHgRg+JzOSRI=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.13/go.mod h1:yoTXOQKea18nrM69wGF9jBdG4WocSZA1h38A+t/MAsk=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
|
||||
@@ -56,14 +56,14 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3x
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 h1:GcLE9ba5ehAQma6wlopUesYg/hbcOhFNWTjELkiWkh4=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 h1:mP49nTpfKtpXLt5SLn8Uv8z6W+03jYVoOSAl/c02nog=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aws/smithy-go v1.24.3 h1:XgOAaUgx+HhVBoP4v8n6HCQoTRDhoMghKqw4LNHsDNg=
|
||||
github.com/aws/smithy-go v1.24.3/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
|
||||
@@ -96,14 +96,14 @@ github.com/charmbracelet/x/conpty v0.1.1 h1:s1bUxjoi7EpqiXysVtC+a8RrvPPNcNvAjfi4
|
||||
github.com/charmbracelet/x/conpty v0.1.1/go.mod h1:OmtR77VODEFbiTzGE9G1XiRJAga6011PIm4u5fTNZpk=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4 h1:pIj18ZCZO4WOVj7jwjLoUb1lC7rS/I8oC3fZWXugNaY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143 h1:zmBor0ftFNqVFp9U59ZoEDRUCIYSGOGSIfGGkNZRufs=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4 h1:VSd4zShIAf/4FgEDFJpapEcAPrc7h3dyyN7V9JlJpQw=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143 h1:aEppolah2k9c0LzKX2fk5ryuyQ0Lq8kCOjkvMw1b8o4=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -173,20 +173,20 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 h1:NIKVuLhDlIV74muWlsMM4CcQZqN6JJ20Qcxd9YMuYcs=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/indaco/herald v0.10.0 h1:XzahEKX6cr50qZQrUdA3QrQBHg8uGm5jETD0UDi21BI=
|
||||
github.com/indaco/herald v0.10.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.1.0 h1:zmYudYo+uamzKTBcIffJVJYrqk9xDNnVrTh+de2zciw=
|
||||
github.com/indaco/herald-md v0.1.0/go.mod h1:Z1HxPCbSn+/+TFzOM/UbsmKeEk/28NNI6JOTileKXto=
|
||||
github.com/kaptinlin/go-i18n v0.3.0 h1:wP76dvYg04bvwTb+8NB+CmdZ2kL7lSSCQ9B/kFv7QHo=
|
||||
github.com/kaptinlin/go-i18n v0.3.0/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
|
||||
github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
|
||||
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
|
||||
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
|
||||
github.com/kaptinlin/go-i18n v0.3.1 h1:plXi3XQE1aYamFi8TU0K6actODmw2+5FSobmhTkfQ/0=
|
||||
github.com/kaptinlin/go-i18n v0.3.1/go.mod h1:ZRoAHj7elWYamfbv7wev7Ajch6LOzjtBaq8nWe8HIVk=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17/go.mod h1:SsfsjqnHG5zuKo1DTBzk1VknaHlL4osHw+X9kZKukpU=
|
||||
github.com/kaptinlin/jsonschema v0.7.7 h1:41BlQJ9dskH0oE5DSzBUrl/w4JQYIr6N6L0B5GNyDoM=
|
||||
@@ -201,12 +201,12 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
|
||||
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mark3labs/mcp-go v0.47.1 h1:A9sJJ20mscl/ssLYHjodfaoBmq6uuhMG7pAPNYaQymQ=
|
||||
github.com/mark3labs/mcp-go v0.47.1/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
|
||||
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
|
||||
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@@ -272,20 +272,20 @@ github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
|
||||
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
|
||||
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
|
||||
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
|
||||
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
|
||||
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
|
||||
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0/go.mod h1:Sje3i3MjSPKTSPvVWCaL8ugBzJwik3u4smCjUeuupqg=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
@@ -298,29 +298,28 @@ golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.273.0 h1:r/Bcv36Xa/te1ugaN1kdJ5LoA5Wj/cL+a4gj6FiPBjQ=
|
||||
google.golang.org/api v0.273.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/genai v1.52.0 h1:ekVIxWHtLUNbt+v0WWi4j3JT4yrHDEbysMcHQcaCQoI=
|
||||
google.golang.org/genai v1.52.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.275.0 h1:vfY5d9vFVJeWEZT65QDd9hbndr7FyZ2+6mIzGAh71NI=
|
||||
google.golang.org/api v0.275.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
|
||||
google.golang.org/genai v1.52.1 h1:dYoljKtLDXMiBdVaClSJ/ZPwZ7j1N0lGjMhwOKOQUlk=
|
||||
google.golang.org/genai v1.52.1/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
|
||||
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d h1:wT2n40TBqFY6wiwazVK9/iTWbsQrgk5ZfCSVFLO9LQA=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
+226
-16
@@ -7,8 +7,11 @@ package acpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
@@ -20,7 +23,6 @@ import (
|
||||
// Version is injected at build time; fallback to "dev".
|
||||
var Version = "dev"
|
||||
|
||||
// Agent implements the acp.Agent interface, delegating to Kit for LLM
|
||||
// execution, tool calls, and session management.
|
||||
type Agent struct {
|
||||
conn *acp.AgentSideConnection
|
||||
@@ -111,13 +113,20 @@ func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.Promp
|
||||
)
|
||||
}
|
||||
|
||||
// Extract text from prompt content blocks.
|
||||
promptText := extractPromptText(params.Prompt)
|
||||
if promptText == "" {
|
||||
// Extract text and file attachments from prompt content blocks.
|
||||
promptText, files := extractPromptContent(params.Prompt)
|
||||
if promptText == "" && len(files) == 0 {
|
||||
return acp.PromptResponse{}, acp.NewInvalidParams("empty prompt")
|
||||
}
|
||||
|
||||
log.Debug("acp: prompt", "session", sessionID, "prompt_len", len(promptText))
|
||||
// If we have files but no text prompt, add a default prompt
|
||||
// This is required because the underlying LLM library needs a non-empty prompt
|
||||
// when there are no previous messages in the conversation.
|
||||
if promptText == "" && len(files) > 0 {
|
||||
promptText = "Please analyze the attached file."
|
||||
}
|
||||
|
||||
log.Debug("acp: prompt", "session", sessionID, "prompt_len", len(promptText), "files", len(files))
|
||||
|
||||
// Create a cancellable context for this prompt turn.
|
||||
promptCtx, cancel := context.WithCancel(ctx)
|
||||
@@ -129,7 +138,13 @@ func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.Promp
|
||||
defer unsub()
|
||||
|
||||
// Run the prompt through Kit's full turn lifecycle.
|
||||
_, err := sess.kit.PromptResult(promptCtx, promptText)
|
||||
// Use PromptResultWithFiles when file attachments are present.
|
||||
var err error
|
||||
if len(files) > 0 {
|
||||
_, err = sess.kit.PromptResultWithFiles(promptCtx, promptText, files)
|
||||
} else {
|
||||
_, err = sess.kit.PromptResult(promptCtx, promptText)
|
||||
}
|
||||
if err != nil {
|
||||
if promptCtx.Err() != nil {
|
||||
return acp.PromptResponse{
|
||||
@@ -162,6 +177,24 @@ func (a *Agent) SetSessionMode(_ context.Context, _ acp.SetSessionModeRequest) (
|
||||
return acp.SetSessionModeResponse{}, nil
|
||||
}
|
||||
|
||||
// SetSessionModel changes the active model for a session.
|
||||
func (a *Agent) SetSessionModel(ctx context.Context, params acp.SetSessionModelRequest) (acp.SetSessionModelResponse, error) {
|
||||
sessionID := string(params.SessionId)
|
||||
sess, ok := a.registry.get(sessionID)
|
||||
if !ok {
|
||||
return acp.SetSessionModelResponse{}, acp.NewInvalidParams(fmt.Sprintf("session not found: %s", sessionID))
|
||||
}
|
||||
|
||||
modelID := string(params.ModelId)
|
||||
log.Debug("acp: set_session_model", "session", sessionID, "model", modelID)
|
||||
|
||||
if err := sess.kit.SetModel(ctx, modelID); err != nil {
|
||||
return acp.SetSessionModelResponse{}, fmt.Errorf("set model: %w", err)
|
||||
}
|
||||
|
||||
return acp.SetSessionModelResponse{}, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Event streaming: Kit events → ACP SessionUpdate notifications
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -231,19 +264,196 @@ func (a *Agent) subscribeEvents(ctx context.Context, k *kit.Kit, sessionID acp.S
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// extractPromptText extracts the concatenated text content from ACP content
|
||||
// blocks. Non-text blocks are ignored for now.
|
||||
func extractPromptText(blocks []acp.ContentBlock) string {
|
||||
var text string
|
||||
for _, block := range blocks {
|
||||
if block.Text != nil {
|
||||
if text != "" {
|
||||
text += "\n"
|
||||
// extractPromptContent extracts text and file attachments from ACP content blocks.
|
||||
// It converts supported content blocks (image, audio, resource) to Kit's LLMFilePart.
|
||||
func extractPromptContent(blocks []acp.ContentBlock) (string, []kit.LLMFilePart) {
|
||||
var textParts []string
|
||||
var files []kit.LLMFilePart
|
||||
|
||||
log.Debug("acp: extracting content", "blocks", len(blocks))
|
||||
|
||||
for i, block := range blocks {
|
||||
switch {
|
||||
// Text content
|
||||
case block.Text != nil:
|
||||
log.Debug("acp: content block", "index", i, "type", "text", "len", len(block.Text.Text))
|
||||
textParts = append(textParts, block.Text.Text)
|
||||
|
||||
// Image data (base64)
|
||||
case block.Image != nil:
|
||||
mimeType := block.Image.MimeType
|
||||
if mimeType == "" {
|
||||
mimeType = "image/png" // Default fallback
|
||||
}
|
||||
text += block.Text.Text
|
||||
log.Debug("acp: content block", "index", i, "type", "image", "mime", mimeType, "data_len", len(block.Image.Data))
|
||||
if data, err := base64.StdEncoding.DecodeString(block.Image.Data); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: "image.png",
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode image", "error", err)
|
||||
}
|
||||
|
||||
// Audio data (base64)
|
||||
case block.Audio != nil:
|
||||
mimeType := block.Audio.MimeType
|
||||
if mimeType == "" {
|
||||
mimeType = "audio/wav" // Default fallback
|
||||
}
|
||||
log.Debug("acp: content block", "index", i, "type", "audio", "mime", mimeType)
|
||||
if data, err := base64.StdEncoding.DecodeString(block.Audio.Data); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: "audio.wav",
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode audio", "error", err)
|
||||
}
|
||||
|
||||
// Embedded resource (text or binary file content)
|
||||
case block.Resource != nil:
|
||||
log.Debug("acp: content block", "index", i, "type", "resource")
|
||||
res := block.Resource.Resource
|
||||
// Text resource - append as text content with file reference
|
||||
if res.TextResourceContents != nil {
|
||||
uri := res.TextResourceContents.Uri
|
||||
content := res.TextResourceContents.Text
|
||||
mimeType := "text/plain"
|
||||
if res.TextResourceContents.MimeType != nil {
|
||||
mimeType = *res.TextResourceContents.MimeType
|
||||
}
|
||||
log.Debug("acp: text resource", "uri", uri, "mime", mimeType, "len", len(content))
|
||||
// Text files are included as formatted text, NOT as FilePart
|
||||
// FilePart is for binary files (images, audio, PDFs) only
|
||||
textParts = append(textParts, fmt.Sprintf("[File: %s]\n```\n%s\n```", uri, content))
|
||||
}
|
||||
// Binary resource (base64 blob) - these become FilePart
|
||||
if res.BlobResourceContents != nil {
|
||||
uri := res.BlobResourceContents.Uri
|
||||
mimeType := "application/octet-stream"
|
||||
if res.BlobResourceContents.MimeType != nil {
|
||||
mimeType = *res.BlobResourceContents.MimeType
|
||||
}
|
||||
log.Debug("acp: binary resource", "uri", uri, "mime", mimeType, "blob_len", len(res.BlobResourceContents.Blob))
|
||||
if data, err := base64.StdEncoding.DecodeString(res.BlobResourceContents.Blob); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: extractFilenameFromURI(uri),
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode binary resource", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Resource link (file reference without embedded content)
|
||||
case block.ResourceLink != nil:
|
||||
uri := block.ResourceLink.Uri
|
||||
name := block.ResourceLink.Name
|
||||
log.Debug("acp: content block", "index", i, "type", "resource_link", "uri", uri, "name", name)
|
||||
// For resource links, we'll try to read the file from disk
|
||||
// This requires the file URI to be accessible (file:// scheme)
|
||||
if content, err := readResourceFromURI(uri); err == nil {
|
||||
// Detect if it's a text file or binary file
|
||||
mimeType := "text/plain"
|
||||
if block.ResourceLink.MimeType != nil {
|
||||
mimeType = *block.ResourceLink.MimeType
|
||||
}
|
||||
log.Debug("acp: resource link loaded", "uri", uri, "mime", mimeType, "size", len(content))
|
||||
|
||||
// Only create FilePart for binary files (images, audio, PDFs, etc.)
|
||||
// Text files are included as formatted text in the message
|
||||
if isTextMimeType(mimeType) || looksLikeText(content) {
|
||||
textParts = append(textParts, fmt.Sprintf("[File: %s]\n```\n%s\n```", uri, string(content)))
|
||||
} else {
|
||||
// Binary file - create FilePart for models that support it
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: extractFilenameFromURI(uri),
|
||||
Data: content,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// If we can't read it, include as a text reference
|
||||
log.Debug("acp: resource link failed to load", "uri", uri, "error", err)
|
||||
textParts = append(textParts, fmt.Sprintf("[Referenced file: %s]", uri))
|
||||
}
|
||||
|
||||
default:
|
||||
log.Debug("acp: content block", "index", i, "type", "unknown/unhandled")
|
||||
}
|
||||
}
|
||||
return text
|
||||
|
||||
// Debug log the extracted content
|
||||
for i, f := range files {
|
||||
log.Debug("acp: extracted file", "index", i, "filename", f.Filename, "mime", f.MediaType, "size", len(f.Data))
|
||||
}
|
||||
|
||||
return strings.Join(textParts, "\n"), files
|
||||
}
|
||||
|
||||
// isTextMimeType returns true if the MIME type indicates text content.
|
||||
func isTextMimeType(mimeType string) bool {
|
||||
return strings.HasPrefix(mimeType, "text/") ||
|
||||
mimeType == "application/json" ||
|
||||
mimeType == "application/xml" ||
|
||||
mimeType == "application/javascript" ||
|
||||
mimeType == "application/typescript" ||
|
||||
mimeType == "application/x-sh" ||
|
||||
mimeType == "application/x-python" ||
|
||||
mimeType == "application/x-yaml" ||
|
||||
mimeType == "application/x-toml"
|
||||
}
|
||||
|
||||
// looksLikeText checks if the content appears to be text (not binary).
|
||||
// It samples the first 512 bytes and checks for null bytes or high
|
||||
// concentration of non-printable characters.
|
||||
func looksLikeText(data []byte) bool {
|
||||
if len(data) == 0 {
|
||||
return true
|
||||
}
|
||||
// Check first 512 bytes (or less if file is smaller)
|
||||
sampleSize := min(len(data), 512)
|
||||
sample := data[:sampleSize]
|
||||
|
||||
// Count non-printable characters
|
||||
nonPrintable := 0
|
||||
for _, b := range sample {
|
||||
// Null byte indicates binary
|
||||
if b == 0 {
|
||||
return false
|
||||
}
|
||||
// Count control characters (except common whitespace)
|
||||
if b < 32 && b != '\n' && b != '\r' && b != '\t' {
|
||||
nonPrintable++
|
||||
}
|
||||
}
|
||||
|
||||
// If more than 30% non-printable, consider it binary
|
||||
return float64(nonPrintable)/float64(sampleSize) < 0.3
|
||||
}
|
||||
|
||||
// extractFilenameFromURI extracts a filename from a file URI or path.
|
||||
func extractFilenameFromURI(uri string) string {
|
||||
// Handle file:// URIs
|
||||
uri = strings.TrimPrefix(uri, "file://")
|
||||
// Extract basename
|
||||
if idx := strings.LastIndex(uri, "/"); idx >= 0 {
|
||||
return uri[idx+1:]
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
// readResourceFromURI attempts to read file content from a file:// URI.
|
||||
func readResourceFromURI(uri string) ([]byte, error) {
|
||||
if !strings.HasPrefix(uri, "file://") {
|
||||
return nil, fmt.Errorf("unsupported URI scheme: %s", uri)
|
||||
}
|
||||
path := uri[7:] // Remove file:// prefix
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
// parseToolArgs attempts to parse a JSON tool args string into a map for
|
||||
|
||||
+351
-113
@@ -25,11 +25,26 @@ type AgentConfig struct {
|
||||
StreamingEnabled bool
|
||||
DebugLogger tools.DebugLogger
|
||||
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
// If nil, remote MCP servers that require OAuth will fail to connect.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used. This allows SDK users to provide a custom tool set (e.g.
|
||||
// CodingTools or tools with a custom WorkDir).
|
||||
CoreTools []fantasy.AgentTool
|
||||
|
||||
// DisableCoreTools, when true, prevents loading any core tools.
|
||||
// If both DisableCoreTools is true and CoreTools is empty, the agent
|
||||
// will have no tools (useful for simple chat completions).
|
||||
DisableCoreTools bool
|
||||
|
||||
// ToolWrapper is an optional function that wraps the combined tool list
|
||||
// before it is passed to the LLM agent. Used by the extensions system
|
||||
// to intercept tool calls/results.
|
||||
@@ -38,6 +53,11 @@ type AgentConfig struct {
|
||||
// ExtraTools are additional tools to include alongside core and MCP tools.
|
||||
// Used by extensions to register custom tools.
|
||||
ExtraTools []fantasy.AgentTool
|
||||
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). The callback receives the server
|
||||
// name, tool count, and any error. Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
}
|
||||
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
@@ -63,6 +83,10 @@ type ToolCallContentHandler func(content string)
|
||||
// ReasoningDeltaHandler is a function type for handling streaming reasoning/thinking deltas.
|
||||
type ReasoningDeltaHandler func(delta string)
|
||||
|
||||
// ReasoningCompleteHandler is a function type for handling reasoning/thinking completion.
|
||||
// Called when the last reasoning token has been processed, before text streaming starts.
|
||||
type ReasoningCompleteHandler func()
|
||||
|
||||
// ToolOutputHandler is a function type for handling streaming tool output chunks.
|
||||
// Used by tools like bash to stream output as it arrives rather than waiting
|
||||
// for the command to complete. The isStderr flag indicates if the chunk
|
||||
@@ -70,6 +94,14 @@ type ReasoningDeltaHandler func(delta string)
|
||||
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
|
||||
type ToolOutputHandler = core.ToolOutputCallback
|
||||
|
||||
// StepMessagesHandler is a function type for persisting messages after each
|
||||
// complete step in a multi-step agent turn. The handler receives the messages
|
||||
// produced by the step (typically an assistant message with tool calls followed
|
||||
// by a tool-role message with results, or a final assistant message with text).
|
||||
// This enables incremental session persistence so that progress is saved as
|
||||
// it happens rather than only at the end of the turn.
|
||||
type StepMessagesHandler func(stepMessages []fantasy.Message)
|
||||
|
||||
// StepUsageHandler is a function type for handling token usage after each
|
||||
// complete step in a multi-step agent turn. This enables real-time cost
|
||||
// tracking during long-running tool-calling conversations.
|
||||
@@ -79,6 +111,10 @@ type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCrea
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
|
||||
// AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
// Additional tools from external MCP servers can be loaded alongside core tools.
|
||||
//
|
||||
// When MCP servers are configured, tool loading happens in the background so the
|
||||
// agent (and UI) can start immediately. The first LLM call automatically waits
|
||||
// for MCP tools to finish loading before proceeding.
|
||||
type Agent struct {
|
||||
toolManager *tools.MCPToolManager
|
||||
fantasyAgent fantasy.Agent
|
||||
@@ -92,6 +128,18 @@ type Agent struct {
|
||||
coreTools []fantasy.AgentTool
|
||||
extraTools []fantasy.AgentTool
|
||||
toolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool // stored for SetModel rebuild
|
||||
|
||||
// providerOptions and modelConfig are stored for rebuilding the fantasy
|
||||
// agent when MCP tools arrive asynchronously or on SetModel.
|
||||
providerOptions fantasy.ProviderOptions
|
||||
skipMaxOutputTokens bool
|
||||
modelConfig *models.ProviderConfig
|
||||
|
||||
// mcpReady is closed when background MCP tool loading completes (success
|
||||
// or failure). nil when no MCP servers are configured.
|
||||
mcpReady chan struct{}
|
||||
// mcpErr holds any error from background MCP loading.
|
||||
mcpErr error
|
||||
}
|
||||
|
||||
// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
|
||||
@@ -106,11 +154,19 @@ type GenerateWithLoopResult struct {
|
||||
TotalUsage fantasy.Usage
|
||||
// StopReason is the LLM provider's finish reason for the final response.
|
||||
StopReason string
|
||||
// PersistedMessageCount is the number of new messages (beyond the original
|
||||
// input) that were already persisted incrementally via OnStepMessages during
|
||||
// generation. The caller should skip these when doing post-generation
|
||||
// persistence to avoid duplicates.
|
||||
PersistedMessageCount int
|
||||
}
|
||||
|
||||
// NewAgent creates a new Agent with core tools and optional MCP tool integration.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are always registered.
|
||||
// External MCP tools are loaded from the config if any MCP servers are configured.
|
||||
// If MCP servers are configured, their tools are loaded in the background —
|
||||
// the agent returns immediately and is usable with core tools only. The first
|
||||
// LLM call (GenerateWithLoop) automatically waits for MCP tools to finish
|
||||
// loading and rebuilds the agent with the full tool set.
|
||||
func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
// Create the LLM provider
|
||||
providerResult, err := models.CreateProvider(ctx, agentConfig.ModelConfig)
|
||||
@@ -120,34 +176,22 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
|
||||
// Register core tools (direct AgentTool implementations, no MCP overhead).
|
||||
// Use caller-provided tools if set, otherwise default to all core tools.
|
||||
coreTools := agentConfig.CoreTools
|
||||
if len(coreTools) == 0 {
|
||||
// DisableCoreTools allows explicitly having zero tools (for chat-only mode).
|
||||
var coreTools []fantasy.AgentTool
|
||||
if agentConfig.DisableCoreTools && len(agentConfig.CoreTools) == 0 {
|
||||
// Explicitly zero tools - chat-only mode
|
||||
coreTools = nil
|
||||
} else if len(agentConfig.CoreTools) > 0 {
|
||||
// Custom tools provided - use them
|
||||
coreTools = agentConfig.CoreTools
|
||||
} else {
|
||||
// Default: load all core tools
|
||||
coreTools = core.AllTools()
|
||||
}
|
||||
|
||||
// Build the combined tool list: core tools + any external MCP tools
|
||||
// Build the initial tool list: core tools + extension tools (no MCP yet).
|
||||
allTools := make([]fantasy.AgentTool, len(coreTools))
|
||||
copy(allTools, coreTools)
|
||||
|
||||
// Load external MCP tools if configured
|
||||
var toolManager *tools.MCPToolManager
|
||||
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
|
||||
toolManager = tools.NewMCPToolManager()
|
||||
toolManager.SetModel(providerResult.Model)
|
||||
|
||||
if agentConfig.DebugLogger != nil {
|
||||
toolManager.SetDebugLogger(agentConfig.DebugLogger)
|
||||
}
|
||||
|
||||
if err := toolManager.LoadTools(ctx, agentConfig.MCPConfig); err != nil {
|
||||
// MCP tool loading failures are non-fatal; core tools still work
|
||||
fmt.Printf("Warning: Failed to load MCP tools: %v\n", err)
|
||||
} else {
|
||||
mcpTools := toolManager.GetTools()
|
||||
allTools = append(allTools, mcpTools...)
|
||||
}
|
||||
}
|
||||
|
||||
// Append any extra tools provided by extensions.
|
||||
if len(agentConfig.ExtraTools) > 0 {
|
||||
allTools = append(allTools, agentConfig.ExtraTools...)
|
||||
@@ -159,6 +203,147 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
}
|
||||
|
||||
// Build agent options
|
||||
agentOpts := buildAgentOptions(agentConfig, providerResult, allTools)
|
||||
|
||||
// Create the agent
|
||||
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
// Determine provider type from model string
|
||||
providerType := "default"
|
||||
if agentConfig.ModelConfig != nil && agentConfig.ModelConfig.ModelString != "" {
|
||||
if p, _, err := models.ParseModelString(agentConfig.ModelConfig.ModelString); err == nil {
|
||||
providerType = p
|
||||
}
|
||||
}
|
||||
|
||||
a := &Agent{
|
||||
fantasyAgent: fantasyAgent,
|
||||
model: providerResult.Model,
|
||||
providerCloser: providerResult.Closer,
|
||||
maxSteps: agentConfig.MaxSteps,
|
||||
systemPrompt: agentConfig.SystemPrompt,
|
||||
loadingMessage: providerResult.Message,
|
||||
providerType: providerType,
|
||||
streamingEnabled: agentConfig.StreamingEnabled,
|
||||
coreTools: coreTools,
|
||||
extraTools: agentConfig.ExtraTools,
|
||||
toolWrapper: agentConfig.ToolWrapper,
|
||||
providerOptions: providerResult.ProviderOptions,
|
||||
skipMaxOutputTokens: providerResult.SkipMaxOutputTokens,
|
||||
modelConfig: agentConfig.ModelConfig,
|
||||
}
|
||||
|
||||
// Start MCP tool loading in the background if servers are configured.
|
||||
// The mcpReady channel is closed when loading completes (success or failure).
|
||||
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
|
||||
toolManager := tools.NewMCPToolManager()
|
||||
toolManager.SetModel(providerResult.Model)
|
||||
if agentConfig.AuthHandler != nil {
|
||||
toolManager.SetAuthHandler(agentConfig.AuthHandler)
|
||||
}
|
||||
if agentConfig.TokenStoreFactory != nil {
|
||||
toolManager.SetTokenStoreFactory(agentConfig.TokenStoreFactory)
|
||||
}
|
||||
if agentConfig.DebugLogger != nil {
|
||||
toolManager.SetDebugLogger(agentConfig.DebugLogger)
|
||||
}
|
||||
// Set per-server loaded callback if provided.
|
||||
if agentConfig.OnMCPServerLoaded != nil {
|
||||
toolManager.SetOnServerLoaded(agentConfig.OnMCPServerLoaded)
|
||||
}
|
||||
a.toolManager = toolManager
|
||||
a.mcpReady = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(a.mcpReady)
|
||||
if err := toolManager.LoadTools(ctx, agentConfig.MCPConfig); err != nil {
|
||||
a.mcpErr = err
|
||||
fmt.Printf("Warning: Failed to load MCP tools: %v\n", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// WaitForMCPTools blocks until background MCP tool loading completes.
|
||||
// Returns nil if no MCP servers are configured or if loading succeeded.
|
||||
// Returns the loading error if all servers failed. Safe to call multiple times.
|
||||
func (a *Agent) WaitForMCPTools() error {
|
||||
if a.mcpReady == nil {
|
||||
return nil
|
||||
}
|
||||
<-a.mcpReady
|
||||
return a.mcpErr
|
||||
}
|
||||
|
||||
// MCPToolsReady returns true if MCP tool loading has completed (or was never
|
||||
// started). This is a non-blocking check useful for UI status display.
|
||||
func (a *Agent) MCPToolsReady() bool {
|
||||
if a.mcpReady == nil {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-a.mcpReady:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ensureMCPTools waits for MCP tools to load and rebuilds the fantasy agent
|
||||
// with the full tool set. Called lazily before the first LLM call.
|
||||
// This is idempotent — subsequent calls after the first rebuild are no-ops.
|
||||
func (a *Agent) ensureMCPTools() {
|
||||
if a.mcpReady == nil {
|
||||
return
|
||||
}
|
||||
<-a.mcpReady
|
||||
|
||||
// If there are MCP tools, rebuild the fantasy agent to include them.
|
||||
if a.toolManager != nil && len(a.toolManager.GetTools()) > 0 {
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// Nil out the channel so future calls are instant no-ops and we
|
||||
// don't rebuild again.
|
||||
a.mcpReady = nil
|
||||
}
|
||||
|
||||
// rebuildFantasyAgent reconstructs the fantasy agent with the current full
|
||||
// tool set (core + MCP + extension tools). Used after MCP tools arrive
|
||||
// asynchronously and by SetModel.
|
||||
func (a *Agent) rebuildFantasyAgent() {
|
||||
allTools := make([]fantasy.AgentTool, len(a.coreTools))
|
||||
copy(allTools, a.coreTools)
|
||||
if a.toolManager != nil {
|
||||
allTools = append(allTools, a.toolManager.GetTools()...)
|
||||
}
|
||||
if len(a.extraTools) > 0 {
|
||||
allTools = append(allTools, a.extraTools...)
|
||||
}
|
||||
if a.toolWrapper != nil {
|
||||
allTools = a.toolWrapper(allTools)
|
||||
}
|
||||
|
||||
providerResult := &models.ProviderResult{
|
||||
Model: a.model,
|
||||
ProviderOptions: a.providerOptions,
|
||||
SkipMaxOutputTokens: a.skipMaxOutputTokens,
|
||||
}
|
||||
agentOpts := buildAgentOptions(&AgentConfig{
|
||||
ModelConfig: a.modelConfig,
|
||||
SystemPrompt: a.systemPrompt,
|
||||
MaxSteps: a.maxSteps,
|
||||
}, providerResult, allTools)
|
||||
|
||||
a.fantasyAgent = fantasy.NewAgent(a.model, agentOpts...)
|
||||
}
|
||||
|
||||
// buildAgentOptions constructs the fantasy.AgentOption slice from config,
|
||||
// provider result, and the combined tool list. Shared by NewAgent,
|
||||
// rebuildFantasyAgent, and SetModel.
|
||||
func buildAgentOptions(agentConfig *AgentConfig, providerResult *models.ProviderResult, allTools []fantasy.AgentTool) []fantasy.AgentOption {
|
||||
var agentOpts []fantasy.AgentOption
|
||||
|
||||
if agentConfig.SystemPrompt != "" {
|
||||
@@ -196,33 +381,15 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
if agentConfig.ModelConfig.TopK != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*agentConfig.ModelConfig.TopK)))
|
||||
}
|
||||
}
|
||||
|
||||
// Create the agent
|
||||
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
// Determine provider type from model string
|
||||
providerType := "default"
|
||||
if agentConfig.ModelConfig != nil && agentConfig.ModelConfig.ModelString != "" {
|
||||
if p, _, err := models.ParseModelString(agentConfig.ModelConfig.ModelString); err == nil {
|
||||
providerType = p
|
||||
if agentConfig.ModelConfig.FrequencyPenalty != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithFrequencyPenalty(float64(*agentConfig.ModelConfig.FrequencyPenalty)))
|
||||
}
|
||||
if agentConfig.ModelConfig.PresencePenalty != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithPresencePenalty(float64(*agentConfig.ModelConfig.PresencePenalty)))
|
||||
}
|
||||
}
|
||||
|
||||
return &Agent{
|
||||
toolManager: toolManager,
|
||||
fantasyAgent: fantasyAgent,
|
||||
model: providerResult.Model,
|
||||
providerCloser: providerResult.Closer,
|
||||
maxSteps: agentConfig.MaxSteps,
|
||||
systemPrompt: agentConfig.SystemPrompt,
|
||||
loadingMessage: providerResult.Message,
|
||||
providerType: providerType,
|
||||
streamingEnabled: agentConfig.StreamingEnabled,
|
||||
coreTools: coreTools,
|
||||
extraTools: agentConfig.ExtraTools,
|
||||
toolWrapper: agentConfig.ToolWrapper,
|
||||
}, nil
|
||||
return agentOpts
|
||||
}
|
||||
|
||||
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time.
|
||||
@@ -231,7 +398,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil)
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
@@ -242,10 +409,17 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
onStreamingResponse StreamingResponseHandler,
|
||||
onReasoningDelta ReasoningDeltaHandler,
|
||||
onReasoningComplete ReasoningCompleteHandler,
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepMessages StepMessagesHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
|
||||
// Wait for background MCP tool loading to complete and rebuild the
|
||||
// fantasy agent with the full tool set. This is a no-op when no MCP
|
||||
// servers are configured or tools have already been integrated.
|
||||
a.ensureMCPTools()
|
||||
|
||||
// Inject tool output handler into context for use by core tools (e.g., bash).
|
||||
if onToolOutput != nil {
|
||||
ctx = core.ContextWithToolOutputCallback(ctx, onToolOutput)
|
||||
@@ -277,6 +451,10 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// when it returns an error, but the OnStepFinish callback fires
|
||||
// for every step that completed before the error occurred.
|
||||
var completedStepMessages []fantasy.Message
|
||||
// persistedCount tracks how many new messages (beyond the original
|
||||
// input) were persisted incrementally via onStepMessages, so the
|
||||
// caller can skip them during post-generation persistence.
|
||||
var persistedCount int
|
||||
|
||||
// Use the streaming agent
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
@@ -295,6 +473,17 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
return nil
|
||||
},
|
||||
|
||||
// Reasoning/thinking complete callback
|
||||
OnReasoningEnd: func(id string, _ fantasy.ReasoningContent) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onReasoningComplete != nil {
|
||||
onReasoningComplete()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Text streaming callback
|
||||
OnTextDelta: func(id, text string) error {
|
||||
if ctx.Err() != nil {
|
||||
@@ -351,6 +540,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// persisted even if a later step is cancelled.
|
||||
completedStepMessages = append(completedStepMessages, step.Messages...)
|
||||
|
||||
// Persist step messages incrementally so progress is saved
|
||||
// as it happens rather than only at the end of the turn.
|
||||
if onStepMessages != nil && len(step.Messages) > 0 {
|
||||
onStepMessages(step.Messages)
|
||||
persistedCount += len(step.Messages)
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
@@ -381,7 +577,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
opts fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
// Drain all pending steer messages (non-blocking).
|
||||
var steered []string
|
||||
var steered []SteerMessage
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
@@ -398,9 +594,9 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if len(steered) > 0 {
|
||||
// Inject each steer message as a user message so the
|
||||
// LLM sees the redirection on the next step.
|
||||
for _, text := range steered {
|
||||
for _, sm := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(text))
|
||||
fantasy.NewUserMessage(sm.Text, sm.Files...))
|
||||
}
|
||||
// Notify that steer messages were consumed.
|
||||
if onConsumed != nil {
|
||||
@@ -429,19 +625,25 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
partialMessages = append(partialMessages, messages...)
|
||||
partialMessages = append(partialMessages, completedStepMessages...)
|
||||
return &GenerateWithLoopResult{
|
||||
ConversationMessages: partialMessages,
|
||||
ConversationMessages: partialMessages,
|
||||
PersistedMessageCount: persistedCount,
|
||||
}, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fire the response callback for callers that use it (e.g. non-streaming
|
||||
// callers that still want the final response notification).
|
||||
if onResponse != nil && result.Response.Content.Text() != "" {
|
||||
// Fire the response callback so callers (e.g. the TUI) can reset
|
||||
// streaming state. This must fire even when the response text is
|
||||
// empty (e.g. reasoning-only responses) so the UI properly resets
|
||||
// the stream component and avoids duplicate content on the next
|
||||
// flush.
|
||||
if onResponse != nil {
|
||||
onResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
r := convertAgentResult(result, messages)
|
||||
r.PersistedMessageCount = persistedCount
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Non-streaming path with no callbacks — use the simpler Generate call.
|
||||
@@ -454,8 +656,9 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For non-streaming, fire the response callback with the final text
|
||||
if onResponse != nil && result.Response.Content.Text() != "" {
|
||||
// For non-streaming, fire the response callback so callers can reset
|
||||
// streaming state (see streaming path comment above).
|
||||
if onResponse != nil {
|
||||
onResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
@@ -623,6 +826,67 @@ func (a *Agent) GetExtensionToolCount() int {
|
||||
return len(a.extraTools)
|
||||
}
|
||||
|
||||
// SetExtraTools replaces the agent's extra tools (e.g. extension-registered
|
||||
// tools) and rebuilds the internal agent with the updated tool list. The
|
||||
// model, system prompt, and all other configuration are preserved.
|
||||
func (a *Agent) SetExtraTools(extraTools []fantasy.AgentTool) {
|
||||
a.extraTools = extraTools
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// AddMCPServer connects to a new MCP server at runtime and makes its tools
|
||||
// available to the agent. Returns the number of tools loaded.
|
||||
// If the agent has no tool manager (no MCP servers were configured at init),
|
||||
// one is created automatically.
|
||||
func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
|
||||
// Ensure MCP tools from initial load are settled first.
|
||||
a.ensureMCPTools()
|
||||
|
||||
if a.toolManager == nil {
|
||||
a.toolManager = tools.NewMCPToolManager()
|
||||
a.toolManager.SetModel(a.model)
|
||||
a.toolManager.SetOnToolsChanged(func() {
|
||||
a.rebuildFantasyAgent()
|
||||
})
|
||||
}
|
||||
|
||||
count, err := a.toolManager.AddServer(ctx, name, cfg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// AddServer's onToolsChanged callback triggers rebuildFantasyAgent,
|
||||
// but only if it was wired. Ensure rebuild happens regardless.
|
||||
a.rebuildFantasyAgent()
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// RemoveMCPServer disconnects an MCP server and removes its tools from the agent.
|
||||
func (a *Agent) RemoveMCPServer(name string) error {
|
||||
if a.toolManager == nil {
|
||||
return fmt.Errorf("no MCP servers loaded")
|
||||
}
|
||||
|
||||
// Ensure MCP tools from initial load are settled first.
|
||||
a.ensureMCPTools()
|
||||
|
||||
err := a.toolManager.RemoveServer(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveServer's onToolsChanged callback triggers rebuildFantasyAgent,
|
||||
// but ensure rebuild happens regardless.
|
||||
a.rebuildFantasyAgent()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMCPToolManager returns the underlying MCP tool manager.
|
||||
// Returns nil if no MCP servers have been configured.
|
||||
func (a *Agent) GetMCPToolManager() *tools.MCPToolManager {
|
||||
return a.toolManager
|
||||
}
|
||||
|
||||
// GetLoadingMessage returns the loading message from provider creation.
|
||||
func (a *Agent) GetLoadingMessage() string {
|
||||
return a.loadingMessage
|
||||
@@ -636,64 +900,20 @@ func (a *Agent) GetLoadedServerNames() []string {
|
||||
return a.toolManager.GetLoadedServerNames()
|
||||
}
|
||||
|
||||
// SetModel swaps the agent's LLM provider to a new model. The existing tools,
|
||||
// system prompt, and configuration are preserved. The old provider is closed
|
||||
// if it has a closer. Returns the previous model string for notification.
|
||||
// SetModel swaps the agent's LLM provider to a new model. The existing tools
|
||||
// and configuration are preserved. When the new model's ProviderConfig carries
|
||||
// a system prompt (from per-model settings), it replaces the agent's stored
|
||||
// prompt so the rebuilt fantasy agent uses it. The old provider is closed if
|
||||
// it has a closer.
|
||||
func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) error {
|
||||
// Ensure MCP tools are loaded before rebuilding (SetModel may be called
|
||||
// before the first LLM call).
|
||||
a.ensureMCPTools()
|
||||
|
||||
providerResult, err := models.CreateProvider(ctx, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create model provider: %v", err)
|
||||
}
|
||||
|
||||
// Rebuild tool list (same as NewAgent).
|
||||
allTools := make([]fantasy.AgentTool, len(a.coreTools))
|
||||
copy(allTools, a.coreTools)
|
||||
if a.toolManager != nil {
|
||||
allTools = append(allTools, a.toolManager.GetTools()...)
|
||||
}
|
||||
if len(a.extraTools) > 0 {
|
||||
allTools = append(allTools, a.extraTools...)
|
||||
}
|
||||
if a.toolWrapper != nil {
|
||||
allTools = a.toolWrapper(allTools)
|
||||
}
|
||||
|
||||
// Rebuild agent options.
|
||||
var agentOpts []fantasy.AgentOption
|
||||
if a.systemPrompt != "" {
|
||||
agentOpts = append(agentOpts, fantasy.WithSystemPrompt(a.systemPrompt))
|
||||
}
|
||||
if len(allTools) > 0 {
|
||||
agentOpts = append(agentOpts, fantasy.WithTools(allTools...))
|
||||
}
|
||||
if a.maxSteps > 0 {
|
||||
agentOpts = append(agentOpts, fantasy.WithStopConditions(
|
||||
fantasy.StepCountIs(a.maxSteps),
|
||||
))
|
||||
}
|
||||
|
||||
// Pass provider-specific options (e.g. OpenAI Responses API reasoning settings).
|
||||
if providerResult.ProviderOptions != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithProviderOptions(providerResult.ProviderOptions))
|
||||
}
|
||||
|
||||
// Pass generation parameters when available.
|
||||
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
|
||||
if config.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
|
||||
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(config.MaxTokens)))
|
||||
}
|
||||
if config.Temperature != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTemperature(float64(*config.Temperature)))
|
||||
}
|
||||
if config.TopP != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTopP(float64(*config.TopP)))
|
||||
}
|
||||
if config.TopK != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*config.TopK)))
|
||||
}
|
||||
|
||||
newFantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
// Close old provider.
|
||||
if a.providerCloser != nil {
|
||||
_ = a.providerCloser.Close()
|
||||
@@ -705,9 +925,18 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
}
|
||||
|
||||
// Swap fields.
|
||||
a.fantasyAgent = newFantasyAgent
|
||||
a.model = providerResult.Model
|
||||
a.providerCloser = providerResult.Closer
|
||||
a.providerOptions = providerResult.ProviderOptions
|
||||
a.skipMaxOutputTokens = providerResult.SkipMaxOutputTokens
|
||||
a.modelConfig = config
|
||||
|
||||
// Update system prompt when the config carries one (from per-model
|
||||
// settings or the global config). This allows model-specific system
|
||||
// prompts to take effect on model switch.
|
||||
if config.SystemPrompt != "" {
|
||||
a.systemPrompt = config.SystemPrompt
|
||||
}
|
||||
|
||||
// Update provider type.
|
||||
if config.ModelString != "" {
|
||||
@@ -716,6 +945,9 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild the fantasy agent with the new model and current tool set.
|
||||
a.rebuildFantasyAgent()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -725,7 +957,13 @@ func (a *Agent) GetModel() fantasy.LanguageModel {
|
||||
}
|
||||
|
||||
// Close closes the agent and cleans up resources.
|
||||
// If MCP tools are still loading in the background, Close waits for them
|
||||
// to finish before closing connections to avoid resource leaks.
|
||||
func (a *Agent) Close() error {
|
||||
// Wait for background MCP loading to finish before closing connections.
|
||||
if a.mcpReady != nil {
|
||||
<-a.mcpReady
|
||||
}
|
||||
var toolErr error
|
||||
if a.toolManager != nil {
|
||||
toolErr = a.toolManager.Close()
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// mockModel is a minimal LanguageModel that satisfies the interface
|
||||
// without making real API calls. Used to test tool management wiring.
|
||||
type mockModel struct{}
|
||||
|
||||
func (m *mockModel) Generate(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{}, nil
|
||||
}
|
||||
func (m *mockModel) Stream(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockModel) GenerateObject(_ context.Context, _ fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return &fantasy.ObjectResponse{}, nil
|
||||
}
|
||||
func (m *mockModel) StreamObject(_ context.Context, _ fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockModel) Provider() string { return "mock" }
|
||||
func (m *mockModel) Model() string { return "mock-model" }
|
||||
|
||||
// testdataDir returns the absolute path to the tools testdata directory.
|
||||
func testdataDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
_, file, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("cannot determine test file path")
|
||||
}
|
||||
return filepath.Join(filepath.Dir(file), "..", "tools", "testdata")
|
||||
}
|
||||
|
||||
// echoServerConfig returns an MCPServerConfig for the test echo MCP server.
|
||||
func echoServerConfig(t *testing.T) config.MCPServerConfig {
|
||||
t.Helper()
|
||||
script := filepath.Join(testdataDir(t), "echo_server.py")
|
||||
if _, err := os.Stat(script); err != nil {
|
||||
t.Skipf("echo_server.py not found: %v", err)
|
||||
}
|
||||
return config.MCPServerConfig{
|
||||
Command: []string{"python3", script},
|
||||
}
|
||||
}
|
||||
|
||||
// newTestAgent creates a minimal Agent with a mock model and no core tools,
|
||||
// suitable for testing MCP server management without an API key.
|
||||
func newTestAgent() *Agent {
|
||||
model := &mockModel{}
|
||||
a := &Agent{
|
||||
model: model,
|
||||
coreTools: nil,
|
||||
extraTools: nil,
|
||||
maxSteps: 10,
|
||||
systemPrompt: "test",
|
||||
fantasyAgent: fantasy.NewAgent(model),
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestAgent_AddMCPServer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Initially no MCP tools.
|
||||
if a.GetMCPToolCount() != 0 {
|
||||
t.Fatalf("Expected 0 MCP tools initially, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
// Add a server.
|
||||
count, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools, got %d", count)
|
||||
}
|
||||
|
||||
// Verify tools are in the agent's tool list.
|
||||
if a.GetMCPToolCount() != 2 {
|
||||
t.Errorf("Expected 2 MCP tools, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
allTools := a.GetTools()
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range allTools {
|
||||
toolNames[tool.Info().Name] = true
|
||||
}
|
||||
if !toolNames["echo__echo"] {
|
||||
t.Error("Expected tool 'echo__echo' in agent tools")
|
||||
}
|
||||
if !toolNames["echo__greet"] {
|
||||
t.Error("Expected tool 'echo__greet' in agent tools")
|
||||
}
|
||||
|
||||
// Verify loaded server names.
|
||||
names := a.GetLoadedServerNames()
|
||||
found := false
|
||||
for _, n := range names {
|
||||
if n == "echo" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected 'echo' in loaded server names: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveMCPServer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add then remove.
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
err = a.RemoveMCPServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify tools removed.
|
||||
if a.GetMCPToolCount() != 0 {
|
||||
t.Errorf("Expected 0 MCP tools after removal, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
// Verify agent's tool list has no MCP tools.
|
||||
for _, tool := range a.GetTools() {
|
||||
if strings.Contains(tool.Info().Name, "echo__") {
|
||||
t.Errorf("Found leftover tool after removal: %s", tool.Info().Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveMCPServer_NoToolManager(t *testing.T) {
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
err := a.RemoveMCPServer("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no tool manager exists")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no MCP servers loaded") {
|
||||
t.Errorf("Expected 'no MCP servers loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddMCPServer_CreatesToolManager(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
// Initially no tool manager.
|
||||
if a.GetMCPToolManager() != nil {
|
||||
t.Fatal("Expected nil tool manager initially")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Tool manager should now exist.
|
||||
if a.GetMCPToolManager() == nil {
|
||||
t.Fatal("Expected tool manager to be created by AddMCPServer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddRemoveAdd_MCP(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add → Remove → Add cycle.
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First add failed: %v", err)
|
||||
}
|
||||
|
||||
err = a.RemoveMCPServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("Remove failed: %v", err)
|
||||
}
|
||||
|
||||
count, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Re-add failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools on re-add, got %d", count)
|
||||
}
|
||||
if a.GetMCPToolCount() != 2 {
|
||||
t.Errorf("Expected 2 MCP tools after re-add, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
}
|
||||
@@ -36,13 +36,26 @@ type AgentCreationOptions struct {
|
||||
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
|
||||
// DebugLogger is an optional logger for debugging MCP communications
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used.
|
||||
CoreTools []fantasy.AgentTool
|
||||
// DisableCoreTools, when true, prevents loading any core tools.
|
||||
// If both DisableCoreTools is true and CoreTools is empty, the agent
|
||||
// will have no tools (useful for simple chat completions).
|
||||
DisableCoreTools bool
|
||||
// ToolWrapper wraps the combined tool list before agent creation.
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
// ExtraTools are additional tools to include (e.g. from extensions).
|
||||
ExtraTools []fantasy.AgentTool
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
}
|
||||
|
||||
// CreateAgent creates an agent with optional spinner for Ollama models.
|
||||
@@ -50,15 +63,19 @@ type AgentCreationOptions struct {
|
||||
// Returns the created agent or an error if creation fails.
|
||||
func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error) {
|
||||
agentConfig := &AgentConfig{
|
||||
ModelConfig: opts.ModelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: opts.SystemPrompt,
|
||||
MaxSteps: opts.MaxSteps,
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
DebugLogger: opts.DebugLogger,
|
||||
CoreTools: opts.CoreTools,
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
ModelConfig: opts.ModelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: opts.SystemPrompt,
|
||||
MaxSteps: opts.MaxSteps,
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
DebugLogger: opts.DebugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
TokenStoreFactory: opts.TokenStoreFactory,
|
||||
CoreTools: opts.CoreTools,
|
||||
DisableCoreTools: opts.DisableCoreTools,
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
}
|
||||
|
||||
var agent *Agent
|
||||
|
||||
+15
-4
@@ -1,6 +1,17 @@
|
||||
package agent
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// SteerMessage carries a steering prompt and optional file attachments
|
||||
// (e.g. clipboard images) through the steer channel.
|
||||
type SteerMessage struct {
|
||||
Text string
|
||||
Files []fantasy.FilePart
|
||||
}
|
||||
|
||||
// steerChKey is the context key for the steer channel.
|
||||
type steerChKey struct{}
|
||||
@@ -11,7 +22,7 @@ type steerConsumedKey struct{}
|
||||
// ContextWithSteerCh returns a new context with the steer channel attached.
|
||||
// The agent's PrepareStep function checks this channel between steps and
|
||||
// injects any pending steer messages as user messages before the next LLM call.
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan string) context.Context {
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan SteerMessage) context.Context {
|
||||
return context.WithValue(ctx, steerChKey{}, ch)
|
||||
}
|
||||
|
||||
@@ -23,8 +34,8 @@ func ContextWithSteerConsumed(ctx context.Context, fn func(count int)) context.C
|
||||
}
|
||||
|
||||
// steerChFromContext extracts the steer channel from the context, or nil.
|
||||
func steerChFromContext(ctx context.Context) <-chan string {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan string)
|
||||
func steerChFromContext(ctx context.Context) <-chan SteerMessage {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan SteerMessage)
|
||||
return ch
|
||||
}
|
||||
|
||||
|
||||
+215
-35
@@ -20,7 +20,7 @@ import (
|
||||
// queueItem holds a prompt and optional image attachments for the execution queue.
|
||||
type queueItem struct {
|
||||
Prompt string
|
||||
Files []fantasy.FilePart
|
||||
Files []kit.LLMFilePart
|
||||
}
|
||||
|
||||
// App is the application-layer orchestrator. It owns the agentic loop,
|
||||
@@ -82,7 +82,7 @@ type App struct {
|
||||
|
||||
// New creates a new App with the provided options and pre-loaded messages.
|
||||
// initialMessages may be nil or empty for a fresh session.
|
||||
func New(opts Options, initialMessages []fantasy.Message) *App {
|
||||
func New(opts Options, initialMessages []kit.LLMMessage) *App {
|
||||
rootCtx, rootCancel := context.WithCancel(context.Background())
|
||||
return &App{
|
||||
opts: opts,
|
||||
@@ -126,9 +126,8 @@ func (a *App) Run(prompt string) int {
|
||||
// If the app is idle the prompt executes immediately; otherwise it is queued.
|
||||
// Returns the current queue depth (0 = started immediately, >0 = queued).
|
||||
//
|
||||
// Satisfies ui.AppController (via RunWithImages which converts ImageAttachment
|
||||
// to fantasy.FilePart).
|
||||
func (a *App) RunWithFiles(prompt string, files []fantasy.FilePart) int {
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) RunWithFiles(prompt string, files []kit.LLMFilePart) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -163,6 +162,24 @@ func (a *App) CancelCurrentStep() {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// IsBusy returns true when the agent is currently processing a turn.
|
||||
func (a *App) IsBusy() bool {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.busy
|
||||
}
|
||||
|
||||
// Abort cancels the current agent step (if running) and clears the queue.
|
||||
// Unlike InterruptAndSend, no new message is injected — the agent simply
|
||||
// stops. Safe to call when idle (no-op).
|
||||
func (a *App) Abort() {
|
||||
a.mu.Lock()
|
||||
a.queue = a.queue[:0]
|
||||
cancel := a.cancelStep
|
||||
a.mu.Unlock()
|
||||
cancel()
|
||||
}
|
||||
|
||||
// QueueLength returns the number of prompts currently waiting in the queue.
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
@@ -188,6 +205,15 @@ func (a *App) QueueLength() int {
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) Steer(prompt string) int {
|
||||
return a.SteerWithFiles(prompt, nil)
|
||||
}
|
||||
|
||||
// SteerWithFiles injects a steering message with optional file attachments
|
||||
// (e.g. pasted images) into the currently running agent turn. Behaves like
|
||||
// Steer but includes file parts alongside the text.
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) SteerWithFiles(prompt string, files []kit.LLMFilePart) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -196,8 +222,8 @@ func (a *App) Steer(prompt string) int {
|
||||
}
|
||||
|
||||
if !a.busy {
|
||||
// Not busy — start immediately, same as Run().
|
||||
item := queueItem{Prompt: prompt}
|
||||
// Not busy — start immediately, same as RunWithFiles().
|
||||
item := queueItem{Prompt: prompt, Files: files}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
@@ -212,7 +238,7 @@ func (a *App) Steer(prompt string) int {
|
||||
// execution, before next LLM call). If PrepareStep doesn't fire
|
||||
// (text-only response), drainQueue will pick it up after the turn.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.InjectSteer(prompt)
|
||||
a.opts.Kit.InjectSteerWithFiles(prompt, files)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
@@ -314,12 +340,12 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) AddContextMessage(text string) {
|
||||
msg := fantasy.NewUserMessage(text)
|
||||
a.store.Add(msg)
|
||||
kitMsg := fantasy.NewUserMessage(text)
|
||||
a.store.Add(kitMsg)
|
||||
|
||||
// Persist to tree session if active.
|
||||
if ts := a.opts.TreeSession; ts != nil {
|
||||
_, _ = ts.AppendLLMMessage(msg)
|
||||
_, _ = ts.AppendLLMMessage(fantasy.NewUserMessage(text))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,6 +383,15 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
if a.program != nil {
|
||||
a.program.Send(msg)
|
||||
}
|
||||
}
|
||||
unsub := a.subscribeSDKEvents(sendFn, nil)
|
||||
defer unsub()
|
||||
|
||||
result, err := a.opts.Kit.Compact(a.rootCtx, nil, customInstructions)
|
||||
if err != nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: err})
|
||||
@@ -382,6 +417,78 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompactAsync is like CompactConversation but calls onComplete/onError
|
||||
// callbacks instead of sending TUI events. Used by the extension API's
|
||||
// ctx.Compact() which needs callback-based notification.
|
||||
func (a *App) CompactAsync(customInstructions string, onComplete func(), onError func(string)) error {
|
||||
a.mu.Lock()
|
||||
if a.closed {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("app is closed")
|
||||
}
|
||||
if a.busy {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("cannot compact while the agent is working")
|
||||
}
|
||||
if a.opts.Kit == nil {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("SDK instance not available")
|
||||
}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
if a.program != nil {
|
||||
a.program.Send(msg)
|
||||
}
|
||||
}
|
||||
unsub := a.subscribeSDKEvents(sendFn, nil)
|
||||
defer unsub()
|
||||
|
||||
result, err := a.opts.Kit.Compact(a.rootCtx, nil, customInstructions)
|
||||
if err != nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: err})
|
||||
if onError != nil {
|
||||
onError(err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
if result == nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: fmt.Errorf("nothing to compact")})
|
||||
if onError != nil {
|
||||
onError("nothing to compact")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Sync in-memory store with the compacted session.
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
|
||||
a.sendEvent(CompactCompleteEvent{
|
||||
Summary: result.Summary,
|
||||
OriginalTokens: result.OriginalTokens,
|
||||
CompactedTokens: result.CompactedTokens,
|
||||
MessagesRemoved: result.MessagesRemoved,
|
||||
})
|
||||
if onComplete != nil {
|
||||
onComplete()
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Non-interactive execution
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -506,11 +613,10 @@ func (a *App) drainQueue(first queueItem) {
|
||||
a.mu.Lock()
|
||||
items = append(items, a.queue...)
|
||||
a.queue = a.queue[:0] // Clear the queue
|
||||
queueLen := len(a.queue)
|
||||
a.mu.Unlock()
|
||||
|
||||
// Send queue updated event (queue is now empty)
|
||||
a.sendEvent(QueueUpdatedEvent{Length: queueLen})
|
||||
// Notify UI: all queued messages have been consumed into this batch.
|
||||
a.sendEvent(QueueUpdatedEvent{Length: 0})
|
||||
|
||||
// Process all collected items as a single batch
|
||||
a.runQueueBatch(items)
|
||||
@@ -523,8 +629,8 @@ func (a *App) drainQueue(first queueItem) {
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
a.mu.Lock()
|
||||
steerItems := make([]queueItem, len(leftover))
|
||||
for i, text := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: text}
|
||||
for i, sm := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: sm.Text, Files: sm.Files}
|
||||
}
|
||||
a.queue = append(steerItems, a.queue...)
|
||||
a.mu.Unlock()
|
||||
@@ -543,6 +649,11 @@ func (a *App) drainQueue(first queueItem) {
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
if hasMore {
|
||||
// Notify UI: these newly queued messages have been consumed into the next batch.
|
||||
a.sendEvent(QueueUpdatedEvent{Length: 0})
|
||||
}
|
||||
|
||||
if !hasMore {
|
||||
// No more items, we're done
|
||||
break
|
||||
@@ -609,7 +720,7 @@ func (a *App) runQueueBatch(items []queueItem) {
|
||||
// executeStep runs a single agentic step by delegating to the SDK's
|
||||
// PromptResult() (or PromptResultWithFiles for multimodal), which handles
|
||||
// session persistence, hooks, extension events, and the generation loop.
|
||||
func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.Msg), files []fantasy.FilePart) (*kit.TurnResult, error) {
|
||||
func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.Msg), files []kit.LLMFilePart) (*kit.TurnResult, error) {
|
||||
// Test hook: bypass SDK entirely.
|
||||
if a.opts.PromptFunc != nil {
|
||||
return a.opts.PromptFunc(ctx, prompt)
|
||||
@@ -776,6 +887,8 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
|
||||
sendFn(StreamChunkEvent{Content: ev.Chunk})
|
||||
case kit.ReasoningDeltaEvent:
|
||||
sendFn(ReasoningChunkEvent{Delta: ev.Delta})
|
||||
case kit.ReasoningCompleteEvent:
|
||||
sendFn(ReasoningCompleteEvent{})
|
||||
case kit.ToolOutputEvent:
|
||||
sendFn(ToolOutputEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
@@ -817,7 +930,8 @@ func (a *App) QuitFromExtension() {
|
||||
// controls styling: "" for plain text, "info" for a system message block,
|
||||
// "error" for an error block. In interactive mode it sends an
|
||||
// ExtensionPrintEvent through the program so the TUI can render it with the
|
||||
// appropriate renderer. In non-interactive mode it falls back to stdout.
|
||||
// appropriate renderer. In non-interactive mode it falls back to stderr with
|
||||
// a level prefix so errors are distinguishable from plain output.
|
||||
func (a *App) PrintFromExtension(level, text string) {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
@@ -826,8 +940,16 @@ func (a *App) PrintFromExtension(level, text string) {
|
||||
prog.Send(ExtensionPrintEvent{Text: text, Level: level})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback: write directly to stdout.
|
||||
fmt.Println(text)
|
||||
// Non-interactive fallback: write to stderr with a level prefix so that
|
||||
// errors and info messages are distinguishable from plain output.
|
||||
switch level {
|
||||
case "error":
|
||||
fmt.Fprintf(os.Stderr, "[ERROR] %s\n", text)
|
||||
case "info":
|
||||
fmt.Fprintf(os.Stderr, "[INFO] %s\n", text)
|
||||
default:
|
||||
fmt.Println(text)
|
||||
}
|
||||
}
|
||||
|
||||
// SetEditorTextFromExtension sends an EditorTextSetEvent to the TUI to
|
||||
@@ -884,6 +1006,47 @@ func (a *App) NotifyWidgetUpdate() {
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyContentReload sends a ContentReloadEvent to the TUI so it refreshes
|
||||
// prompt templates and skills from their provider callbacks. Called by file
|
||||
// watchers when .md/.txt files change in prompt or skill directories.
|
||||
// In non-interactive mode this is a no-op.
|
||||
func (a *App) NotifyContentReload() {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(ContentReloadEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyMCPToolsReady sends an MCPToolsReadyEvent to the TUI so it refreshes
|
||||
// tool names and MCP tool count from provider callbacks. Called when background
|
||||
// MCP tool loading completes. In non-interactive mode this is a no-op.
|
||||
func (a *App) NotifyMCPToolsReady() {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(MCPToolsReadyEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyMCPServerLoaded sends an MCPServerLoadedEvent to the TUI so it can
|
||||
// display a system message when a single MCP server finishes loading. Called
|
||||
// per server as background MCP tool loading progresses.
|
||||
func (a *App) NotifyMCPServerLoaded(serverName string, toolCount int, err error) {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(MCPServerLoadedEvent{
|
||||
ServerName: serverName,
|
||||
ToolCount: toolCount,
|
||||
Error: err,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// SendEvent sends a tea.Msg to the registered program. Safe to call from
|
||||
// any goroutine. No-op when no program is registered.
|
||||
//
|
||||
@@ -968,11 +1131,12 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback.
|
||||
// Non-interactive fallback: render a simple framed block to stderr so
|
||||
// it is visually distinct from plain stdout output.
|
||||
if opts.Subtitle != "" {
|
||||
fmt.Printf("%s\n — %s\n", opts.Text, opts.Subtitle)
|
||||
fmt.Fprintf(os.Stderr, "--- %s ---\n%s\n", opts.Subtitle, opts.Text)
|
||||
} else {
|
||||
fmt.Println(opts.Text)
|
||||
fmt.Fprintf(os.Stderr, "---\n%s\n---\n", opts.Text)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1001,9 +1165,10 @@ func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool)
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// NOTE: We do NOT call SetContextTokens here. Context fill is set once
|
||||
// at turn completion via updateUsageFromTurnResult using FinalUsage.InputTokens,
|
||||
// which reflects the full accumulated context. Per-step context tokens would
|
||||
// cause the display to jump around during multi-step tool calls.
|
||||
// at turn completion via updateUsageFromTurnResult, which sums all token
|
||||
// categories (Input + CacheRead + CacheCreate + Output) from FinalUsage.
|
||||
// Per-step context tokens would cause the display to jump around during
|
||||
// multi-step tool calls.
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
@@ -1067,15 +1232,30 @@ func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt strin
|
||||
}
|
||||
|
||||
// --- Context window fill (drives the % bar) ---
|
||||
// Use FinalUsage.InputTokens as the context window fill. The API's InputTokens
|
||||
// already includes the full conversation history (system prompt + all previous
|
||||
// messages + current user message). Adding OutputTokens would double-count since
|
||||
// the output becomes part of the input for the next turn.
|
||||
if result.FinalUsage != nil && result.FinalUsage.InputTokens > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: calling SetContextTokens=%d (FinalUsage.InputTokens)",
|
||||
result.FinalUsage.InputTokens)
|
||||
// Calculate context fill from the LAST API call's usage. The context
|
||||
// window is filled by everything sent to and received from the model:
|
||||
//
|
||||
// InputTokens — non-cached input (may be small with prompt caching)
|
||||
// CacheReadTokens — input tokens served from cache
|
||||
// CacheCreationTokens — input tokens written to cache this call
|
||||
// OutputTokens — assistant output (becomes input next turn)
|
||||
//
|
||||
// With Anthropic prompt caching, InputTokens can drop to near-zero while
|
||||
// CacheReadTokens holds the bulk of the context. We must sum all four to
|
||||
// get the true context window utilization.
|
||||
//
|
||||
// We use FinalUsage (last step only), NOT TotalUsage, because TotalUsage
|
||||
// sums across all tool-calling steps — and each step re-sends the full
|
||||
// conversation, so TotalUsage massively overstates the actual window fill.
|
||||
if result.FinalUsage != nil {
|
||||
u := result.FinalUsage
|
||||
contextFill := int(u.InputTokens) + int(u.CacheReadTokens) + int(u.CacheCreationTokens) + int(u.OutputTokens)
|
||||
if contextFill > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: SetContextTokens=%d (Input=%d + CacheRead=%d + CacheCreate=%d + Output=%d)",
|
||||
contextFill, u.InputTokens, u.CacheReadTokens, u.CacheCreationTokens, u.OutputTokens)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(contextFill)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(int(result.FinalUsage.InputTokens))
|
||||
}
|
||||
}
|
||||
|
||||
+25
-21
@@ -7,8 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
@@ -574,13 +572,13 @@ func TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen(t *testing.T) {
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 999,
|
||||
OutputTokens: 111,
|
||||
CacheReadTokens: 7,
|
||||
CacheCreationTokens: 3,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{InputTokens: 456},
|
||||
FinalUsage: &kit.LLMUsage{InputTokens: 456},
|
||||
}, "prompt", true)
|
||||
|
||||
usage.mu.Lock()
|
||||
@@ -608,13 +606,13 @@ func TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero(t *testing.T) {
|
||||
// Simulate OpenAI-compatible behavior: all prompt tokens cached, InputTokens=0
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 0, // All cached - subtracted from prompt
|
||||
OutputTokens: 150, // Actual generated tokens
|
||||
CacheReadTokens: 500, // Cache hit
|
||||
CacheCreationTokens: 0,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{InputTokens: 0, OutputTokens: 150},
|
||||
FinalUsage: &kit.LLMUsage{InputTokens: 0, OutputTokens: 150},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
@@ -632,33 +630,39 @@ func TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly verifies that context
|
||||
// window fill uses InputTokens only (not input+output). The API's InputTokens
|
||||
// already includes the full conversation history; adding output would double-count.
|
||||
func TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly(t *testing.T) {
|
||||
// TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories verifies that
|
||||
// context window fill uses all token categories from the final API call:
|
||||
// InputTokens + CacheReadTokens + CacheCreationTokens + OutputTokens.
|
||||
// With Anthropic prompt caching, InputTokens can be near-zero while
|
||||
// CacheReadTokens holds the bulk of the context.
|
||||
func TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 200,
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 3,
|
||||
OutputTokens: 5,
|
||||
CacheReadTokens: 0,
|
||||
CacheCreationTokens: 4317,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{
|
||||
InputTokens: 1000, // Full context including history
|
||||
OutputTokens: 200,
|
||||
FinalUsage: &kit.LLMUsage{
|
||||
InputTokens: 3, // Non-cached input (small with caching)
|
||||
OutputTokens: 5, // Assistant output
|
||||
CacheReadTokens: 0, // No cache reads on first call
|
||||
CacheCreationTokens: 4317, // System prompt + tools written to cache
|
||||
},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
// Context tokens should be InputTokens only (1000), not input+output (1200)
|
||||
// because InputTokens already includes the full conversation history
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 1000 {
|
||||
t.Fatalf("expected context tokens=1000 (InputTokens only), got calls=%d tokens=%d",
|
||||
usage.contextCalls, usage.lastContextTokens)
|
||||
// Context tokens should be Input + CacheRead + CacheCreate + Output = 4325
|
||||
expected := 3 + 0 + 4317 + 5
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != expected {
|
||||
t.Fatalf("expected context tokens=%d (all categories), got calls=%d tokens=%d",
|
||||
expected, usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
+27
-3
@@ -1,6 +1,6 @@
|
||||
package app
|
||||
|
||||
import "charm.land/fantasy"
|
||||
import kit "github.com/mark3labs/kit/pkg/kit"
|
||||
|
||||
// StreamChunkEvent is sent by the app layer when a streaming text delta arrives
|
||||
// from the LLM. Each chunk contains an incremental portion of the response.
|
||||
@@ -16,6 +16,11 @@ type ReasoningChunkEvent struct {
|
||||
Delta string
|
||||
}
|
||||
|
||||
// ReasoningCompleteEvent is sent when reasoning/thinking is finished, after
|
||||
// the last reasoning token has been processed. The TUI uses this to freeze
|
||||
// the reasoning duration counter.
|
||||
type ReasoningCompleteEvent struct{}
|
||||
|
||||
// ToolCallStartedEvent is sent when a tool call has been parsed and is about to execute.
|
||||
// It carries the tool name and its arguments for display purposes.
|
||||
type ToolCallStartedEvent struct {
|
||||
@@ -118,8 +123,8 @@ type SpinnerEvent struct {
|
||||
// MessageCreatedEvent is sent when a new message is added to the message store.
|
||||
// This allows the TUI to stay in sync with the conversation history.
|
||||
type MessageCreatedEvent struct {
|
||||
// Message is the fantasy message that was added to the store.
|
||||
Message fantasy.Message
|
||||
// Message is the message that was added to the store.
|
||||
Message kit.LLMMessage
|
||||
}
|
||||
|
||||
// CompactCompleteEvent is sent when a /compact operation finishes successfully.
|
||||
@@ -162,6 +167,25 @@ type ModelChangedEvent struct {
|
||||
// from its WidgetProvider on the next render cycle.
|
||||
type WidgetUpdateEvent struct{}
|
||||
|
||||
// ContentReloadEvent is sent when prompt templates or skills are reloaded
|
||||
// from disk (e.g. by a file watcher detecting changes). The TUI refreshes
|
||||
// its autocomplete entries and internal state from the provider callbacks.
|
||||
type ContentReloadEvent struct{}
|
||||
|
||||
// MCPToolsReadyEvent is sent when background MCP tool loading completes.
|
||||
// The TUI refreshes its tool names and MCP tool count from provider callbacks
|
||||
// so that /tools and the startup info bar reflect the loaded MCP tools.
|
||||
type MCPToolsReadyEvent struct{}
|
||||
|
||||
// MCPServerLoadedEvent is sent when a single MCP server finishes loading
|
||||
// (successfully or with error). The TUI displays a system message so users
|
||||
// see real-time progress as each server initializes.
|
||||
type MCPServerLoadedEvent struct {
|
||||
ServerName string
|
||||
ToolCount int
|
||||
Error error // nil on success
|
||||
}
|
||||
|
||||
// EditorTextSetEvent is sent when an extension calls ctx.SetEditorText to
|
||||
// pre-fill the input editor with text. The TUI handles this by setting the
|
||||
// textarea content and moving the cursor to the end.
|
||||
|
||||
@@ -3,14 +3,14 @@ package app
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// MessageStore is a thread-safe in-memory store for the conversation history.
|
||||
// On-disk persistence is handled by the TreeManager at the app/SDK layer.
|
||||
type MessageStore struct {
|
||||
mu sync.RWMutex
|
||||
messages []fantasy.Message
|
||||
messages []kit.LLMMessage
|
||||
}
|
||||
|
||||
// NewMessageStore creates an empty MessageStore.
|
||||
@@ -20,14 +20,14 @@ func NewMessageStore() *MessageStore {
|
||||
|
||||
// NewMessageStoreWithMessages creates a MessageStore pre-populated with the
|
||||
// given messages. This is used when loading an existing session at startup.
|
||||
func NewMessageStoreWithMessages(msgs []fantasy.Message) *MessageStore {
|
||||
cp := make([]fantasy.Message, len(msgs))
|
||||
func NewMessageStoreWithMessages(msgs []kit.LLMMessage) *MessageStore {
|
||||
cp := make([]kit.LLMMessage, len(msgs))
|
||||
copy(cp, msgs)
|
||||
return &MessageStore{messages: cp}
|
||||
}
|
||||
|
||||
// Add appends a single message to the store.
|
||||
func (s *MessageStore) Add(msg fantasy.Message) {
|
||||
func (s *MessageStore) Add(msg kit.LLMMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.messages = append(s.messages, msg)
|
||||
@@ -36,22 +36,22 @@ func (s *MessageStore) Add(msg fantasy.Message) {
|
||||
// Replace replaces the entire message history with the given slice. This is
|
||||
// used after an agent step returns the full updated conversation (including
|
||||
// tool calls and results).
|
||||
func (s *MessageStore) Replace(msgs []fantasy.Message) {
|
||||
func (s *MessageStore) Replace(msgs []kit.LLMMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
cp := make([]fantasy.Message, len(msgs))
|
||||
cp := make([]kit.LLMMessage, len(msgs))
|
||||
copy(cp, msgs)
|
||||
s.messages = cp
|
||||
}
|
||||
|
||||
// GetAll returns a snapshot copy of the current message slice.
|
||||
// The returned slice is safe to modify without affecting the store.
|
||||
func (s *MessageStore) GetAll() []fantasy.Message {
|
||||
func (s *MessageStore) GetAll() []kit.LLMMessage {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
cp := make([]fantasy.Message, len(s.messages))
|
||||
cp := make([]kit.LLMMessage, len(s.messages))
|
||||
copy(cp, s.messages)
|
||||
return cp
|
||||
}
|
||||
|
||||
@@ -4,16 +4,29 @@ import (
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// makeTextMsg builds a minimal fantasy.Message with a single TextPart.
|
||||
func makeTextMsg(role, text string) fantasy.Message {
|
||||
return fantasy.Message{
|
||||
Role: fantasy.MessageRole(role),
|
||||
// makeTextMsg builds a minimal kit.LLMMessage using fantasy.NewUserMessage
|
||||
// or constructing with the given role.
|
||||
func makeTextMsg(role, text string) kit.LLMMessage {
|
||||
return kit.LLMMessage{
|
||||
Role: kit.LLMMessageRole(role),
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: text}},
|
||||
}
|
||||
}
|
||||
|
||||
// textOf extracts the plain text from an LLMMessage for assertions.
|
||||
func textOf(msg kit.LLMMessage) string {
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(fantasy.TextPart); ok {
|
||||
return tp.Text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// NewMessageStore / NewMessageStoreWithMessages
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -29,7 +42,7 @@ func TestNewMessageStore_empty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewMessageStoreWithMessages_preloaded(t *testing.T) {
|
||||
msgs := []fantasy.Message{
|
||||
msgs := []kit.LLMMessage{
|
||||
makeTextMsg("user", "hello"),
|
||||
makeTextMsg("assistant", "hi"),
|
||||
}
|
||||
@@ -42,7 +55,7 @@ func TestNewMessageStoreWithMessages_preloaded(t *testing.T) {
|
||||
// NewMessageStoreWithMessages must deep-copy the slice so that external
|
||||
// modifications don't affect the store.
|
||||
func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
|
||||
msgs := []fantasy.Message{makeTextMsg("user", "hello")}
|
||||
msgs := []kit.LLMMessage{makeTextMsg("user", "hello")}
|
||||
s := NewMessageStoreWithMessages(msgs)
|
||||
|
||||
// Mutate the source slice.
|
||||
@@ -52,9 +65,8 @@ func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(got))
|
||||
}
|
||||
tp, ok := got[0].Content[0].(fantasy.TextPart)
|
||||
if !ok || tp.Text != "hello" {
|
||||
t.Fatalf("store was mutated by external slice change; got %q", tp.Text)
|
||||
if textOf(got[0]) != "hello" {
|
||||
t.Fatalf("store was mutated by external slice change; got %q", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,9 +92,8 @@ func TestAdd_preservesOrder(t *testing.T) {
|
||||
}
|
||||
got := s.GetAll()
|
||||
for i, expected := range texts {
|
||||
tp, ok := got[i].Content[0].(fantasy.TextPart)
|
||||
if !ok || tp.Text != expected {
|
||||
t.Fatalf("message[%d]: expected %q, got %q", i, expected, tp.Text)
|
||||
if textOf(got[i]) != expected {
|
||||
t.Fatalf("message[%d]: expected %q, got %q", i, expected, textOf(got[i]))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -95,7 +106,7 @@ func TestReplace_swapsHistory(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s.Add(makeTextMsg("user", "old"))
|
||||
|
||||
replacement := []fantasy.Message{
|
||||
replacement := []kit.LLMMessage{
|
||||
makeTextMsg("user", "new1"),
|
||||
makeTextMsg("assistant", "new2"),
|
||||
}
|
||||
@@ -105,25 +116,22 @@ func TestReplace_swapsHistory(t *testing.T) {
|
||||
t.Fatalf("expected 2 messages after replace, got %d", s.Len())
|
||||
}
|
||||
got := s.GetAll()
|
||||
tp0, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
tp1, _ := got[1].Content[0].(fantasy.TextPart)
|
||||
if tp0.Text != "new1" || tp1.Text != "new2" {
|
||||
t.Fatalf("unexpected messages after replace: %q %q", tp0.Text, tp1.Text)
|
||||
if textOf(got[0]) != "new1" || textOf(got[1]) != "new2" {
|
||||
t.Fatalf("unexpected messages after replace: %q %q", textOf(got[0]), textOf(got[1]))
|
||||
}
|
||||
}
|
||||
|
||||
// Replace must deep-copy the incoming slice.
|
||||
func TestReplace_isolatesInput(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
replacement := []fantasy.Message{makeTextMsg("user", "original")}
|
||||
replacement := []kit.LLMMessage{makeTextMsg("user", "original")}
|
||||
s.Replace(replacement)
|
||||
|
||||
replacement[0] = makeTextMsg("user", "mutated")
|
||||
|
||||
got := s.GetAll()
|
||||
tp, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "original" {
|
||||
t.Fatalf("store was mutated by external slice change after Replace; got %q", tp.Text)
|
||||
if textOf(got[0]) != "original" {
|
||||
t.Fatalf("store was mutated by external slice change after Replace; got %q", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,9 +148,8 @@ func TestGetAll_returnsCopy(t *testing.T) {
|
||||
got[0] = makeTextMsg("user", "mutated")
|
||||
|
||||
internal := s.GetAll()
|
||||
tp, _ := internal[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "hello" {
|
||||
t.Fatalf("GetAll returned non-copy; store was mutated to %q", tp.Text)
|
||||
if textOf(internal[0]) != "hello" {
|
||||
t.Fatalf("GetAll returned non-copy; store was mutated to %q", textOf(internal[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,9 +186,8 @@ func TestClear_allowsSubsequentAdds(t *testing.T) {
|
||||
t.Fatalf("expected 1 message after Clear+Add, got %d", s.Len())
|
||||
}
|
||||
got := s.GetAll()
|
||||
tp, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "after" {
|
||||
t.Fatalf("expected %q, got %q", "after", tp.Text)
|
||||
if textOf(got[0]) != "after" {
|
||||
t.Fatalf("expected %q, got %q", "after", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,8 +21,10 @@ type UsageUpdater interface {
|
||||
// the provider does not return exact counts.
|
||||
EstimateAndUpdateUsage(inputText, outputText string)
|
||||
// SetContextTokens records the approximate current context window fill
|
||||
// level. This should be the final API call's input+output tokens (from
|
||||
// FinalResponse.Usage), NOT the aggregate TotalUsage.
|
||||
// level. This should be the sum of ALL token categories from the last
|
||||
// API call: InputTokens + CacheReadTokens + CacheCreationTokens +
|
||||
// OutputTokens. With Anthropic prompt caching, InputTokens can be
|
||||
// near-zero while CacheReadTokens holds the bulk of the context.
|
||||
SetContextTokens(tokens int)
|
||||
}
|
||||
|
||||
@@ -67,10 +69,6 @@ type Options struct {
|
||||
// Debug enables verbose debug logging.
|
||||
Debug bool
|
||||
|
||||
// CompactMode selects the compact renderer instead of the block renderer for
|
||||
// message formatting.
|
||||
CompactMode bool
|
||||
|
||||
// UsageTracker is an optional callback for recording token usage after each
|
||||
// agent step. When non-nil, the app layer calls UpdateUsage (or
|
||||
// EstimateAndUpdateUsage as a fallback) using the usage data returned by the
|
||||
|
||||
@@ -428,6 +428,10 @@ type PreviousCompaction struct {
|
||||
ModifiedFiles []string
|
||||
}
|
||||
|
||||
// StreamCallback is called for each chunk of text during streaming compaction.
|
||||
// Return a non-nil error to cancel the stream.
|
||||
type StreamCallback func(delta string) error
|
||||
|
||||
// Compact summarises older messages using the LLM, returning the compaction
|
||||
// result and a new message slice (summary message + preserved recent
|
||||
// messages).
|
||||
@@ -442,6 +446,8 @@ type PreviousCompaction struct {
|
||||
//
|
||||
// prev carries file tracking from a previous compaction for cumulative
|
||||
// tracking. Pass nil if there is no prior compaction.
|
||||
// onChunk is an optional callback for streaming summary text. Pass nil for
|
||||
// non-streaming compaction.
|
||||
func Compact(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
@@ -449,6 +455,7 @@ func Compact(
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
prev *PreviousCompaction,
|
||||
onChunk StreamCallback,
|
||||
) (*CompactionResult, []fantasy.Message, error) {
|
||||
opts.defaults()
|
||||
|
||||
@@ -487,9 +494,9 @@ func Compact(
|
||||
var err error
|
||||
|
||||
if IsSplitTurn(messages, cutPoint) {
|
||||
summaryText, err = compactSplitTurn(ctx, model, oldMessages, messages, cutPoint, opts, customInstructions)
|
||||
summaryText, err = compactSplitTurn(ctx, model, oldMessages, messages, cutPoint, opts, customInstructions, onChunk)
|
||||
} else {
|
||||
summaryText, err = compactNormal(ctx, model, oldMessages, opts, customInstructions)
|
||||
summaryText, err = compactNormal(ctx, model, oldMessages, opts, customInstructions, onChunk)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -527,15 +534,17 @@ func Compact(
|
||||
}
|
||||
|
||||
// compactNormal generates a summary for a clean turn-boundary cut.
|
||||
// If onChunk is provided, text deltas are streamed to it.
|
||||
func compactNormal(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
oldMessages []fantasy.Message,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
conversationText := serializeMessages(oldMessages)
|
||||
return generateSummary(ctx, model, conversationText, opts, customInstructions)
|
||||
return generateSummary(ctx, model, conversationText, opts, customInstructions, onChunk)
|
||||
}
|
||||
|
||||
// compactSplitTurn handles the case where the cut point lands mid-turn.
|
||||
@@ -546,6 +555,7 @@ func compactNormal(
|
||||
//
|
||||
// The merged result preserves context from both the older history and the
|
||||
// beginning of the current long turn.
|
||||
// If onChunk is provided, both summaries and the separator are streamed.
|
||||
func compactSplitTurn(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
@@ -554,6 +564,7 @@ func compactSplitTurn(
|
||||
cutPoint int,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
// Find where the split turn starts.
|
||||
turnStart := findTurnStart(allMessages, cutPoint)
|
||||
@@ -573,12 +584,19 @@ func compactSplitTurn(
|
||||
// Generate history summary if there are complete turns before the split.
|
||||
if len(historyMessages) >= 2 {
|
||||
historySummary, err = generateSummary(ctx, model,
|
||||
serializeMessages(historyMessages), opts, "")
|
||||
serializeMessages(historyMessages), opts, "", onChunk)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("split turn history summary failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the separator between history and turn prefix summaries.
|
||||
if onChunk != nil && historySummary != "" {
|
||||
if err := onChunk("\n\n---\n\n## Current Turn (in progress)\n\n"); err != nil {
|
||||
return "", fmt.Errorf("streaming separator failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate turn prefix summary.
|
||||
turnPrefixText := serializeMessages(turnPrefixMessages)
|
||||
turnPrefixPrompt := "The messages above are the BEGINNING of a long turn that was split. " +
|
||||
@@ -588,16 +606,10 @@ func compactSplitTurn(
|
||||
turnPrefixPrompt += "\n\nAdditional instructions: " + customInstructions
|
||||
}
|
||||
|
||||
summaryAgent := fantasy.NewAgent(model,
|
||||
fantasy.WithSystemPrompt(defaultSystemPrompt),
|
||||
)
|
||||
result, err := summaryAgent.Generate(ctx, fantasy.AgentCall{
|
||||
Prompt: turnPrefixText + "\n\n" + turnPrefixPrompt,
|
||||
})
|
||||
turnPrefixSummary, err := generateSummary(ctx, model, turnPrefixText, opts, turnPrefixPrompt, onChunk)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("split turn prefix summary failed: %w", err)
|
||||
}
|
||||
turnPrefixSummary := result.Response.Content.Text()
|
||||
|
||||
// Merge the two summaries.
|
||||
if historySummary != "" && turnPrefixSummary != "" {
|
||||
@@ -610,12 +622,14 @@ func compactSplitTurn(
|
||||
}
|
||||
|
||||
// generateSummary calls the LLM to produce a structured summary.
|
||||
// If onChunk is provided, the summary is streamed using Agent.Stream().
|
||||
func generateSummary(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
conversationText string,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
userPrompt := opts.SummaryPrompt
|
||||
if userPrompt == "" {
|
||||
@@ -628,8 +642,31 @@ func generateSummary(
|
||||
summaryAgent := fantasy.NewAgent(model,
|
||||
fantasy.WithSystemPrompt(defaultSystemPrompt),
|
||||
)
|
||||
|
||||
prompt := conversationText + "\n\n" + userPrompt
|
||||
|
||||
// Use streaming if onChunk is provided.
|
||||
if onChunk != nil {
|
||||
var fullText strings.Builder
|
||||
_, err := summaryAgent.Stream(ctx, fantasy.AgentStreamCall{
|
||||
Prompt: prompt,
|
||||
OnTextDelta: func(_, delta string) error {
|
||||
if delta != "" {
|
||||
fullText.WriteString(delta)
|
||||
return onChunk(delta)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("compaction summarisation (streaming) failed: %w", err)
|
||||
}
|
||||
return fullText.String(), nil
|
||||
}
|
||||
|
||||
// Non-streaming path.
|
||||
result, err := summaryAgent.Generate(ctx, fantasy.AgentCall{
|
||||
Prompt: conversationText + "\n\n" + userPrompt,
|
||||
Prompt: prompt,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("compaction summarisation failed: %w", err)
|
||||
|
||||
@@ -243,7 +243,7 @@ func TestCompact_TooFewMessages(t *testing.T) {
|
||||
makeTextMessageN(fantasy.MessageRoleUser, 400),
|
||||
}
|
||||
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil)
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -262,7 +262,7 @@ func TestCompact_WithinBudget(t *testing.T) {
|
||||
makeTextMessageN(fantasy.MessageRoleAssistant, 400),
|
||||
}
|
||||
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil)
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
@@ -157,11 +157,28 @@ type Theme struct {
|
||||
Markdown MarkdownThemeConfig `json:"markdown,omitzero" yaml:"markdown,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationParams defines generation parameter defaults that can be attached
|
||||
// to individual models. These act as model-level defaults — CLI flags and
|
||||
// global config values take precedence when explicitly set.
|
||||
type GenerationParams struct {
|
||||
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
|
||||
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
|
||||
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model that can be used with custom/custom
|
||||
// or other custom/ prefixed models. These models are loaded from the config file
|
||||
// and merged into the custom provider in the model registry.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
@@ -169,6 +186,11 @@ type CustomModelConfig struct {
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
|
||||
// Generation parameter defaults for this model.
|
||||
// These are applied when the user hasn't explicitly set the corresponding
|
||||
// CLI flag or global config value.
|
||||
Params GenerationParams `json:"params,omitzero" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
@@ -191,18 +213,19 @@ type Config struct {
|
||||
Model string `json:"model,omitempty" yaml:"model,omitempty"`
|
||||
MaxSteps int `json:"max-steps,omitempty" yaml:"max-steps,omitempty"`
|
||||
Debug bool `json:"debug,omitempty" yaml:"debug,omitempty"`
|
||||
Compact bool `json:"compact,omitempty" yaml:"compact,omitempty"`
|
||||
SystemPrompt string `json:"system-prompt,omitempty" yaml:"system-prompt,omitempty"`
|
||||
ProviderAPIKey string `json:"provider-api-key,omitempty" yaml:"provider-api-key,omitempty"`
|
||||
ProviderURL string `json:"provider-url,omitempty" yaml:"provider-url,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"`
|
||||
Theme any `json:"theme" yaml:"theme"`
|
||||
// Model generation parameters
|
||||
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
|
||||
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
|
||||
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
|
||||
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
|
||||
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequency-penalty,omitempty" yaml:"frequency-penalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presence-penalty,omitempty" yaml:"presence-penalty,omitempty"`
|
||||
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
|
||||
|
||||
// Thinking / extended reasoning
|
||||
ThinkingLevel string `json:"thinking-level,omitempty" yaml:"thinking-level,omitempty"`
|
||||
@@ -216,6 +239,12 @@ type Config struct {
|
||||
|
||||
// Custom model definitions (under custom/ provider)
|
||||
CustomModels map[string]CustomModelConfig `json:"customModels,omitempty" yaml:"customModels,omitempty"`
|
||||
|
||||
// Per-model generation parameter overrides. Keys are "provider/model" strings
|
||||
// (e.g. "anthropic/claude-sonnet-4-5-20250929", "openai/gpt-4o"). These
|
||||
// settings act as model-level defaults — CLI flags and global config values
|
||||
// take precedence when explicitly set.
|
||||
ModelSettings map[string]GenerationParams `json:"modelSettings,omitempty" yaml:"modelSettings,omitempty"`
|
||||
}
|
||||
|
||||
// GetTransportType returns the transport type for the server config, mapping
|
||||
@@ -364,16 +393,55 @@ mcpServers:
|
||||
# debug: false # Enable debug logging
|
||||
# system-prompt: "/path/to/system-prompt.txt" # System prompt text file
|
||||
|
||||
# Model generation parameters (all optional)
|
||||
# Model generation parameters (all optional, apply globally to all models)
|
||||
# max-tokens: 4096 # Maximum tokens in response
|
||||
# temperature: 0.7 # Randomness (0.0-1.0)
|
||||
# top-p: 0.95 # Nucleus sampling (0.0-1.0)
|
||||
# top-k: 40 # Top K sampling
|
||||
# frequency-penalty: 0.0 # Penalize frequent tokens (0.0-2.0)
|
||||
# presence-penalty: 0.0 # Penalize present tokens (0.0-2.0)
|
||||
# stop-sequences: ["Human:", "Assistant:"] # Custom stop sequences
|
||||
|
||||
# Per-model generation parameter overrides (apply to specific models)
|
||||
# These act as model-level defaults — CLI flags and global settings above take precedence.
|
||||
# Keys are "provider/model" strings matching the model you use.
|
||||
# modelSettings:
|
||||
# anthropic/claude-sonnet-4-5-20250929:
|
||||
# temperature: 0.3
|
||||
# maxTokens: 8192
|
||||
# openai/gpt-4o:
|
||||
# temperature: 0.7
|
||||
# topP: 0.95
|
||||
# topK: 40
|
||||
# frequencyPenalty: 0.1
|
||||
# presencePenalty: 0.1
|
||||
# anthropic/claude-opus-4-6:
|
||||
# thinkingLevel: "high"
|
||||
# maxTokens: 16384
|
||||
# systemPrompt: "You are a deep reasoning assistant." # or a file path
|
||||
|
||||
# API Configuration (can also use environment variables)
|
||||
# provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google
|
||||
# provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama
|
||||
|
||||
# Custom model definitions (under custom/ provider)
|
||||
# customModels:
|
||||
# my-local-llama:
|
||||
# name: "Local Llama 3"
|
||||
# baseUrl: "http://localhost:8080/v1"
|
||||
# family: "llama"
|
||||
# temperature: true
|
||||
# cost:
|
||||
# input: 0.0
|
||||
# output: 0.0
|
||||
# limit:
|
||||
# context: 131072
|
||||
# output: 8192
|
||||
# params: # Generation parameter defaults for this model
|
||||
# temperature: 0.8
|
||||
# topP: 0.95
|
||||
# topK: 40
|
||||
# systemPrompt: "You are a helpful local assistant."
|
||||
`
|
||||
|
||||
_, err = file.WriteString(content)
|
||||
|
||||
+1
-20
@@ -67,7 +67,7 @@ func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return readDirectory(absPath)
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("'%s' is a directory, not a file. Use the ls tool to list directory contents.", args.Path)), nil
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absPath)
|
||||
@@ -116,25 +116,6 @@ func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
return fantasy.NewTextResponse(tr.Content), nil
|
||||
}
|
||||
|
||||
func readDirectory(absPath string) (fantasy.ToolResponse, error) {
|
||||
entries, err := os.ReadDir(absPath)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to read directory: %v", err)), nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if entry.IsDir() {
|
||||
name += "/"
|
||||
}
|
||||
result.WriteString(name + "\n")
|
||||
}
|
||||
|
||||
tr := truncateHead(result.String(), 500, defaultMaxBytes)
|
||||
return fantasy.NewTextResponse(tr.Content), nil
|
||||
}
|
||||
|
||||
// resolvePathWithWorkDir resolves a path to an absolute path relative to the
|
||||
// given workDir. If workDir is empty, os.Getwd() is used.
|
||||
func resolvePathWithWorkDir(path, workDir string) (string, error) {
|
||||
|
||||
+26
-33
@@ -130,13 +130,22 @@ func executeSubagent(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolRe
|
||||
), fmt.Errorf("no subagent spawner in context")
|
||||
}
|
||||
|
||||
// Detach from the parent's deadline so the subagent gets its own
|
||||
// independent timeout (applied downstream in Kit.Subagent). The parent
|
||||
// context may carry a tight deadline from the LLM generation loop or
|
||||
// other tool timeouts that would prematurely kill the subagent.
|
||||
// We preserve context values (spawner, etc.) and propagate parent
|
||||
// cancellation (e.g. user hits Ctrl-C) without inheriting the deadline.
|
||||
spawnCtx := detachedWithCancel(ctx)
|
||||
// Build a clean context for the subagent that inherits values (e.g. the
|
||||
// spawner callback) but is completely detached from the parent's
|
||||
// deadline AND cancellation. The subagent gets its own independent
|
||||
// timeout (applied downstream in Kit.Subagent).
|
||||
//
|
||||
// Why full detachment instead of propagating parent cancellation?
|
||||
// The parent context may already be done (deadline exceeded or
|
||||
// cancelled) by the time this tool handler executes — for example when
|
||||
// the generation loop context carries a deadline, when the user
|
||||
// double-ESC cancels mid-turn, or when parallel tool execution
|
||||
// encounters a race between stream completion and tool dispatch. Using
|
||||
// context.WithoutCancel (Go 1.21+) ensures the subagent always starts
|
||||
// cleanly with a fresh timeout, following the pattern used by crush for
|
||||
// shutdown-resilient child work. The subagent's own timeout
|
||||
// (defaultSubagentTimeout / user-specified) provides the safety net.
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: ctx})
|
||||
|
||||
// Spawn in-process subagent.
|
||||
result, err := spawner(spawnCtx, call.ID, args.Task, args.Model, args.SystemPrompt, timeout)
|
||||
@@ -173,37 +182,21 @@ func executeSubagent(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolRe
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context detachment
|
||||
// Context helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// detachedContext wraps a parent context, preserving its values but removing
|
||||
// its deadline and cancellation. This allows the subagent to have its own
|
||||
// independent timeout while still accessing context-stored values (e.g. the
|
||||
// subagent spawner function).
|
||||
type detachedContext struct {
|
||||
// valuesContext preserves a parent context's values (e.g. the subagent
|
||||
// spawner callback) while stripping its deadline and cancellation. Combined
|
||||
// with context.WithoutCancel() this gives the subagent a completely clean
|
||||
// context that only inherits value-based dependencies.
|
||||
type valuesContext struct {
|
||||
parent context.Context
|
||||
}
|
||||
|
||||
func (d detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false }
|
||||
func (d detachedContext) Done() <-chan struct{} { return nil }
|
||||
func (d detachedContext) Err() error { return nil }
|
||||
func (d detachedContext) Value(key any) any { return d.parent.Value(key) }
|
||||
|
||||
// detachedWithCancel creates a new context that inherits values from the
|
||||
// parent but has no deadline. Cancellation of the parent is propagated: when
|
||||
// the parent is cancelled the returned context is also cancelled, but the
|
||||
// parent's deadline does not apply to the child.
|
||||
func detachedWithCancel(parent context.Context) context.Context {
|
||||
child, cancel := context.WithCancel(detachedContext{parent: parent})
|
||||
go func() {
|
||||
select {
|
||||
case <-parent.Done():
|
||||
cancel()
|
||||
case <-child.Done():
|
||||
}
|
||||
}()
|
||||
return child
|
||||
}
|
||||
func (v valuesContext) Deadline() (time.Time, bool) { return time.Time{}, false }
|
||||
func (v valuesContext) Done() <-chan struct{} { return nil }
|
||||
func (v valuesContext) Err() error { return nil }
|
||||
func (v valuesContext) Value(key any) any { return v.parent.Value(key) }
|
||||
|
||||
// truncateResponse limits the response length to avoid overwhelming context windows.
|
||||
func truncateResponse(s string, maxLen int) string {
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestValuesContext_StripsDeadlineAndCancellation(t *testing.T) {
|
||||
// Parent with a tight deadline.
|
||||
parent, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
time.Sleep(5 * time.Millisecond) // Let deadline expire.
|
||||
|
||||
if parent.Err() == nil {
|
||||
t.Fatal("expected parent to be expired")
|
||||
}
|
||||
|
||||
vc := valuesContext{parent: parent}
|
||||
|
||||
if _, ok := vc.Deadline(); ok {
|
||||
t.Error("valuesContext should report no deadline")
|
||||
}
|
||||
if vc.Done() != nil {
|
||||
t.Error("valuesContext.Done() should return nil")
|
||||
}
|
||||
if vc.Err() != nil {
|
||||
t.Errorf("valuesContext.Err() should be nil, got %v", vc.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesContext_PreservesValues(t *testing.T) {
|
||||
type testKey struct{}
|
||||
parent := context.WithValue(context.Background(), testKey{}, "hello")
|
||||
|
||||
vc := valuesContext{parent: parent}
|
||||
|
||||
got, ok := vc.Value(testKey{}).(string)
|
||||
if !ok || got != "hello" {
|
||||
t.Errorf("expected value 'hello', got %q (ok=%v)", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_SurvivesCancelledParent(t *testing.T) {
|
||||
// Simulate the exact scenario from the bug: the parent generation
|
||||
// context is already cancelled when the subagent tool handler runs.
|
||||
parent, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancelled before detach.
|
||||
|
||||
// This is what executeSubagent now does:
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parent})
|
||||
|
||||
// The spawn context must be alive.
|
||||
if spawnCtx.Err() != nil {
|
||||
t.Fatalf("spawnCtx should be alive, got err: %v", spawnCtx.Err())
|
||||
}
|
||||
|
||||
// Adding a timeout should produce a working context.
|
||||
tCtx, tCancel := context.WithTimeout(spawnCtx, 5*time.Second)
|
||||
defer tCancel()
|
||||
|
||||
if tCtx.Err() != nil {
|
||||
t.Fatalf("timeout context should be alive, got err: %v", tCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_SurvivesDeadlineExceededParent(t *testing.T) {
|
||||
// Simulate: parent had a deadline that already expired.
|
||||
parent, pCancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer pCancel()
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
if parent.Err() != context.DeadlineExceeded {
|
||||
t.Fatalf("expected parent deadline exceeded, got: %v", parent.Err())
|
||||
}
|
||||
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parent})
|
||||
|
||||
if spawnCtx.Err() != nil {
|
||||
t.Fatalf("spawnCtx should be alive after deadline-exceeded parent, got: %v", spawnCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_PreservesSpawnerValue(t *testing.T) {
|
||||
// Verify the subagent spawner callback survives context detachment.
|
||||
called := false
|
||||
spawner := SubagentSpawnFunc(func(ctx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration) (*SubagentSpawnResult, error) {
|
||||
called = true
|
||||
return &SubagentSpawnResult{Response: "ok"}, nil
|
||||
})
|
||||
|
||||
parent := WithSubagentSpawner(context.Background(), spawner)
|
||||
// Cancel the parent.
|
||||
parentCtx, cancel := context.WithCancel(parent)
|
||||
cancel()
|
||||
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parentCtx})
|
||||
|
||||
// Should be able to retrieve the spawner from the detached context.
|
||||
recovered := getSubagentSpawner(spawnCtx)
|
||||
if recovered == nil {
|
||||
t.Fatal("spawner should be recoverable from detached context")
|
||||
}
|
||||
|
||||
result, err := recovered(spawnCtx, "tc1", "test task", "", "", time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("spawner call failed: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("spawner was not called")
|
||||
}
|
||||
if result.Response != "ok" {
|
||||
t.Errorf("expected 'ok', got %q", result.Response)
|
||||
}
|
||||
}
|
||||
@@ -77,6 +77,64 @@ type Context struct {
|
||||
// ctx.CancelAndSend("Stop what you're doing and focus on the tests")
|
||||
CancelAndSend func(string)
|
||||
|
||||
// Abort cancels the current agent turn (if running) and clears the
|
||||
// message queue. Unlike CancelAndSend, no new message is injected —
|
||||
// the agent simply stops. Safe to call when idle (no-op).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ctx.Abort() // stop whatever the agent is doing
|
||||
Abort func()
|
||||
|
||||
// IsIdle returns true when the agent is not processing a turn.
|
||||
// Extensions can use this to decide whether to dispatch immediately
|
||||
// or queue work for later.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// if ctx.IsIdle() {
|
||||
// ctx.SendMessage("start new task")
|
||||
// }
|
||||
IsIdle func() bool
|
||||
|
||||
// Compact triggers context compaction, summarising older messages to
|
||||
// free context window space. Returns an error if compaction cannot
|
||||
// start (e.g. agent is busy or app is closed). The actual compaction
|
||||
// runs asynchronously; use OnComplete/OnError callbacks in
|
||||
// CompactConfig to observe the result.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := ctx.Compact(ext.CompactConfig{
|
||||
// OnComplete: func() { ctx.PrintInfo("Compaction done") },
|
||||
// OnError: func(errMsg string) { ctx.PrintError("Compact failed: " + errMsg) },
|
||||
// })
|
||||
Compact func(CompactConfig) error
|
||||
|
||||
// SendMultimodalMessage injects a message with file attachments (images,
|
||||
// documents) into the conversation and triggers a new agent turn. Files
|
||||
// are described by FilePart structs containing the raw bytes, filename,
|
||||
// and MIME type. If the agent is busy the message is queued.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// data, _ := os.ReadFile("photo.jpg")
|
||||
// ctx.SendMultimodalMessage("Describe this image", []ext.FilePart{
|
||||
// {Filename: "photo.jpg", Data: data, MediaType: "image/jpeg"},
|
||||
// })
|
||||
SendMultimodalMessage func(text string, files []FilePart)
|
||||
|
||||
// GetSessionUsage returns aggregated token usage and cost statistics
|
||||
// for the current session. This includes total input/output tokens,
|
||||
// cache read/write tokens, total cost, and request count.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// usage := ctx.GetSessionUsage()
|
||||
// fmt.Sprintf("Tokens: ↑%d ↓%d Cost: $%.3f",
|
||||
// usage.TotalInputTokens, usage.TotalOutputTokens, usage.TotalCost)
|
||||
GetSessionUsage func() SessionUsage
|
||||
|
||||
// SetWidget places or updates a persistent widget in the TUI. Widgets
|
||||
// remain visible across agent turns until explicitly removed. The
|
||||
// widget is identified by WidgetConfig.ID; calling SetWidget with the
|
||||
@@ -937,6 +995,48 @@ type StatusBarEntry struct {
|
||||
Priority int
|
||||
}
|
||||
|
||||
// CompactConfig configures a programmatic context compaction request.
|
||||
type CompactConfig struct {
|
||||
// CustomInstructions is optional text appended to the summary prompt
|
||||
// (e.g. "Focus on the API design decisions"). Empty uses the default.
|
||||
CustomInstructions string
|
||||
// OnComplete is called when compaction finishes successfully.
|
||||
// May be nil if the caller doesn't need notification.
|
||||
OnComplete func()
|
||||
// OnError is called when compaction fails. The argument is the error message.
|
||||
// May be nil if the caller doesn't need notification.
|
||||
OnError func(errMsg string)
|
||||
}
|
||||
|
||||
// FilePart describes a file attachment for multimodal messages. Extensions
|
||||
// use this with SendMultimodalMessage to attach images or documents.
|
||||
type FilePart struct {
|
||||
// Filename is the name of the file (e.g. "photo.jpg").
|
||||
Filename string
|
||||
// Data is the raw file content.
|
||||
Data []byte
|
||||
// MediaType is the MIME type (e.g. "image/jpeg", "application/pdf").
|
||||
MediaType string
|
||||
}
|
||||
|
||||
// SessionUsage contains aggregated token usage and cost statistics for
|
||||
// the current session. Extensions use this with GetSessionUsage() to
|
||||
// report usage information.
|
||||
type SessionUsage struct {
|
||||
// TotalInputTokens is the sum of input tokens across all requests.
|
||||
TotalInputTokens int
|
||||
// TotalOutputTokens is the sum of output tokens across all requests.
|
||||
TotalOutputTokens int
|
||||
// TotalCacheReadTokens is the sum of cache read tokens.
|
||||
TotalCacheReadTokens int
|
||||
// TotalCacheWriteTokens is the sum of cache write tokens.
|
||||
TotalCacheWriteTokens int
|
||||
// TotalCost is the total cost in USD across all requests.
|
||||
TotalCost float64
|
||||
// RequestCount is the number of LLM requests made in this session.
|
||||
RequestCount int
|
||||
}
|
||||
|
||||
// PrintBlockOpts configures a custom styled block for PrintBlock.
|
||||
type PrintBlockOpts struct {
|
||||
// Text is the main content to display.
|
||||
|
||||
@@ -154,6 +154,11 @@ func NewInstaller(projectDir string) *Installer {
|
||||
|
||||
// Install clones a git repository to the appropriate scope.
|
||||
func (i *Installer) Install(source *GitSource, scope InstallScope) error {
|
||||
return i.install(source, scope, nil)
|
||||
}
|
||||
|
||||
// install is the internal implementation that supports optional include paths.
|
||||
func (i *Installer) install(source *GitSource, scope InstallScope, includePaths []string) error {
|
||||
targetDir := i.getInstallPath(source, scope)
|
||||
|
||||
// Check if already installed
|
||||
@@ -199,6 +204,7 @@ func (i *Installer) Install(source *GitSource, scope InstallScope) error {
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
Include: includePaths,
|
||||
}
|
||||
if err := i.addToManifest(entry, scope); err != nil {
|
||||
// Don't fail the install, just log the error
|
||||
@@ -268,7 +274,22 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
|
||||
cleanCmd.Dir = targetDir
|
||||
_ = cleanCmd.Run() // Ignore errors - clean is best effort
|
||||
|
||||
// Update manifest timestamp
|
||||
// Update manifest timestamp, preserving existing fields like Include
|
||||
existing, _ := i.loadManifest(scope)
|
||||
var include []string
|
||||
var installed time.Time
|
||||
if existing != nil {
|
||||
for _, p := range existing.Packages {
|
||||
if p.Host+"/"+p.Path == source.Identity() {
|
||||
include = p.Include
|
||||
installed = p.Installed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if installed.IsZero() {
|
||||
installed = time.Now()
|
||||
}
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
@@ -277,8 +298,9 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
|
||||
Ref: "",
|
||||
Pinned: false,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
Installed: installed,
|
||||
Updated: time.Now(),
|
||||
Include: include,
|
||||
}
|
||||
_ = i.addToManifest(entry, scope) // Best effort - don't fail update if manifest fails
|
||||
|
||||
@@ -503,30 +525,7 @@ func (i *Installer) PreviewExtensions(source *GitSource) ([]ExtensionPreview, st
|
||||
// InstallWithInclude clones a repo and installs only the specified extensions.
|
||||
// includePaths are relative paths like "./git/main.go" - if empty, installs all.
|
||||
func (i *Installer) InstallWithInclude(source *GitSource, scope InstallScope, includePaths []string) error {
|
||||
// First, do a regular install
|
||||
if err := i.Install(source, scope); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If specific includes were requested, update the manifest
|
||||
if len(includePaths) > 0 {
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
Host: source.Host,
|
||||
Path: source.Path,
|
||||
Ref: source.Ref,
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Include: includePaths,
|
||||
}
|
||||
|
||||
if err := addEntryToManifest(entry, scope); err != nil {
|
||||
return fmt.Errorf("updating manifest with includes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return i.install(source, scope, includePaths)
|
||||
}
|
||||
|
||||
// CleanupTempDir removes a temporary directory used for preview.
|
||||
|
||||
@@ -34,15 +34,10 @@ func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
|
||||
for _, p := range paths {
|
||||
ext, err := loadSingleExtension(p)
|
||||
if err != nil {
|
||||
log.Warn("skipping extension", "path", p, "err", err)
|
||||
continue
|
||||
}
|
||||
loaded = append(loaded, *ext)
|
||||
log.Debug("loaded extension", "path", p,
|
||||
"handlers", countHandlers(ext),
|
||||
"tools", len(ext.Tools),
|
||||
"commands", len(ext.Commands),
|
||||
"tool_renderers", len(ext.ToolRenderers))
|
||||
log.Debug("loaded extension", "path", p, "handlers", countHandlers(ext), "tools", len(ext.Tools), "commands", len(ext.Commands), "tool_renderers", len(ext.ToolRenderers))
|
||||
}
|
||||
return loaded, nil
|
||||
}
|
||||
@@ -133,7 +128,7 @@ func findExtensionsInDir(dir string) []string {
|
||||
|
||||
for _, entry := range entries {
|
||||
full := filepath.Join(dir, entry.Name())
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") {
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") && !strings.HasSuffix(entry.Name(), "_test.go") {
|
||||
results = append(results, full)
|
||||
} else if entry.IsDir() {
|
||||
main := filepath.Join(full, "main.go")
|
||||
@@ -190,9 +185,13 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
// Allow walking into examples/ so we can reach examples/extensions/ etc,
|
||||
// but don't treat examples/ itself or non-extension subdirs as extension locations.
|
||||
if relPath == "examples" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
if !isExtDir {
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
if relPath == base { // Top-level directory
|
||||
@@ -202,13 +201,6 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
results = append(results, mainPath)
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
@@ -227,7 +219,7 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
}
|
||||
|
||||
// It's a file
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -253,10 +253,13 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
// Or check if it's a subdirectory of examples/ that might contain extensions
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
// Allow walking into examples/ so we can reach examples/extensions/ etc,
|
||||
// but don't treat examples/ itself or non-extension subdirs as extension locations.
|
||||
if relPath == "examples" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
if !isExtDir {
|
||||
// Check for main.go before skipping
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
@@ -272,18 +275,6 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
// Inside a valid extensions directory
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
previews = append(previews, ExtensionPreview{
|
||||
Path: "./" + relPath + "/main.go",
|
||||
Name: deriveExtensionName(relPath+"/main.go", true),
|
||||
IsMain: true,
|
||||
})
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
|
||||
// Not an extension location
|
||||
@@ -309,7 +300,7 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
}
|
||||
|
||||
// It's a file - check if it's a valid extension
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,12 +2,12 @@ package extensions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
@@ -86,6 +86,21 @@ func normalizeContext(ctx Context) Context {
|
||||
if ctx.CancelAndSend == nil {
|
||||
ctx.CancelAndSend = func(string) {}
|
||||
}
|
||||
if ctx.Abort == nil {
|
||||
ctx.Abort = func() {}
|
||||
}
|
||||
if ctx.IsIdle == nil {
|
||||
ctx.IsIdle = func() bool { return true }
|
||||
}
|
||||
if ctx.Compact == nil {
|
||||
ctx.Compact = func(CompactConfig) error { return fmt.Errorf("compact not available") }
|
||||
}
|
||||
if ctx.SendMultimodalMessage == nil {
|
||||
ctx.SendMultimodalMessage = func(string, []FilePart) {}
|
||||
}
|
||||
if ctx.GetSessionUsage == nil {
|
||||
ctx.GetSessionUsage = func() SessionUsage { return SessionUsage{} }
|
||||
}
|
||||
if ctx.SetWidget == nil {
|
||||
ctx.SetWidget = func(WidgetConfig) {}
|
||||
}
|
||||
@@ -355,10 +370,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
|
||||
for _, handler := range handlers {
|
||||
result, err := safeCall(handler, event, ctx)
|
||||
if err != nil {
|
||||
log.Warn("extension handler error",
|
||||
"path", ext.Path,
|
||||
"event", event.Type(),
|
||||
"err", err)
|
||||
log.Printf("WARN extension handler error: path=%s event=%s err=%v", ext.Path, event.Type(), err)
|
||||
continue
|
||||
}
|
||||
if result == nil {
|
||||
@@ -692,9 +704,7 @@ func (r *Runner) EmitCustomEvent(name, data string) {
|
||||
safeInvoke := func(h func(string)) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Warn("custom event handler panicked",
|
||||
"event", name,
|
||||
"err", fmt.Sprintf("%v", rec))
|
||||
log.Printf("WARN custom event handler panicked: event=%s err=%v", name, rec)
|
||||
}
|
||||
}()
|
||||
h(data)
|
||||
|
||||
@@ -31,6 +31,7 @@ func Symbols() interp.Exports {
|
||||
// Session types
|
||||
"SessionMessage": reflect.ValueOf((*SessionMessage)(nil)),
|
||||
"ExtensionEntry": reflect.ValueOf((*ExtensionEntry)(nil)),
|
||||
"SessionUsage": reflect.ValueOf((*SessionUsage)(nil)),
|
||||
|
||||
// Option types
|
||||
"OptionDef": reflect.ValueOf((*OptionDef)(nil)),
|
||||
@@ -44,6 +45,8 @@ func Symbols() interp.Exports {
|
||||
// LLM completion types
|
||||
"CompleteRequest": reflect.ValueOf((*CompleteRequest)(nil)),
|
||||
"CompleteResponse": reflect.ValueOf((*CompleteResponse)(nil)),
|
||||
"CompactConfig": reflect.ValueOf((*CompactConfig)(nil)),
|
||||
"FilePart": reflect.ValueOf((*FilePart)(nil)),
|
||||
|
||||
// Status bar types
|
||||
"StatusBarEntry": reflect.ValueOf((*StatusBarEntry)(nil)),
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// Watcher monitors extension directories for file changes and triggers
|
||||
// a reload callback when .go files are created, modified, or removed.
|
||||
// It uses fsnotify for kernel-level file notifications (inotify on Linux,
|
||||
// kqueue on macOS) with debouncing to coalesce rapid editor writes.
|
||||
type Watcher struct {
|
||||
watcher *fsnotify.Watcher
|
||||
onReload func()
|
||||
debounce time.Duration
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewWatcher creates a file watcher that monitors the given directories
|
||||
// for .go file changes. When a change is detected (after debouncing),
|
||||
// onReload is called. The watcher must be started with Start() and
|
||||
// stopped with Close().
|
||||
func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
|
||||
fsw, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating file watcher: %w", err)
|
||||
}
|
||||
|
||||
for _, dir := range dirs {
|
||||
// Watch the directory itself.
|
||||
if err := fsw.Add(dir); err != nil {
|
||||
log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Also watch immediate subdirectories (for */main.go pattern).
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
subdir := filepath.Join(dir, entry.Name())
|
||||
if err := fsw.Add(subdir); err != nil {
|
||||
log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Watcher{
|
||||
watcher: fsw,
|
||||
onReload: onReload,
|
||||
debounce: 300 * time.Millisecond,
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins watching for file changes. It blocks until the context
|
||||
// is cancelled or Close() is called. Typically called in a goroutine.
|
||||
func (w *Watcher) Start(ctx context.Context) {
|
||||
w.mu.Lock()
|
||||
ctx, w.cancel = context.WithCancel(ctx)
|
||||
w.mu.Unlock()
|
||||
|
||||
defer close(w.done)
|
||||
|
||||
var timer *time.Timer
|
||||
var timerC <-chan time.Time
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
return
|
||||
|
||||
case event, ok := <-w.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Only care about .go files.
|
||||
if !strings.HasSuffix(event.Name, ".go") {
|
||||
continue
|
||||
}
|
||||
|
||||
// React to write, create, remove, rename events.
|
||||
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op)
|
||||
|
||||
// Debounce: reset timer on each event.
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
timer = time.NewTimer(w.debounce)
|
||||
timerC = timer.C
|
||||
|
||||
case <-timerC:
|
||||
timerC = nil
|
||||
timer = nil
|
||||
log.Printf("DEBUG watcher: reloading extensions")
|
||||
w.onReload()
|
||||
|
||||
case err, ok := <-w.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Printf("WARN watcher: error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the watcher and releases resources.
|
||||
func (w *Watcher) Close() error {
|
||||
w.mu.Lock()
|
||||
cancel := w.cancel
|
||||
w.mu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Wait for the event loop to finish.
|
||||
<-w.done
|
||||
return w.watcher.Close()
|
||||
}
|
||||
|
||||
// WatchedDirs returns the directories to watch for extension changes.
|
||||
// This includes the global extensions directory and the project-local
|
||||
// .kit/extensions/ directory (if they exist). Explicit -e paths that
|
||||
// point to directories are also included; explicit file paths cause
|
||||
// their parent directory to be watched instead.
|
||||
func WatchedDirs(extraPaths []string) []string {
|
||||
var dirs []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
add := func(dir string) {
|
||||
abs, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if seen[abs] {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the directory exists.
|
||||
info, err := os.Stat(abs)
|
||||
if err != nil || !info.IsDir() {
|
||||
return
|
||||
}
|
||||
|
||||
seen[abs] = true
|
||||
dirs = append(dirs, abs)
|
||||
}
|
||||
|
||||
// Global extensions dir.
|
||||
add(globalExtensionsDir())
|
||||
|
||||
// Project-local extensions dir.
|
||||
add(filepath.Join(".kit", "extensions"))
|
||||
|
||||
// Explicit paths that are directories.
|
||||
for _, p := range extraPaths {
|
||||
info, err := os.Stat(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.IsDir() {
|
||||
add(p)
|
||||
} else {
|
||||
// For explicit files, watch the parent directory.
|
||||
add(filepath.Dir(p))
|
||||
}
|
||||
}
|
||||
|
||||
return dirs
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWatcher_ReloadsOnGoFileChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write an initial extension file.
|
||||
extFile := filepath.Join(dir, "test.go")
|
||||
if err := os.WriteFile(extFile, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
// Modify the file.
|
||||
time.Sleep(50 * time.Millisecond) // let watcher settle
|
||||
if err := os.WriteFile(extFile, []byte("package main\n// changed\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait for debounce (300ms) + margin.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 1 {
|
||||
t.Errorf("expected 1 reload, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcher_IgnoresNonGoFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
// Write a non-.go file.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
txtFile := filepath.Join(dir, "notes.txt")
|
||||
if err := os.WriteFile(txtFile, []byte("hello"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait past the debounce window.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 0 {
|
||||
t.Errorf("expected 0 reloads for .txt file, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcher_Debounces(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
extFile := filepath.Join(dir, "ext.go")
|
||||
if err := os.WriteFile(extFile, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Rapid-fire writes (simulating editor save: write temp, rename, etc.).
|
||||
for range 5 {
|
||||
if err := os.WriteFile(extFile, []byte("package main\n// changed\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Wait for debounce to fire.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 1 {
|
||||
t.Errorf("expected 1 debounced reload, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchedDirs_Deduplicates(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dirs := WatchedDirs([]string{dir, dir})
|
||||
|
||||
count := 0
|
||||
for _, d := range dirs {
|
||||
abs, _ := filepath.Abs(dir)
|
||||
if d == abs {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected directory to appear once, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchedDirs_FileParent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
file := filepath.Join(dir, "ext.go")
|
||||
if err := os.WriteFile(file, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dirs := WatchedDirs([]string{file})
|
||||
|
||||
abs, _ := filepath.Abs(dir)
|
||||
found := false
|
||||
for _, d := range dirs {
|
||||
if d == abs {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected parent dir %s in watched dirs %v", abs, dirs)
|
||||
}
|
||||
}
|
||||
+103
-23
@@ -33,6 +33,10 @@ type AgentSetupOptions struct {
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used. Allows SDK users to pass custom tools (e.g. with WithWorkDir).
|
||||
CoreTools []fantasy.AgentTool
|
||||
// DisableCoreTools, when true, prevents loading any core tools.
|
||||
// If both DisableCoreTools is true and CoreTools is empty, the agent
|
||||
// will have no tools (useful for simple chat completions).
|
||||
DisableCoreTools bool
|
||||
// ExtraTools are additional tools added alongside core, MCP, and extension
|
||||
// tools. They do not replace the defaults — they extend them.
|
||||
ExtraTools []fantasy.AgentTool
|
||||
@@ -40,6 +44,34 @@ type AgentSetupOptions struct {
|
||||
// wrapping. Used by the SDK hook system. Both wrappers compose:
|
||||
// extension wrapper runs first (inner), then this wrapper (outer).
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
// ProviderConfig, when non-nil, is used directly instead of calling
|
||||
// BuildProviderConfig(). Callers that already hold viperInitMu can
|
||||
// pre-build this and release the lock before calling SetupAgent, so the
|
||||
// slow agent/MCP initialisation runs concurrently with other New() calls.
|
||||
ProviderConfig *models.ProviderConfig
|
||||
// Debug enables debug logging. When zero-value, viper is consulted.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
Debug bool
|
||||
// NoExtensions skips extension loading. When false, viper is consulted.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
NoExtensions bool
|
||||
// MaxSteps overrides the agent step limit. 0 means use viper value.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
MaxSteps int
|
||||
// StreamingEnabled controls streaming. Only meaningful when ProviderConfig
|
||||
// is also set.
|
||||
StreamingEnabled bool
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
}
|
||||
|
||||
// AgentSetupResult bundles the created agent and any debug logger so the caller
|
||||
@@ -54,15 +86,17 @@ type AgentSetupResult struct {
|
||||
|
||||
// BuildProviderConfig creates a *models.ProviderConfig from the current viper
|
||||
// state. All entry points (root, script, SDK) converge through this function.
|
||||
//
|
||||
// Generation parameter pointers (Temperature, TopP, etc.) are only set when
|
||||
// the user has explicitly configured them via CLI flag, environment variable,
|
||||
// or global config file. This allows per-model defaults from modelSettings
|
||||
// and customModels to fill in unset parameters downstream.
|
||||
func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
systemPrompt, err := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to load system prompt: %w", err)
|
||||
}
|
||||
|
||||
temperature := float32(viper.GetFloat64("temperature"))
|
||||
topP := float32(viper.GetFloat64("top-p"))
|
||||
topK := int32(viper.GetInt("top-k"))
|
||||
numGPU := int32(viper.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(viper.GetInt("main-gpu"))
|
||||
|
||||
@@ -72,9 +106,6 @@ func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
Temperature: &temperature,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
@@ -82,21 +113,66 @@ func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
}
|
||||
|
||||
// Only set generation parameter pointers when the user has explicitly
|
||||
// provided a value. This leaves nil pointers for unset params, allowing
|
||||
// per-model defaults (modelSettings / customModels params) to apply.
|
||||
if viper.IsSet("temperature") {
|
||||
v := float32(viper.GetFloat64("temperature"))
|
||||
cfg.Temperature = &v
|
||||
}
|
||||
if viper.IsSet("top-p") {
|
||||
v := float32(viper.GetFloat64("top-p"))
|
||||
cfg.TopP = &v
|
||||
}
|
||||
if viper.IsSet("top-k") {
|
||||
v := int32(viper.GetInt("top-k"))
|
||||
cfg.TopK = &v
|
||||
}
|
||||
if viper.IsSet("frequency-penalty") {
|
||||
v := float32(viper.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &v
|
||||
}
|
||||
if viper.IsSet("presence-penalty") {
|
||||
v := float32(viper.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &v
|
||||
}
|
||||
|
||||
return cfg, systemPrompt, nil
|
||||
}
|
||||
|
||||
// SetupAgent creates an agent from the current viper state + the provided
|
||||
// options. It wraps BuildProviderConfig and agent.CreateAgent.
|
||||
func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, error) {
|
||||
modelConfig, systemPrompt, err := BuildProviderConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var modelConfig *models.ProviderConfig
|
||||
var systemPrompt string
|
||||
|
||||
if opts.ProviderConfig != nil {
|
||||
// Pre-built config supplied by caller (e.g. Kit.New after releasing
|
||||
// viperInitMu). Use it directly — no viper reads needed here.
|
||||
modelConfig = opts.ProviderConfig
|
||||
systemPrompt = modelConfig.SystemPrompt
|
||||
} else {
|
||||
var err error
|
||||
modelConfig, systemPrompt, err = BuildProviderConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve debug / no-extensions / max-steps / streaming: prefer explicit
|
||||
// fields (set when ProviderConfig was pre-built) over viper fallback.
|
||||
debugEnabled := opts.Debug || viper.GetBool("debug")
|
||||
noExtensions := opts.NoExtensions || viper.GetBool("no-extensions")
|
||||
maxSteps := opts.MaxSteps
|
||||
if maxSteps == 0 {
|
||||
maxSteps = viper.GetInt("max-steps")
|
||||
}
|
||||
streamingEnabled := opts.StreamingEnabled || viper.GetBool("stream")
|
||||
|
||||
// Create the appropriate debug logger.
|
||||
var debugLogger tools.DebugLogger
|
||||
var bufferedLogger *tools.BufferedDebugLogger
|
||||
if viper.GetBool("debug") {
|
||||
if debugEnabled {
|
||||
if opts.UseBufferedLogger {
|
||||
bufferedLogger = tools.NewBufferedDebugLogger(true)
|
||||
debugLogger = bufferedLogger
|
||||
@@ -108,7 +184,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
// Load extensions unless --no-extensions is set.
|
||||
var extRunner *extensions.Runner
|
||||
var extCreationOpts extensionCreationOpts
|
||||
if !viper.GetBool("no-extensions") {
|
||||
if !noExtensions {
|
||||
var extErr error
|
||||
extRunner, extCreationOpts, extErr = loadExtensions()
|
||||
if extErr != nil {
|
||||
@@ -137,18 +213,22 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
}
|
||||
|
||||
a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxSteps: viper.GetInt("max-steps"),
|
||||
StreamingEnabled: viper.GetBool("stream"),
|
||||
ShowSpinner: opts.ShowSpinner,
|
||||
Quiet: opts.Quiet,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
CoreTools: opts.CoreTools,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxSteps: maxSteps,
|
||||
StreamingEnabled: streamingEnabled,
|
||||
ShowSpinner: opts.ShowSpinner,
|
||||
Quiet: opts.Quiet,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
TokenStoreFactory: opts.TokenStoreFactory,
|
||||
CoreTools: opts.CoreTools,
|
||||
DisableCoreTools: opts.DisableCoreTools,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create agent: %w", err)
|
||||
|
||||
+236
-9
@@ -2,6 +2,8 @@ package models
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@@ -31,12 +33,14 @@ func loadCustomModelsFromConfig() map[string]ModelInfo {
|
||||
|
||||
// modelConfigToModelInfo converts a CustomModelConfig to a ModelInfo.
|
||||
func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
return ModelInfo{
|
||||
info := ModelInfo{
|
||||
ID: modelID,
|
||||
Name: cfg.Name,
|
||||
Attachment: cfg.Attachment,
|
||||
Reasoning: cfg.Reasoning,
|
||||
Temperature: cfg.Temperature,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIKey: cfg.APIKey,
|
||||
Cost: Cost{
|
||||
Input: cfg.Cost.Input,
|
||||
Output: cfg.Cost.Output,
|
||||
@@ -46,19 +50,242 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
Output: cfg.Limit.Output,
|
||||
},
|
||||
}
|
||||
|
||||
// Convert custom model generation params if any are set.
|
||||
if p := convertGenerationParams(cfg.Params); p != nil {
|
||||
info.Params = p
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// LoadModelSettingsFromConfig loads per-model generation parameter overrides
|
||||
// from the config file. Keys are "provider/model" strings. Returns nil if
|
||||
// no model settings are configured.
|
||||
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
|
||||
if !viper.IsSet("modelSettings") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings map[string]GenerationParamsConfig
|
||||
if err := viper.UnmarshalKey("modelSettings", &settings); err != nil {
|
||||
log.Printf("Warning: Failed to parse modelSettings: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]*GenerationParams, len(settings))
|
||||
for modelKey, cfg := range settings {
|
||||
if p := convertGenerationParams(cfg); p != nil {
|
||||
result[modelKey] = p
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// convertGenerationParams converts a GenerationParamsConfig to a GenerationParams.
|
||||
// Returns nil if no parameters are set.
|
||||
func convertGenerationParams(cfg GenerationParamsConfig) *GenerationParams {
|
||||
p := &GenerationParams{}
|
||||
any := false
|
||||
|
||||
if cfg.MaxTokens != nil {
|
||||
p.MaxTokens = cfg.MaxTokens
|
||||
any = true
|
||||
}
|
||||
if cfg.Temperature != nil {
|
||||
p.Temperature = cfg.Temperature
|
||||
any = true
|
||||
}
|
||||
if cfg.TopP != nil {
|
||||
p.TopP = cfg.TopP
|
||||
any = true
|
||||
}
|
||||
if cfg.TopK != nil {
|
||||
p.TopK = cfg.TopK
|
||||
any = true
|
||||
}
|
||||
if cfg.FrequencyPenalty != nil {
|
||||
p.FrequencyPenalty = cfg.FrequencyPenalty
|
||||
any = true
|
||||
}
|
||||
if cfg.PresencePenalty != nil {
|
||||
p.PresencePenalty = cfg.PresencePenalty
|
||||
any = true
|
||||
}
|
||||
if len(cfg.StopSequences) > 0 {
|
||||
p.StopSequences = cfg.StopSequences
|
||||
any = true
|
||||
}
|
||||
if cfg.ThinkingLevel != "" {
|
||||
p.ThinkingLevel = ParseThinkingLevel(cfg.ThinkingLevel)
|
||||
any = true
|
||||
}
|
||||
if cfg.SystemPrompt != "" {
|
||||
p.SystemPrompt = cfg.SystemPrompt
|
||||
any = true
|
||||
}
|
||||
|
||||
if !any {
|
||||
return nil
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// ApplyModelSettings merges per-model generation parameter defaults from the
|
||||
// registry into a ProviderConfig. Model-level params are only applied for
|
||||
// fields where the user has not explicitly set a value (i.e., the
|
||||
// corresponding viper key is not set via CLI flag or global config).
|
||||
//
|
||||
// The lookup order is:
|
||||
// 1. modelSettings["provider/model"] from config (highest model-level priority)
|
||||
// 2. ModelInfo.Params from custom model definitions
|
||||
//
|
||||
// Both are overridden by explicit CLI flags / global config values.
|
||||
func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
provider, modelName, err := ParseModelString(config.ModelString)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect model-level params: modelSettings override > custom model params.
|
||||
// modelSettings takes priority because it's the more specific/intentional config.
|
||||
var params *GenerationParams
|
||||
|
||||
// First check modelSettings from config.
|
||||
if settings := LoadModelSettingsFromConfig(); settings != nil {
|
||||
modelKey := provider + "/" + modelName
|
||||
if p, ok := settings[modelKey]; ok {
|
||||
params = p
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to ModelInfo.Params (from custom model definitions).
|
||||
if params == nil && modelInfo != nil && modelInfo.Params != nil {
|
||||
params = modelInfo.Params
|
||||
}
|
||||
|
||||
if params == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Apply each parameter only when the user hasn't explicitly set it.
|
||||
// We check viper.IsSet() which returns true only when the key was
|
||||
// set via CLI flag, environment variable, or config file global section.
|
||||
|
||||
if params.MaxTokens != nil && !isExplicitlySet("max-tokens") {
|
||||
config.MaxTokens = *params.MaxTokens
|
||||
}
|
||||
if params.Temperature != nil && !isExplicitlySet("temperature") {
|
||||
config.Temperature = params.Temperature
|
||||
}
|
||||
if params.TopP != nil && !isExplicitlySet("top-p") {
|
||||
config.TopP = params.TopP
|
||||
}
|
||||
if params.TopK != nil && !isExplicitlySet("top-k") {
|
||||
config.TopK = params.TopK
|
||||
}
|
||||
if params.FrequencyPenalty != nil && !isExplicitlySet("frequency-penalty") {
|
||||
config.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
if params.PresencePenalty != nil && !isExplicitlySet("presence-penalty") {
|
||||
config.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
if len(params.StopSequences) > 0 && !isExplicitlySet("stop-sequences") {
|
||||
config.StopSequences = params.StopSequences
|
||||
}
|
||||
if params.ThinkingLevel != "" && !isExplicitlySet("thinking-level") {
|
||||
config.ThinkingLevel = params.ThinkingLevel
|
||||
}
|
||||
if params.SystemPrompt != "" && config.SystemPrompt == "" {
|
||||
// Resolve file paths: if the value points to an existing file, read it.
|
||||
// We check config.SystemPrompt == "" rather than isExplicitlySet because
|
||||
// viper.BindPFlag causes IsSet to return true even for unset flags.
|
||||
config.SystemPrompt = LoadSystemPromptValue(params.SystemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSystemPromptValue resolves a system prompt value that may be either
|
||||
// inline text or a file path. If the value is a path to an existing file,
|
||||
// its contents are read and returned. Otherwise the string is returned as-is.
|
||||
// This mirrors config.LoadSystemPrompt but lives in the models package to
|
||||
// avoid circular dependencies.
|
||||
func LoadSystemPromptValue(input string) string {
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
if info, err := os.Stat(input); err == nil && !info.IsDir() {
|
||||
content, err := os.ReadFile(input)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to read system prompt file %q: %v", input, err)
|
||||
return input
|
||||
}
|
||||
return strings.TrimSpace(string(content))
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
// isExplicitlySet returns true when the user has explicitly set a config key
|
||||
// via CLI flag, environment variable, or the global section of the config file.
|
||||
// Model-level defaults should not override explicitly set values.
|
||||
func isExplicitlySet(key string) bool {
|
||||
// viper.IsSet returns true if the key has been set in any of the
|
||||
// data stores (flag, env, config file, default). We need to check
|
||||
// whether the value was set at the global config level (not just
|
||||
// as a default). For generation params, the global config keys use
|
||||
// hyphenated names (e.g. "max-tokens", "top-p").
|
||||
//
|
||||
// Since viper merges all sources, IsSet returns true even for config
|
||||
// file values. This means global config file values (e.g.
|
||||
// temperature: 0.7 at the top level) will correctly take precedence
|
||||
// over model-level defaults, which is the desired behavior.
|
||||
return viper.IsSet(key)
|
||||
}
|
||||
|
||||
// GenerationParams holds per-model generation parameter defaults.
|
||||
// These are stored on ModelInfo and applied during provider creation.
|
||||
// Nil pointer fields mean "no model-level default" — the global config
|
||||
// or CLI flag value (if any) will be used instead.
|
||||
type GenerationParams struct {
|
||||
MaxTokens *int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
FrequencyPenalty *float32
|
||||
PresencePenalty *float32
|
||||
StopSequences []string
|
||||
ThinkingLevel ThinkingLevel
|
||||
SystemPrompt string // Per-model system prompt (inline text or file path)
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model configuration loaded from the config file.
|
||||
// This is a duplicate here to avoid circular dependencies with internal/config.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
Params GenerationParamsConfig `json:"params,omitzero" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationParamsConfig is the JSON/YAML-serializable form of generation
|
||||
// parameter defaults. Used in both customModels[].params and modelSettings[].
|
||||
type GenerationParamsConfig struct {
|
||||
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
|
||||
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
|
||||
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
|
||||
@@ -0,0 +1,422 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func TestConvertGenerationParams(t *testing.T) {
|
||||
t.Run("empty config returns nil", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p != nil {
|
||||
t.Errorf("expected nil, got %+v", p)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("temperature only", func(t *testing.T) {
|
||||
temp := float32(0.7)
|
||||
cfg := GenerationParamsConfig{Temperature: &temp}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.Temperature == nil || *p.Temperature != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", p.Temperature)
|
||||
}
|
||||
if p.TopP != nil {
|
||||
t.Errorf("expected nil TopP, got %v", p.TopP)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("all params set", func(t *testing.T) {
|
||||
maxTokens := 8192
|
||||
temp := float32(0.5)
|
||||
topP := float32(0.9)
|
||||
topK := int32(50)
|
||||
freqPenalty := float32(0.1)
|
||||
presPenalty := float32(0.2)
|
||||
cfg := GenerationParamsConfig{
|
||||
MaxTokens: &maxTokens,
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
FrequencyPenalty: &freqPenalty,
|
||||
PresencePenalty: &presPenalty,
|
||||
StopSequences: []string{"STOP"},
|
||||
ThinkingLevel: "high",
|
||||
}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.MaxTokens == nil || *p.MaxTokens != 8192 {
|
||||
t.Errorf("expected maxTokens 8192, got %v", p.MaxTokens)
|
||||
}
|
||||
if p.Temperature == nil || *p.Temperature != 0.5 {
|
||||
t.Errorf("expected temperature 0.5, got %v", p.Temperature)
|
||||
}
|
||||
if p.TopP == nil || *p.TopP != 0.9 {
|
||||
t.Errorf("expected topP 0.9, got %v", p.TopP)
|
||||
}
|
||||
if p.TopK == nil || *p.TopK != 50 {
|
||||
t.Errorf("expected topK 50, got %v", p.TopK)
|
||||
}
|
||||
if p.FrequencyPenalty == nil || *p.FrequencyPenalty != 0.1 {
|
||||
t.Errorf("expected frequencyPenalty 0.1, got %v", p.FrequencyPenalty)
|
||||
}
|
||||
if p.PresencePenalty == nil || *p.PresencePenalty != 0.2 {
|
||||
t.Errorf("expected presencePenalty 0.2, got %v", p.PresencePenalty)
|
||||
}
|
||||
if len(p.StopSequences) != 1 || p.StopSequences[0] != "STOP" {
|
||||
t.Errorf("expected stop sequences [STOP], got %v", p.StopSequences)
|
||||
}
|
||||
if p.ThinkingLevel != ThinkingHigh {
|
||||
t.Errorf("expected thinking level high, got %v", p.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking level parsing", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{ThinkingLevel: "medium"}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.ThinkingLevel != ThinkingMedium {
|
||||
t.Errorf("expected thinking level medium, got %v", p.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
t.Run("system prompt only", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{SystemPrompt: "You are helpful."}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.SystemPrompt != "You are helpful." {
|
||||
t.Errorf("expected system prompt, got %q", p.SystemPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelConfigToModelInfoWithParams(t *testing.T) {
|
||||
temp := float32(0.8)
|
||||
topP := float32(0.95)
|
||||
cfg := CustomModelConfig{
|
||||
Name: "Test Model",
|
||||
BaseURL: "http://localhost:8080/v1",
|
||||
Temperature: true,
|
||||
Params: GenerationParamsConfig{
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
},
|
||||
}
|
||||
|
||||
info := modelConfigToModelInfo("test-model", cfg)
|
||||
|
||||
if info.Params == nil {
|
||||
t.Fatal("expected non-nil Params")
|
||||
}
|
||||
if info.Params.Temperature == nil || *info.Params.Temperature != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", info.Params.Temperature)
|
||||
}
|
||||
if info.Params.TopP == nil || *info.Params.TopP != 0.95 {
|
||||
t.Errorf("expected topP 0.95, got %v", info.Params.TopP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfigToModelInfoWithoutParams(t *testing.T) {
|
||||
cfg := CustomModelConfig{
|
||||
Name: "Test Model",
|
||||
BaseURL: "http://localhost:8080/v1",
|
||||
}
|
||||
|
||||
info := modelConfigToModelInfo("test-model", cfg)
|
||||
|
||||
if info.Params != nil {
|
||||
t.Errorf("expected nil Params, got %+v", info.Params)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyModelSettings(t *testing.T) {
|
||||
// Save and restore viper state.
|
||||
originalViper := viper.AllSettings()
|
||||
defer func() {
|
||||
viper.Reset()
|
||||
for k, v := range originalViper {
|
||||
viper.Set(k, v)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("applies model params when not explicitly set", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
temp := float32(0.8)
|
||||
topK := int32(50)
|
||||
maxTokens := 4096
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
TopK: &topK,
|
||||
MaxTokens: &maxTokens,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.Temperature == nil || *config.Temperature != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", config.Temperature)
|
||||
}
|
||||
if config.TopK == nil || *config.TopK != 50 {
|
||||
t.Errorf("expected topK 50, got %v", config.TopK)
|
||||
}
|
||||
if config.MaxTokens != 4096 {
|
||||
t.Errorf("expected maxTokens 4096, got %d", config.MaxTokens)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit viper values take precedence", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
viper.Set("temperature", 0.3)
|
||||
|
||||
temp := float32(0.8)
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
},
|
||||
}
|
||||
|
||||
explicitTemp := float32(0.3)
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
Temperature: &explicitTemp,
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Temperature should NOT be overridden because it's explicitly set in viper
|
||||
if config.Temperature == nil || *config.Temperature != 0.3 {
|
||||
t.Errorf("expected temperature 0.3 (explicit), got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil model info is safe", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
ApplyModelSettings(config, nil)
|
||||
|
||||
if config.Temperature != nil {
|
||||
t.Errorf("expected nil temperature, got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model info without params is safe", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{ID: "test-model"}
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.Temperature != nil {
|
||||
t.Errorf("expected nil temperature, got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelSettings from viper takes priority over ModelInfo.Params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
// Set up modelSettings in viper (simulating config file)
|
||||
viper.Set("modelSettings", map[string]any{
|
||||
"custom/test-model": map[string]any{
|
||||
"temperature": 0.5,
|
||||
"topK": 30,
|
||||
},
|
||||
})
|
||||
|
||||
// ModelInfo has different params
|
||||
temp := float32(0.8)
|
||||
topK := int32(50)
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
TopK: &topK,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// modelSettings should win over ModelInfo.Params
|
||||
if config.Temperature == nil || *config.Temperature != 0.5 {
|
||||
t.Errorf("expected temperature 0.5 (from modelSettings), got %v", config.Temperature)
|
||||
}
|
||||
if config.TopK == nil || *config.TopK != 30 {
|
||||
t.Errorf("expected topK 30 (from modelSettings), got %v", config.TopK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stop sequences applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
StopSequences: []string{"STOP", "END"},
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if len(config.StopSequences) != 2 || config.StopSequences[0] != "STOP" {
|
||||
t.Errorf("expected stop sequences [STOP END], got %v", config.StopSequences)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking level applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
ThinkingLevel: ThinkingHigh,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.ThinkingLevel != ThinkingHigh {
|
||||
t.Errorf("expected thinking level high, got %v", config.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("system prompt applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "You are a coding assistant.",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "You are a coding assistant." {
|
||||
t.Errorf("expected system prompt to be set, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit system prompt takes precedence", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "Model-specific prompt",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
SystemPrompt: "Global prompt",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Global system prompt should NOT be overridden because config
|
||||
// already has a non-empty SystemPrompt.
|
||||
if config.SystemPrompt != "Global prompt" {
|
||||
t.Errorf("expected global prompt preserved, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("system prompt from file path", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
// Create a temp file with a system prompt
|
||||
tmpFile, err := os.CreateTemp("", "kit-test-prompt-*.txt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
if _, err := tmpFile.WriteString(" Prompt from file "); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: tmpFile.Name(),
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "Prompt from file" {
|
||||
t.Errorf("expected trimmed file content, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelSettings system prompt overrides custom model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
viper.Set("modelSettings", map[string]any{
|
||||
"custom/test-model": map[string]any{
|
||||
"systemPrompt": "From modelSettings",
|
||||
},
|
||||
})
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "From custom model",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "From modelSettings" {
|
||||
t.Errorf("expected modelSettings prompt, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
+48
-148
@@ -10,7 +10,6 @@ import (
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -144,20 +143,22 @@ func ParseThinkingLevel(s string) ThinkingLevel {
|
||||
|
||||
// ProviderConfig holds configuration for creating LLM providers.
|
||||
type ProviderConfig struct {
|
||||
ModelString string
|
||||
SystemPrompt string
|
||||
ProviderAPIKey string
|
||||
ProviderURL string
|
||||
MaxTokens int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
StopSequences []string
|
||||
NumGPU *int32
|
||||
MainGPU *int32
|
||||
TLSSkipVerify bool
|
||||
ThinkingLevel ThinkingLevel
|
||||
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
|
||||
ModelString string
|
||||
SystemPrompt string
|
||||
ProviderAPIKey string
|
||||
ProviderURL string
|
||||
MaxTokens int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
FrequencyPenalty *float32
|
||||
PresencePenalty *float32
|
||||
StopSequences []string
|
||||
NumGPU *int32
|
||||
MainGPU *int32
|
||||
TLSSkipVerify bool
|
||||
ThinkingLevel ThinkingLevel
|
||||
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
|
||||
}
|
||||
|
||||
// ProviderResult contains the result of provider creation.
|
||||
@@ -240,6 +241,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
validateModelConfig(config, modelInfo)
|
||||
}
|
||||
|
||||
// Apply per-model generation parameter defaults. Model-level params are
|
||||
// only applied for fields where the user hasn't explicitly set a value
|
||||
// via CLI flag or global config.
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Create the base provider
|
||||
var result *ProviderResult
|
||||
var createErr error
|
||||
@@ -525,13 +531,13 @@ func buildOpenAIProviderOptions(config *ProviderConfig, modelName string) fantas
|
||||
func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort {
|
||||
switch level {
|
||||
case ThinkingMinimal:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortMinimal)
|
||||
return new(openai.ReasoningEffortMinimal)
|
||||
case ThinkingLow:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortLow)
|
||||
return new(openai.ReasoningEffortLow)
|
||||
case ThinkingMedium:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortMedium)
|
||||
return new(openai.ReasoningEffortMedium)
|
||||
case ThinkingHigh:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortHigh)
|
||||
return new(openai.ReasoningEffortHigh)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -1000,139 +1006,29 @@ func createVercelProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
|
||||
// thinkTagRegex matches <think>...</think> tags for extracting reasoning content
|
||||
// from models that wrap thinking in XML-like tags (e.g., Qwen, DeepSeek).
|
||||
var thinkTagRegex = regexp.MustCompile(`(?s)<think>(.*?)</think>`)
|
||||
|
||||
// customExtraContentFunc extracts reasoning from <think> tags in the content field.
|
||||
// This handles models like Qwen and DeepSeek that return reasoning wrapped in XML tags
|
||||
// rather than using a separate reasoning_content field.
|
||||
func customExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
|
||||
var content []fantasy.Content
|
||||
if choice.Message.Content == "" {
|
||||
return content
|
||||
}
|
||||
|
||||
// Check for <think> tags in the content
|
||||
matches := thinkTagRegex.FindStringSubmatch(choice.Message.Content)
|
||||
if len(matches) > 1 {
|
||||
// Found reasoning content in <think> tags
|
||||
reasoning := strings.TrimSpace(matches[1])
|
||||
if reasoning != "" {
|
||||
content = append(content, fantasy.ReasoningContent{
|
||||
Text: reasoning,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// customStreamExtraFunc handles streaming responses with <think> tags.
|
||||
// It extracts reasoning content and emits proper reasoning events.
|
||||
func customStreamExtraFunc(
|
||||
chunk openaisdk.ChatCompletionChunk,
|
||||
yield func(fantasy.StreamPart) bool,
|
||||
ctx map[string]any,
|
||||
) (map[string]any, bool) {
|
||||
if len(chunk.Choices) == 0 {
|
||||
return ctx, true
|
||||
}
|
||||
|
||||
const reasoningStartedKey = "reasoning_started"
|
||||
const reasoningBufferKey = "reasoning_buffer"
|
||||
const inThinkTagKey = "in_think_tag"
|
||||
|
||||
reasoningStarted, _ := ctx[reasoningStartedKey].(bool)
|
||||
inThinkTag, _ := ctx[inThinkTagKey].(bool)
|
||||
reasoningBuffer, _ := ctx[reasoningBufferKey].(string)
|
||||
|
||||
for i, choice := range chunk.Choices {
|
||||
content := choice.Delta.Content
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for <think> tag start
|
||||
if strings.Contains(content, "<think>") {
|
||||
inThinkTag = true
|
||||
ctx[inThinkTagKey] = true
|
||||
|
||||
// Emit reasoning start event
|
||||
if !reasoningStarted {
|
||||
reasoningStarted = true
|
||||
ctx[reasoningStartedKey] = true
|
||||
if !yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeReasoningStart,
|
||||
ID: fmt.Sprintf("%d", i),
|
||||
}) {
|
||||
return ctx, false
|
||||
}
|
||||
}
|
||||
|
||||
// Extract content after <think>
|
||||
parts := strings.SplitN(content, "<think>", 2)
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
reasoningBuffer += parts[1]
|
||||
ctx[reasoningBufferKey] = reasoningBuffer
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for </think> tag end
|
||||
if strings.Contains(content, "</think>") {
|
||||
inThinkTag = false
|
||||
ctx[inThinkTagKey] = false
|
||||
|
||||
// Extract content before </think>
|
||||
parts := strings.SplitN(content, "</think>", 2)
|
||||
if len(parts) > 0 {
|
||||
reasoningBuffer += parts[0]
|
||||
}
|
||||
|
||||
// Emit the accumulated reasoning
|
||||
if reasoningBuffer != "" {
|
||||
if !yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeReasoningDelta,
|
||||
ID: fmt.Sprintf("%d", i),
|
||||
Delta: reasoningBuffer,
|
||||
}) {
|
||||
return ctx, false
|
||||
}
|
||||
ctx[reasoningBufferKey] = ""
|
||||
}
|
||||
|
||||
// Emit reasoning end
|
||||
if !yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeReasoningEnd,
|
||||
ID: fmt.Sprintf("%d", i),
|
||||
}) {
|
||||
return ctx, false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate reasoning content while in think tag
|
||||
if inThinkTag {
|
||||
reasoningBuffer += content
|
||||
ctx[reasoningBufferKey] = reasoningBuffer
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, true
|
||||
}
|
||||
|
||||
// customToPromptFunc converts prompts to OpenAI format using the default conversion.
|
||||
func customToPromptFunc(prompt fantasy.Prompt, systemPrompt, user string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
|
||||
return openai.DefaultToPrompt(prompt, systemPrompt, user)
|
||||
}
|
||||
|
||||
func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
if config.ProviderURL == "" {
|
||||
return nil, fmt.Errorf("custom provider requires --provider-url")
|
||||
// Resolve base URL: per-model override > global provider-url flag/config
|
||||
registry := GetGlobalRegistry()
|
||||
modelInfo := registry.LookupModel("custom", modelName)
|
||||
|
||||
baseURL := config.ProviderURL
|
||||
if modelInfo != nil && modelInfo.BaseURL != "" {
|
||||
baseURL = modelInfo.BaseURL
|
||||
}
|
||||
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("custom provider requires --provider-url or a baseUrl in the model config")
|
||||
}
|
||||
|
||||
apiKey := config.ProviderAPIKey
|
||||
if modelInfo != nil && modelInfo.APIKey != "" {
|
||||
apiKey = modelInfo.APIKey
|
||||
}
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("CUSTOM_API_KEY")
|
||||
}
|
||||
@@ -1141,15 +1037,13 @@ func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
apiKey = "custom"
|
||||
}
|
||||
|
||||
// Use the openai provider directly with custom hooks to handle <think> tags
|
||||
// from models like Qwen and DeepSeek that wrap reasoning in XML tags.
|
||||
// <think> tag extraction is handled transparently at the agent layer,
|
||||
// so no provider-level hooks are needed here.
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithBaseURL(config.ProviderURL))
|
||||
opts = append(opts, openai.WithBaseURL(baseURL))
|
||||
opts = append(opts, openai.WithAPIKey(apiKey))
|
||||
opts = append(opts, openai.WithName("custom"))
|
||||
opts = append(opts, openai.WithLanguageModelOptions(
|
||||
openai.WithLanguageModelExtraContentFunc(customExtraContentFunc),
|
||||
openai.WithLanguageModelStreamExtraFunc(customStreamExtraFunc),
|
||||
openai.WithLanguageModelToPromptFunc(customToPromptFunc),
|
||||
))
|
||||
|
||||
@@ -1277,6 +1171,12 @@ func buildOllamaOptions(config *ProviderConfig) map[string]any {
|
||||
if config.TopK != nil {
|
||||
options["top_k"] = int(*config.TopK)
|
||||
}
|
||||
if config.FrequencyPenalty != nil {
|
||||
options["frequency_penalty"] = *config.FrequencyPenalty
|
||||
}
|
||||
if config.PresencePenalty != nil {
|
||||
options["presence_penalty"] = *config.PresencePenalty
|
||||
}
|
||||
if len(config.StopSequences) > 0 {
|
||||
options["stop"] = config.StopSequences
|
||||
}
|
||||
|
||||
@@ -24,6 +24,13 @@ type ModelInfo struct {
|
||||
Cost Cost
|
||||
Limit Limit
|
||||
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
||||
BaseURL string // Per-model base URL override (custom models only)
|
||||
APIKey string // Per-model API key override (custom models only)
|
||||
|
||||
// Params holds per-model generation parameter defaults. These are applied
|
||||
// when the user hasn't explicitly set the corresponding CLI flag or global
|
||||
// config value. Nil pointer fields mean "no model-level default".
|
||||
Params *GenerationParams
|
||||
}
|
||||
|
||||
// SupportsCaching returns true if this model family supports prompt caching.
|
||||
@@ -234,6 +241,18 @@ func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
|
||||
return &modelInfo
|
||||
}
|
||||
|
||||
// LookupModelForSettings is a convenience function that parses a
|
||||
// "provider/model" string and looks up the ModelInfo in the global registry.
|
||||
// Returns nil when the model string is invalid or the model is unknown.
|
||||
// Used by Kit.SetModel to pre-apply per-model settings before CreateProvider.
|
||||
func LookupModelForSettings(modelString string) *ModelInfo {
|
||||
provider, modelName, err := ParseModelString(modelString)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return GetGlobalRegistry().LookupModel(provider, modelName)
|
||||
}
|
||||
|
||||
// getRequiredEnvVars returns the required environment variables for a provider.
|
||||
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
@@ -367,8 +386,8 @@ func (r *ModelsRegistry) GetFantasyProviders() []string {
|
||||
|
||||
// isProviderLLMSupported checks if a provider can be used with the LLM layer.
|
||||
func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
// Ollama is always supported (via openaicompat pointed at localhost)
|
||||
if providerID == "ollama" {
|
||||
// Ollama and custom are always supported (model names are user-defined).
|
||||
if providerID == "ollama" || providerID == "custom" {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -400,6 +419,52 @@ func (r *ModelsRegistry) GetProviderInfo(provider string) *ProviderInfo {
|
||||
return &info
|
||||
}
|
||||
|
||||
// ValidateModelString checks whether a model string is well-formed and refers
|
||||
// to a known provider. It returns a user-friendly error with suggestions when
|
||||
// the model or provider is unrecognised. Passing validation does not guarantee
|
||||
// that API authentication will succeed — it only catches obvious mistakes
|
||||
// (typos, missing provider prefix, non-existent provider names) early so that
|
||||
// callers such as subagent spawning can return fast feedback.
|
||||
//
|
||||
// Unknown models under a known provider are allowed (the provider API is the
|
||||
// authority), but a completely unknown provider is rejected.
|
||||
func (r *ModelsRegistry) ValidateModelString(modelString string) error {
|
||||
provider, modelName, err := ParseModelString(modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ollama and custom are always valid — model names are user-defined.
|
||||
if provider == "ollama" || provider == "custom" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the provider exists in the registry.
|
||||
providerInfo := r.GetProviderInfo(provider)
|
||||
if providerInfo == nil {
|
||||
known := r.GetSupportedProviders()
|
||||
return fmt.Errorf(
|
||||
"unknown provider %q in model string %q. Known providers: %s",
|
||||
provider, modelString, strings.Join(known, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
// Provider exists — check if the model is known. An unknown model is
|
||||
// only a warning (the provider API decides), but we surface suggestions
|
||||
// so the caller can self-correct.
|
||||
if r.LookupModel(provider, modelName) == nil {
|
||||
if suggestions := r.SuggestModels(provider, modelName); len(suggestions) > 0 {
|
||||
return fmt.Errorf(
|
||||
"model %q not found for provider %s. Did you mean one of: %s",
|
||||
modelName, provider, strings.Join(suggestions, ", "),
|
||||
)
|
||||
}
|
||||
// No suggestions — let it through; the provider API is the authority.
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Global registry instance
|
||||
var globalRegistry = NewModelsRegistry()
|
||||
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateModelString(t *testing.T) {
|
||||
registry := GetGlobalRegistry()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantErr bool
|
||||
errSubstr string // expected substring in error message (empty = don't check)
|
||||
}{
|
||||
{
|
||||
name: "valid anthropic model",
|
||||
model: "anthropic/claude-sonnet-4-6",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing provider prefix",
|
||||
model: "claude-sonnet-4-6",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
model: "",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "unknown provider",
|
||||
model: "fakeprovider/some-model",
|
||||
wantErr: true,
|
||||
errSubstr: "unknown provider",
|
||||
},
|
||||
{
|
||||
name: "ollama always valid",
|
||||
model: "ollama/llama3",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "custom always valid",
|
||||
model: "custom/my-fine-tune",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty provider",
|
||||
model: "/claude-sonnet-4-6",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "empty model name",
|
||||
model: "anthropic/",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "unknown model under known provider (no suggestions)",
|
||||
model: "anthropic/totally-unknown-xyz-999",
|
||||
wantErr: false, // no suggestions → passes through
|
||||
},
|
||||
{
|
||||
name: "typo model under known provider with suggestions",
|
||||
model: "anthropic/claude-sonet", // misspelled "sonnet"
|
||||
wantErr: true,
|
||||
errSubstr: "Did you mean",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := registry.ValidateModelString(tt.model)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Errorf("ValidateModelString(%q) = nil, want error", tt.model)
|
||||
}
|
||||
if !tt.wantErr && err != nil {
|
||||
t.Errorf("ValidateModelString(%q) = %v, want nil", tt.model, err)
|
||||
}
|
||||
if tt.errSubstr != "" && err != nil {
|
||||
if !strings.Contains(err.Error(), tt.errSubstr) {
|
||||
t.Errorf("ValidateModelString(%q) error = %q, want substring %q",
|
||||
tt.model, err.Error(), tt.errSubstr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,11 +2,10 @@ package prompts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
// LoadOptions configures how templates are discovered and loaded.
|
||||
@@ -74,10 +73,7 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
DroppedPath: tpl.FilePath,
|
||||
Reason: fmt.Sprintf("template from %s overridden by %s", source, existing.Source),
|
||||
})
|
||||
log.Debug("template collision",
|
||||
"name", tpl.Name,
|
||||
"dropped", tpl.FilePath,
|
||||
"kept", existing.FilePath)
|
||||
log.Printf("DEBUG template collision: name=%s dropped=%s kept=%s", tpl.Name, tpl.FilePath, existing.FilePath)
|
||||
} else {
|
||||
tpl.Source = source
|
||||
seen[tpl.Name] = tpl
|
||||
|
||||
@@ -0,0 +1,317 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
)
|
||||
|
||||
// TestCompactionCreatesNewLeaf verifies that after compaction, the compaction
|
||||
// entry has no parent (creating a new root), and BuildContext returns only
|
||||
// the summary and kept messages, not the old compacted messages.
|
||||
func TestCompactionCreatesNewLeaf(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add some messages: M1, M2 (old, will be compacted), M3, M4 (kept)
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1 - old"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2 - old"}}}
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 3 - kept"}}}
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - kept"}}}
|
||||
|
||||
_, _ = tm.AppendMessage(msg1)
|
||||
_, _ = tm.AppendMessage(msg2)
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
id4, _ := tm.AppendMessage(msg4)
|
||||
|
||||
// Verify initial state - all messages should be in context
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("expected 4 messages before compaction, got %d", len(messages))
|
||||
}
|
||||
|
||||
// Verify entry IDs
|
||||
entryIDs := tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 4 {
|
||||
t.Fatalf("expected 4 entry IDs before compaction, got %d", len(entryIDs))
|
||||
}
|
||||
|
||||
// Now add a compaction entry, simulating that M3 is the first kept entry
|
||||
summary := "Summary of old messages"
|
||||
compactionID, err := tm.AppendCompaction(summary, id3, 1000, 500, 2, []string{}, []string{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to append compaction: %v", err)
|
||||
}
|
||||
|
||||
// Verify the compaction entry has no parent (empty ParentID)
|
||||
compactionEntry := tm.GetEntry(compactionID).(*CompactionEntry)
|
||||
if compactionEntry.ParentID != "" {
|
||||
t.Errorf("compaction entry should have no parent, got %q", compactionEntry.ParentID)
|
||||
}
|
||||
|
||||
// Verify the leaf is now the compaction entry
|
||||
if tm.GetLeafID() != compactionID {
|
||||
t.Errorf("leaf should be compaction entry %q, got %q", compactionID, tm.GetLeafID())
|
||||
}
|
||||
|
||||
// Now BuildContext should return: [summary] + [M3, M4]
|
||||
messages, _, _ = tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages after compaction (summary + 2 kept), got %d", len(messages))
|
||||
}
|
||||
|
||||
// First message should be the summary
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be system summary, got %s", messages[0].Role)
|
||||
}
|
||||
summaryText := messages[0].Content[0].(fantasy.TextPart).Text
|
||||
if summaryText != "[Conversation summary — earlier messages were compacted]\n\n"+summary {
|
||||
t.Errorf("unexpected summary text: %s", summaryText)
|
||||
}
|
||||
|
||||
// Second message should be M3 (kept)
|
||||
if messages[1].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("second message should be user (M3), got %s", messages[1].Role)
|
||||
}
|
||||
m3Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m3Text != "Message 3 - kept" {
|
||||
t.Errorf("unexpected M3 text: %s", m3Text)
|
||||
}
|
||||
|
||||
// Third message should be M4 (kept)
|
||||
if messages[2].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("third message should be assistant (M4), got %s", messages[2].Role)
|
||||
}
|
||||
m4Text := messages[2].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - kept" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
|
||||
// Verify GetContextEntryIDs returns correct IDs
|
||||
entryIDs = tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 3 {
|
||||
t.Fatalf("expected 3 entry IDs after compaction (empty for summary + 2 kept), got %d: %v", len(entryIDs), entryIDs)
|
||||
}
|
||||
|
||||
// First entry ID should be empty (summary has no entry)
|
||||
if entryIDs[0] != "" {
|
||||
t.Errorf("first entry ID should be empty (summary), got %q", entryIDs[0])
|
||||
}
|
||||
|
||||
// Second and third should be id3 and id4 (the kept messages)
|
||||
if entryIDs[1] != id3 {
|
||||
t.Errorf("second entry ID should be %q (M3), got %q", id3, entryIDs[1])
|
||||
}
|
||||
if entryIDs[2] != id4 {
|
||||
t.Errorf("third entry ID should be %q (M4), got %q", id4, entryIDs[2])
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompactionWithNewMessagesAfterCompaction verifies that messages appended
|
||||
// after compaction are correctly included in the context.
|
||||
func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add initial messages
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2"}}}
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 3 - kept"}}}
|
||||
|
||||
_, _ = tm.AppendMessage(msg1)
|
||||
_, _ = tm.AppendMessage(msg2)
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
|
||||
// Compact, keeping only M3
|
||||
_, _ = tm.AppendCompaction("Summary", id3, 1000, 500, 2, []string{}, []string{})
|
||||
|
||||
// Add a new message after compaction
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}}
|
||||
_, _ = tm.AppendMessage(msg4)
|
||||
|
||||
// BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)]
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// Verify order: summary, M4 (new), M3 (kept)
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be summary, got %s", messages[0].Role)
|
||||
}
|
||||
if messages[1].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("second message should be assistant (M4), got %s", messages[1].Role)
|
||||
}
|
||||
m4Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - after compaction" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("third message should be user (M3), got %s", messages[2].Role)
|
||||
}
|
||||
|
||||
// Verify that M1 is NOT in the context
|
||||
for i, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Message 1" {
|
||||
t.Errorf("Message 1 (compacted) should not be in context at index %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompactionWithNoKeptMessages verifies compaction when all messages are compacted.
|
||||
func TestCompactionWithNoKeptMessages(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add messages that will all be compacted
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2"}}}
|
||||
|
||||
if _, err := tm.AppendMessage(msg1); err != nil {
|
||||
t.Fatalf("failed to append message: %v", err)
|
||||
}
|
||||
if _, err := tm.AppendMessage(msg2); err != nil {
|
||||
t.Fatalf("failed to append message: %v", err)
|
||||
}
|
||||
|
||||
// Compact with no kept messages (empty firstKeptEntryID)
|
||||
summary := "All messages summarized"
|
||||
compactionID, _ := tm.AppendCompaction(summary, "", 1000, 100, 2, []string{}, []string{})
|
||||
|
||||
// Verify the compaction entry has no parent
|
||||
compactionEntry := tm.GetEntry(compactionID).(*CompactionEntry)
|
||||
if compactionEntry.ParentID != "" {
|
||||
t.Errorf("compaction entry should have no parent, got %q", compactionEntry.ParentID)
|
||||
}
|
||||
|
||||
// BuildContext should return only the summary
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message (summary only), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("message should be system summary, got %s", messages[0].Role)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipleCompactions verifies that multiple compactions work correctly.
|
||||
func TestMultipleCompactions(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// First batch of messages
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Batch 1 - User"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Batch 1 - Assistant"}}}
|
||||
id1, _ := tm.AppendMessage(msg1)
|
||||
id2, _ := tm.AppendMessage(msg2)
|
||||
|
||||
// First compaction
|
||||
_, _ = tm.AppendCompaction("Summary 1", id1, 1000, 500, 1, []string{}, []string{})
|
||||
|
||||
// Second batch
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Batch 2 - User"}}}
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Batch 2 - Assistant"}}}
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
id4, _ := tm.AppendMessage(msg4)
|
||||
|
||||
// Second compaction (compacting the first compaction + batch 2)
|
||||
// Note: id3 is the first kept entry, so id3 and id4 should be preserved
|
||||
compactionID2, _ := tm.AppendCompaction("Summary 2", id3, 1000, 500, 3, []string{}, []string{})
|
||||
|
||||
// Verify second compaction has no parent
|
||||
compactionEntry2 := tm.GetEntry(compactionID2).(*CompactionEntry)
|
||||
if compactionEntry2.ParentID != "" {
|
||||
t.Errorf("second compaction entry should have no parent, got %q", compactionEntry2.ParentID)
|
||||
}
|
||||
|
||||
// Add final message
|
||||
msg5 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Final message"}}}
|
||||
id5, _ := tm.AppendMessage(msg5)
|
||||
|
||||
// BuildContext should include:
|
||||
// - Summary 2 (from second compaction)
|
||||
// - msg5 (final message)
|
||||
// - msg3, msg4 (kept from second compaction)
|
||||
// But NOT Summary 1 or msg1, msg2 (they're before the first kept entry of compaction 2)
|
||||
messages, _, _ := tm.BuildContext()
|
||||
|
||||
// Should have: Summary 2 + msg5 + msg3 + msg4 = 4 messages
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("expected 4 messages (Summary 2 + msg5 + msg3 + msg4), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// First should be Summary 2
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be system (Summary 2), got %s", messages[0].Role)
|
||||
}
|
||||
summaryText := messages[0].Content[0].(fantasy.TextPart).Text
|
||||
if summaryText != "[Conversation summary — earlier messages were compacted]\n\nSummary 2" {
|
||||
t.Errorf("unexpected summary: %s", summaryText)
|
||||
}
|
||||
|
||||
// Verify msg5 is included
|
||||
foundFinal := false
|
||||
for _, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Final message" {
|
||||
foundFinal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundFinal {
|
||||
t.Error("Final message (msg5) should be in context")
|
||||
}
|
||||
|
||||
// Verify msg1, msg2 are NOT included (compacted by first compaction, then second)
|
||||
for _, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser || msg.Role == fantasy.MessageRoleAssistant {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Batch 1 - User" || text == "Batch 1 - Assistant" {
|
||||
t.Errorf("Batch 1 messages should not be in context, found: %s", text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify entry IDs
|
||||
entryIDs := tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 4 {
|
||||
t.Fatalf("expected 4 entry IDs, got %d: %v", len(entryIDs), entryIDs)
|
||||
}
|
||||
|
||||
// First should be empty (summary)
|
||||
if entryIDs[0] != "" {
|
||||
t.Errorf("first entry ID should be empty (summary), got %q", entryIDs[0])
|
||||
}
|
||||
|
||||
// Check that id5 is in the list
|
||||
if !slices.Contains(entryIDs, id5) {
|
||||
t.Errorf("id5 (final message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
|
||||
// Verify id3 and id4 ARE in the list (they were kept)
|
||||
foundID3, foundID4 := false, false
|
||||
for _, id := range entryIDs {
|
||||
if id == id3 {
|
||||
foundID3 = true
|
||||
}
|
||||
if id == id4 {
|
||||
foundID4 = true
|
||||
}
|
||||
}
|
||||
if !foundID3 {
|
||||
t.Errorf("id3 (kept message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
if !foundID4 {
|
||||
t.Errorf("id4 (kept message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
|
||||
// Verify id1 and id2 are NOT in the list (they were compacted away)
|
||||
for _, id := range entryIDs {
|
||||
if id == id1 || id == id2 {
|
||||
t.Errorf("id1 or id2 (compacted) should not be in entry IDs, found %q in %v", id, entryIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,7 @@ const (
|
||||
EntryTypeSessionInfo EntryType = "session_info"
|
||||
EntryTypeExtensionData EntryType = "extension_data"
|
||||
EntryTypeCompaction EntryType = "compaction"
|
||||
EntryTypeSystemPrompt EntryType = "system_prompt"
|
||||
)
|
||||
|
||||
// CurrentVersion is the session format version for JSONL tree sessions.
|
||||
@@ -117,6 +118,19 @@ type CompactionEntry struct {
|
||||
ModifiedFiles []string `json:"modified_files,omitempty"`
|
||||
}
|
||||
|
||||
// SystemPromptEntry records the system prompt and model used for the session.
|
||||
// This is primarily for sharing/debugging to see what instructions were
|
||||
// active during the conversation. It does NOT participate in the tree
|
||||
// structure (no ParentID) and is not used when building LLM context.
|
||||
type SystemPromptEntry struct {
|
||||
Type EntryType `json:"type"` // always "system_prompt"
|
||||
ID string `json:"id"` // unique entry ID
|
||||
Timestamp time.Time `json:"timestamp"` // when captured
|
||||
Content string `json:"content"` // the system prompt text
|
||||
Model string `json:"model"` // the model used (e.g., "claude-sonnet-4-5")
|
||||
Provider string `json:"provider"` // the provider used (e.g., "anthropic")
|
||||
}
|
||||
|
||||
// GenerateEntryID creates a unique entry identifier (16 hex chars).
|
||||
func GenerateEntryID() string {
|
||||
bytes := make([]byte, 8)
|
||||
@@ -217,6 +231,18 @@ func NewCompactionEntry(parentID, summary, firstKeptEntryID string, tokensBefore
|
||||
}
|
||||
}
|
||||
|
||||
// NewSystemPromptEntry creates a SystemPromptEntry.
|
||||
func NewSystemPromptEntry(content, model, provider string) *SystemPromptEntry {
|
||||
return &SystemPromptEntry{
|
||||
Type: EntryTypeSystemPrompt,
|
||||
ID: GenerateEntryID(),
|
||||
Timestamp: time.Now(),
|
||||
Content: content,
|
||||
Model: model,
|
||||
Provider: provider,
|
||||
}
|
||||
}
|
||||
|
||||
// --- JSONL marshaling helpers ---
|
||||
|
||||
// MarshalEntry serializes any entry to a JSON line (no trailing newline).
|
||||
@@ -295,6 +321,13 @@ func UnmarshalEntry(data []byte) (any, error) {
|
||||
}
|
||||
return &e, nil
|
||||
|
||||
case EntryTypeSystemPrompt:
|
||||
var e SystemPromptEntry
|
||||
if err := json.Unmarshal(data, &e); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal system_prompt entry: %w", err)
|
||||
}
|
||||
return &e, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown entry type: %q", env.Type)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSystemPromptEntry(t *testing.T) {
|
||||
// Test creation
|
||||
content := "You are a helpful coding assistant."
|
||||
model := "claude-sonnet-4-5"
|
||||
provider := "anthropic"
|
||||
entry := NewSystemPromptEntry(content, model, provider)
|
||||
|
||||
if entry.Type != EntryTypeSystemPrompt {
|
||||
t.Errorf("Expected type %q, got %q", EntryTypeSystemPrompt, entry.Type)
|
||||
}
|
||||
|
||||
if entry.Content != content {
|
||||
t.Errorf("Expected content %q, got %q", content, entry.Content)
|
||||
}
|
||||
|
||||
if entry.Model != model {
|
||||
t.Errorf("Expected model %q, got %q", model, entry.Model)
|
||||
}
|
||||
|
||||
if entry.Provider != provider {
|
||||
t.Errorf("Expected provider %q, got %q", provider, entry.Provider)
|
||||
}
|
||||
|
||||
if entry.ID == "" {
|
||||
t.Error("Expected non-empty ID")
|
||||
}
|
||||
|
||||
// Test marshaling
|
||||
data, err := MarshalEntry(entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Test unmarshaling
|
||||
unmarshaled, err := UnmarshalEntry(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
sysPrompt, ok := unmarshaled.(*SystemPromptEntry)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *SystemPromptEntry, got %T", unmarshaled)
|
||||
}
|
||||
|
||||
if sysPrompt.Type != EntryTypeSystemPrompt {
|
||||
t.Errorf("Unmarshaled: expected type %q, got %q", EntryTypeSystemPrompt, sysPrompt.Type)
|
||||
}
|
||||
|
||||
if sysPrompt.Content != content {
|
||||
t.Errorf("Unmarshaled: expected content %q, got %q", content, sysPrompt.Content)
|
||||
}
|
||||
|
||||
if sysPrompt.Model != model {
|
||||
t.Errorf("Unmarshaled: expected model %q, got %q", model, sysPrompt.Model)
|
||||
}
|
||||
|
||||
if sysPrompt.Provider != provider {
|
||||
t.Errorf("Unmarshaled: expected provider %q, got %q", provider, sysPrompt.Provider)
|
||||
}
|
||||
|
||||
if sysPrompt.ID != entry.ID {
|
||||
t.Errorf("Unmarshaled: expected ID %q, got %q", entry.ID, sysPrompt.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemPromptEntryJSONStructure(t *testing.T) {
|
||||
content := "Test system prompt content"
|
||||
model := "gpt-4o"
|
||||
provider := "openai"
|
||||
entry := NewSystemPromptEntry(content, model, provider)
|
||||
|
||||
data, err := MarshalEntry(entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify JSON structure
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
t.Fatalf("Failed to unmarshal to raw map: %v", err)
|
||||
}
|
||||
|
||||
if raw["type"] != "system_prompt" {
|
||||
t.Errorf("Expected type 'system_prompt', got %v", raw["type"])
|
||||
}
|
||||
|
||||
if raw["content"] != content {
|
||||
t.Errorf("Expected content %q, got %v", content, raw["content"])
|
||||
}
|
||||
|
||||
if raw["model"] != model {
|
||||
t.Errorf("Expected model %q, got %v", model, raw["model"])
|
||||
}
|
||||
|
||||
if raw["provider"] != provider {
|
||||
t.Errorf("Expected provider %q, got %v", provider, raw["provider"])
|
||||
}
|
||||
|
||||
if raw["id"] == "" || raw["id"] == nil {
|
||||
t.Error("Expected non-empty id field")
|
||||
}
|
||||
|
||||
if raw["timestamp"] == "" || raw["timestamp"] == nil {
|
||||
t.Error("Expected non-empty timestamp field")
|
||||
}
|
||||
}
|
||||
@@ -114,6 +114,187 @@ func CreateTreeSession(cwd string) (*TreeManager, error) {
|
||||
return tm, nil
|
||||
}
|
||||
|
||||
// ForkToNewSession creates a new session file containing the history up to and
|
||||
// including the target entry ID. This matches Pi's /fork behavior: it creates
|
||||
// a completely new session file with a parent_session reference, copying all
|
||||
// entries from the root to the target point.
|
||||
func (tm *TreeManager) ForkToNewSession(cwd string, targetID string) (*TreeManager, error) {
|
||||
tm.mu.RLock()
|
||||
defer tm.mu.RUnlock()
|
||||
|
||||
// Get the branch from root to target (root-to-leaf order).
|
||||
branch := tm.getBranchLocked(targetID)
|
||||
if len(branch) == 0 {
|
||||
return nil, fmt.Errorf("target entry %q not found", targetID)
|
||||
}
|
||||
|
||||
// Create a new session file.
|
||||
newTm, err := CreateTreeSession(cwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set the parent session reference in the header.
|
||||
newTm.header.ParentSession = tm.filePath
|
||||
newTm.header.ParentSessionID = tm.header.ID
|
||||
|
||||
// Rewrite the header with the parent reference.
|
||||
// We need to close and recreate the file to rewrite the header.
|
||||
if err := newTm.file.Close(); err != nil {
|
||||
return nil, fmt.Errorf("failed to close new session file: %w", err)
|
||||
}
|
||||
|
||||
// Recreate the file and write the updated header.
|
||||
f, err := os.Create(newTm.filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to recreate session file: %w", err)
|
||||
}
|
||||
newTm.file = f
|
||||
|
||||
if err := newTm.writeEntry(&newTm.header); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to write session header: %w", err)
|
||||
}
|
||||
|
||||
// Copy entries from the branch to the new session.
|
||||
// We need to remap IDs since the new session is independent.
|
||||
idMap := make(map[string]string) // old ID -> new ID
|
||||
var prevNewID string
|
||||
|
||||
for _, entry := range branch {
|
||||
oldID := tm.EntryID(entry)
|
||||
newID := GenerateEntryID()
|
||||
idMap[oldID] = newID
|
||||
|
||||
// Create a copy of the entry with the new ID and remapped parent.
|
||||
var newEntry any
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
newEntry = &MessageEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeMessage,
|
||||
ID: newID,
|
||||
ParentID: prevNewID, // Chain sequentially in new session
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Role: e.Role,
|
||||
Parts: e.Parts,
|
||||
Model: e.Model,
|
||||
Provider: e.Provider,
|
||||
}
|
||||
// Copy label if present.
|
||||
if label, ok := tm.labels[oldID]; ok {
|
||||
newTm.labels[newID] = label
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
newEntry = &ModelChangeEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeModelChange,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Provider: e.Provider,
|
||||
ModelID: e.ModelID,
|
||||
}
|
||||
|
||||
case *LabelEntry:
|
||||
// Remap the target ID if it's in our copied branch.
|
||||
newTargetID := e.TargetID
|
||||
if mapped, ok := idMap[e.TargetID]; ok {
|
||||
newTargetID = mapped
|
||||
}
|
||||
newEntry = &LabelEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeLabel,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
TargetID: newTargetID,
|
||||
Label: e.Label,
|
||||
}
|
||||
|
||||
case *SessionInfoEntry:
|
||||
newEntry = &SessionInfoEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeSessionInfo,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Name: e.Name,
|
||||
}
|
||||
newTm.sessionName = e.Name
|
||||
|
||||
case *ExtensionDataEntry:
|
||||
newEntry = &ExtensionDataEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeExtensionData,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
ExtType: e.ExtType,
|
||||
Data: e.Data,
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
// Remap the from ID if it's in our copied branch.
|
||||
newFromID := e.FromID
|
||||
if mapped, ok := idMap[e.FromID]; ok {
|
||||
newFromID = mapped
|
||||
}
|
||||
newEntry = &BranchSummaryEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeBranchSummary,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
FromID: newFromID,
|
||||
Summary: e.Summary,
|
||||
}
|
||||
|
||||
case *CompactionEntry:
|
||||
// Remap the first kept entry ID if it's in our copied branch.
|
||||
newFirstKeptID := e.FirstKeptEntryID
|
||||
if mapped, ok := idMap[e.FirstKeptEntryID]; ok {
|
||||
newFirstKeptID = mapped
|
||||
}
|
||||
newEntry = &CompactionEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeCompaction,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Summary: e.Summary,
|
||||
FirstKeptEntryID: newFirstKeptID,
|
||||
TokensBefore: e.TokensBefore,
|
||||
TokensAfter: e.TokensAfter,
|
||||
MessagesRemoved: e.MessagesRemoved,
|
||||
ReadFiles: e.ReadFiles,
|
||||
ModifiedFiles: e.ModifiedFiles,
|
||||
}
|
||||
}
|
||||
|
||||
if newEntry != nil {
|
||||
if err := newTm.appendAndPersist(newEntry); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to copy entry: %w", err)
|
||||
}
|
||||
prevNewID = newID
|
||||
}
|
||||
}
|
||||
|
||||
// Set the leaf to the last entry in the new session.
|
||||
newTm.leafID = prevNewID
|
||||
|
||||
return newTm, nil
|
||||
}
|
||||
|
||||
// OpenTreeSession opens an existing JSONL session file.
|
||||
func OpenTreeSession(path string) (*TreeManager, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
@@ -328,11 +509,19 @@ func (tm *TreeManager) AppendExtensionData(extType, data string) (string, error)
|
||||
// AppendCompaction adds a compaction entry to the tree. The entry records
|
||||
// the summary and the ID of the first entry that should be preserved in the
|
||||
// LLM context. Messages before that entry are replaced by the summary.
|
||||
//
|
||||
// The compaction entry becomes a new "root" for the post-compaction branch
|
||||
// with no parent (empty ParentID). This breaks the parent chain so that old
|
||||
// compacted messages are no longer traversed when building context. The kept
|
||||
// messages are explicitly collected via FirstKeptEntryID in BuildContext.
|
||||
func (tm *TreeManager) AppendCompaction(summary, firstKeptEntryID string, tokensBefore, tokensAfter, messagesRemoved int, readFiles, modifiedFiles []string) (string, error) {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
||||
entry := NewCompactionEntry(tm.leafID, summary, firstKeptEntryID, tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
|
||||
// The compaction entry has no parent, making it a new "root" for the
|
||||
// post-compaction branch. This ensures old compacted messages are not
|
||||
// traversed when walking from the current leaf.
|
||||
entry := NewCompactionEntry("", summary, firstKeptEntryID, tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -502,14 +691,18 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
// Find the last compaction entry on this branch — it determines
|
||||
// which older messages are replaced by the summary.
|
||||
var lastCompaction *CompactionEntry
|
||||
var compactionIndex = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if c, ok := branch[i].(*CompactionEntry); ok {
|
||||
lastCompaction = c
|
||||
compactionIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If there is a compaction, inject the summary first.
|
||||
// If there is a compaction, inject the summary first and collect
|
||||
// the kept messages starting from FirstKeptEntryID (since the
|
||||
// compaction entry's parent chain doesn't include them).
|
||||
if lastCompaction != nil {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
@@ -519,21 +712,104 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Determine whether to skip entries (everything before firstKeptEntryID).
|
||||
skipping := lastCompaction != nil
|
||||
for _, entry := range branch {
|
||||
// Once we reach the first kept entry, stop skipping.
|
||||
if skipping {
|
||||
entryID := tm.EntryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
// Collect entries from the compaction entry itself (at compactionIndex)
|
||||
// and any entries before it in the branch (newer messages).
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
// Convert branch summary to a user message for context.
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (summary injected).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Now collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not in the current branch because the compaction entry
|
||||
// is parented to the first kept entry's parent, not the first kept entry.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting
|
||||
// messages that were added after the compaction.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
// Messages after the compaction are collected from the branch walk above.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
// Process this kept entry.
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return messages, provider, modelID
|
||||
}
|
||||
|
||||
// No compaction - process the entire branch normally.
|
||||
for _, entry := range branch {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
@@ -559,10 +835,6 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (the last one on the branch).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -672,31 +944,92 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
|
||||
// Find the last compaction entry for skip logic.
|
||||
var lastCompaction *CompactionEntry
|
||||
var compactionIndex = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if c, ok := branch[i].(*CompactionEntry); ok {
|
||||
lastCompaction = c
|
||||
compactionIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var ids []string
|
||||
|
||||
// If there's a compaction summary injected, it has no entry ID.
|
||||
// If there's a compaction, we need to collect IDs from:
|
||||
// 1. Entries after the compaction entry in the branch (newer messages)
|
||||
// 2. Entries from FirstKeptEntryID onwards (kept messages)
|
||||
if lastCompaction != nil {
|
||||
ids = append(ids, "") // placeholder for the summary system message
|
||||
}
|
||||
// Placeholder for the summary system message (no entry ID).
|
||||
ids = append(ids, "")
|
||||
|
||||
skipping := lastCompaction != nil
|
||||
for _, entry := range branch {
|
||||
if skipping {
|
||||
entryID := tm.EntryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
continue
|
||||
// Collect IDs from entries after the compaction entry (newer messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect IDs from the kept messages starting at FirstKeptEntryID.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// No compaction - collect IDs from the entire branch.
|
||||
for _, entry := range branch {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
@@ -712,9 +1045,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *CompactionEntry:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -60,14 +60,16 @@ type MCPConnection struct {
|
||||
// creation, health monitoring, and cleanup. The pool runs background health checks
|
||||
// to proactively identify and remove unhealthy connections.
|
||||
type MCPConnectionPool struct {
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
mu sync.RWMutex
|
||||
model fantasy.LanguageModel
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
mu sync.RWMutex
|
||||
model fantasy.LanguageModel
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
oauthFlow *OAuthFlowRunner
|
||||
tokenStoreFactory TokenStoreFactory // custom factory for per-server token stores (nil = default FileTokenStore)
|
||||
}
|
||||
|
||||
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
|
||||
@@ -75,19 +77,24 @@ type MCPConnectionPool struct {
|
||||
// goroutine for periodic health checks that runs until Close is called.
|
||||
// The model parameter is used for MCP servers that require sampling support.
|
||||
// Thread-safe for concurrent use immediately after creation.
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool) *MCPConnectionPool {
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler, tokenStoreFactory TokenStoreFactory) *MCPConnectionPool {
|
||||
if config == nil {
|
||||
config = DefaultConnectionPoolConfig()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &MCPConnectionPool{
|
||||
connections: make(map[string]*MCPConnection),
|
||||
config: config,
|
||||
model: model,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
debug: debug,
|
||||
connections: make(map[string]*MCPConnection),
|
||||
config: config,
|
||||
model: model,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
debug: debug,
|
||||
tokenStoreFactory: tokenStoreFactory,
|
||||
}
|
||||
|
||||
if authHandler != nil {
|
||||
pool.oauthFlow = NewOAuthFlowRunner(authHandler)
|
||||
}
|
||||
|
||||
go pool.startHealthCheck()
|
||||
@@ -103,6 +110,15 @@ func (p *MCPConnectionPool) SetDebugLogger(logger DebugLogger) {
|
||||
p.debugLogger = logger
|
||||
}
|
||||
|
||||
// SetOAuthFlow sets the OAuth flow runner for the connection pool.
|
||||
// When set, the pool can trigger OAuth re-authorization when a tool call fails
|
||||
// with an OAuth error (e.g. expired token). Thread-safe and can be called at any time.
|
||||
func (p *MCPConnectionPool) SetOAuthFlow(flow *OAuthFlowRunner) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.oauthFlow = flow
|
||||
}
|
||||
|
||||
// GetConnection retrieves or creates a connection for the specified MCP server.
|
||||
// If a healthy, non-idle connection exists in the pool, it will be reused.
|
||||
// Otherwise, a new connection is created and added to the pool.
|
||||
@@ -230,18 +246,43 @@ func (p *MCPConnectionPool) performHealthCheck(ctx context.Context, conn *MCPCon
|
||||
|
||||
// createConnection creates a new connection
|
||||
func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
|
||||
client, err := p.createMCPClient(ctx, serverName, serverConfig)
|
||||
mcpClient, err := p.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// SSE transport can return OAuth error during Start()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry after successful auth
|
||||
mcpClient, err = p.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.initializeClient(ctx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
// Streamable HTTP transport returns OAuth error during Initialize()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry initialization after successful auth
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
conn := &MCPConnection{
|
||||
client: client,
|
||||
client: mcpClient,
|
||||
serverName: serverName,
|
||||
serverConfig: serverConfig,
|
||||
lastUsed: time.Now(),
|
||||
@@ -323,13 +364,29 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
|
||||
}
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
|
||||
// scopes are discovered automatically via dynamic client registration and
|
||||
// server metadata (RFC 9728).
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
options = append(options, transport.WithOAuth(transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}))
|
||||
}
|
||||
|
||||
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := sseClient.Start(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to start SSE client: %v", err)
|
||||
return nil, fmt.Errorf("failed to start SSE client: %w", err)
|
||||
}
|
||||
|
||||
return sseClient, nil
|
||||
@@ -354,18 +411,44 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
|
||||
}
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
|
||||
// scopes are discovered automatically via dynamic client registration and
|
||||
// server metadata (RFC 9728).
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
options = append(options, transport.WithHTTPOAuth(transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}))
|
||||
}
|
||||
|
||||
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := streamableClient.Start(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to start streamable HTTP client: %v", err)
|
||||
return nil, fmt.Errorf("failed to start streamable HTTP client: %w", err)
|
||||
}
|
||||
|
||||
return streamableClient, nil
|
||||
}
|
||||
|
||||
// createTokenStore creates a token store for the given server URL.
|
||||
// If a custom TokenStoreFactory is configured, it is used; otherwise the
|
||||
// default file-backed token store is created.
|
||||
func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenStore, error) {
|
||||
if p.tokenStoreFactory != nil {
|
||||
return p.tokenStoreFactory(serverURL)
|
||||
}
|
||||
return NewFileTokenStore(serverURL)
|
||||
}
|
||||
|
||||
// initializeClient initializes the client
|
||||
func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error {
|
||||
initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
@@ -381,7 +464,7 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.
|
||||
|
||||
_, err := client.Initialize(initCtx, initRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialization timeout or failed: %v", err)
|
||||
return fmt.Errorf("initialization timeout or failed: %w", err)
|
||||
}
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
@@ -512,6 +595,27 @@ func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient {
|
||||
return clients
|
||||
}
|
||||
|
||||
// RemoveConnection closes and removes a single connection from the pool.
|
||||
// Returns an error if the connection does not exist or if closing fails.
|
||||
// Thread-safe for concurrent use.
|
||||
func (p *MCPConnectionPool) RemoveConnection(serverName string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
conn, exists := p.connections[serverName]
|
||||
if !exists {
|
||||
return fmt.Errorf("connection %q not found in pool", serverName)
|
||||
}
|
||||
|
||||
err := conn.client.Close()
|
||||
delete(p.connections, serverName)
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Removed connection %s", serverName))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the connection pool, closing all client connections
|
||||
// and stopping the background health check goroutine. It attempts to close all
|
||||
// connections even if some fail, logging any errors encountered.
|
||||
@@ -539,6 +643,9 @@ func (p *MCPConnectionPool) Close() error {
|
||||
|
||||
// isConnectionError checks if the error is connection-related
|
||||
func isConnectionError(err error) bool {
|
||||
if IsOAuthError(err) {
|
||||
return false // OAuth errors are recoverable, not connection failures
|
||||
}
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "Connection not found") ||
|
||||
strings.Contains(errStr, "transport error") ||
|
||||
|
||||
@@ -59,9 +59,30 @@ func (t *mcpFantasyTool) Run(ctx context.Context, call fantasy.ToolCall) (fantas
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// Mark connection as unhealthy for automatic recovery
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
// Handle OAuth re-authorization: token may have expired mid-session.
|
||||
if t.mapping.manager.connectionPool.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := t.mapping.manager.connectionPool.oauthFlow.RunAuthFlow(ctx, t.mapping.serverName, err); flowErr != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", t.mapping.originalName, flowErr)
|
||||
}
|
||||
// Retry the tool call after successful re-auth.
|
||||
result, err = conn.client.CallTool(ctx, mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: t.mapping.originalName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool after re-auth: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Mark connection as unhealthy for automatic recovery
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal the MCP result to JSON string
|
||||
|
||||
+230
-28
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
@@ -18,13 +20,25 @@ import (
|
||||
// pooling, health checks, tool name prefixing to avoid conflicts, and sampling support for LLM interactions.
|
||||
// Thread-safe for concurrent tool invocations.
|
||||
type MCPToolManager struct {
|
||||
connectionPool *MCPConnectionPool
|
||||
tools []fantasy.AgentTool
|
||||
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
|
||||
model fantasy.LanguageModel // LLM model for sampling
|
||||
config *config.Config
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
connectionPool *MCPConnectionPool
|
||||
tools []fantasy.AgentTool
|
||||
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
|
||||
mu sync.Mutex // protects tools and toolMap during parallel loading
|
||||
model fantasy.LanguageModel // LLM model for sampling
|
||||
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
|
||||
tokenStoreFactory TokenStoreFactory // factory for creating per-server token stores (nil = default FileTokenStore)
|
||||
config *config.Config
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
|
||||
// onServerLoaded, if non-nil, is called when each server finishes loading.
|
||||
// Called with server name, tool count, and error (nil on success).
|
||||
onServerLoaded func(serverName string, toolCount int, err error)
|
||||
|
||||
// onToolsChanged, if non-nil, is called after AddServer or RemoveServer
|
||||
// mutates the tool list. The agent layer uses this to trigger a
|
||||
// rebuildFantasyAgent so the LLM sees the updated tools.
|
||||
onToolsChanged func()
|
||||
}
|
||||
|
||||
// toolMapping stores the mapping between prefixed tool names and their original details
|
||||
@@ -53,6 +67,22 @@ func (m *MCPToolManager) SetModel(model fantasy.LanguageModel) {
|
||||
m.model = model
|
||||
}
|
||||
|
||||
// SetAuthHandler sets the OAuth handler for remote MCP server authentication.
|
||||
// When set, remote transports (streamable HTTP, SSE) are configured with OAuth
|
||||
// support, enabling automatic authorization flows when servers require authentication.
|
||||
// This method should be called before LoadTools.
|
||||
func (m *MCPToolManager) SetAuthHandler(handler MCPAuthHandler) {
|
||||
m.authHandler = handler
|
||||
}
|
||||
|
||||
// SetTokenStoreFactory sets a custom factory for creating per-server OAuth token
|
||||
// stores. When set, the factory is called for each remote MCP server instead of
|
||||
// using the default file-based token store. This method should be called before
|
||||
// LoadTools.
|
||||
func (m *MCPToolManager) SetTokenStoreFactory(factory TokenStoreFactory) {
|
||||
m.tokenStoreFactory = factory
|
||||
}
|
||||
|
||||
// SetDebugLogger sets the debug logger for the tool manager.
|
||||
// The logger will be used to output detailed debugging information about MCP connections,
|
||||
// tool loading, and execution. If a connection pool exists, it will also be configured
|
||||
@@ -64,48 +94,207 @@ func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnServerLoaded sets the callback that's invoked when each MCP server finishes
|
||||
// loading. The callback receives the server name, tool count, and any error.
|
||||
// Call this before LoadTools to receive per-server notifications.
|
||||
func (m *MCPToolManager) SetOnServerLoaded(cb func(serverName string, toolCount int, err error)) {
|
||||
m.onServerLoaded = cb
|
||||
}
|
||||
|
||||
// SetOnToolsChanged sets the callback that's invoked after AddServer or
|
||||
// RemoveServer mutates the tool list. The agent layer uses this to trigger
|
||||
// a rebuild of the fantasy agent so the LLM sees the updated tool set.
|
||||
func (m *MCPToolManager) SetOnToolsChanged(cb func()) {
|
||||
m.onToolsChanged = cb
|
||||
}
|
||||
|
||||
// AddServer connects to a new MCP server at runtime and loads its tools.
|
||||
// The server's tools are immediately available to the agent after this call.
|
||||
// Returns the number of tools loaded from the server.
|
||||
//
|
||||
// If the connection pool has not been initialised yet (i.e. LoadTools was never
|
||||
// called), AddServer creates one automatically using the manager's current
|
||||
// configuration.
|
||||
//
|
||||
// Returns an error if a server with the same name is already loaded, or if
|
||||
// the connection or tool loading fails.
|
||||
func (m *MCPToolManager) AddServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
|
||||
m.mu.Lock()
|
||||
// Check for duplicate.
|
||||
if _, exists := m.toolMap[name+"__"]; exists {
|
||||
m.mu.Unlock()
|
||||
return 0, fmt.Errorf("MCP server %q is already loaded", name)
|
||||
}
|
||||
// More thorough duplicate check: scan toolMap for any key with the server prefix.
|
||||
prefix := name + "__"
|
||||
for k := range m.toolMap {
|
||||
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
||||
m.mu.Unlock()
|
||||
return 0, fmt.Errorf("MCP server %q is already loaded", name)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Lazily create the connection pool if LoadTools was never called.
|
||||
m.ensureConnectionPool()
|
||||
|
||||
count, err := m.loadServerTools(ctx, name, cfg)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to add MCP server %q: %w", name, err)
|
||||
}
|
||||
|
||||
// Notify listeners.
|
||||
if m.onServerLoaded != nil {
|
||||
m.onServerLoaded(name, count, nil)
|
||||
}
|
||||
if m.onToolsChanged != nil {
|
||||
m.onToolsChanged()
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// RemoveServer disconnects an MCP server and removes all its tools.
|
||||
// After this call the agent will no longer see or be able to call tools from
|
||||
// the named server. Returns an error if the server is not loaded.
|
||||
func (m *MCPToolManager) RemoveServer(name string) error {
|
||||
prefix := name + "__"
|
||||
|
||||
m.mu.Lock()
|
||||
|
||||
// Check the server actually has tools loaded.
|
||||
found := false
|
||||
for k := range m.toolMap {
|
||||
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("MCP server %q is not loaded", name)
|
||||
}
|
||||
|
||||
// Remove tools belonging to this server.
|
||||
newTools := make([]fantasy.AgentTool, 0, len(m.tools))
|
||||
for _, t := range m.tools {
|
||||
if len(t.Info().Name) < len(prefix) || t.Info().Name[:len(prefix)] != prefix {
|
||||
newTools = append(newTools, t)
|
||||
}
|
||||
}
|
||||
m.tools = newTools
|
||||
|
||||
// Remove tool mappings.
|
||||
for k := range m.toolMap {
|
||||
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
||||
delete(m.toolMap, k)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Close the connection in the pool (best-effort).
|
||||
if m.connectionPool != nil {
|
||||
_ = m.connectionPool.RemoveConnection(name)
|
||||
}
|
||||
|
||||
if m.onToolsChanged != nil {
|
||||
m.onToolsChanged()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureConnectionPool lazily creates a connection pool if one does not exist.
|
||||
// This allows AddServer to work even if LoadTools was never called.
|
||||
func (m *MCPToolManager) ensureConnectionPool() {
|
||||
if m.connectionPool != nil {
|
||||
return
|
||||
}
|
||||
debug := false
|
||||
if m.config != nil {
|
||||
debug = m.config.Debug
|
||||
}
|
||||
if m.debugLogger == nil {
|
||||
m.debugLogger = NewSimpleDebugLogger(debug)
|
||||
}
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, debug, m.authHandler, m.tokenStoreFactory)
|
||||
m.connectionPool.SetDebugLogger(m.debugLogger)
|
||||
}
|
||||
|
||||
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
|
||||
// It initializes the connection pool, connects to each configured server, and loads their tools.
|
||||
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
|
||||
// Returns an error only if all configured servers fail to load; partial failures are logged as warnings.
|
||||
// This method is thread-safe and idempotent.
|
||||
func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error {
|
||||
func (m *MCPToolManager) LoadTools(ctx context.Context, cfg *config.Config) error {
|
||||
// Initialize connection pool
|
||||
m.config = config
|
||||
m.debug = config.Debug
|
||||
m.config = cfg
|
||||
m.debug = cfg.Debug
|
||||
if m.debugLogger == nil {
|
||||
m.debugLogger = NewSimpleDebugLogger(config.Debug)
|
||||
m.debugLogger = NewSimpleDebugLogger(cfg.Debug)
|
||||
}
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug)
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, cfg.Debug, m.authHandler, m.tokenStoreFactory)
|
||||
m.connectionPool.SetDebugLogger(m.debugLogger)
|
||||
|
||||
var loadErrors []string
|
||||
// Load all servers in parallel. Each server connection (subprocess
|
||||
// spawn, MCP initialize handshake, ListTools) is independent and
|
||||
// typically dominated by process startup latency. Running them
|
||||
// concurrently reduces total wall-clock time from O(n * avg) to
|
||||
// O(max).
|
||||
type serverResult struct {
|
||||
name string
|
||||
err error
|
||||
}
|
||||
|
||||
for serverName, serverConfig := range config.MCPServers {
|
||||
if err := m.loadServerTools(ctx, serverName, serverConfig); err != nil {
|
||||
loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", serverName, err))
|
||||
fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", serverName, err)
|
||||
continue
|
||||
results := make(chan serverResult, len(cfg.MCPServers))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for serverName, serverConfig := range cfg.MCPServers {
|
||||
wg.Add(1)
|
||||
go func(name string, sc config.MCPServerConfig) {
|
||||
defer wg.Done()
|
||||
count, err := m.loadServerTools(ctx, name, sc)
|
||||
results <- serverResult{name: name, err: err}
|
||||
// Notify callback if set (for real-time UI updates).
|
||||
if m.onServerLoaded != nil {
|
||||
m.onServerLoaded(name, count, err)
|
||||
}
|
||||
}(serverName, serverConfig)
|
||||
}
|
||||
|
||||
// Close results channel once all goroutines finish.
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}()
|
||||
|
||||
var loadErrors []string
|
||||
for r := range results {
|
||||
if r.err != nil {
|
||||
loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", r.name, r.err))
|
||||
fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", r.name, r.err)
|
||||
}
|
||||
}
|
||||
|
||||
// If all servers failed to load, return an error
|
||||
if len(loadErrors) == len(config.MCPServers) && len(config.MCPServers) > 0 {
|
||||
if len(loadErrors) == len(cfg.MCPServers) && len(cfg.MCPServers) > 0 {
|
||||
return fmt.Errorf("all MCP servers failed to load: %s", strings.Join(loadErrors, "; "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadServerTools loads tools from a single MCP server
|
||||
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) error {
|
||||
// loadServerTools loads tools from a single MCP server.
|
||||
// Thread-safe: may be called concurrently for different servers.
|
||||
// Returns the number of tools loaded from this server, or -1 on error.
|
||||
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (int, error) {
|
||||
// Add debug logging
|
||||
m.debugLogConnectionInfo(serverName, serverConfig)
|
||||
|
||||
// Get connection from pool
|
||||
conn, err := m.connectionPool.GetConnection(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get connection from pool: %v", err)
|
||||
return -1, fmt.Errorf("failed to get connection from pool: %v", err)
|
||||
}
|
||||
|
||||
// Get tools from this server
|
||||
@@ -113,7 +302,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
|
||||
if err != nil {
|
||||
// Handle connection error
|
||||
m.connectionPool.HandleConnectionError(serverName, err)
|
||||
return fmt.Errorf("failed to list tools: %v", err)
|
||||
return -1, fmt.Errorf("failed to list tools: %v", err)
|
||||
}
|
||||
|
||||
// Create name set for allowed tools
|
||||
@@ -125,6 +314,10 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
|
||||
}
|
||||
}
|
||||
|
||||
// Build tools locally before acquiring the lock.
|
||||
var localTools []fantasy.AgentTool
|
||||
localMap := make(map[string]*toolMapping)
|
||||
|
||||
// Convert MCP tools to fantasy AgentTools with prefixed names
|
||||
for _, mcpTool := range listResults.Tools {
|
||||
// Filter tools based on allowedTools/excludedTools
|
||||
@@ -142,7 +335,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
|
||||
// Convert MCP InputSchema to map[string]any for fantasy ToolInfo
|
||||
marshaledSchema, err := json.Marshal(mcpTool.InputSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
|
||||
return -1, fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
|
||||
}
|
||||
|
||||
// Fix for JSON Schema draft-07 vs draft-04 compatibility
|
||||
@@ -151,7 +344,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
|
||||
// Parse into map[string]any for fantasy's parameters format
|
||||
var schemaMap map[string]any
|
||||
if err := json.Unmarshal(marshaledSchema, &schemaMap); err != nil {
|
||||
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
|
||||
return -1, fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
|
||||
}
|
||||
|
||||
// Extract properties and required from the schema
|
||||
@@ -184,7 +377,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
|
||||
serverConfig: serverConfig,
|
||||
manager: m,
|
||||
}
|
||||
m.toolMap[prefixedName] = mapping
|
||||
localMap[prefixedName] = mapping
|
||||
|
||||
// Create fantasy AgentTool
|
||||
fantasyTool := &mcpFantasyTool{
|
||||
@@ -197,10 +390,16 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
|
||||
mapping: mapping,
|
||||
}
|
||||
|
||||
m.tools = append(m.tools, fantasyTool)
|
||||
localTools = append(localTools, fantasyTool)
|
||||
}
|
||||
|
||||
return nil
|
||||
// Merge into the manager under the lock.
|
||||
m.mu.Lock()
|
||||
maps.Copy(m.toolMap, localMap)
|
||||
m.tools = append(m.tools, localTools...)
|
||||
m.mu.Unlock()
|
||||
|
||||
return len(localTools), nil
|
||||
}
|
||||
|
||||
// GetTools returns all loaded tools as fantasy AgentTools from all configured MCP servers.
|
||||
@@ -225,6 +424,9 @@ func (m *MCPToolManager) GetLoadedServerNames() []string {
|
||||
// proper cleanup of stdio processes, network connections, and other resources.
|
||||
// It is safe to call Close multiple times.
|
||||
func (m *MCPToolManager) Close() error {
|
||||
if m.connectionPool == nil {
|
||||
return nil
|
||||
}
|
||||
return m.connectionPool.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,323 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// testdataDir returns the absolute path to the testdata directory.
|
||||
func testdataDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
_, file, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("cannot determine test file path")
|
||||
}
|
||||
return filepath.Join(filepath.Dir(file), "testdata")
|
||||
}
|
||||
|
||||
// echoServerConfig returns an MCPServerConfig for the test echo MCP server.
|
||||
func echoServerConfig(t *testing.T) config.MCPServerConfig {
|
||||
t.Helper()
|
||||
script := filepath.Join(testdataDir(t), "echo_server.py")
|
||||
if _, err := os.Stat(script); err != nil {
|
||||
t.Skipf("echo_server.py not found: %v", err)
|
||||
}
|
||||
return config.MCPServerConfig{
|
||||
Command: []string{"python3", script},
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_Integration tests adding a real MCP server
|
||||
// at runtime and verifying tools are loaded.
|
||||
func TestMCPToolManager_AddServer_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Track callbacks.
|
||||
var mu sync.Mutex
|
||||
var loadedServer string
|
||||
var loadedCount int
|
||||
toolsChangedCount := 0
|
||||
|
||||
manager.SetOnServerLoaded(func(name string, count int, err error) {
|
||||
mu.Lock()
|
||||
loadedServer = name
|
||||
loadedCount = count
|
||||
mu.Unlock()
|
||||
})
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
toolsChangedCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// Add the server.
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools from echo server, got %d", count)
|
||||
}
|
||||
|
||||
// Verify callbacks fired.
|
||||
mu.Lock()
|
||||
if loadedServer != "echo" {
|
||||
t.Errorf("Expected onServerLoaded for 'echo', got %q", loadedServer)
|
||||
}
|
||||
if loadedCount != 2 {
|
||||
t.Errorf("Expected onServerLoaded count=2, got %d", loadedCount)
|
||||
}
|
||||
if toolsChangedCount != 1 {
|
||||
t.Errorf("Expected onToolsChanged called once, got %d", toolsChangedCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Verify tools are accessible.
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("Expected 2 tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Verify tool names are prefixed.
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
toolNames[tool.Info().Name] = true
|
||||
}
|
||||
if !toolNames["echo__echo"] {
|
||||
t.Error("Expected tool 'echo__echo'")
|
||||
}
|
||||
if !toolNames["echo__greet"] {
|
||||
t.Error("Expected tool 'echo__greet'")
|
||||
}
|
||||
|
||||
// Verify server appears in loaded names.
|
||||
names := manager.GetLoadedServerNames()
|
||||
if !slices.Contains(names, "echo") {
|
||||
t.Errorf("Expected 'echo' in loaded server names, got: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_RemoveServer_Integration tests removing a real MCP server
|
||||
// and verifying tools are cleaned up.
|
||||
func TestMCPToolManager_RemoveServer_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add the server first.
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatalf("Expected 2 tools, got %d", count)
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
toolsChangedCount := 0
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
toolsChangedCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// Remove the server.
|
||||
err = manager.RemoveServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify tools are gone.
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("Expected 0 tools after removal, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Verify callback fired.
|
||||
mu.Lock()
|
||||
if toolsChangedCount != 1 {
|
||||
t.Errorf("Expected onToolsChanged called once, got %d", toolsChangedCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Verify server is gone from loaded names.
|
||||
names := manager.GetLoadedServerNames()
|
||||
for _, n := range names {
|
||||
if n == "echo" {
|
||||
t.Error("Server 'echo' should not appear in loaded names after removal")
|
||||
}
|
||||
}
|
||||
|
||||
// Removing again should error.
|
||||
err = manager.RemoveServer("echo")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error removing already-removed server")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not loaded") {
|
||||
t.Errorf("Expected 'not loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddRemoveMultiple_Integration tests adding and removing
|
||||
// multiple servers, verifying tool isolation.
|
||||
func TestMCPToolManager_AddRemoveMultiple_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add two servers with the same binary but different names.
|
||||
count1, err := manager.AddServer(ctx, "server-a", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer server-a failed: %v", err)
|
||||
}
|
||||
count2, err := manager.AddServer(ctx, "server-b", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer server-b failed: %v", err)
|
||||
}
|
||||
|
||||
totalTools := count1 + count2
|
||||
if totalTools != 4 {
|
||||
t.Fatalf("Expected 4 total tools (2+2), got %d", totalTools)
|
||||
}
|
||||
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 4 {
|
||||
t.Fatalf("Expected 4 tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Remove server-a, verify server-b tools remain.
|
||||
err = manager.RemoveServer("server-a")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer server-a failed: %v", err)
|
||||
}
|
||||
|
||||
tools = manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("Expected 2 tools after removing server-a, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Remaining tools should all be from server-b.
|
||||
for _, tool := range tools {
|
||||
if !strings.HasPrefix(tool.Info().Name, "server-b__") {
|
||||
t.Errorf("Expected tool from server-b, got: %s", tool.Info().Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove server-b.
|
||||
err = manager.RemoveServer("server-b")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer server-b failed: %v", err)
|
||||
}
|
||||
|
||||
tools = manager.GetTools()
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("Expected 0 tools after removing all servers, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_DuplicateDetection_Integration tests that
|
||||
// adding a server with the same name as an already loaded server errors.
|
||||
func TestMCPToolManager_AddServer_DuplicateDetection_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add the server.
|
||||
_, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to add again with the same name.
|
||||
_, err = manager.AddServer(ctx, "echo", cfg)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error adding duplicate server")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "already loaded") {
|
||||
t.Errorf("Expected 'already loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddAfterRemove_Integration tests that a server can be
|
||||
// re-added after being removed.
|
||||
func TestMCPToolManager_AddAfterRemove_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add, remove, re-add.
|
||||
_, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
err = manager.RemoveServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer failed: %v", err)
|
||||
}
|
||||
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Re-AddServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools on re-add, got %d", count)
|
||||
}
|
||||
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Errorf("Expected 2 tools after re-add, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// TestMCPToolManager_AddServer_DuplicateName verifies that adding a server
|
||||
// with a name that already exists returns an error.
|
||||
func TestMCPToolManager_AddServer_DuplicateName(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Command: []string{"non-existent-command"},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// First add will fail (bad command), but let's test the duplicate detection
|
||||
// by simulating a loaded server via LoadTools first.
|
||||
loadCfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"test-server": cfg,
|
||||
},
|
||||
}
|
||||
// This will fail to load but creates the connection pool.
|
||||
_ = manager.LoadTools(ctx, loadCfg)
|
||||
|
||||
// Now try to add the same server name — the tools didn't load (bad command),
|
||||
// so AddServer should not find a duplicate and should fail with connection error.
|
||||
_, err := manager.AddServer(ctx, "test-server", cfg)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when adding server with bad command, got nil")
|
||||
}
|
||||
// It should be a connection error, not a duplicate error.
|
||||
if strings.Contains(err.Error(), "already loaded") {
|
||||
t.Fatalf("Should not report duplicate since server failed to load initially: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_RemoveServer_NotLoaded verifies that removing a server
|
||||
// that doesn't exist returns an appropriate error.
|
||||
func TestMCPToolManager_RemoveServer_NotLoaded(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
err := manager.RemoveServer("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when removing non-existent server, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not loaded") {
|
||||
t.Errorf("Expected 'not loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_CreatesConnectionPool verifies that AddServer
|
||||
// lazily creates a connection pool when LoadTools was never called.
|
||||
func TestMCPToolManager_AddServer_CreatesConnectionPool(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
// Connection pool should be nil initially.
|
||||
if manager.connectionPool != nil {
|
||||
t.Fatal("Expected nil connection pool before any operation")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// AddServer with a bad command — should fail, but the pool should be created.
|
||||
_, err := manager.AddServer(ctx, "lazy-server", config.MCPServerConfig{
|
||||
Command: []string{"non-existent-command"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for bad command")
|
||||
}
|
||||
|
||||
// Connection pool should have been created.
|
||||
if manager.connectionPool == nil {
|
||||
t.Fatal("Expected connection pool to be created lazily by AddServer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_OnToolsChanged_Callback verifies that the onToolsChanged
|
||||
// callback fires on RemoveServer (we can't easily test AddServer with a real
|
||||
// MCP server, but we can test the callback wiring).
|
||||
func TestMCPToolManager_OnToolsChanged_Callback(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
var mu sync.Mutex
|
||||
callCount := 0
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
callCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// RemoveServer on non-existent should NOT fire callback.
|
||||
_ = manager.RemoveServer("nonexistent")
|
||||
|
||||
mu.Lock()
|
||||
if callCount != 0 {
|
||||
t.Errorf("Expected 0 callback calls for failed remove, got %d", callCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// TestMCPToolManager_Close_NilPool verifies Close is safe when the connection
|
||||
// pool was never initialized.
|
||||
func TestMCPToolManager_Close_NilPool(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
err := manager.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil error from Close with nil pool, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPConnectionPool_RemoveConnection_NotFound verifies that removing a
|
||||
// non-existent connection returns an error.
|
||||
func TestMCPConnectionPool_RemoveConnection_NotFound(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), nil, false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
err := pool.RemoveConnection("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for non-existent connection")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Expected 'not found' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_EnsureConnectionPool_Idempotent verifies that
|
||||
// ensureConnectionPool doesn't recreate an existing pool.
|
||||
func TestMCPToolManager_EnsureConnectionPool_Idempotent(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
// First call creates the pool.
|
||||
manager.ensureConnectionPool()
|
||||
pool1 := manager.connectionPool
|
||||
if pool1 == nil {
|
||||
t.Fatal("Expected pool to be created")
|
||||
}
|
||||
|
||||
// Second call should be a no-op.
|
||||
manager.ensureConnectionPool()
|
||||
pool2 := manager.connectionPool
|
||||
if pool1 != pool2 {
|
||||
t.Fatal("Expected ensureConnectionPool to be idempotent")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
)
|
||||
|
||||
// MCPAuthHandler is the internal interface for handling MCP OAuth flows.
|
||||
// The SDK-level kit.MCPAuthHandler is adapted to this interface in cmd/root.go
|
||||
// or pkg/kit/kit.go, keeping the tools package decoupled from the SDK.
|
||||
type MCPAuthHandler interface {
|
||||
// RedirectURI returns the OAuth redirect URI for transport setup.
|
||||
RedirectURI() string
|
||||
// HandleAuth is called when a server requires OAuth authorization.
|
||||
// It receives the server name and the authorization URL the user must visit.
|
||||
// It returns the full callback URL (containing code and state query params)
|
||||
// after the user completes authorization.
|
||||
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
|
||||
}
|
||||
|
||||
// TokenStoreFactory creates a transport.TokenStore for a given MCP server URL.
|
||||
// When provided to the connection pool, it is called once per remote MCP server
|
||||
// instead of using the default file-based token store. Implementations can
|
||||
// return any transport.TokenStore — in-memory, database-backed, encrypted, etc.
|
||||
type TokenStoreFactory func(serverURL string) (transport.TokenStore, error)
|
||||
|
||||
// OAuthFlowRunner handles the OAuth authorization flow when an MCP server
|
||||
// returns an OAuthAuthorizationRequiredError. It coordinates dynamic client
|
||||
// registration, PKCE generation, user authorization (via MCPAuthHandler),
|
||||
// and token exchange.
|
||||
type OAuthFlowRunner struct {
|
||||
handler MCPAuthHandler
|
||||
}
|
||||
|
||||
// NewOAuthFlowRunner creates a new OAuthFlowRunner with the given auth handler.
|
||||
func NewOAuthFlowRunner(handler MCPAuthHandler) *OAuthFlowRunner {
|
||||
return &OAuthFlowRunner{handler: handler}
|
||||
}
|
||||
|
||||
// RunAuthFlow executes the OAuth authorization flow for the given server.
|
||||
// It extracts the OAuthHandler from the error, performs dynamic client registration
|
||||
// if needed, generates PKCE parameters, delegates to the MCPAuthHandler for user
|
||||
// interaction, and exchanges the authorization code for a token.
|
||||
func (r *OAuthFlowRunner) RunAuthFlow(ctx context.Context, serverName string, authErr error) error {
|
||||
// Extract the OAuthHandler from the authorization-required error.
|
||||
oauthHandler := client.GetOAuthHandler(authErr)
|
||||
if oauthHandler == nil {
|
||||
return fmt.Errorf("oauth flow: failed to extract OAuth handler from error: %w", authErr)
|
||||
}
|
||||
|
||||
// Perform dynamic client registration if no client ID is configured yet.
|
||||
if oauthHandler.GetClientID() == "" {
|
||||
if err := oauthHandler.RegisterClient(ctx, "kit"); err != nil {
|
||||
return fmt.Errorf("oauth flow: dynamic client registration failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate PKCE code verifier and challenge.
|
||||
codeVerifier, err := client.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to generate code verifier: %w", err)
|
||||
}
|
||||
codeChallenge := client.GenerateCodeChallenge(codeVerifier)
|
||||
|
||||
// Generate a random state parameter for CSRF protection.
|
||||
state, err := client.GenerateState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Build the authorization URL the user needs to visit.
|
||||
authURL, err := oauthHandler.GetAuthorizationURL(ctx, state, codeChallenge)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to get authorization URL: %w", err)
|
||||
}
|
||||
|
||||
// Delegate to the MCPAuthHandler for user-facing authorization (e.g. open
|
||||
// browser, wait for redirect). It returns the full callback URL containing
|
||||
// the authorization code and state.
|
||||
callbackURL, err := r.handler.HandleAuth(ctx, serverName, authURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: user authorization failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse the callback URL to extract the authorization code and state.
|
||||
parsed, err := url.Parse(callbackURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to parse callback URL: %w", err)
|
||||
}
|
||||
|
||||
code := parsed.Query().Get("code")
|
||||
returnedState := parsed.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
return fmt.Errorf("oauth flow: callback URL missing 'code' parameter")
|
||||
}
|
||||
if returnedState == "" {
|
||||
return fmt.Errorf("oauth flow: callback URL missing 'state' parameter")
|
||||
}
|
||||
|
||||
// Exchange the authorization code for an access token.
|
||||
if err := oauthHandler.ProcessAuthorizationResponse(ctx, code, returnedState, codeVerifier); err != nil {
|
||||
return fmt.Errorf("oauth flow: token exchange failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsOAuthError returns true if the error is an OAuthAuthorizationRequiredError.
|
||||
func IsOAuthError(err error) bool {
|
||||
return client.IsOAuthAuthorizationRequiredError(err)
|
||||
}
|
||||
+111
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Minimal MCP server over stdio for testing. Exposes one tool: echo."""
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
def read_message():
|
||||
"""Read a JSON-RPC message from stdin."""
|
||||
line = sys.stdin.readline()
|
||||
if not line:
|
||||
return None
|
||||
return json.loads(line.strip())
|
||||
|
||||
|
||||
def write_message(msg):
|
||||
"""Write a JSON-RPC message to stdout."""
|
||||
sys.stdout.write(json.dumps(msg) + "\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def handle(msg):
|
||||
method = msg.get("method", "")
|
||||
mid = msg.get("id")
|
||||
|
||||
if method == "initialize":
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "test-echo", "version": "1.0.0"},
|
||||
},
|
||||
})
|
||||
elif method == "notifications/initialized":
|
||||
pass # no response needed
|
||||
elif method == "tools/list":
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "Echoes the input text back.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "Text to echo"}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "greet",
|
||||
"description": "Returns a greeting.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "Name to greet"}
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
})
|
||||
elif method == "tools/call":
|
||||
tool_name = msg["params"]["name"]
|
||||
args = msg["params"].get("arguments", {})
|
||||
if tool_name == "echo":
|
||||
text = args.get("text", "")
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"content": [{"type": "text", "text": text}]
|
||||
},
|
||||
})
|
||||
elif tool_name == "greet":
|
||||
name = args.get("name", "World")
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"content": [{"type": "text", "text": f"Hello, {name}!"}]
|
||||
},
|
||||
})
|
||||
else:
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
|
||||
})
|
||||
elif method == "ping":
|
||||
write_message({"jsonrpc": "2.0", "id": mid, "result": {}})
|
||||
else:
|
||||
if mid is not None:
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"error": {"code": -32601, "message": f"Unknown method: {method}"},
|
||||
})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
while True:
|
||||
msg = read_message()
|
||||
if msg is None:
|
||||
break
|
||||
handle(msg)
|
||||
@@ -0,0 +1,155 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
)
|
||||
|
||||
// Compile-time check that FileTokenStore implements transport.TokenStore.
|
||||
var _ transport.TokenStore = (*FileTokenStore)(nil)
|
||||
|
||||
// FileTokenStore is a file-backed implementation of transport.TokenStore that
|
||||
// persists OAuth tokens as JSON on disk. Tokens are stored in a shared JSON file
|
||||
// keyed by server URL, allowing multiple MCP servers to maintain independent tokens.
|
||||
//
|
||||
// The token file is located at $XDG_CONFIG_HOME/.kit/mcp_tokens.json, falling back
|
||||
// to ~/.config/.kit/mcp_tokens.json when XDG_CONFIG_HOME is not set.
|
||||
//
|
||||
// FileTokenStore is safe for concurrent use.
|
||||
type FileTokenStore struct {
|
||||
serverKey string
|
||||
filePath string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewFileTokenStore creates a new FileTokenStore for the given server URL.
|
||||
// The serverKey is used as the map key in the shared token file, and should
|
||||
// typically be the MCP server's base URL.
|
||||
//
|
||||
// Returns an error if the token file path cannot be resolved.
|
||||
func NewFileTokenStore(serverKey string) (*FileTokenStore, error) {
|
||||
filePath, err := resolveTokenFilePath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving token file path: %w", err)
|
||||
}
|
||||
|
||||
return &FileTokenStore{
|
||||
serverKey: serverKey,
|
||||
filePath: filePath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetToken returns the stored token for this store's server key.
|
||||
// Returns transport.ErrNoToken if no token exists for the server key or if
|
||||
// the token file does not yet exist.
|
||||
// Returns context.Canceled or context.DeadlineExceeded if the context is done.
|
||||
func (s *FileTokenStore) GetToken(ctx context.Context) (*transport.Token, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
tokens, err := readTokenFile(s.filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, transport.ErrNoToken
|
||||
}
|
||||
return nil, fmt.Errorf("reading token file: %w", err)
|
||||
}
|
||||
|
||||
token, ok := tokens[s.serverKey]
|
||||
if !ok {
|
||||
return nil, transport.ErrNoToken
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// SaveToken persists the given token for this store's server key.
|
||||
// If the token file or its parent directories do not exist, they are created.
|
||||
// Existing tokens for other server keys are preserved.
|
||||
// Returns context.Canceled or context.DeadlineExceeded if the context is done.
|
||||
func (s *FileTokenStore) SaveToken(ctx context.Context, token *transport.Token) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
tokens, err := readTokenFile(s.filePath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("reading token file: %w", err)
|
||||
}
|
||||
if tokens == nil {
|
||||
tokens = make(map[string]*transport.Token)
|
||||
}
|
||||
|
||||
tokens[s.serverKey] = token
|
||||
|
||||
if err := writeTokenFile(s.filePath, tokens); err != nil {
|
||||
return fmt.Errorf("writing token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveTokenFilePath determines the path to the token file using
|
||||
// XDG_CONFIG_HOME if set, otherwise falling back to ~/.config/.kit/.
|
||||
func resolveTokenFilePath() (string, error) {
|
||||
configDir := os.Getenv("XDG_CONFIG_HOME")
|
||||
if configDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("determining user home directory: %w", err)
|
||||
}
|
||||
configDir = filepath.Join(home, ".config")
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, ".kit", "mcp_tokens.json"), nil
|
||||
}
|
||||
|
||||
// readTokenFile reads and unmarshals the token file into a server-keyed map.
|
||||
// Returns os.ErrNotExist (via os.IsNotExist) if the file does not exist.
|
||||
func readTokenFile(path string) (map[string]*transport.Token, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokens map[string]*transport.Token
|
||||
if err := json.Unmarshal(data, &tokens); err != nil {
|
||||
return nil, fmt.Errorf("unmarshaling token file: %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// writeTokenFile marshals the token map and writes it to disk, creating
|
||||
// parent directories as needed. The file is written with 0600 permissions
|
||||
// to protect sensitive token data.
|
||||
func writeTokenFile(path string, tokens map[string]*transport.Token) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("creating token directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(tokens, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling tokens: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0600); err != nil {
|
||||
return fmt.Errorf("writing token file %s: %w", path, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"image/color"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// blockRenderer handles rendering of content blocks with configurable options
|
||||
@@ -175,7 +177,7 @@ func renderContentBlock(content string, containerWidth int, options ...rendering
|
||||
borderChars = 1
|
||||
}
|
||||
|
||||
theme := GetTheme()
|
||||
theme := style.GetTheme()
|
||||
|
||||
// Resolve foreground color: caller override or theme default.
|
||||
fgColor := theme.Text
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/ui/core"
|
||||
)
|
||||
|
||||
// ==========================================================================
|
||||
@@ -59,7 +60,7 @@ func TestInputComponent_SubmitEmitsSubmitMsg(t *testing.T) {
|
||||
t.Fatal("expected a cmd from pressing enter on non-empty input")
|
||||
}
|
||||
|
||||
sm, ok := msg.(submitMsg)
|
||||
sm, ok := msg.(core.SubmitMsg)
|
||||
if !ok {
|
||||
t.Fatalf("expected submitMsg, got %T", msg)
|
||||
}
|
||||
@@ -83,7 +84,7 @@ func TestInputComponent_CtrlD_SubmitEmitsSubmitMsg(t *testing.T) {
|
||||
if msg == nil {
|
||||
t.Fatal("expected a cmd from ctrl+d on non-empty input")
|
||||
}
|
||||
sm, ok := msg.(submitMsg)
|
||||
sm, ok := msg.(core.SubmitMsg)
|
||||
if !ok {
|
||||
t.Fatalf("expected submitMsg from ctrl+d, got %T", msg)
|
||||
}
|
||||
@@ -175,7 +176,7 @@ func TestInputComponent_ClearForwardsAsSubmitMsg(t *testing.T) {
|
||||
t.Fatalf("%s: expected submitMsg cmd, got nil", alias)
|
||||
}
|
||||
msg := runCmd(cmd)
|
||||
sm, ok := msg.(submitMsg)
|
||||
sm, ok := msg.(core.SubmitMsg)
|
||||
if !ok {
|
||||
t.Fatalf("%s: expected submitMsg, got %T", alias, msg)
|
||||
}
|
||||
@@ -230,7 +231,7 @@ func TestInputComponent_ClearQueue_ForwardsAsSubmitMsg(t *testing.T) {
|
||||
t.Fatalf("%s: expected submitMsg cmd, got nil", alias)
|
||||
}
|
||||
msg := runCmd(cmd)
|
||||
sm, ok := msg.(submitMsg)
|
||||
sm, ok := msg.(core.SubmitMsg)
|
||||
if !ok {
|
||||
t.Fatalf("%s: expected submitMsg, got %T", alias, msg)
|
||||
}
|
||||
@@ -258,7 +259,7 @@ func TestInputComponent_UnknownSlashCommand_ForwardsAsSubmit(t *testing.T) {
|
||||
if msg == nil {
|
||||
t.Fatal("expected submitMsg for unknown slash command")
|
||||
}
|
||||
sm, ok := msg.(submitMsg)
|
||||
sm, ok := msg.(core.SubmitMsg)
|
||||
if !ok {
|
||||
t.Fatalf("expected submitMsg for unknown slash command, got %T", msg)
|
||||
}
|
||||
@@ -275,10 +276,9 @@ func TestInputComponent_UnknownSlashCommand_ForwardsAsSubmit(t *testing.T) {
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// newTestStream creates a StreamComponent with a fixed width and model name,
|
||||
// in non-compact mode.
|
||||
// newTestStream creates a StreamComponent with a fixed width and model name.
|
||||
func newTestStream() *StreamComponent {
|
||||
return NewStreamComponent(false, 80, "test-model")
|
||||
return NewStreamComponent(80, "test-model")
|
||||
}
|
||||
|
||||
// sendStreamMsg calls component.Update and returns the updated component.
|
||||
@@ -699,3 +699,38 @@ func TestStreamComponent_StaleFlushTick_Discarded(t *testing.T) {
|
||||
t.Fatalf("expected streamContent='new' after current flush, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_ConsumeOverflow_NoHeight verifies that when height is
|
||||
// unconstrained (0), ConsumeOverflow always returns "".
|
||||
func TestStreamComponent_ConsumeOverflow_NoOp(t *testing.T) {
|
||||
c := newTestStream()
|
||||
// Commit some content directly.
|
||||
c.streamContent.WriteString("line1\nline2\nline3")
|
||||
c.phase = streamPhaseActive
|
||||
|
||||
// ConsumeOverflow is a no-op in alt screen mode — always returns "".
|
||||
if got := c.ConsumeOverflow(); got != "" {
|
||||
t.Fatalf("expected empty from no-op ConsumeOverflow, got %q", got)
|
||||
}
|
||||
|
||||
// Also returns "" with a height set.
|
||||
c.height = 2
|
||||
if got := c.ConsumeOverflow(); got != "" {
|
||||
t.Fatalf("expected empty from no-op ConsumeOverflow with height, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_GetRenderedContent_ReturnsAll verifies that
|
||||
// GetRenderedContent returns all accumulated content.
|
||||
func TestStreamComponent_GetRenderedContent_ReturnsAll(t *testing.T) {
|
||||
c := newTestStream()
|
||||
c.renderer = nil
|
||||
c.phase = streamPhaseActive
|
||||
|
||||
c.streamContent.WriteString("a\nb\nc\nd\ne")
|
||||
|
||||
got := c.GetRenderedContent()
|
||||
if got != "a\nb\nc\nd\ne" {
|
||||
t.Fatalf("expected full content, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
+11
-44
@@ -5,39 +5,33 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"charm.land/lipgloss/v2"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// CLI manages the command-line interface for KIT, providing message rendering,
|
||||
// user input handling, and display management. It supports both standard and compact
|
||||
// display modes, handles streaming responses, tracks token usage, and manages the
|
||||
// overall conversation flow between the user and AI assistants.
|
||||
// user input handling, and display management. It handles streaming responses,
|
||||
// tracks token usage, and manages the overall conversation flow between the
|
||||
// user and AI assistants.
|
||||
type CLI struct {
|
||||
renderer Renderer
|
||||
usageTracker *UsageTracker
|
||||
width int
|
||||
compactMode bool
|
||||
debug bool
|
||||
modelName string
|
||||
}
|
||||
|
||||
// NewCLI creates and initializes a new CLI instance with the specified display modes.
|
||||
// The debug parameter enables debug message rendering, while compact enables a more
|
||||
// condensed display format. Returns an initialized CLI ready for interaction or an
|
||||
// NewCLI creates and initializes a new CLI instance. The debug parameter enables
|
||||
// debug message rendering. Returns an initialized CLI ready for interaction or an
|
||||
// error if initialization fails.
|
||||
func NewCLI(debug bool, compact bool) (*CLI, error) {
|
||||
func NewCLI(debug bool) (*CLI, error) {
|
||||
cli := &CLI{
|
||||
compactMode: compact,
|
||||
debug: debug,
|
||||
debug: debug,
|
||||
}
|
||||
cli.updateSize()
|
||||
if compact {
|
||||
cli.renderer = NewCompactRenderer(cli.width, debug)
|
||||
} else {
|
||||
cli.renderer = newMessageRenderer(cli.width, debug)
|
||||
}
|
||||
cli.renderer = newMessageRenderer(cli.width, debug)
|
||||
|
||||
return cli, nil
|
||||
}
|
||||
@@ -132,7 +126,7 @@ func (c *CLI) DisplayInfo(message string) {
|
||||
// DisplayExtensionBlock renders a custom styled block with the given border
|
||||
// color and optional subtitle. Used by extensions via ctx.PrintBlock.
|
||||
func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) {
|
||||
theme := GetTheme()
|
||||
theme := style.GetTheme()
|
||||
|
||||
borderClr := theme.Info
|
||||
if borderColor != "" {
|
||||
@@ -178,33 +172,6 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) {
|
||||
fmt.Println(c.renderer.RenderDebugConfigMessage(config, time.Now()).Content)
|
||||
}
|
||||
|
||||
// UpdateUsageFromResponse records token usage using metadata from the fantasy
|
||||
// response. Only actual API-reported tokens are used for cost tracking.
|
||||
// If the provider doesn't report token counts, no usage is recorded.
|
||||
func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText string) {
|
||||
if c.usageTracker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage := response.Usage
|
||||
inputTokens := int(usage.InputTokens)
|
||||
outputTokens := int(usage.OutputTokens)
|
||||
|
||||
// Only use actual API-reported tokens for cost tracking.
|
||||
// We intentionally do NOT estimate tokens - estimation is inaccurate
|
||||
// and should never be used for cost calculations.
|
||||
if inputTokens > 0 {
|
||||
cacheReadTokens := int(usage.CacheReadTokens)
|
||||
cacheWriteTokens := int(usage.CacheCreationTokens)
|
||||
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
// Per-response usage is a single API call, so it represents the
|
||||
// actual context window fill level.
|
||||
c.usageTracker.SetContextTokens(inputTokens + outputTokens)
|
||||
}
|
||||
// If inputTokens is 0, the provider didn't report usage - we skip recording
|
||||
// rather than estimating, to ensure cost accuracy.
|
||||
}
|
||||
|
||||
// DisplayUsageAfterResponse renders and displays token usage information immediately
|
||||
// following an AI response. This provides real-time feedback about the cost and
|
||||
// token consumption of each interaction.
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
package clipboard
|
||||
|
||||
import (
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"github.com/atotto/clipboard"
|
||||
)
|
||||
|
||||
// CopyToClipboard writes text to both the system clipboard and via OSC 52.
|
||||
// Returns a tea.Cmd that can be used in Bubble Tea's Update flow.
|
||||
func CopyToClipboard(text string) tea.Cmd {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tea.Sequence(
|
||||
// Method 1: OSC 52 escape sequence (works in modern terminals)
|
||||
tea.SetClipboard(text),
|
||||
|
||||
// Method 2: Native system clipboard (atotto/clipboard)
|
||||
func() tea.Msg {
|
||||
// Best effort - ignore errors
|
||||
_ = clipboard.WriteAll(text)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package ui
|
||||
package commands
|
||||
|
||||
import (
|
||||
"slices"
|
||||
@@ -7,6 +7,10 @@ import (
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
)
|
||||
|
||||
// ListThemesFunc is set by the ui package to provide theme name completion.
|
||||
// This breaks the circular dependency between commands and ui packages.
|
||||
var ListThemesFunc func() []string
|
||||
|
||||
// SlashCommand represents a user-invokable slash command with its metadata.
|
||||
// Commands can have multiple aliases and are organized by category for better
|
||||
// discoverability and help display.
|
||||
@@ -99,7 +103,10 @@ var SlashCommands = []SlashCommand{
|
||||
Description: "Switch color theme (e.g. /theme catppuccin)",
|
||||
Category: "System",
|
||||
Complete: func(prefix string) []string {
|
||||
names := ListThemes()
|
||||
if ListThemesFunc == nil {
|
||||
return nil
|
||||
}
|
||||
names := ListThemesFunc()
|
||||
if prefix == "" {
|
||||
return names
|
||||
}
|
||||
@@ -112,6 +119,12 @@ var SlashCommands = []SlashCommand{
|
||||
return matches
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "/reload-ext",
|
||||
Description: "Hot-reload all extensions from disk",
|
||||
Category: "System",
|
||||
Aliases: []string{"/re"},
|
||||
},
|
||||
{
|
||||
Name: "/quit",
|
||||
Description: "Exit the application",
|
||||
@@ -127,7 +140,7 @@ var SlashCommands = []SlashCommand{
|
||||
},
|
||||
{
|
||||
Name: "/fork",
|
||||
Description: "Branch from an earlier message",
|
||||
Description: "Fork to new session from an earlier message",
|
||||
Category: "Navigation",
|
||||
},
|
||||
{
|
||||
@@ -1,444 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
// CompactRenderer handles rendering messages in a space-efficient compact format,
|
||||
// optimized for terminals with limited vertical space. It displays messages with
|
||||
// minimal decorations while maintaining readability and essential information.
|
||||
type CompactRenderer struct {
|
||||
width int
|
||||
debug bool
|
||||
|
||||
// getToolRenderer returns extension-provided rendering overrides for a
|
||||
// specific tool. May be nil if no extensions are loaded. Used in
|
||||
// RenderToolMessage to check for custom header/body formatting before
|
||||
// falling back to builtin renderers.
|
||||
getToolRenderer func(toolName string) *ToolRendererData
|
||||
}
|
||||
|
||||
// NewCompactRenderer creates and initializes a new CompactRenderer with the specified
|
||||
// terminal width and debug mode setting. The width parameter determines line wrapping,
|
||||
// while debug enables additional diagnostic output in rendered messages.
|
||||
func NewCompactRenderer(width int, debug bool) *CompactRenderer {
|
||||
return &CompactRenderer{
|
||||
width: width,
|
||||
debug: debug,
|
||||
}
|
||||
}
|
||||
|
||||
// SetWidth updates the terminal width for the renderer, affecting how content
|
||||
// is wrapped and formatted in subsequent render operations.
|
||||
func (r *CompactRenderer) SetWidth(width int) {
|
||||
r.width = width
|
||||
}
|
||||
|
||||
// RenderUserMessage renders a user's input message in compact format with a
|
||||
// distinctive symbol (>) and label. The content is formatted to preserve structure
|
||||
// while minimizing vertical space usage. Returns a UIMessage with formatted content
|
||||
// and metadata.
|
||||
func (r *CompactRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Info).Render(">")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Info).Bold(true).Render("User")
|
||||
|
||||
// Only run markdown rendering when the message contains code spans or
|
||||
// fenced code blocks. Plain text is rendered directly so that newlines
|
||||
// are preserved without the extra paragraph spacing glamour adds.
|
||||
var compactContent string
|
||||
if strings.Contains(content, "`") {
|
||||
mdContent := strings.ReplaceAll(content, "\n", "\n\n")
|
||||
compactContent = r.formatUserAssistantContent(mdContent)
|
||||
compactContent = removeBlankLines(compactContent)
|
||||
} else {
|
||||
compactContent = content
|
||||
}
|
||||
|
||||
// Handle multi-line content
|
||||
lines := strings.Split(compactContent, "\n")
|
||||
var formattedLines []string
|
||||
|
||||
for i, line := range lines {
|
||||
if i == 0 {
|
||||
// First line includes symbol and label
|
||||
formattedLines = append(formattedLines, fmt.Sprintf("%s %s %s", symbol, label, line))
|
||||
} else {
|
||||
// Subsequent lines without indentation for compact mode
|
||||
formattedLines = append(formattedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return UIMessage{
|
||||
Type: UserMessage,
|
||||
Content: strings.Join(formattedLines, "\n"),
|
||||
Height: len(formattedLines),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderAssistantMessage renders an AI assistant's response in compact format with
|
||||
// a distinctive symbol (<) and the model name as label. Empty content is ignored
|
||||
// and returns an empty message. Returns a UIMessage with formatted content and metadata.
|
||||
func (r *CompactRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
|
||||
// Ignore empty responses - don't render anything
|
||||
compactContent := r.formatUserAssistantContent(content)
|
||||
if compactContent == "" {
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
Content: "",
|
||||
Height: 0,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Primary).Render("<")
|
||||
|
||||
// Use the full model name, fallback to "Assistant" if empty
|
||||
if modelName == "" {
|
||||
modelName = "Assistant"
|
||||
}
|
||||
label := lipgloss.NewStyle().Foreground(theme.Primary).Bold(true).Render(modelName)
|
||||
|
||||
// Handle multi-line content
|
||||
lines := strings.Split(compactContent, "\n")
|
||||
var formattedLines []string
|
||||
|
||||
for i, line := range lines {
|
||||
if i == 0 {
|
||||
// First line includes symbol and label
|
||||
formattedLines = append(formattedLines, fmt.Sprintf("%s %s %s", symbol, label, line))
|
||||
} else {
|
||||
// Subsequent lines without indentation for compact mode
|
||||
formattedLines = append(formattedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
Content: strings.Join(formattedLines, "\n"),
|
||||
Height: len(formattedLines),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolMessage renders a unified tool block in compact format, combining
|
||||
// the tool invocation header (icon + display name + params) with the execution
|
||||
// result body. Status is indicated by icon: checkmark for success, cross for error.
|
||||
func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage {
|
||||
theme := GetTheme()
|
||||
|
||||
// Resolve extension renderer once for all overrides.
|
||||
var extRd *ToolRendererData
|
||||
if r.getToolRenderer != nil {
|
||||
extRd = r.getToolRenderer(toolName)
|
||||
}
|
||||
|
||||
// Status icon
|
||||
var icon string
|
||||
iconColor := theme.Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
iconColor = theme.Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
iconStr := lipgloss.NewStyle().Foreground(iconColor).Bold(true).Render(icon)
|
||||
|
||||
// Extension can override display name.
|
||||
displayName := toolDisplayName(toolName)
|
||||
if extRd != nil && extRd.DisplayName != "" {
|
||||
displayName = extRd.DisplayName
|
||||
}
|
||||
nameStr := lipgloss.NewStyle().Foreground(theme.Info).Bold(true).Render(displayName)
|
||||
|
||||
// Format params — check extension renderer first.
|
||||
paramBudget := max(r.width-10-len(displayName), 20)
|
||||
var params string
|
||||
if extRd != nil && extRd.RenderHeader != nil {
|
||||
params = extRd.RenderHeader(toolArgs, paramBudget)
|
||||
}
|
||||
if params == "" {
|
||||
params = formatToolParams(toolArgs, paramBudget)
|
||||
}
|
||||
|
||||
// Build header line
|
||||
header := iconStr + " " + nameStr
|
||||
if params != "" {
|
||||
header += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
}
|
||||
|
||||
// Format body: check extension renderer first, then compact builtin, then default.
|
||||
var body string
|
||||
if extRd != nil && extRd.RenderBody != nil {
|
||||
body = extRd.RenderBody(toolResult, isError, r.width-4)
|
||||
// Apply markdown rendering if requested and body is non-empty.
|
||||
if body != "" && extRd.BodyMarkdown {
|
||||
body = strings.TrimSuffix(toMarkdown(body, r.width-4), "\n")
|
||||
}
|
||||
}
|
||||
if body == "" {
|
||||
if isError {
|
||||
body = lipgloss.NewStyle().Foreground(theme.Error).Render(r.formatToolResult(toolResult))
|
||||
} else {
|
||||
// Use compact summary renderers instead of full tool body renderers.
|
||||
body = renderToolBodyCompact(toolName, toolArgs, toolResult, r.width-4)
|
||||
if body == "" {
|
||||
formatted := r.formatToolResult(toolResult)
|
||||
if formatted == "" {
|
||||
body = lipgloss.NewStyle().Foreground(theme.Muted).Italic(true).Render("(no output)")
|
||||
} else {
|
||||
body = lipgloss.NewStyle().Foreground(theme.Muted).Render(formatted)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Combine header + indented body
|
||||
var lines []string
|
||||
lines = append(lines, header)
|
||||
if body != "" {
|
||||
for line := range strings.SplitSeq(body, "\n") {
|
||||
lines = append(lines, " "+line)
|
||||
}
|
||||
}
|
||||
|
||||
return UIMessage{
|
||||
Type: ToolMessage,
|
||||
Content: strings.Join(lines, "\n"),
|
||||
Height: len(lines),
|
||||
}
|
||||
}
|
||||
|
||||
// RenderSystemMessage renders a system notification or informational message in
|
||||
// compact format with a distinctive symbol (*) and "System" label. Content is
|
||||
// formatted to fit on a single line for minimal space usage.
|
||||
func (r *CompactRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Muted).Render("◇")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Muted).Bold(true).Render("System")
|
||||
|
||||
compactContent := r.formatCompactContent(content)
|
||||
|
||||
line := fmt.Sprintf("%s %-8s %s", symbol, label, compactContent)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
Content: line,
|
||||
Height: 1,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderErrorMessage renders an error notification in compact format with a
|
||||
// distinctive error symbol (!) and styling to ensure visibility. The error
|
||||
// content is displayed in a single line with appropriate color highlighting.
|
||||
func (r *CompactRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Error).Render("!")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Error).Bold(true).Render("Error")
|
||||
|
||||
compactContent := lipgloss.NewStyle().Foreground(theme.Error).Render(r.formatCompactContent(errorMsg))
|
||||
|
||||
line := fmt.Sprintf("%s %-8s %s", symbol, label, compactContent)
|
||||
|
||||
return UIMessage{
|
||||
Type: ErrorMessage,
|
||||
Content: line,
|
||||
Height: 1,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugMessage renders diagnostic information in compact format when debug
|
||||
// mode is enabled. Messages are truncated if they exceed the available width to
|
||||
// maintain single-line display.
|
||||
func (r *CompactRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Tool).Bold(true).Render("Debug")
|
||||
|
||||
// Truncate message if too long
|
||||
content := message
|
||||
if len(content) > r.width-20 {
|
||||
content = content[:r.width-23] + "..."
|
||||
}
|
||||
|
||||
line := fmt.Sprintf("%s %-8s %s", symbol, label, content)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
Content: line,
|
||||
Height: 1,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugConfigMessage renders configuration settings in compact format for
|
||||
// debugging purposes. Config entries are displayed as key=value pairs separated
|
||||
// by commas, truncated if necessary to fit on a single line.
|
||||
func (r *CompactRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage {
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Tool).Bold(true).Render("Debug")
|
||||
|
||||
// Format config as compact key=value pairs
|
||||
var configPairs []string
|
||||
for key, value := range config {
|
||||
if value != nil {
|
||||
configPairs = append(configPairs, fmt.Sprintf("%s=%v", key, value))
|
||||
}
|
||||
}
|
||||
|
||||
content := strings.Join(configPairs, ", ")
|
||||
if len(content) > r.width-20 {
|
||||
content = content[:r.width-23] + "..."
|
||||
}
|
||||
|
||||
line := fmt.Sprintf("%s %-8s %s", symbol, label, content)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
Content: line,
|
||||
Height: 1,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// formatCompactContent formats content for compact single-line display
|
||||
func (r *CompactRenderer) formatCompactContent(content string) string {
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Remove markdown formatting for compact display
|
||||
content = strings.ReplaceAll(content, "\n", " ")
|
||||
content = strings.ReplaceAll(content, "\t", " ")
|
||||
|
||||
// Collapse multiple spaces
|
||||
for strings.Contains(content, " ") {
|
||||
content = strings.ReplaceAll(content, " ", " ")
|
||||
}
|
||||
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// Truncate if too long (unless in debug mode)
|
||||
maxLen := max(
|
||||
// Reserve space for symbol and label more conservatively
|
||||
r.width-28,
|
||||
// Minimum width for readability
|
||||
40)
|
||||
if !r.debug && len(content) > maxLen {
|
||||
content = content[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// formatUserAssistantContent formats user and assistant content using glamour markdown rendering
|
||||
func (r *CompactRenderer) formatUserAssistantContent(content string) string {
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Calculate available width more conservatively
|
||||
// Account for: symbol (1) + spaces (2) + label (up to 20 chars) + space (1) + margin (4)
|
||||
availableWidth := max(r.width-28,
|
||||
// Minimum width for readability
|
||||
40)
|
||||
|
||||
// Use glamour to render markdown content with proper width
|
||||
rendered := toMarkdown(content, availableWidth)
|
||||
return strings.TrimSuffix(rendered, "\n")
|
||||
}
|
||||
|
||||
// wrapText wraps text to the specified width, preserving existing line breaks
|
||||
func (r *CompactRenderer) wrapText(text string, width int) string {
|
||||
if width <= 0 {
|
||||
return text
|
||||
}
|
||||
|
||||
lines := strings.Split(text, "\n")
|
||||
var wrappedLines []string
|
||||
|
||||
for _, line := range lines {
|
||||
if len(line) <= width {
|
||||
wrappedLines = append(wrappedLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
// Wrap long lines
|
||||
words := strings.Fields(line)
|
||||
if len(words) == 0 {
|
||||
wrappedLines = append(wrappedLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
currentLine := ""
|
||||
for _, word := range words {
|
||||
// If adding this word would exceed the width, start a new line
|
||||
if len(currentLine)+len(word)+1 > width && currentLine != "" {
|
||||
wrappedLines = append(wrappedLines, currentLine)
|
||||
currentLine = word
|
||||
} else {
|
||||
if currentLine == "" {
|
||||
currentLine = word
|
||||
} else {
|
||||
currentLine += " " + word
|
||||
}
|
||||
}
|
||||
}
|
||||
if currentLine != "" {
|
||||
wrappedLines = append(wrappedLines, currentLine)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(wrappedLines, "\n")
|
||||
}
|
||||
|
||||
// formatToolResult formats tool results preserving formatting but limiting to 5 lines
|
||||
func (r *CompactRenderer) formatToolResult(result string) string {
|
||||
if result == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if this is bash output with stdout/stderr tags
|
||||
if strings.Contains(result, "<stdout>") || strings.Contains(result, "<stderr>") {
|
||||
result = r.formatBashOutput(result)
|
||||
}
|
||||
|
||||
// Calculate available width more conservatively
|
||||
availableWidth := max(r.width-28,
|
||||
// Minimum width for readability
|
||||
40)
|
||||
|
||||
// First wrap the text to prevent long lines (tool results are usually plain text, not markdown)
|
||||
wrappedResult := r.wrapText(result, availableWidth)
|
||||
|
||||
// Then limit to 5 lines
|
||||
lines := strings.Split(wrappedResult, "\n")
|
||||
if len(lines) > 5 {
|
||||
lines = lines[:5]
|
||||
// Add truncation indicator
|
||||
if len(lines) == 5 && lines[4] != "" {
|
||||
lines[4] = lines[4] + "..."
|
||||
} else {
|
||||
lines = append(lines, "...")
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// formatBashOutput formats bash command output by removing stdout/stderr tags
|
||||
// and styling appropriately. Delegates tag parsing to the shared parseBashOutput
|
||||
// helper.
|
||||
func (r *CompactRenderer) formatBashOutput(result string) string {
|
||||
return parseBashOutput(result, GetTheme())
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package ui
|
||||
package core
|
||||
|
||||
// ImageAttachment holds a clipboard image that will be sent alongside the
|
||||
// user's text prompt to the LLM. The data is raw image bytes; MediaType is
|
||||
@@ -10,9 +10,9 @@ type ImageAttachment struct {
|
||||
MediaType string
|
||||
}
|
||||
|
||||
// submitMsg is sent by the InputComponent when the user submits a text prompt.
|
||||
// SubmitMsg is sent by the InputComponent when the user submits a text prompt.
|
||||
// The parent model receives this and calls app.Run(Text) to start agent processing.
|
||||
type submitMsg struct {
|
||||
type SubmitMsg struct {
|
||||
// Text is the user's input text to send to the agent.
|
||||
Text string
|
||||
// Images holds clipboard image attachments to send alongside the text.
|
||||
@@ -20,10 +20,10 @@ type submitMsg struct {
|
||||
Images []ImageAttachment
|
||||
}
|
||||
|
||||
// cancelTimerExpiredMsg is sent by the tea.Tick command that starts when the user
|
||||
// CancelTimerExpiredMsg is sent by the tea.Tick command that starts when the user
|
||||
// presses ESC once during stateWorking. If this message arrives before the user
|
||||
// presses ESC a second time, the canceling state is reset to false.
|
||||
type cancelTimerExpiredMsg struct{}
|
||||
type CancelTimerExpiredMsg struct{}
|
||||
|
||||
// --- Tree session events ---
|
||||
|
||||
@@ -42,14 +42,14 @@ type TreeNodeSelectedMsg struct {
|
||||
// TreeCancelledMsg is sent when the user cancels the tree selector (ESC).
|
||||
type TreeCancelledMsg struct{}
|
||||
|
||||
// shellCommandMsg is sent by the InputComponent when the user submits a
|
||||
// ShellCommandMsg is sent by the InputComponent when the user submits a
|
||||
// ! or !! prefixed command. The parent model intercepts this to execute
|
||||
// the shell command directly instead of forwarding to the LLM.
|
||||
//
|
||||
// Matching pi's behavior:
|
||||
// - !cmd → run shell command, output INCLUDED in LLM context
|
||||
// - !!cmd → run shell command, output EXCLUDED from LLM context
|
||||
type shellCommandMsg struct {
|
||||
type ShellCommandMsg struct {
|
||||
// Command is the shell command to execute (prefix stripped).
|
||||
Command string
|
||||
// ExcludeFromContext is true for !! (output excluded from LLM context),
|
||||
@@ -57,9 +57,9 @@ type shellCommandMsg struct {
|
||||
ExcludeFromContext bool
|
||||
}
|
||||
|
||||
// shellCommandResultMsg carries the result of a shell command execution
|
||||
// ShellCommandResultMsg carries the result of a shell command execution
|
||||
// back to the parent model for display.
|
||||
type shellCommandResultMsg struct {
|
||||
type ShellCommandResultMsg struct {
|
||||
// Command is the original shell command that was executed.
|
||||
Command string
|
||||
// Output is the combined stdout/stderr output.
|
||||
@@ -68,6 +68,6 @@ type shellCommandResultMsg struct {
|
||||
ExitCode int
|
||||
// Err is non-nil if the command failed to start or timed out.
|
||||
Err error
|
||||
// ExcludeFromContext mirrors the flag from shellCommandMsg.
|
||||
// ExcludeFromContext mirrors the flag from ShellCommandMsg.
|
||||
ExcludeFromContext bool
|
||||
}
|
||||
@@ -139,7 +139,9 @@ func (h *CLIEventHandler) Handle(msg tea.Msg) {
|
||||
case "block":
|
||||
h.cli.DisplayExtensionBlock(e.Text, e.BorderColor, e.Subtitle)
|
||||
default:
|
||||
fmt.Println(e.Text)
|
||||
// Route unstyled extension prints through the system message
|
||||
// renderer so they get consistent formatting and timestamps.
|
||||
h.cli.DisplayInfo(e.Text)
|
||||
}
|
||||
|
||||
case app.StepCompleteEvent:
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package ui
|
||||
|
||||
// This file re-exports types from subpackages for backward compatibility.
|
||||
// External importers can continue using ui.XXX without needing to import
|
||||
// from subpackages directly.
|
||||
|
||||
import (
|
||||
"github.com/mark3labs/kit/internal/ui/commands"
|
||||
"github.com/mark3labs/kit/internal/ui/core"
|
||||
"github.com/mark3labs/kit/internal/ui/fileutil"
|
||||
"github.com/mark3labs/kit/internal/ui/prefs"
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// Re-export from core package
|
||||
type (
|
||||
ImageAttachment = core.ImageAttachment
|
||||
SubmitMsg = core.SubmitMsg
|
||||
CancelTimerExpiredMsg = core.CancelTimerExpiredMsg
|
||||
TreeNodeSelectedMsg = core.TreeNodeSelectedMsg
|
||||
TreeCancelledMsg = core.TreeCancelledMsg
|
||||
ShellCommandMsg = core.ShellCommandMsg
|
||||
ShellCommandResultMsg = core.ShellCommandResultMsg
|
||||
)
|
||||
|
||||
// Re-export from commands package
|
||||
type (
|
||||
SlashCommand = commands.SlashCommand
|
||||
ExtensionCommand = commands.ExtensionCommand
|
||||
)
|
||||
|
||||
// Re-export functions from fileutil package
|
||||
var ProcessFileAttachments = fileutil.ProcessFileAttachments
|
||||
|
||||
// Re-export from prefs package
|
||||
var (
|
||||
LoadThemePreference = prefs.LoadThemePreference
|
||||
SaveThemePreference = prefs.SaveThemePreference
|
||||
LoadModelPreference = prefs.LoadModelPreference
|
||||
SaveModelPreference = prefs.SaveModelPreference
|
||||
LoadThinkingLevelPreference = prefs.LoadThinkingLevelPreference
|
||||
SaveThinkingLevelPreference = prefs.SaveThinkingLevelPreference
|
||||
)
|
||||
|
||||
// Re-export from style package
|
||||
type (
|
||||
Theme = style.Theme
|
||||
MarkdownThemeColors = style.MarkdownThemeColors
|
||||
)
|
||||
|
||||
var (
|
||||
GetTheme = style.GetTheme
|
||||
SetTheme = style.SetTheme
|
||||
DefaultTheme = style.DefaultTheme
|
||||
ApplyTheme = style.ApplyTheme
|
||||
ApplyThemeWithoutSave = style.ApplyThemeWithoutSave
|
||||
ListThemes = style.ListThemes
|
||||
RegisterThemeFromConfig = style.RegisterThemeFromConfig
|
||||
KitBanner = style.KitBanner
|
||||
AdaptiveColor = style.AdaptiveColor
|
||||
IsDarkBackground = style.IsDarkBackground
|
||||
)
|
||||
@@ -25,7 +25,6 @@ type CLISetupOptions struct {
|
||||
Agent AgentInterface
|
||||
ModelString string
|
||||
Debug bool
|
||||
Compact bool
|
||||
Quiet bool
|
||||
ShowDebug bool // Whether to show debug config
|
||||
ProviderAPIKey string // For OAuth detection
|
||||
@@ -76,7 +75,7 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
|
||||
return nil, nil // No CLI in quiet mode
|
||||
}
|
||||
|
||||
cli, err := NewCLI(opts.Debug, opts.Compact)
|
||||
cli, err := NewCLI(opts.Debug)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create CLI: %v", err)
|
||||
}
|
||||
@@ -110,9 +109,7 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("")
|
||||
|
||||
// Display model info
|
||||
// Display model info (the system message block provides its own spacing).
|
||||
if provider != "unknown" && model != "unknown" {
|
||||
cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", provider, model))
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package ui
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -5,32 +5,34 @@ import (
|
||||
"time"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// Renderer is the interface satisfied by both MessageRenderer and
|
||||
// CompactRenderer. It allows model.go and cli.go to call rendering methods
|
||||
// without branching on compact mode.
|
||||
// Renderer is the interface satisfied by MessageRenderer. It allows model.go
|
||||
// and cli.go to call rendering methods uniformly.
|
||||
type Renderer interface {
|
||||
RenderUserMessage(content string, timestamp time.Time) UIMessage
|
||||
RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage
|
||||
RenderReasoningBlock(content string, timestamp time.Time) UIMessage
|
||||
RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage
|
||||
RenderSystemMessage(content string, timestamp time.Time) UIMessage
|
||||
RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage
|
||||
RenderDebugMessage(message string, timestamp time.Time) UIMessage
|
||||
RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage
|
||||
SetWidth(width int)
|
||||
UpdateTheme()
|
||||
}
|
||||
|
||||
// Compile-time checks that both renderers satisfy the Renderer interface.
|
||||
// Compile-time check that MessageRenderer satisfies the Renderer interface.
|
||||
var _ Renderer = (*MessageRenderer)(nil)
|
||||
var _ Renderer = (*CompactRenderer)(nil)
|
||||
|
||||
// parseBashOutput parses <stdout>/<stderr> tagged output from bash tool
|
||||
// results, styling stderr with the theme's error color. Returns the
|
||||
// combined, styled output string with tags stripped.
|
||||
//
|
||||
// Shared by both MessageRenderer and CompactRenderer.
|
||||
func parseBashOutput(result string, theme Theme) string {
|
||||
// Shared by MessageRenderer.
|
||||
func parseBashOutput(result string, theme style.Theme) string {
|
||||
var formattedResult strings.Builder
|
||||
remaining := result
|
||||
|
||||
|
||||
@@ -2,20 +2,22 @@ package ui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/commands"
|
||||
)
|
||||
|
||||
// FuzzyMatch represents the result of a fuzzy string matching operation,
|
||||
// containing the matched command and its relevance score. Higher scores
|
||||
// indicate better matches.
|
||||
type FuzzyMatch struct {
|
||||
Command *SlashCommand
|
||||
Command *commands.SlashCommand
|
||||
Score int
|
||||
}
|
||||
|
||||
// FuzzyMatchCommands performs fuzzy string matching on the provided slash commands
|
||||
// based on the query string. Returns a slice of matches sorted by relevance score
|
||||
// in descending order. An empty query returns all commands with zero scores.
|
||||
func FuzzyMatchCommands(query string, commands []SlashCommand) []FuzzyMatch {
|
||||
func FuzzyMatchCommands(query string, commands []commands.SlashCommand) []FuzzyMatch {
|
||||
if query == "" || query == "/" {
|
||||
// Return all commands when query is empty or just "/"
|
||||
matches := make([]FuzzyMatch, len(commands))
|
||||
@@ -57,7 +59,7 @@ func FuzzyMatchCommands(query string, commands []SlashCommand) []FuzzyMatch {
|
||||
}
|
||||
|
||||
// fuzzyScore calculates the fuzzy match score for a command
|
||||
func fuzzyScore(query string, cmd *SlashCommand) int {
|
||||
func fuzzyScore(query string, cmd *commands.SlashCommand) int {
|
||||
// Check exact match first
|
||||
cmdName := strings.ToLower(strings.TrimPrefix(cmd.Name, "/"))
|
||||
if cmdName == query {
|
||||
|
||||
+123
-67
@@ -10,6 +10,9 @@ import (
|
||||
"charm.land/lipgloss/v2"
|
||||
|
||||
"github.com/mark3labs/kit/internal/clipboard"
|
||||
"github.com/mark3labs/kit/internal/ui/commands"
|
||||
"github.com/mark3labs/kit/internal/ui/core"
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// InputComponent is the interactive text input field for the parent AppModel.
|
||||
@@ -29,7 +32,7 @@ import (
|
||||
// app.Run().
|
||||
type InputComponent struct {
|
||||
textarea textarea.Model
|
||||
commands []SlashCommand
|
||||
commands []commands.SlashCommand
|
||||
showPopup bool
|
||||
filtered []FuzzyMatch
|
||||
selected int
|
||||
@@ -42,17 +45,17 @@ type InputComponent struct {
|
||||
// Argument completion state. When the user types "/cmd " followed by
|
||||
// a partial argument and the command has a Complete function, the popup
|
||||
// switches to argument-completion mode showing suggestions from Complete.
|
||||
argMode bool // true when showing arg completions
|
||||
argCommand string // command prefix for arg mode (e.g. "/bookmark")
|
||||
argSynthCmds []SlashCommand // backing storage for synthetic arg entries
|
||||
argMode bool // true when showing arg completions
|
||||
argCommand string // command prefix for arg mode (e.g. "/bookmark")
|
||||
argSynthCmds []commands.SlashCommand // backing storage for synthetic arg entries
|
||||
|
||||
// File completion state. When the user types @ followed by a partial
|
||||
// file path, the popup shows file/directory suggestions from the cwd.
|
||||
fileMode bool // true when showing @file completions
|
||||
filePrefix string // current text after @ being matched
|
||||
fileAtStartIdx int // byte offset of @ in the textarea value
|
||||
fileSuggestions []FileSuggestion // backing storage for file entries
|
||||
fileSynthCmds []SlashCommand // synthetic SlashCommands wrapping file entries
|
||||
fileMode bool // true when showing @file completions
|
||||
filePrefix string // current text after @ being matched
|
||||
fileAtStartIdx int // byte offset of @ in the textarea value
|
||||
fileSuggestions []FileSuggestion // backing storage for file entries
|
||||
fileSynthCmds []commands.SlashCommand // synthetic commands.SlashCommands wrapping file entries
|
||||
|
||||
// cwd is the working directory used for @file path resolution and
|
||||
// autocomplete suggestions. Set by the parent via SetCwd.
|
||||
@@ -66,12 +69,12 @@ type InputComponent struct {
|
||||
hideHint bool
|
||||
|
||||
// agentBusy indicates the agent is currently working. When true, the
|
||||
// hint text shows steering shortcut (Ctrl+S) instead of submit.
|
||||
// hint text shows steering shortcut (Ctrl+X s) instead of submit.
|
||||
agentBusy bool
|
||||
|
||||
// pendingImages holds clipboard images attached to the next submission.
|
||||
// Images are added via Ctrl+V and cleared on submit or Ctrl+U.
|
||||
pendingImages []ImageAttachment
|
||||
pendingImages []core.ImageAttachment
|
||||
|
||||
// history stores previously submitted prompts (most recent last).
|
||||
// Limited to maxHistory entries; duplicates of the previous entry are
|
||||
@@ -94,7 +97,7 @@ const maxHistory = 100
|
||||
|
||||
// clipboardImageMsg is the result of an async clipboard image read.
|
||||
type clipboardImageMsg struct {
|
||||
image *ImageAttachment
|
||||
image *core.ImageAttachment
|
||||
err error
|
||||
}
|
||||
|
||||
@@ -106,7 +109,7 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
|
||||
ta.Placeholder = "Type your message..."
|
||||
ta.ShowLineNumbers = false
|
||||
ta.Prompt = ""
|
||||
ta.CharLimit = 5000
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(3) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
@@ -119,7 +122,7 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
|
||||
)
|
||||
|
||||
// Style the textarea using theme colors.
|
||||
theme := GetTheme()
|
||||
theme := style.GetTheme()
|
||||
styles := ta.Styles()
|
||||
styles.Focused.Base = lipgloss.NewStyle()
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
@@ -130,7 +133,7 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
|
||||
|
||||
return &InputComponent{
|
||||
textarea: ta,
|
||||
commands: SlashCommands,
|
||||
commands: commands.SlashCommands,
|
||||
width: width,
|
||||
popupHeight: 7,
|
||||
title: title,
|
||||
@@ -329,7 +332,7 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.filePrefix = prefix
|
||||
s.fileAtStartIdx = atIdx
|
||||
s.fileSuggestions = suggestions
|
||||
s.fileSynthCmds = make([]SlashCommand, len(suggestions))
|
||||
s.fileSynthCmds = make([]commands.SlashCommand, len(suggestions))
|
||||
s.filtered = make([]FuzzyMatch, len(suggestions))
|
||||
for i, fs := range suggestions {
|
||||
name := fs.RelPath
|
||||
@@ -337,7 +340,7 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if fs.IsDir {
|
||||
desc = "directory"
|
||||
}
|
||||
s.fileSynthCmds[i] = SlashCommand{Name: name, Description: desc}
|
||||
s.fileSynthCmds[i] = commands.SlashCommand{Name: name, Description: desc}
|
||||
s.filtered[i] = FuzzyMatch{Command: &s.fileSynthCmds[i], Score: fs.Score}
|
||||
}
|
||||
s.selected = 0
|
||||
@@ -396,14 +399,14 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
|
||||
cmd := strings.TrimSpace(trimmed[2:])
|
||||
if cmd != "" {
|
||||
return func() tea.Msg {
|
||||
return shellCommandMsg{Command: cmd, ExcludeFromContext: true}
|
||||
return core.ShellCommandMsg{Command: cmd, ExcludeFromContext: true}
|
||||
}
|
||||
}
|
||||
} else if strings.HasPrefix(trimmed, "!") {
|
||||
cmd := strings.TrimSpace(trimmed[1:])
|
||||
if cmd != "" {
|
||||
return func() tea.Msg {
|
||||
return shellCommandMsg{Command: cmd, ExcludeFromContext: false}
|
||||
return core.ShellCommandMsg{Command: cmd, ExcludeFromContext: false}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -411,9 +414,9 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
|
||||
// Resolve via canonical command lookup so aliases are handled uniformly.
|
||||
// Only /quit is handled locally — all other slash commands (including
|
||||
// /clear and /clear-queue) are forwarded to the parent model via
|
||||
// submitMsg so the parent can update its own state (scrollback, queue
|
||||
// submitMsg so the parent can update its own state (ScrollList, queue
|
||||
// counts, etc.) in one place.
|
||||
if sc := GetCommandByName(trimmed); sc != nil {
|
||||
if sc := commands.GetCommandByName(trimmed); sc != nil {
|
||||
switch sc.Name {
|
||||
case "/quit":
|
||||
return tea.Quit
|
||||
@@ -426,7 +429,7 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
|
||||
images := s.pendingImages
|
||||
s.pendingImages = nil
|
||||
return func() tea.Msg {
|
||||
return submitMsg{Text: trimmed, Images: images}
|
||||
return core.SubmitMsg{Text: trimmed, Images: images}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -463,7 +466,7 @@ func (s *InputComponent) resetHistoryBrowsing() {
|
||||
func (s *InputComponent) View() tea.View {
|
||||
containerStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := GetTheme()
|
||||
theme := style.GetTheme()
|
||||
|
||||
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
@@ -486,10 +489,8 @@ func (s *InputComponent) View() tea.View {
|
||||
view.WriteString("\n")
|
||||
view.WriteString(inputBoxStyle.Render(s.textarea.View()))
|
||||
|
||||
if s.showPopup && len(s.filtered) > 0 {
|
||||
view.WriteString("\n")
|
||||
view.WriteString(s.renderPopup())
|
||||
}
|
||||
// Popup is now rendered as a centered overlay in AppModel.View()
|
||||
// instead of inline here to prevent bottom overflow
|
||||
|
||||
// Show image attachment indicator when images are pending.
|
||||
if len(s.pendingImages) > 0 {
|
||||
@@ -513,12 +514,12 @@ func (s *InputComponent) View() tea.View {
|
||||
availableHintWidth := s.width - 3
|
||||
if s.agentBusy {
|
||||
// When the agent is working, show steering shortcut.
|
||||
if availableHintWidth >= 55 {
|
||||
hint = "enter queue • ctrl+s steer • esc esc cancel"
|
||||
} else if availableHintWidth >= 35 {
|
||||
hint = "↵ queue • ^S steer • esc×2 cancel"
|
||||
if availableHintWidth >= 60 {
|
||||
hint = "enter queue • ctrl+x s steer • esc esc cancel"
|
||||
} else if availableHintWidth >= 40 {
|
||||
hint = "↵ queue • ^X s steer • esc×2 cancel"
|
||||
} else {
|
||||
hint = "^S steer"
|
||||
hint = "^X s steer"
|
||||
}
|
||||
} else if availableHintWidth >= 67 {
|
||||
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
|
||||
@@ -537,19 +538,62 @@ func (s *InputComponent) View() tea.View {
|
||||
}
|
||||
|
||||
// renderPopup renders the autocomplete popup for slash command suggestions.
|
||||
func (s *InputComponent) renderPopup() string {
|
||||
theme := GetTheme()
|
||||
// When rendered inline (not centered), returns the styled popup content.
|
||||
// RenderPopupCentered renders the popup as a centered overlay.
|
||||
func (s *InputComponent) RenderPopupCentered(termWidth, termHeight int) string {
|
||||
if !s.showPopup || len(s.filtered) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
popupContent := s.renderPopupWithOptions(true)
|
||||
|
||||
// Center popup using lipgloss.Place
|
||||
positioned := lipgloss.Place(
|
||||
termWidth,
|
||||
termHeight,
|
||||
lipgloss.Center,
|
||||
lipgloss.Center,
|
||||
popupContent,
|
||||
)
|
||||
|
||||
return positioned
|
||||
}
|
||||
|
||||
// renderPopupWithOptions renders the popup content with optional center styling.
|
||||
func (s *InputComponent) renderPopupWithOptions(centered bool) string {
|
||||
theme := style.GetTheme()
|
||||
popupWidth := max(s.width-4, 20)
|
||||
|
||||
// Use the theme background for the popup - the full-width item backgrounds
|
||||
// and primary-colored selection will provide sufficient contrast
|
||||
popupBg := theme.Background
|
||||
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.MutedBorder).
|
||||
BorderForeground(theme.Primary).
|
||||
Background(popupBg).
|
||||
Padding(1, 2).
|
||||
Width(popupWidth).
|
||||
MarginLeft(0)
|
||||
MarginLeft(0).
|
||||
MarginBottom(1) // Visual depth/shadow effect
|
||||
|
||||
// Inner content width: popup minus border (2) and horizontal padding (4).
|
||||
innerWidth := max(popupWidth-6, 10)
|
||||
|
||||
// Item background styles for high contrast
|
||||
normalItemBg := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.Text).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
|
||||
selectedItemBg := lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1).
|
||||
Bold(true)
|
||||
|
||||
var items []string
|
||||
|
||||
visibleItems := min(len(s.filtered), s.popupHeight)
|
||||
@@ -563,44 +607,45 @@ func (s *InputComponent) renderPopup() string {
|
||||
match := s.filtered[i]
|
||||
sc := match.Command
|
||||
|
||||
// Choose the appropriate background style
|
||||
itemStyle := normalItemBg
|
||||
if i == s.selected {
|
||||
itemStyle = selectedItemBg
|
||||
}
|
||||
|
||||
// Build indicator with proper coloring
|
||||
var indicator string
|
||||
if i == s.selected {
|
||||
indicator = lipgloss.NewStyle().Foreground(theme.Primary).Render("> ")
|
||||
indicator = "> "
|
||||
} else {
|
||||
indicator = " "
|
||||
}
|
||||
|
||||
nameStyle := lipgloss.NewStyle().Foreground(theme.Secondary).Bold(true)
|
||||
descStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
if i == s.selected {
|
||||
nameStyle = nameStyle.Foreground(theme.Primary)
|
||||
descStyle = descStyle.Foreground(theme.Text)
|
||||
}
|
||||
|
||||
// Build content with name and description
|
||||
var content string
|
||||
if s.fileMode {
|
||||
// File mode: use full width for the path, show description
|
||||
// (e.g. "directory") inline after a gap.
|
||||
// File mode: use full width for the path, show description inline
|
||||
maxNameLen := max(innerWidth-16, 8)
|
||||
displayName := sc.Name
|
||||
if len(displayName) > maxNameLen && maxNameLen > 3 {
|
||||
displayName = displayName[:maxNameLen-3] + "..."
|
||||
}
|
||||
name := nameStyle.Render(displayName)
|
||||
|
||||
if sc.Description != "" && innerWidth > 30 {
|
||||
items = append(items, indicator+name+" "+descStyle.Render(sc.Description))
|
||||
content = indicator + displayName + " " + sc.Description
|
||||
} else {
|
||||
items = append(items, indicator+name)
|
||||
content = indicator + displayName
|
||||
}
|
||||
} else {
|
||||
// Line layout: indicator(2) + name(nameWidth-2 visual) + desc.
|
||||
// Line layout: indicator(2) + name(nameWidth-2 visual) + desc
|
||||
if innerWidth < 20 {
|
||||
// Very narrow: show truncated name only, no fixed column.
|
||||
// Very narrow: show truncated name only
|
||||
displayName := sc.Name
|
||||
maxName := max(innerWidth-2, 3)
|
||||
if len(displayName) > maxName {
|
||||
displayName = displayName[:maxName-1] + "…"
|
||||
}
|
||||
items = append(items, indicator+nameStyle.Render(displayName))
|
||||
content = indicator + displayName
|
||||
} else {
|
||||
nameWidth := 15
|
||||
if innerWidth < 25 {
|
||||
@@ -611,33 +656,41 @@ func (s *InputComponent) renderPopup() string {
|
||||
if len(displayName) > maxNameChars {
|
||||
displayName = displayName[:maxNameChars-1] + "…"
|
||||
}
|
||||
name := nameStyle.Width(maxNameChars).Render(displayName)
|
||||
|
||||
// Description gets remaining space.
|
||||
// Description gets remaining space
|
||||
maxDescLen := max(innerWidth-nameWidth, 0)
|
||||
desc := sc.Description
|
||||
if maxDescLen < 4 {
|
||||
items = append(items, indicator+name)
|
||||
} else {
|
||||
if maxDescLen >= 4 && desc != "" {
|
||||
if len(desc) > maxDescLen {
|
||||
desc = desc[:maxDescLen-3] + "..."
|
||||
}
|
||||
items = append(items, indicator+name+descStyle.Render(desc))
|
||||
content = indicator + lipgloss.NewStyle().Width(maxNameChars).Render(displayName) + desc
|
||||
} else {
|
||||
content = indicator + displayName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
items = append(items, itemStyle.Render(content))
|
||||
}
|
||||
|
||||
// Add scroll indicators with background
|
||||
scrollStyle := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
|
||||
if startIdx > 0 {
|
||||
items = append([]string{lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" ↑ more above")}, items...)
|
||||
items = append([]string{scrollStyle.Render(" ↑ more above")}, items...)
|
||||
}
|
||||
if endIdx < len(s.filtered) {
|
||||
items = append(items, lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" ↓ more below"))
|
||||
items = append(items, scrollStyle.Render(" ↓ more below"))
|
||||
}
|
||||
|
||||
content := strings.Join(items, "\n")
|
||||
|
||||
// Adapt footer text to available width.
|
||||
// Adapt footer text to available width with background
|
||||
var footerText string
|
||||
if innerWidth >= 50 {
|
||||
footerText = "↑↓ navigate • tab complete • ↵ select • esc dismiss"
|
||||
@@ -646,7 +699,10 @@ func (s *InputComponent) renderPopup() string {
|
||||
} else {
|
||||
footerText = "↑↓ tab ↵ esc"
|
||||
}
|
||||
footer := lipgloss.NewStyle().Foreground(theme.VeryMuted).Italic(true).
|
||||
footer := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Italic(true).
|
||||
Render(footerText)
|
||||
|
||||
return popupStyle.Render(content + "\n\n" + footer)
|
||||
@@ -676,10 +732,10 @@ func (s *InputComponent) completeArgs(line string) []FuzzyMatch {
|
||||
|
||||
s.argMode = true
|
||||
s.argCommand = cmdName
|
||||
s.argSynthCmds = make([]SlashCommand, len(suggestions))
|
||||
s.argSynthCmds = make([]commands.SlashCommand, len(suggestions))
|
||||
s.filtered = make([]FuzzyMatch, len(suggestions))
|
||||
for i, sug := range suggestions {
|
||||
s.argSynthCmds[i] = SlashCommand{Name: sug}
|
||||
s.argSynthCmds[i] = commands.SlashCommand{Name: sug}
|
||||
s.filtered[i] = FuzzyMatch{Command: &s.argSynthCmds[i]}
|
||||
}
|
||||
return s.filtered
|
||||
@@ -687,7 +743,7 @@ func (s *InputComponent) completeArgs(line string) []FuzzyMatch {
|
||||
|
||||
// findCommandWithComplete looks up a command by name that has a non-nil
|
||||
// Complete function.
|
||||
func (s *InputComponent) findCommandWithComplete(name string) *SlashCommand {
|
||||
func (s *InputComponent) findCommandWithComplete(name string) *commands.SlashCommand {
|
||||
for i := range s.commands {
|
||||
if s.commands[i].Name == name && s.commands[i].Complete != nil {
|
||||
return &s.commands[i]
|
||||
@@ -705,7 +761,7 @@ func readClipboardImageCmd() tea.Cmd {
|
||||
return clipboardImageMsg{err: err}
|
||||
}
|
||||
return clipboardImageMsg{
|
||||
image: &ImageAttachment{
|
||||
image: &core.ImageAttachment{
|
||||
Data: img.Data,
|
||||
MediaType: img.MediaType,
|
||||
},
|
||||
@@ -715,7 +771,7 @@ func readClipboardImageCmd() tea.Cmd {
|
||||
|
||||
// ClearPendingImages removes all pending image attachments and returns them.
|
||||
// Used by the parent model when consuming images for submission.
|
||||
func (s *InputComponent) ClearPendingImages() []ImageAttachment {
|
||||
func (s *InputComponent) ClearPendingImages() []core.ImageAttachment {
|
||||
images := s.pendingImages
|
||||
s.pendingImages = nil
|
||||
return images
|
||||
|
||||
@@ -0,0 +1,381 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/render"
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// MessageItem implementations for ScrollList
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TextMessageItem represents a completed text message (user or assistant)
|
||||
// in the scrollback. It uses pre-rendered styled content from MessageRenderer.
|
||||
type TextMessageItem struct {
|
||||
id string
|
||||
role string // "user" or "assistant"
|
||||
content string // Raw content (for re-rendering if needed)
|
||||
preRendered string // Pre-rendered styled content from MessageRenderer
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// NewTextMessageItem creates a new text message for the scrollback.
|
||||
// The content should be pre-rendered using MessageRenderer for proper styling.
|
||||
func NewTextMessageItem(id string, role string, content string) *TextMessageItem {
|
||||
return &TextMessageItem{
|
||||
id: id,
|
||||
role: role,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewStyledMessageItem creates a message item with pre-rendered styled content.
|
||||
// This is the preferred way to create messages when you have styled content from MessageRenderer.
|
||||
func NewStyledMessageItem(id string, role string, rawContent string, preRendered string) *TextMessageItem {
|
||||
return &TextMessageItem{
|
||||
id: id,
|
||||
role: role,
|
||||
content: rawContent,
|
||||
preRendered: preRendered,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TextMessageItem) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *TextMessageItem) Render(width int) string {
|
||||
// If we have pre-rendered styled content, return it
|
||||
if m.preRendered != "" {
|
||||
return m.preRendered
|
||||
}
|
||||
|
||||
// Fallback to simple formatting if no pre-rendered content
|
||||
return m.renderContent(width)
|
||||
}
|
||||
|
||||
func (m *TextMessageItem) Height() int {
|
||||
rendered := m.Render(0) // Width doesn't matter since we use pre-rendered
|
||||
if rendered == "" {
|
||||
return 0
|
||||
}
|
||||
return strings.Count(rendered, "\n") + 1
|
||||
}
|
||||
|
||||
func (m *TextMessageItem) renderContent(width int) string {
|
||||
var parts []string
|
||||
|
||||
// Role indicator
|
||||
if m.role == "user" {
|
||||
parts = append(parts, "│ ▸ You")
|
||||
} else {
|
||||
parts = append(parts, "") // Assistant messages start without role
|
||||
}
|
||||
|
||||
// Content with simple wrapping
|
||||
contentWidth := max(width-4, 20)
|
||||
|
||||
for line := range strings.SplitSeq(m.content, "\n") {
|
||||
if len(line) <= contentWidth {
|
||||
parts = append(parts, "│ "+line)
|
||||
} else {
|
||||
// Basic wrap
|
||||
for len(line) > contentWidth {
|
||||
parts = append(parts, "│ "+line[:contentWidth])
|
||||
line = line[contentWidth:]
|
||||
}
|
||||
if len(line) > 0 {
|
||||
parts = append(parts, "│ "+line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// StreamingMessageItem - Live streaming assistant/reasoning text
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// StreamingMessageItem represents actively streaming assistant or reasoning text.
|
||||
// It accumulates content chunks and re-renders on each update for live display.
|
||||
type StreamingMessageItem struct {
|
||||
id string
|
||||
role string // "assistant" or "reasoning"
|
||||
content string // Accumulated streaming content
|
||||
timestamp time.Time
|
||||
startTime time.Time // When streaming started (for live duration counter)
|
||||
modelName string
|
||||
streaming bool // true while actively streaming
|
||||
finalDuration time.Duration // Frozen duration when complete
|
||||
cachedRender string
|
||||
cachedWidth int
|
||||
}
|
||||
|
||||
// NewStreamingMessageItem creates a new streaming message item.
|
||||
func NewStreamingMessageItem(id, role string, modelName string) *StreamingMessageItem {
|
||||
now := time.Now()
|
||||
return &StreamingMessageItem{
|
||||
id: id,
|
||||
role: role,
|
||||
timestamp: now,
|
||||
startTime: now,
|
||||
modelName: modelName,
|
||||
streaming: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the unique identifier.
|
||||
func (s *StreamingMessageItem) ID() string {
|
||||
return s.id
|
||||
}
|
||||
|
||||
// Render renders the streaming message with live content.
|
||||
func (s *StreamingMessageItem) Render(width int) string {
|
||||
// For reasoning, never cache - we need live duration updates
|
||||
// For assistant, cache is OK
|
||||
if s.role != "reasoning" && s.cachedWidth == width && s.cachedRender != "" {
|
||||
return s.cachedRender
|
||||
}
|
||||
|
||||
var rendered string
|
||||
if s.role == "reasoning" {
|
||||
// Calculate duration in milliseconds for render.ReasoningBlock
|
||||
var durationMs int64
|
||||
if s.finalDuration > 0 {
|
||||
durationMs = s.finalDuration.Milliseconds()
|
||||
} else if !s.startTime.IsZero() {
|
||||
durationMs = time.Since(s.startTime).Milliseconds()
|
||||
}
|
||||
ty := createTypography(style.GetTheme())
|
||||
rendered = render.ReasoningBlock(s.content, durationMs, ty, style.GetTheme())
|
||||
} else {
|
||||
// Render as assistant message
|
||||
rendered = render.AssistantBlock(s.content, width, style.GetTheme())
|
||||
}
|
||||
|
||||
// Cache and return (but reasoning is never cached due to live duration)
|
||||
if s.role != "reasoning" {
|
||||
s.cachedRender = rendered
|
||||
s.cachedWidth = width
|
||||
}
|
||||
return rendered
|
||||
}
|
||||
|
||||
// Height returns the number of lines.
|
||||
func (s *StreamingMessageItem) Height() int {
|
||||
// For reasoning blocks, cachedRender is never populated (rendering is
|
||||
// width-independent and includes a live timer). Fall back to Render(0)
|
||||
// so callers always get the correct height.
|
||||
rendered := s.cachedRender
|
||||
if rendered == "" {
|
||||
rendered = s.Render(0)
|
||||
}
|
||||
if rendered == "" {
|
||||
return 0
|
||||
}
|
||||
return strings.Count(rendered, "\n") + 1
|
||||
}
|
||||
|
||||
// AppendChunk adds a content chunk and invalidates the render cache.
|
||||
func (s *StreamingMessageItem) AppendChunk(chunk string) {
|
||||
s.content += chunk
|
||||
s.cachedWidth = 0 // Invalidate cache
|
||||
}
|
||||
|
||||
// MarkComplete marks the streaming message as complete and freezes the duration.
|
||||
func (s *StreamingMessageItem) MarkComplete() {
|
||||
s.streaming = false
|
||||
// Freeze the duration for reasoning blocks
|
||||
if s.role == "reasoning" && !s.startTime.IsZero() {
|
||||
s.finalDuration = time.Since(s.startTime)
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// StreamingBashOutputItem - Live bash command output
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// StreamingBashOutputItem represents live bash command output.
|
||||
type StreamingBashOutputItem struct {
|
||||
id string
|
||||
command string
|
||||
stdoutLines []string
|
||||
stderrLines []string
|
||||
maxLines int
|
||||
complete bool
|
||||
cachedRender string
|
||||
cachedWidth int
|
||||
}
|
||||
|
||||
// NewStreamingBashOutputItem creates a new streaming bash output item.
|
||||
func NewStreamingBashOutputItem(id string, command string) *StreamingBashOutputItem {
|
||||
return &StreamingBashOutputItem{
|
||||
id: id,
|
||||
command: command,
|
||||
stdoutLines: make([]string, 0),
|
||||
stderrLines: make([]string, 0),
|
||||
maxLines: 100, // Cap lines to prevent memory issues
|
||||
complete: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *StreamingBashOutputItem) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *StreamingBashOutputItem) Render(width int) string {
|
||||
// Return cached if width matches and complete
|
||||
if m.complete && m.cachedWidth == width && m.cachedRender != "" {
|
||||
return m.cachedRender
|
||||
}
|
||||
|
||||
theme := style.GetTheme()
|
||||
var parts []string
|
||||
|
||||
// Header with command
|
||||
if m.command != "" {
|
||||
headerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true)
|
||||
parts = append(parts, headerStyle.Render(fmt.Sprintf("▸ %s", m.command)))
|
||||
}
|
||||
|
||||
const lineIndent = " "
|
||||
lineWidth := width - len(lineIndent)
|
||||
|
||||
// Stdout lines
|
||||
if len(m.stdoutLines) > 0 {
|
||||
outputStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Text).
|
||||
Background(theme.CodeBg).
|
||||
PaddingLeft(1).
|
||||
Width(lineWidth)
|
||||
for _, line := range m.stdoutLines {
|
||||
parts = append(parts, lineIndent+outputStyle.Render(line))
|
||||
}
|
||||
}
|
||||
|
||||
// Stderr lines
|
||||
if len(m.stderrLines) > 0 {
|
||||
stderrStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Background(theme.CodeBg).
|
||||
PaddingLeft(1).
|
||||
Width(lineWidth)
|
||||
for _, line := range m.stderrLines {
|
||||
parts = append(parts, lineIndent+stderrStyle.Render(line))
|
||||
}
|
||||
}
|
||||
|
||||
result := strings.Join(parts, "\n")
|
||||
if m.complete {
|
||||
m.cachedRender = result
|
||||
m.cachedWidth = width
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *StreamingBashOutputItem) Height() int {
|
||||
if m.cachedRender != "" {
|
||||
return strings.Count(m.cachedRender, "\n") + 1
|
||||
}
|
||||
// Estimate: command header + stdout + stderr
|
||||
return 1 + len(m.stdoutLines) + len(m.stderrLines)
|
||||
}
|
||||
|
||||
// AppendStdout adds a stdout line to the output.
|
||||
func (m *StreamingBashOutputItem) AppendStdout(line string) {
|
||||
m.stdoutLines = append(m.stdoutLines, line)
|
||||
// Cap lines
|
||||
if len(m.stdoutLines) > m.maxLines {
|
||||
m.stdoutLines = m.stdoutLines[len(m.stdoutLines)-m.maxLines:]
|
||||
}
|
||||
m.cachedWidth = 0 // Invalidate cache
|
||||
}
|
||||
|
||||
// AppendStderr adds a stderr line to the output.
|
||||
func (m *StreamingBashOutputItem) AppendStderr(line string) {
|
||||
m.stderrLines = append(m.stderrLines, line)
|
||||
// Cap lines
|
||||
if len(m.stderrLines) > m.maxLines {
|
||||
m.stderrLines = m.stderrLines[len(m.stderrLines)-m.maxLines:]
|
||||
}
|
||||
m.cachedWidth = 0 // Invalidate cache
|
||||
}
|
||||
|
||||
// MarkComplete marks the bash output as complete.
|
||||
func (m *StreamingBashOutputItem) MarkComplete() {
|
||||
m.complete = true
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// SystemMessageItem - System messages (commands, info, errors)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// SystemMessageItem represents a system message (commands, info, errors).
|
||||
type SystemMessageItem struct {
|
||||
id string
|
||||
content string
|
||||
timestamp time.Time
|
||||
cachedRender string
|
||||
cachedWidth int
|
||||
}
|
||||
|
||||
// NewSystemMessageItem creates a new system message for the scrollback.
|
||||
func NewSystemMessageItem(id, content string) *SystemMessageItem {
|
||||
return &SystemMessageItem{
|
||||
id: id,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Render(width int) string {
|
||||
// Return cached render if width matches
|
||||
if m.cachedWidth == width && m.cachedRender != "" {
|
||||
return m.cachedRender
|
||||
}
|
||||
|
||||
// Simple system message formatting
|
||||
rendered := "│ " + strings.ReplaceAll(m.content, "\n", "\n│ ")
|
||||
|
||||
// Cache and return
|
||||
m.cachedRender = rendered
|
||||
m.cachedWidth = width
|
||||
return rendered
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Height() int {
|
||||
if m.cachedRender != "" {
|
||||
return strings.Count(m.cachedRender, "\n") + 1
|
||||
}
|
||||
// Estimate
|
||||
if m.cachedWidth > 0 {
|
||||
return (len(m.content) / max(m.cachedWidth-10, 40)) + 3
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Helper: generateMessageID
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
var messageCounter = 0
|
||||
|
||||
func generateMessageID() string {
|
||||
messageCounter++
|
||||
return fmt.Sprintf("msg-%d-%d", time.Now().UnixNano(), messageCounter)
|
||||
}
|
||||
+37
-46
@@ -3,17 +3,16 @@ package ui
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/indaco/herald"
|
||||
)
|
||||
|
||||
// ansiEscapeRe matches ANSI escape sequences used for terminal styling.
|
||||
var ansiEscapeRe = regexp.MustCompile(`\x1b\[[0-9;]*m`)
|
||||
"github.com/mark3labs/kit/internal/ui/render"
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// MessageType represents different categories of messages displayed in the UI,
|
||||
// each with distinct visual styling and formatting rules.
|
||||
@@ -142,7 +141,7 @@ func newMessageRenderer(width int, debug bool) *MessageRenderer {
|
||||
return &MessageRenderer{
|
||||
width: width,
|
||||
debug: debug,
|
||||
ty: createTypography(GetTheme()),
|
||||
ty: createTypography(style.GetTheme()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,12 +152,7 @@ func (r *MessageRenderer) SetWidth(width int) {
|
||||
|
||||
// RenderUserMessage renders a user's input message using herald Tip alert
|
||||
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
rendered := r.ty.Tip(content)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
rendered := render.UserBlock(content, r.width, r.ty, style.GetTheme())
|
||||
|
||||
return UIMessage{
|
||||
Type: UserMessage,
|
||||
@@ -170,18 +164,21 @@ func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time)
|
||||
|
||||
// RenderAssistantMessage renders an AI assistant's response
|
||||
func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
Content: "",
|
||||
Height: 0,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
rendered := render.AssistantBlock(content, r.width, style.GetTheme())
|
||||
|
||||
// Use markdown rendering with Chroma syntax highlighting
|
||||
rendered := toMarkdown(content, r.width-4)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderReasoningBlock renders a reasoning/thinking block with the same styling
|
||||
// as live streaming: muted italic text with margin. This is used when resuming
|
||||
// sessions to display saved reasoning content.
|
||||
func (r *MessageRenderer) RenderReasoningBlock(content string, timestamp time.Time) UIMessage {
|
||||
rendered := render.ReasoningBlock(content, 0, r.ty, style.GetTheme())
|
||||
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
@@ -193,12 +190,7 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
|
||||
// RenderSystemMessage renders KIT system messages using herald Note alert
|
||||
func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "No content available"
|
||||
}
|
||||
|
||||
rendered := r.ty.Note(content)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
rendered := render.SystemBlock(content, r.ty, style.GetTheme())
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
@@ -264,8 +256,7 @@ func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timest
|
||||
|
||||
// RenderErrorMessage renders error notifications
|
||||
func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
|
||||
rendered := r.ty.Caution(errorMsg)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
rendered := render.ErrorBlock(errorMsg, r.ty, style.GetTheme())
|
||||
|
||||
return UIMessage{
|
||||
Type: ErrorMessage,
|
||||
@@ -297,16 +288,16 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
}
|
||||
|
||||
var icon string
|
||||
iconColor := GetTheme().Success
|
||||
iconColor := style.GetTheme().Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
iconColor = GetTheme().Error
|
||||
iconColor = style.GetTheme().Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// Style the tool name with color
|
||||
theme := GetTheme()
|
||||
theme := style.GetTheme()
|
||||
nameColor := theme.Info
|
||||
if isError {
|
||||
nameColor = theme.Error
|
||||
@@ -325,7 +316,7 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
if extRd != nil && extRd.RenderBody != nil {
|
||||
body = extRd.RenderBody(toolResult, isError, r.width-8)
|
||||
if body != "" && extRd.BodyMarkdown {
|
||||
body = strings.TrimSuffix(toMarkdown(body, r.width-8), "\n")
|
||||
body = strings.TrimSuffix(style.ToMarkdown(body, r.width-8), "\n")
|
||||
}
|
||||
}
|
||||
if body == "" {
|
||||
@@ -343,6 +334,12 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
body = r.ty.Italic("(no output)")
|
||||
}
|
||||
|
||||
// Wrap all tool errors in a herald Caution alert so the error text
|
||||
// renders inside a contained block instead of spilling into the layout.
|
||||
if isError && strings.TrimSpace(body) != "" {
|
||||
body = r.ty.Alert(herald.AlertCaution, body)
|
||||
}
|
||||
|
||||
// Compose: icon + name + params, then body
|
||||
fullContent := r.ty.Compose(
|
||||
headerLine,
|
||||
@@ -371,7 +368,7 @@ func (r *MessageRenderer) formatToolResult(toolName, result string) string {
|
||||
if strings.Contains(toolName, "bash") || strings.Contains(toolName, "command") ||
|
||||
strings.Contains(toolName, "shell") {
|
||||
if strings.Contains(result, "<stdout>") || strings.Contains(result, "<stderr>") {
|
||||
return parseBashOutput(result, GetTheme())
|
||||
return parseBashOutput(result, style.GetTheme())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -379,7 +376,7 @@ func (r *MessageRenderer) formatToolResult(toolName, result string) string {
|
||||
}
|
||||
|
||||
// createTypography creates a typography instance from theme
|
||||
func createTypography(theme Theme) *herald.Typography {
|
||||
func createTypography(theme style.Theme) *herald.Typography {
|
||||
return herald.New(
|
||||
herald.WithPalette(herald.ColorPalette{
|
||||
Primary: theme.Primary,
|
||||
@@ -408,14 +405,8 @@ func createTypography(theme Theme) *herald.Typography {
|
||||
)
|
||||
}
|
||||
|
||||
// removeBlankLines removes lines that are visually blank from rendered output.
|
||||
func removeBlankLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
filtered := lines[:0]
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(ansiEscapeRe.ReplaceAllString(line, "")) != "" {
|
||||
filtered = append(filtered, line)
|
||||
}
|
||||
}
|
||||
return strings.Join(filtered, "\n")
|
||||
// UpdateTheme refreshes the renderer's typography instance with colors from
|
||||
// the current theme. This is called when the user changes themes via /theme.
|
||||
func (r *MessageRenderer) UpdateTheme() {
|
||||
r.ty = createTypography(style.GetTheme())
|
||||
}
|
||||
|
||||
+1093
-452
File diff suppressed because it is too large
Load Diff
+99
-285
@@ -5,9 +5,7 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"charm.land/bubbles/v2/key"
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
)
|
||||
@@ -29,16 +27,14 @@ type ModelSelectedMsg struct {
|
||||
// ModelSelectorCancelledMsg is sent when the user cancels the selector.
|
||||
type ModelSelectorCancelledMsg struct{}
|
||||
|
||||
// ModelSelectorComponent is a full-screen Bubble Tea component that displays
|
||||
// a filterable list of available models. It follows the same pattern as
|
||||
// TreeSelectorComponent: inline text search, scrolling list, and custom
|
||||
// messages for result delivery.
|
||||
// ModelSelectorComponent is a Bubble Tea component that displays a filterable
|
||||
// list of available models as a centered overlay popup. It delegates rendering
|
||||
// and keyboard navigation to PopupList and converts results into the
|
||||
// ModelSelectedMsg / ModelSelectorCancelledMsg messages expected by AppModel.
|
||||
type ModelSelectorComponent struct {
|
||||
allModels []ModelEntry // all available models (pre-sorted)
|
||||
filtered []ModelEntry // subset matching the current search
|
||||
cursor int
|
||||
search string
|
||||
currentModel string // "provider/model" of the active model (for checkmark)
|
||||
popup *PopupList
|
||||
allModels []ModelEntry // kept for the custom filter callback
|
||||
currentModel string // "provider/model" of the active model
|
||||
width int
|
||||
height int
|
||||
active bool
|
||||
@@ -61,7 +57,22 @@ func NewModelSelector(currentModel string, width, height int) *ModelSelectorComp
|
||||
continue
|
||||
}
|
||||
|
||||
// For the custom provider, skip the built-in "custom" stub when
|
||||
// user-defined models are present — the stub is a fallback for
|
||||
// --provider-url usage and would just clutter the list.
|
||||
userDefinedCustomModels := 0
|
||||
if providerID == "custom" {
|
||||
for modelID := range modelsMap {
|
||||
if modelID != "custom" {
|
||||
userDefinedCustomModels++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelID, info := range modelsMap {
|
||||
if providerID == "custom" && modelID == "custom" && userDefinedCustomModels > 0 {
|
||||
continue
|
||||
}
|
||||
allModels = append(allModels, ModelEntry{
|
||||
Provider: providerID,
|
||||
ModelID: modelID,
|
||||
@@ -80,24 +91,31 @@ func NewModelSelector(currentModel string, width, height int) *ModelSelectorComp
|
||||
return allModels[i].ModelID < allModels[j].ModelID
|
||||
})
|
||||
|
||||
ms := &ModelSelectorComponent{
|
||||
// Build PopupItems from model entries.
|
||||
items := make([]PopupItem, len(allModels))
|
||||
for i, m := range allModels {
|
||||
items[i] = PopupItem{
|
||||
Label: m.ModelID,
|
||||
Description: fmt.Sprintf("[%s]", m.Provider),
|
||||
Active: m.Provider+"/"+m.ModelID == currentModel,
|
||||
Meta: m,
|
||||
}
|
||||
}
|
||||
|
||||
popup := NewPopupList("Model Selector", items, width, height)
|
||||
popup.Subtitle = "Only showing models with configured API keys"
|
||||
popup.FilterFunc = func(query string, allItems []PopupItem) []PopupItem {
|
||||
return filterModels(query, allItems)
|
||||
}
|
||||
|
||||
return &ModelSelectorComponent{
|
||||
popup: popup,
|
||||
allModels: allModels,
|
||||
filtered: allModels,
|
||||
currentModel: currentModel,
|
||||
width: width,
|
||||
height: height,
|
||||
active: true,
|
||||
}
|
||||
|
||||
// Position cursor on the current model if found.
|
||||
for i, m := range ms.filtered {
|
||||
if m.Provider+"/"+m.ModelID == currentModel {
|
||||
ms.cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return ms
|
||||
}
|
||||
|
||||
// Init implements tea.Model.
|
||||
@@ -111,177 +129,43 @@ func (ms *ModelSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case tea.WindowSizeMsg:
|
||||
ms.width = msg.Width
|
||||
ms.height = msg.Height
|
||||
ms.popup.SetSize(msg.Width, msg.Height)
|
||||
return ms, nil
|
||||
|
||||
case tea.KeyPressMsg:
|
||||
switch {
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("up", "k"))):
|
||||
if ms.cursor > 0 {
|
||||
ms.cursor--
|
||||
}
|
||||
result := ms.popup.HandleKey(msg.String(), msg.Text)
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("down", "j"))):
|
||||
if ms.cursor < len(ms.filtered)-1 {
|
||||
ms.cursor++
|
||||
if result.Selected != nil {
|
||||
ms.active = false
|
||||
entry := result.Selected.Meta.(ModelEntry)
|
||||
modelStr := entry.Provider + "/" + entry.ModelID
|
||||
return ms, func() tea.Msg {
|
||||
return ModelSelectedMsg{ModelString: modelStr}
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("pgup"))):
|
||||
ms.cursor -= ms.visibleHeight()
|
||||
if ms.cursor < 0 {
|
||||
ms.cursor = 0
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("pgdown"))):
|
||||
ms.cursor += ms.visibleHeight()
|
||||
if ms.cursor >= len(ms.filtered) {
|
||||
ms.cursor = len(ms.filtered) - 1
|
||||
}
|
||||
if ms.cursor < 0 {
|
||||
ms.cursor = 0
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("home"))):
|
||||
ms.cursor = 0
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("end"))):
|
||||
ms.cursor = max(len(ms.filtered)-1, 0)
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))):
|
||||
if ms.cursor < len(ms.filtered) {
|
||||
entry := ms.filtered[ms.cursor]
|
||||
ms.active = false
|
||||
return ms, func() tea.Msg {
|
||||
return ModelSelectedMsg{
|
||||
ModelString: entry.Provider + "/" + entry.ModelID,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))):
|
||||
if ms.search != "" {
|
||||
ms.search = ""
|
||||
ms.rebuildFiltered()
|
||||
} else {
|
||||
ms.active = false
|
||||
return ms, func() tea.Msg {
|
||||
return ModelSelectorCancelledMsg{}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// Inline text search.
|
||||
if msg.Text != "" && len(msg.Text) == 1 {
|
||||
ch := msg.Text[0]
|
||||
if ch >= 32 && ch < 127 {
|
||||
ms.search += string(ch)
|
||||
ms.rebuildFiltered()
|
||||
}
|
||||
}
|
||||
if key.Matches(msg, key.NewBinding(key.WithKeys("backspace"))) && len(ms.search) > 0 {
|
||||
ms.search = ms.search[:len(ms.search)-1]
|
||||
ms.rebuildFiltered()
|
||||
}
|
||||
if result.Cancelled {
|
||||
ms.active = false
|
||||
return ms, func() tea.Msg {
|
||||
return ModelSelectorCancelledMsg{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
// View implements tea.Model.
|
||||
// View implements tea.Model — not used for overlay rendering.
|
||||
// Use RenderOverlay for the centered overlay approach.
|
||||
func (ms *ModelSelectorComponent) View() tea.View {
|
||||
theme := GetTheme()
|
||||
// Fallback full-screen rendering (unused when rendered as overlay).
|
||||
v := tea.NewView(ms.popup.RenderCentered(ms.width, ms.height))
|
||||
v.AltScreen = true
|
||||
return v
|
||||
}
|
||||
|
||||
headerStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
PaddingLeft(2)
|
||||
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
PaddingLeft(2)
|
||||
|
||||
infoStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Warning).
|
||||
PaddingLeft(2)
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
// Header.
|
||||
b.WriteString(headerStyle.Render("Model Selector"))
|
||||
b.WriteString("\n")
|
||||
// Adapt help text to terminal width.
|
||||
if ms.width >= 56 {
|
||||
b.WriteString(helpStyle.Render("↑/↓: move enter: select esc: cancel type to filter"))
|
||||
} else if ms.width >= 35 {
|
||||
b.WriteString(helpStyle.Render("↑↓ move ↵ select esc type"))
|
||||
} else {
|
||||
b.WriteString(helpStyle.Render("↑↓ ↵ esc"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
if ms.width >= 48 {
|
||||
b.WriteString(infoStyle.Render("Only showing models with configured API keys"))
|
||||
} else {
|
||||
b.WriteString(infoStyle.Render("Models with API keys"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
// Search input.
|
||||
searchStyle := lipgloss.NewStyle().Foreground(theme.Info).PaddingLeft(2)
|
||||
if ms.search != "" {
|
||||
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ms.search)))
|
||||
} else {
|
||||
b.WriteString(searchStyle.Render("> "))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
b.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(strings.Repeat("─", ms.width)))
|
||||
b.WriteString("\n")
|
||||
|
||||
if len(ms.filtered) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
|
||||
if ms.search != "" {
|
||||
b.WriteString(emptyStyle.Render("No models matching \"" + ms.search + "\""))
|
||||
} else {
|
||||
b.WriteString(emptyStyle.Render("No models available (check API keys)"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
} else {
|
||||
// Visible window.
|
||||
visH := ms.visibleHeight()
|
||||
startIdx := 0
|
||||
if ms.cursor >= visH {
|
||||
startIdx = ms.cursor - visH + 1
|
||||
}
|
||||
endIdx := min(startIdx+visH, len(ms.filtered))
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
entry := ms.filtered[i]
|
||||
line := ms.renderEntry(entry, i == ms.cursor)
|
||||
b.WriteString(line)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Footer.
|
||||
b.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(strings.Repeat("─", ms.width)))
|
||||
b.WriteString("\n")
|
||||
|
||||
footerParts := []string{
|
||||
fmt.Sprintf("(%d/%d)", ms.cursor+1, len(ms.filtered)),
|
||||
}
|
||||
if ms.cursor < len(ms.filtered) {
|
||||
entry := ms.filtered[ms.cursor]
|
||||
if entry.Name != "" {
|
||||
footerParts = append(footerParts, fmt.Sprintf("Model Name: %s", entry.Name))
|
||||
}
|
||||
if entry.ContextLimit > 0 {
|
||||
footerParts = append(footerParts, fmt.Sprintf("Context: %dK", entry.ContextLimit/1000))
|
||||
}
|
||||
}
|
||||
|
||||
footerStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
|
||||
b.WriteString(footerStyle.Render(strings.Join(footerParts, " ")))
|
||||
|
||||
return tea.NewView(b.String())
|
||||
// RenderOverlay returns the popup as a centered overlay string, ready to be
|
||||
// composited on top of the main content via overlayContent().
|
||||
func (ms *ModelSelectorComponent) RenderOverlay(termWidth, termHeight int) string {
|
||||
return ms.popup.RenderCentered(termWidth, termHeight)
|
||||
}
|
||||
|
||||
// IsActive returns whether the selector is still accepting input.
|
||||
@@ -289,56 +173,50 @@ func (ms *ModelSelectorComponent) IsActive() bool {
|
||||
return ms.active
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
// --- Model-specific fuzzy filter ---
|
||||
|
||||
func (ms *ModelSelectorComponent) visibleHeight() int {
|
||||
// Reserve: header(1) + help(1) + info(1) + search(1) + separator(1) + footer(2) = 7.
|
||||
// Minimum 3 entries so the selector is still usable on short terminals.
|
||||
return max(ms.height-7, 3)
|
||||
}
|
||||
// filterModels scores and filters PopupItems whose Meta is a ModelEntry.
|
||||
func filterModels(query string, items []PopupItem) []PopupItem {
|
||||
if query == "" {
|
||||
return items
|
||||
}
|
||||
q := strings.ToLower(query)
|
||||
|
||||
func (ms *ModelSelectorComponent) rebuildFiltered() {
|
||||
if ms.search == "" {
|
||||
ms.filtered = ms.allModels
|
||||
} else {
|
||||
query := strings.ToLower(ms.search)
|
||||
ms.filtered = ms.filtered[:0]
|
||||
type scored struct {
|
||||
item PopupItem
|
||||
score int
|
||||
}
|
||||
var matches []scored
|
||||
|
||||
type scored struct {
|
||||
entry ModelEntry
|
||||
score int
|
||||
for _, item := range items {
|
||||
entry, ok := item.Meta.(ModelEntry)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var matches []scored
|
||||
|
||||
for _, entry := range ms.allModels {
|
||||
s := ms.fuzzyScoreModel(query, entry)
|
||||
if s > 0 {
|
||||
matches = append(matches, scored{entry: entry, score: s})
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by score descending, then alphabetically.
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
if matches[i].score != matches[j].score {
|
||||
return matches[i].score > matches[j].score
|
||||
}
|
||||
return matches[i].entry.ModelID < matches[j].entry.ModelID
|
||||
})
|
||||
|
||||
ms.filtered = make([]ModelEntry, len(matches))
|
||||
for i, m := range matches {
|
||||
ms.filtered[i] = m.entry
|
||||
s := fuzzyScoreModelEntry(q, entry)
|
||||
if s > 0 {
|
||||
matches = append(matches, scored{item: item, score: s})
|
||||
}
|
||||
}
|
||||
|
||||
// Clamp cursor.
|
||||
if ms.cursor >= len(ms.filtered) {
|
||||
ms.cursor = max(len(ms.filtered)-1, 0)
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
if matches[i].score != matches[j].score {
|
||||
return matches[i].score > matches[j].score
|
||||
}
|
||||
a := matches[i].item.Meta.(ModelEntry)
|
||||
b := matches[j].item.Meta.(ModelEntry)
|
||||
return a.ModelID < b.ModelID
|
||||
})
|
||||
|
||||
result := make([]PopupItem, len(matches))
|
||||
for i, m := range matches {
|
||||
result[i] = m.item
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// fuzzyScoreModel scores a model entry against the search query.
|
||||
func (ms *ModelSelectorComponent) fuzzyScoreModel(query string, entry ModelEntry) int {
|
||||
// fuzzyScoreModelEntry scores a model entry against the search query.
|
||||
func fuzzyScoreModelEntry(query string, entry ModelEntry) int {
|
||||
modelID := strings.ToLower(entry.ModelID)
|
||||
provider := strings.ToLower(entry.Provider)
|
||||
name := strings.ToLower(entry.Name)
|
||||
@@ -391,67 +269,3 @@ func (ms *ModelSelectorComponent) fuzzyScoreModel(query string, entry ModelEntry
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (ms *ModelSelectorComponent) renderEntry(entry ModelEntry, isCursor bool) string {
|
||||
theme := GetTheme()
|
||||
modelStr := entry.ModelID
|
||||
providerStr := fmt.Sprintf("[%s]", entry.Provider)
|
||||
|
||||
// Cursor indicator.
|
||||
var cursor string
|
||||
if isCursor {
|
||||
cursor = lipgloss.NewStyle().Foreground(theme.Accent).Render("-> ")
|
||||
} else {
|
||||
cursor = " "
|
||||
}
|
||||
|
||||
// Active model checkmark.
|
||||
var active string
|
||||
activeWidth := 0
|
||||
if entry.Provider+"/"+entry.ModelID == ms.currentModel {
|
||||
active = lipgloss.NewStyle().Foreground(theme.Success).Render(" \u2713")
|
||||
activeWidth = 2 // " ✓"
|
||||
}
|
||||
|
||||
// Truncate model ID and provider tag to fit terminal width.
|
||||
// Layout: cursor(3) + model + " " + provider + active.
|
||||
// Use rune length for display-width accuracy (the "…" suffix is 1 rune / 1 column).
|
||||
const cursorWidth = 3
|
||||
available := max(ms.width-cursorWidth-activeWidth-1, 10) // 1 for space between model and provider
|
||||
provDisplayLen := len([]rune(providerStr))
|
||||
modelDisplayLen := len([]rune(modelStr))
|
||||
|
||||
if modelDisplayLen+1+provDisplayLen > available {
|
||||
// Prioritize model name — truncate it, but keep provider visible.
|
||||
maxModel := max(available-provDisplayLen-1, 6)
|
||||
if maxModel < modelDisplayLen {
|
||||
if maxModel > 3 {
|
||||
runes := []rune(modelStr)
|
||||
modelStr = string(runes[:maxModel-1]) + "…"
|
||||
} else {
|
||||
runes := []rune(modelStr)
|
||||
modelStr = string(runes[:maxModel])
|
||||
}
|
||||
}
|
||||
// If provider itself is too long, drop it.
|
||||
modelDisplayLen = len([]rune(modelStr))
|
||||
if modelDisplayLen+1+provDisplayLen > available {
|
||||
providerStr = ""
|
||||
}
|
||||
}
|
||||
|
||||
// Style the model ID.
|
||||
modelStyle := lipgloss.NewStyle().Foreground(theme.Text)
|
||||
if isCursor {
|
||||
modelStyle = modelStyle.Bold(true).Foreground(theme.Accent)
|
||||
}
|
||||
|
||||
// Style the provider tag.
|
||||
providerStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
|
||||
result := cursor + modelStyle.Render(modelStr)
|
||||
if providerStr != "" {
|
||||
result += " " + providerStyle.Render(providerStr)
|
||||
}
|
||||
return result + active
|
||||
}
|
||||
|
||||
+212
-51
@@ -2,12 +2,14 @@ package ui
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
"github.com/mark3labs/kit/internal/ui/core"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -70,7 +72,7 @@ func (s *stubAppController) AddContextMessage(_ string) {
|
||||
// no-op in tests
|
||||
}
|
||||
|
||||
func (s *stubAppController) RunWithFiles(prompt string, _ []fantasy.FilePart) int {
|
||||
func (s *stubAppController) RunWithFiles(prompt string, _ []kit.LLMFilePart) int {
|
||||
s.runCalls = append(s.runCalls, prompt)
|
||||
return s.queueLen
|
||||
}
|
||||
@@ -80,6 +82,11 @@ func (s *stubAppController) Steer(prompt string) int {
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
func (s *stubAppController) SteerWithFiles(prompt string, _ []kit.LLMFilePart) int {
|
||||
s.runCalls = append(s.runCalls, prompt)
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Stub child components
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -87,7 +94,6 @@ func (s *stubAppController) Steer(prompt string) int {
|
||||
// stubStreamComponent satisfies streamComponentIface without rendering anything.
|
||||
type stubStreamComponent struct {
|
||||
resetCalled int
|
||||
height int
|
||||
lastMsg tea.Msg
|
||||
renderedContent string // returned by GetRenderedContent
|
||||
}
|
||||
@@ -99,11 +105,11 @@ func (s *stubStreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
func (s *stubStreamComponent) View() tea.View { return tea.NewView("") }
|
||||
func (s *stubStreamComponent) Reset() { s.resetCalled++; s.renderedContent = "" }
|
||||
func (s *stubStreamComponent) SetHeight(h int) { s.height = h }
|
||||
func (s *stubStreamComponent) GetRenderedContent() string { return s.renderedContent }
|
||||
func (s *stubStreamComponent) SpinnerView() string { return "" }
|
||||
func (s *stubStreamComponent) SetThinkingVisible(bool) {}
|
||||
func (s *stubStreamComponent) HasReasoning() bool { return false }
|
||||
func (s *stubStreamComponent) UpdateTheme() {}
|
||||
|
||||
// stubInputComponent satisfies inputComponentIface without rendering anything.
|
||||
type stubInputComponent struct {
|
||||
@@ -130,11 +136,12 @@ func newTestAppModel(ctrl AppController) (*AppModel, *stubStreamComponent, *stub
|
||||
stream: stream,
|
||||
input: input,
|
||||
renderer: newMessageRenderer(80, false),
|
||||
compactMode: false,
|
||||
modelName: "test-model",
|
||||
width: 80,
|
||||
height: 24,
|
||||
streamingBashMaxLines: 50, // Initialize buffer cap like NewAppModel does
|
||||
scrollList: NewScrollList(80, 20),
|
||||
messages: []MessageItem{},
|
||||
}
|
||||
return m, stream, input
|
||||
}
|
||||
@@ -167,7 +174,7 @@ func TestStateTransition_InputToWorking(t *testing.T) {
|
||||
t.Fatalf("expected stateInput, got %v", m.state)
|
||||
}
|
||||
|
||||
m = sendMsg(m, submitMsg{Text: "hello"})
|
||||
m = sendMsg(m, core.SubmitMsg{Text: "hello"})
|
||||
|
||||
if m.state != stateWorking {
|
||||
t.Fatalf("expected stateWorking after submitMsg, got %v", m.state)
|
||||
@@ -355,7 +362,7 @@ func TestESCCancel_timerExpiry(t *testing.T) {
|
||||
m.state = stateWorking
|
||||
m.canceling = true
|
||||
|
||||
m = sendMsg(m, cancelTimerExpiredMsg{})
|
||||
m = sendMsg(m, core.CancelTimerExpiredMsg{})
|
||||
|
||||
if m.canceling {
|
||||
t.Fatal("expected canceling=false after timer expiry")
|
||||
@@ -408,7 +415,7 @@ func TestQueuedMessages_storedOnQueuedSubmit(t *testing.T) {
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
_, cmd := m.Update(submitMsg{Text: "queued prompt"})
|
||||
_, cmd := m.Update(core.SubmitMsg{Text: "queued prompt"})
|
||||
|
||||
if len(m.queuedMessages) != 1 {
|
||||
t.Fatalf("expected 1 queued message, got %d", len(m.queuedMessages))
|
||||
@@ -416,7 +423,7 @@ func TestQueuedMessages_storedOnQueuedSubmit(t *testing.T) {
|
||||
if m.queuedMessages[0] != "queued prompt" {
|
||||
t.Fatalf("expected queued message text 'queued prompt', got %q", m.queuedMessages[0])
|
||||
}
|
||||
// Should NOT produce a tea.Println cmd (message is anchored, not in scrollback).
|
||||
// Should NOT flush (message is anchored in ScrollList).
|
||||
if cmd != nil {
|
||||
t.Fatal("expected nil cmd for queued submit (message should not print to scrollback)")
|
||||
}
|
||||
@@ -506,19 +513,19 @@ func TestWindowResize_propagatesToStream(t *testing.T) {
|
||||
// sets the stream height after a resize.
|
||||
func TestWindowResize_distributeHeight(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, stream, _ := newTestAppModel(ctrl)
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
|
||||
// With height=30, stream height = 30 - 1 (separator) - 9 (input) - 1 (statusBar) = 19
|
||||
// With height=30, scroll height = 30 - 1 (separator) - 9 (input) - 1 (statusBar) = 19
|
||||
m = sendMsg(m, tea.WindowSizeMsg{Width: 80, Height: 30})
|
||||
_ = m
|
||||
|
||||
if stream.height != 19 {
|
||||
t.Fatalf("expected stream height=19, got %d", stream.height)
|
||||
if m.scrollList.height != 19 {
|
||||
t.Fatalf("expected scroll list height=19, got %d", m.scrollList.height)
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// tea.Println on step complete
|
||||
// Step complete behavior
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TestStepComplete_preservesStreamContent verifies that StepCompleteEvent
|
||||
@@ -551,65 +558,87 @@ func TestStepComplete_noStreamContent_noCmd(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubmitMsg_printsUserMessage verifies that submitMsg produces a tea.Println
|
||||
// cmd for the user message.
|
||||
// TestSubmitMsg_printsUserMessage verifies that submitMsg adds the user message
|
||||
// to the ScrollList messages and triggers a layout update.
|
||||
func TestSubmitMsg_printsUserMessage(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
|
||||
_, cmd := m.Update(submitMsg{Text: "user query"})
|
||||
m = sendMsg(m, core.SubmitMsg{Text: "user query"})
|
||||
|
||||
if cmd == nil {
|
||||
t.Fatal("expected non-nil cmd (tea.Println) for user message on submitMsg")
|
||||
// In alt screen mode, user messages are added to the in-memory ScrollList
|
||||
// rather than printed separately. Verify the message was added.
|
||||
found := false
|
||||
for _, msg := range m.messages {
|
||||
if tm, ok := msg.(*TextMessageItem); ok && tm.role == "user" && tm.content == "user query" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("expected user message 'user query' in ScrollList messages")
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallStarted_flushesOnly verifies that ToolCallStartedEvent flushes
|
||||
// accumulated stream content but does NOT print a tool call block (the unified
|
||||
// block is printed later on ToolResultEvent).
|
||||
// TestToolCallStarted_flushesOnly verifies that ToolCallStartedEvent marks
|
||||
// any active StreamingMessageItem as complete and resets the stream.
|
||||
func TestToolCallStarted_flushesOnly(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, stream, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
// With no stream content, flush returns nil → cmd should be nil.
|
||||
_, cmd := m.Update(app.ToolCallStartedEvent{
|
||||
// With no stream content, nothing should change.
|
||||
initialCount := len(m.messages)
|
||||
m = sendMsg(m, app.ToolCallStartedEvent{
|
||||
ToolName: "bash",
|
||||
ToolArgs: `{"cmd":"ls"}`,
|
||||
})
|
||||
|
||||
if cmd != nil {
|
||||
t.Fatal("expected nil cmd on ToolCallStartedEvent with no stream content")
|
||||
if len(m.messages) != initialCount {
|
||||
t.Fatal("expected no new messages on ToolCallStartedEvent with no stream content")
|
||||
}
|
||||
|
||||
// With stream content, flush returns tea.Println → cmd should be non-nil.
|
||||
// Simulate a StreamingMessageItem already in messages (as if appendStreamingChunk was called)
|
||||
// plus the stream component having rendered content.
|
||||
streamItem := NewStreamingMessageItem("stream-1", "assistant", "test-model")
|
||||
streamItem.AppendChunk("partial text")
|
||||
m.messages = append(m.messages, streamItem)
|
||||
stream.renderedContent = "partial text"
|
||||
_, cmd = m.Update(app.ToolCallStartedEvent{
|
||||
|
||||
_ = sendMsg(m, app.ToolCallStartedEvent{
|
||||
ToolName: "bash",
|
||||
ToolArgs: `{"cmd":"ls"}`,
|
||||
})
|
||||
|
||||
if cmd == nil {
|
||||
t.Fatal("expected non-nil cmd on ToolCallStartedEvent with stream content to flush")
|
||||
// The StreamingMessageItem should have been marked complete.
|
||||
if streamItem.streaming {
|
||||
t.Fatal("expected StreamingMessageItem to be marked complete after ToolCallStartedEvent")
|
||||
}
|
||||
// Stream should have been reset.
|
||||
if stream.resetCalled == 0 {
|
||||
t.Fatal("expected stream.Reset() to be called")
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_printsAndStartsSpinner verifies that ToolResultEvent produces
|
||||
// a non-nil cmd and the stream receives a SpinnerEvent.
|
||||
// TestToolResult_printsAndStartsSpinner verifies that ToolResultEvent adds
|
||||
// the tool result to the ScrollList and the stream receives a SpinnerEvent.
|
||||
func TestToolResult_printsAndStartsSpinner(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, stream, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
_, cmd := m.Update(app.ToolResultEvent{
|
||||
initialCount := len(m.messages)
|
||||
|
||||
m = sendMsg(m, app.ToolResultEvent{
|
||||
ToolName: "bash",
|
||||
ToolArgs: "{}",
|
||||
Result: "output",
|
||||
IsError: false,
|
||||
})
|
||||
|
||||
if cmd == nil {
|
||||
t.Fatal("expected non-nil cmd on ToolResultEvent")
|
||||
// Tool result should have been added to ScrollList messages.
|
||||
if len(m.messages) <= initialCount {
|
||||
t.Fatal("expected tool result message added to ScrollList")
|
||||
}
|
||||
// Stream should have received a SpinnerEvent to start spinner for next LLM call.
|
||||
if stream.lastMsg == nil {
|
||||
@@ -621,7 +650,7 @@ func TestToolResult_printsAndStartsSpinner(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestToolOutputEvent_accumulatesBashOutput verifies that ToolOutputEvent
|
||||
// accumulates stdout and stderr lines into the streaming bash output buffers.
|
||||
// accumulates stdout and stderr lines into a StreamingBashOutputItem in the ScrollList.
|
||||
func TestToolOutputEvent_accumulatesBashOutput(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
@@ -635,11 +664,23 @@ func TestToolOutputEvent_accumulatesBashOutput(t *testing.T) {
|
||||
IsStderr: false,
|
||||
})
|
||||
|
||||
if len(m.streamingBashOutput) != 1 || m.streamingBashOutput[0] != "line one\n" {
|
||||
t.Fatalf("expected streamingBashOutput=['line one\\n'], got %v", m.streamingBashOutput)
|
||||
// Should have created a StreamingBashOutputItem in messages.
|
||||
var bashItem *StreamingBashOutputItem
|
||||
for _, msg := range m.messages {
|
||||
if item, ok := msg.(*StreamingBashOutputItem); ok {
|
||||
bashItem = item
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(m.streamingBashStderr) != 0 {
|
||||
t.Fatalf("expected empty streamingBashStderr, got %v", m.streamingBashStderr)
|
||||
if bashItem == nil {
|
||||
t.Fatal("expected StreamingBashOutputItem in messages after ToolOutputEvent")
|
||||
return
|
||||
}
|
||||
if len(bashItem.stdoutLines) != 1 || bashItem.stdoutLines[0] != "line one\n" {
|
||||
t.Fatalf("expected stdout=['line one\\n'], got %v", bashItem.stdoutLines)
|
||||
}
|
||||
if len(bashItem.stderrLines) != 0 {
|
||||
t.Fatalf("expected empty stderr, got %v", bashItem.stderrLines)
|
||||
}
|
||||
|
||||
// Send another stdout chunk.
|
||||
@@ -650,8 +691,15 @@ func TestToolOutputEvent_accumulatesBashOutput(t *testing.T) {
|
||||
IsStderr: false,
|
||||
})
|
||||
|
||||
if len(m.streamingBashOutput) != 2 {
|
||||
t.Fatalf("expected 2 stdout lines, got %d", len(m.streamingBashOutput))
|
||||
// Re-find the bash item (same item, updated)
|
||||
bashItem = nil
|
||||
for _, msg := range m.messages {
|
||||
if item, ok := msg.(*StreamingBashOutputItem); ok {
|
||||
bashItem = item
|
||||
}
|
||||
}
|
||||
if bashItem == nil || len(bashItem.stdoutLines) != 2 {
|
||||
t.Fatalf("expected 2 stdout lines, got %d", len(bashItem.stdoutLines))
|
||||
}
|
||||
|
||||
// Send stderr chunk.
|
||||
@@ -662,11 +710,17 @@ func TestToolOutputEvent_accumulatesBashOutput(t *testing.T) {
|
||||
IsStderr: true,
|
||||
})
|
||||
|
||||
if len(m.streamingBashStderr) != 1 {
|
||||
t.Fatalf("expected 1 stderr line, got %d", len(m.streamingBashStderr))
|
||||
bashItem = nil
|
||||
for _, msg := range m.messages {
|
||||
if item, ok := msg.(*StreamingBashOutputItem); ok {
|
||||
bashItem = item
|
||||
}
|
||||
}
|
||||
if m.streamingBashStderr[0] != "error: something failed\n" {
|
||||
t.Fatalf("expected stderr 'error: something failed\\n', got %q", m.streamingBashStderr[0])
|
||||
if bashItem == nil || len(bashItem.stderrLines) != 1 {
|
||||
t.Fatalf("expected 1 stderr line, got %d", len(bashItem.stderrLines))
|
||||
}
|
||||
if bashItem.stderrLines[0] != "error: something failed\n" {
|
||||
t.Fatalf("expected stderr 'error: something failed\\n', got %q", bashItem.stderrLines[0])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -748,16 +802,19 @@ func TestToolCallStarted_nonBashTool_doesNotSetCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestStepError_printCmd verifies that StepErrorEvent with a non-nil error
|
||||
// produces a non-nil cmd (the tea.Println call for the error message).
|
||||
// adds an error message to the ScrollList.
|
||||
func TestStepError_printCmd(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
_, cmd := m.Update(app.StepErrorEvent{Err: errors.New("agent failed")})
|
||||
initialCount := len(m.messages)
|
||||
|
||||
if cmd == nil {
|
||||
t.Fatal("expected non-nil cmd (tea.Println) on StepErrorEvent with error")
|
||||
m = sendMsg(m, app.StepErrorEvent{Err: errors.New("agent failed")})
|
||||
|
||||
// Error should have been added to ScrollList messages.
|
||||
if len(m.messages) <= initialCount {
|
||||
t.Fatal("expected error message added to ScrollList on StepErrorEvent")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -827,7 +884,7 @@ func TestSubmit_duringWorking_stays(t *testing.T) {
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
m = sendMsg(m, submitMsg{Text: "queued prompt"})
|
||||
m = sendMsg(m, core.SubmitMsg{Text: "queued prompt"})
|
||||
|
||||
if m.state != stateWorking {
|
||||
t.Fatalf("expected stateWorking to persist after submitMsg during working, got %v", m.state)
|
||||
@@ -836,3 +893,107 @@ func TestSubmit_duringWorking_stays(t *testing.T) {
|
||||
t.Fatalf("expected Run('queued prompt') called, got %v", ctrl.runCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// truncateMessageForBlock
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TestTruncateMessageForBlock_shortMessage verifies that short messages are
|
||||
// returned unchanged.
|
||||
func TestTruncateMessageForBlock_shortMessage(t *testing.T) {
|
||||
msg := "hello world"
|
||||
got := truncateMessageForBlock(msg, 3, 80)
|
||||
if got != msg {
|
||||
t.Fatalf("expected unchanged message, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_exactLines verifies that a message with exactly
|
||||
// maxLines hard lines is returned unchanged.
|
||||
func TestTruncateMessageForBlock_exactLines(t *testing.T) {
|
||||
msg := "line1\nline2\nline3"
|
||||
got := truncateMessageForBlock(msg, 3, 80)
|
||||
if got != msg {
|
||||
t.Fatalf("expected unchanged message, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_tooManyLines verifies that messages exceeding
|
||||
// maxLines are truncated with an ellipsis.
|
||||
func TestTruncateMessageForBlock_tooManyLines(t *testing.T) {
|
||||
msg := "line1\nline2\nline3\nline4\nline5"
|
||||
got := truncateMessageForBlock(msg, 3, 80)
|
||||
want := "line1\nline2\nline3…"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_longWrappingLine verifies that a single long
|
||||
// line that would wrap beyond maxLines is truncated.
|
||||
func TestTruncateMessageForBlock_longWrappingLine(t *testing.T) {
|
||||
// 100 chars at width 20 = 5 visual lines, exceeds maxLines=3
|
||||
msg := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
got := truncateMessageForBlock(msg, 3, 20)
|
||||
// Should be truncated to 3*20=60 runes + "…"
|
||||
if len([]rune(got)) != 61 { // 60 runes + "…"
|
||||
t.Fatalf("expected 61 runes (60 + ellipsis), got %d runes: %q", len([]rune(got)), got)
|
||||
}
|
||||
if got[len(got)-3:] != "…" { // "…" is 3 bytes in UTF-8
|
||||
t.Fatal("expected trailing ellipsis")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_emptyMessage verifies that empty messages are
|
||||
// returned unchanged.
|
||||
func TestTruncateMessageForBlock_emptyMessage(t *testing.T) {
|
||||
got := truncateMessageForBlock("", 3, 80)
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty string, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_mixedWrapAndHardLines verifies truncation when
|
||||
// some hard lines wrap and the total exceeds maxLines.
|
||||
func TestTruncateMessageForBlock_mixedWrapAndHardLines(t *testing.T) {
|
||||
// First line: 40 chars at width 20 = 2 visual lines
|
||||
// Second line: "short" = 1 visual line (total: 3, exactly at limit)
|
||||
// Third line: would exceed
|
||||
msg := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nshort\nextra"
|
||||
got := truncateMessageForBlock(msg, 3, 20)
|
||||
want := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nshort…"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRenderQueuedMessages_truncatesLongMessages verifies that the rendered
|
||||
// queued message view truncates long messages instead of showing them in full.
|
||||
func TestRenderQueuedMessages_truncatesLongMessages(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.width = 80
|
||||
|
||||
// Queue a very long message (20 lines).
|
||||
var b strings.Builder
|
||||
for i := range 20 {
|
||||
if i > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString("This is a long line of text for testing purposes")
|
||||
}
|
||||
m.queuedMessages = []string{b.String()}
|
||||
|
||||
rendered := m.renderQueuedMessages()
|
||||
if rendered == "" {
|
||||
t.Fatal("expected non-empty rendered output")
|
||||
}
|
||||
|
||||
// The full message would be ~20+ lines. With truncation to 3 content
|
||||
// lines + badge + padding, it should be much shorter.
|
||||
lines := len(strings.Split(rendered, "\n"))
|
||||
// 3 content lines + 1 badge + 2 padding + border overhead ≈ ~7 lines max
|
||||
if lines > 10 {
|
||||
t.Fatalf("expected truncated output to be ≤10 lines, got %d lines", lines)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -133,7 +135,7 @@ func (o *overlayDialog) handleKey(msg tea.KeyPressMsg) (*overlayResult, tea.Cmd)
|
||||
// composition. The dialog is a bordered box centered (or anchored)
|
||||
// horizontally within the terminal width.
|
||||
func (o *overlayDialog) Render() string {
|
||||
theme := GetTheme()
|
||||
theme := style.GetTheme()
|
||||
|
||||
// Calculate dialog dimensions, clamped to terminal bounds.
|
||||
termW := max(o.width, 10)
|
||||
@@ -157,7 +159,7 @@ func (o *overlayDialog) Render() string {
|
||||
// Render body text (potentially as markdown).
|
||||
bodyText := o.content
|
||||
if o.markdown {
|
||||
bodyText = toMarkdown(bodyText, innerWidth)
|
||||
bodyText = style.ToMarkdown(bodyText, innerWidth)
|
||||
}
|
||||
bodyText = strings.TrimRight(bodyText, "\n")
|
||||
|
||||
|
||||
@@ -0,0 +1,501 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// PopupItem represents a single entry in a PopupList. The component renders
|
||||
// Label as the primary text and Description as secondary text to its right.
|
||||
// The Active flag renders a checkmark to indicate the currently-active item
|
||||
// (e.g. the current model). Meta is opaque caller data returned on selection.
|
||||
type PopupItem struct {
|
||||
Label string // primary display text
|
||||
Description string // secondary text (shown right of label)
|
||||
Active bool // true → render checkmark indicator
|
||||
Meta any // opaque data returned on selection
|
||||
}
|
||||
|
||||
// PopupList is a generic, themed, scrollable fuzzy-find popup list. It is
|
||||
// rendered as a centered overlay on top of the normal TUI layout and can be
|
||||
// reused by any feature that needs a selection popup (slash commands, model
|
||||
// selector, session picker, extension-provided lists, etc.).
|
||||
//
|
||||
// The caller is responsible for:
|
||||
// - Building the initial item list
|
||||
// - Providing a fuzzy-filter callback (or nil for substring matching)
|
||||
// - Handling the result when the user selects or cancels
|
||||
//
|
||||
// Navigation: up/down to move, enter to select, esc to cancel, type to filter.
|
||||
type PopupList struct {
|
||||
// Title shown at the top of the popup.
|
||||
Title string
|
||||
// Subtitle shown below the title (dimmed).
|
||||
Subtitle string
|
||||
// FooterHint overrides the default keyboard-hint footer.
|
||||
FooterHint string
|
||||
|
||||
allItems []PopupItem // full unfiltered list
|
||||
filtered []PopupItem // subset matching the current search
|
||||
cursor int
|
||||
search string
|
||||
|
||||
// FilterFunc is called with (query, allItems) and should return the
|
||||
// filtered+scored subset. When nil, a default substring match is used.
|
||||
FilterFunc func(query string, items []PopupItem) []PopupItem
|
||||
|
||||
width int
|
||||
height int
|
||||
maxVisible int // max items visible at once (0 = auto from height)
|
||||
showSearch bool
|
||||
}
|
||||
|
||||
// PopupResult is returned by HandleKey to tell the caller what happened.
|
||||
type PopupResult struct {
|
||||
// Selected is non-nil when the user pressed Enter on an item.
|
||||
Selected *PopupItem
|
||||
// Cancelled is true when the user pressed Esc with no search text.
|
||||
Cancelled bool
|
||||
// Changed is true when the search or cursor moved (caller should re-render).
|
||||
Changed bool
|
||||
}
|
||||
|
||||
// NewPopupList creates a new popup list with the given items and dimensions.
|
||||
func NewPopupList(title string, items []PopupItem, width, height int) *PopupList {
|
||||
p := &PopupList{
|
||||
Title: title,
|
||||
allItems: items,
|
||||
filtered: items,
|
||||
width: width,
|
||||
height: height,
|
||||
showSearch: true,
|
||||
}
|
||||
// Position cursor on the active item if one exists.
|
||||
for i, item := range p.filtered {
|
||||
if item.Active {
|
||||
p.cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// SetSize updates the popup dimensions (e.g. on window resize).
|
||||
func (p *PopupList) SetSize(width, height int) {
|
||||
p.width = width
|
||||
p.height = height
|
||||
}
|
||||
|
||||
// visibleCount returns the number of items visible at once.
|
||||
func (p *PopupList) visibleCount() int {
|
||||
if p.maxVisible > 0 {
|
||||
return p.maxVisible
|
||||
}
|
||||
// Reserve: title(1) + subtitle(1) + search(1) + separator(1) + footer(2) + border(2) + padding(2) = 10
|
||||
overhead := 8
|
||||
if p.Subtitle != "" {
|
||||
overhead++
|
||||
}
|
||||
if p.showSearch {
|
||||
overhead += 2 // search line + separator
|
||||
}
|
||||
return max(p.height/2-overhead, 3)
|
||||
}
|
||||
|
||||
// HandleKey processes a single key event and returns the result. The caller
|
||||
// should inspect PopupResult to decide whether to re-render, close the popup,
|
||||
// or act on a selection.
|
||||
//
|
||||
// keyName is the Bubble Tea key string (e.g. "up", "down", "enter", "esc").
|
||||
// keyText is the printable text for character keys (e.g. "a", "1").
|
||||
func (p *PopupList) HandleKey(keyName, keyText string) PopupResult {
|
||||
switch keyName {
|
||||
case "up":
|
||||
if p.cursor > 0 {
|
||||
p.cursor--
|
||||
return PopupResult{Changed: true}
|
||||
}
|
||||
return PopupResult{}
|
||||
|
||||
case "down":
|
||||
if p.cursor < len(p.filtered)-1 {
|
||||
p.cursor++
|
||||
return PopupResult{Changed: true}
|
||||
}
|
||||
return PopupResult{}
|
||||
|
||||
case "pgup":
|
||||
p.cursor -= p.visibleCount()
|
||||
if p.cursor < 0 {
|
||||
p.cursor = 0
|
||||
}
|
||||
return PopupResult{Changed: true}
|
||||
|
||||
case "pgdown":
|
||||
p.cursor += p.visibleCount()
|
||||
if p.cursor >= len(p.filtered) {
|
||||
p.cursor = max(len(p.filtered)-1, 0)
|
||||
}
|
||||
return PopupResult{Changed: true}
|
||||
|
||||
case "home":
|
||||
p.cursor = 0
|
||||
return PopupResult{Changed: true}
|
||||
|
||||
case "end":
|
||||
p.cursor = max(len(p.filtered)-1, 0)
|
||||
return PopupResult{Changed: true}
|
||||
|
||||
case "enter":
|
||||
if p.cursor < len(p.filtered) {
|
||||
item := p.filtered[p.cursor]
|
||||
return PopupResult{Selected: &item}
|
||||
}
|
||||
return PopupResult{}
|
||||
|
||||
case "esc":
|
||||
if p.search != "" {
|
||||
p.search = ""
|
||||
p.rebuildFiltered()
|
||||
return PopupResult{Changed: true}
|
||||
}
|
||||
return PopupResult{Cancelled: true}
|
||||
|
||||
case "backspace":
|
||||
if len(p.search) > 0 {
|
||||
p.search = p.search[:len(p.search)-1]
|
||||
p.rebuildFiltered()
|
||||
return PopupResult{Changed: true}
|
||||
}
|
||||
return PopupResult{}
|
||||
|
||||
default:
|
||||
// Printable character → append to search.
|
||||
if keyText != "" && len(keyText) == 1 {
|
||||
ch := keyText[0]
|
||||
if ch >= 32 && ch < 127 {
|
||||
p.search += string(ch)
|
||||
p.rebuildFiltered()
|
||||
return PopupResult{Changed: true}
|
||||
}
|
||||
}
|
||||
return PopupResult{}
|
||||
}
|
||||
}
|
||||
|
||||
// Render returns the styled popup content (bordered box) ready to be placed
|
||||
// as a centered overlay via lipgloss.Place + overlayContent.
|
||||
func (p *PopupList) Render() string {
|
||||
theme := style.GetTheme()
|
||||
popupWidth := max(min(p.width-4, 80), 20)
|
||||
popupBg := theme.Background
|
||||
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Primary).
|
||||
Background(popupBg).
|
||||
Padding(1, 2).
|
||||
Width(popupWidth).
|
||||
MarginBottom(1)
|
||||
|
||||
// Inner content width: popup minus border (2) and horizontal padding (4).
|
||||
innerWidth := max(popupWidth-6, 10)
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
// Title.
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
b.WriteString(titleStyle.Render(p.Title))
|
||||
b.WriteString("\n")
|
||||
|
||||
// Subtitle.
|
||||
if p.Subtitle != "" {
|
||||
subtitleStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
b.WriteString(subtitleStyle.Render(p.Subtitle))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Search input.
|
||||
if p.showSearch {
|
||||
searchStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
if p.search != "" {
|
||||
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", p.search)))
|
||||
} else {
|
||||
b.WriteString(searchStyle.Render("> "))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
// Separator.
|
||||
sepStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg)
|
||||
b.WriteString(sepStyle.Render(strings.Repeat("─", innerWidth)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Item list.
|
||||
normalItemBg := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.Text).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
|
||||
selectedItemBg := lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1).
|
||||
Bold(true)
|
||||
|
||||
scrollStyle := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
|
||||
vis := p.visibleCount()
|
||||
var items []string
|
||||
|
||||
if len(p.filtered) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
if p.search != "" {
|
||||
items = append(items, emptyStyle.Render("No matches for \""+p.search+"\""))
|
||||
} else {
|
||||
items = append(items, emptyStyle.Render("No items"))
|
||||
}
|
||||
} else {
|
||||
startIdx := 0
|
||||
if p.cursor >= vis {
|
||||
startIdx = p.cursor - vis + 1
|
||||
}
|
||||
endIdx := min(startIdx+vis, len(p.filtered))
|
||||
|
||||
if startIdx > 0 {
|
||||
items = append(items, scrollStyle.Render(" ↑ more above"))
|
||||
}
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
entry := p.filtered[i]
|
||||
isCursor := i == p.cursor
|
||||
|
||||
itemStyle := normalItemBg
|
||||
if isCursor {
|
||||
itemStyle = selectedItemBg
|
||||
}
|
||||
|
||||
// Build indicator.
|
||||
var indicator string
|
||||
if isCursor {
|
||||
indicator = "> "
|
||||
} else {
|
||||
indicator = " "
|
||||
}
|
||||
|
||||
// Build content: indicator + label + description + active checkmark.
|
||||
content := p.renderItemContent(indicator, entry, innerWidth, isCursor)
|
||||
items = append(items, itemStyle.Render(content))
|
||||
}
|
||||
|
||||
if endIdx < len(p.filtered) {
|
||||
items = append(items, scrollStyle.Render(" ↓ more below"))
|
||||
}
|
||||
}
|
||||
|
||||
content := b.String() + strings.Join(items, "\n")
|
||||
|
||||
// Footer with count and keyboard hints.
|
||||
var footerParts []string
|
||||
footerParts = append(footerParts, fmt.Sprintf("(%d/%d)", p.cursor+1, len(p.filtered)))
|
||||
|
||||
footerHint := p.FooterHint
|
||||
if footerHint == "" {
|
||||
if innerWidth >= 50 {
|
||||
footerHint = "↑↓ navigate • enter select • esc cancel • type to filter"
|
||||
} else if innerWidth >= 30 {
|
||||
footerHint = "↑↓ nav • ↵ select • esc"
|
||||
} else {
|
||||
footerHint = "↑↓ ↵ esc"
|
||||
}
|
||||
}
|
||||
footerParts = append(footerParts, footerHint)
|
||||
|
||||
footer := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Italic(true).
|
||||
Render(strings.Join(footerParts, " "))
|
||||
|
||||
return popupStyle.Render(content + "\n\n" + footer)
|
||||
}
|
||||
|
||||
// RenderCentered returns the popup placed at the center of a termWidth×termHeight
|
||||
// canvas, ready to be composed with overlayContent().
|
||||
func (p *PopupList) RenderCentered(termWidth, termHeight int) string {
|
||||
popupContent := p.Render()
|
||||
return lipgloss.Place(
|
||||
termWidth,
|
||||
termHeight,
|
||||
lipgloss.Center,
|
||||
lipgloss.Center,
|
||||
popupContent,
|
||||
)
|
||||
}
|
||||
|
||||
// IsSearching returns true when the search input is non-empty.
|
||||
func (p *PopupList) IsSearching() bool {
|
||||
return p.search != ""
|
||||
}
|
||||
|
||||
// SelectedItem returns the item under the cursor, or nil if the list is empty.
|
||||
func (p *PopupList) SelectedItem() *PopupItem {
|
||||
if p.cursor < len(p.filtered) {
|
||||
item := p.filtered[p.cursor]
|
||||
return &item
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
func (p *PopupList) rebuildFiltered() {
|
||||
if p.FilterFunc != nil {
|
||||
p.filtered = p.FilterFunc(p.search, p.allItems)
|
||||
} else {
|
||||
p.filtered = defaultFilter(p.search, p.allItems)
|
||||
}
|
||||
// Clamp cursor.
|
||||
if p.cursor >= len(p.filtered) {
|
||||
p.cursor = max(len(p.filtered)-1, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// defaultFilter is a simple case-insensitive substring + fuzzy character match.
|
||||
func defaultFilter(query string, items []PopupItem) []PopupItem {
|
||||
if query == "" {
|
||||
return items
|
||||
}
|
||||
q := strings.ToLower(query)
|
||||
type scored struct {
|
||||
item PopupItem
|
||||
score int
|
||||
}
|
||||
var matches []scored
|
||||
for _, item := range items {
|
||||
label := strings.ToLower(item.Label)
|
||||
desc := strings.ToLower(item.Description)
|
||||
|
||||
var s int
|
||||
switch {
|
||||
case label == q:
|
||||
s = 1000
|
||||
case strings.HasPrefix(label, q):
|
||||
s = 800 - len(label) + len(q)
|
||||
case strings.Contains(label, q):
|
||||
s = 600
|
||||
case strings.Contains(desc, q):
|
||||
s = 400
|
||||
default:
|
||||
s = fuzzyCharacterMatch(q, label)
|
||||
}
|
||||
if s > 0 {
|
||||
matches = append(matches, scored{item: item, score: s})
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by score descending, then alphabetically by label.
|
||||
for i := 0; i < len(matches)-1; i++ {
|
||||
for j := i + 1; j < len(matches); j++ {
|
||||
if matches[j].score > matches[i].score ||
|
||||
(matches[j].score == matches[i].score && matches[j].item.Label < matches[i].item.Label) {
|
||||
matches[i], matches[j] = matches[j], matches[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]PopupItem, len(matches))
|
||||
for i, m := range matches {
|
||||
result[i] = m.item
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// renderItemContent builds the display string for a single item row.
|
||||
func (p *PopupList) renderItemContent(indicator string, entry PopupItem, innerWidth int, isCursor bool) string {
|
||||
theme := style.GetTheme()
|
||||
|
||||
// Reserve space: indicator(2) + potential checkmark(2)
|
||||
activeWidth := 0
|
||||
if entry.Active {
|
||||
activeWidth = 2
|
||||
}
|
||||
available := max(innerWidth-2-activeWidth, 6) // 2 for indicator, already included
|
||||
|
||||
label := entry.Label
|
||||
desc := entry.Description
|
||||
|
||||
if desc != "" {
|
||||
// Two-column layout: label + description.
|
||||
descWidth := len([]rune(desc)) + 1 // 1 space gap
|
||||
labelMax := max(available-descWidth, available*2/3)
|
||||
if len([]rune(label)) > labelMax && labelMax > 3 {
|
||||
runes := []rune(label)
|
||||
label = string(runes[:labelMax-1]) + "…"
|
||||
}
|
||||
labelDisplayLen := len([]rune(label))
|
||||
|
||||
// If label + desc don't fit, truncate or drop desc.
|
||||
if labelDisplayLen+1+len([]rune(desc)) > available {
|
||||
remaining := available - labelDisplayLen - 1
|
||||
if remaining >= 4 {
|
||||
runes := []rune(desc)
|
||||
if len(runes) > remaining {
|
||||
desc = string(runes[:remaining-1]) + "…"
|
||||
}
|
||||
} else {
|
||||
desc = ""
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single column: just the label.
|
||||
if len([]rune(label)) > available && available > 3 {
|
||||
runes := []rune(label)
|
||||
label = string(runes[:available-1]) + "…"
|
||||
}
|
||||
}
|
||||
|
||||
result := indicator + label
|
||||
if desc != "" {
|
||||
descStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
if isCursor {
|
||||
// When selected, use a dimmer foreground that still contrasts with Primary bg.
|
||||
descStyle = lipgloss.NewStyle().Foreground(theme.Background)
|
||||
}
|
||||
result += " " + descStyle.Render(desc)
|
||||
}
|
||||
if entry.Active {
|
||||
checkStyle := lipgloss.NewStyle().Foreground(theme.Success)
|
||||
if isCursor {
|
||||
checkStyle = lipgloss.NewStyle().Foreground(theme.Background)
|
||||
}
|
||||
result += checkStyle.Render(" ✓")
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPopupList_NewPositionsCursorOnActiveItem(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha"},
|
||||
{Label: "beta"},
|
||||
{Label: "gamma", Active: true},
|
||||
{Label: "delta"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
if p.cursor != 2 {
|
||||
t.Errorf("expected cursor on active item (index 2), got %d", p.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_HandleKey_Navigation(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha"},
|
||||
{Label: "beta"},
|
||||
{Label: "gamma"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
// Initial cursor at 0.
|
||||
if p.cursor != 0 {
|
||||
t.Fatalf("expected cursor 0, got %d", p.cursor)
|
||||
}
|
||||
|
||||
// Down → 1.
|
||||
res := p.HandleKey("down", "")
|
||||
if !res.Changed || p.cursor != 1 {
|
||||
t.Errorf("down: changed=%v cursor=%d", res.Changed, p.cursor)
|
||||
}
|
||||
|
||||
// Down → 2.
|
||||
p.HandleKey("down", "")
|
||||
if p.cursor != 2 {
|
||||
t.Errorf("expected cursor 2, got %d", p.cursor)
|
||||
}
|
||||
|
||||
// Down at end → stays at 2.
|
||||
res = p.HandleKey("down", "")
|
||||
if p.cursor != 2 {
|
||||
t.Errorf("down at end: expected cursor 2, got %d", p.cursor)
|
||||
}
|
||||
|
||||
// Up → 1.
|
||||
res = p.HandleKey("up", "")
|
||||
if !res.Changed || p.cursor != 1 {
|
||||
t.Errorf("up: changed=%v cursor=%d", res.Changed, p.cursor)
|
||||
}
|
||||
|
||||
// Home → 0.
|
||||
p.HandleKey("home", "")
|
||||
if p.cursor != 0 {
|
||||
t.Errorf("home: expected cursor 0, got %d", p.cursor)
|
||||
}
|
||||
|
||||
// End → 2.
|
||||
p.HandleKey("end", "")
|
||||
if p.cursor != 2 {
|
||||
t.Errorf("end: expected cursor 2, got %d", p.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_HandleKey_Search(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "apple"},
|
||||
{Label: "banana"},
|
||||
{Label: "cherry"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
// Type "an" → should filter to banana.
|
||||
p.HandleKey("a", "a")
|
||||
p.HandleKey("n", "n")
|
||||
|
||||
if !p.IsSearching() {
|
||||
t.Error("expected IsSearching() to be true")
|
||||
}
|
||||
if len(p.filtered) == 0 {
|
||||
t.Fatal("expected at least one filtered result")
|
||||
}
|
||||
// banana should match (contains "an").
|
||||
found := false
|
||||
for _, item := range p.filtered {
|
||||
if item.Label == "banana" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected 'banana' in filtered results")
|
||||
}
|
||||
|
||||
// Backspace removes last char.
|
||||
p.HandleKey("backspace", "")
|
||||
if p.search != "a" {
|
||||
t.Errorf("expected search 'a' after backspace, got %q", p.search)
|
||||
}
|
||||
|
||||
// Esc clears search.
|
||||
res := p.HandleKey("esc", "")
|
||||
if res.Cancelled {
|
||||
t.Error("esc with search should clear search, not cancel")
|
||||
}
|
||||
if p.search != "" {
|
||||
t.Errorf("expected empty search after esc, got %q", p.search)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_HandleKey_SelectAndCancel(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha", Meta: "first"},
|
||||
{Label: "beta", Meta: "second"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
// Select first item.
|
||||
res := p.HandleKey("enter", "")
|
||||
if res.Selected == nil {
|
||||
t.Fatal("expected a selection on enter")
|
||||
}
|
||||
if res.Selected.Label != "alpha" {
|
||||
t.Errorf("expected 'alpha', got %q", res.Selected.Label)
|
||||
}
|
||||
if res.Selected.Meta != "first" {
|
||||
t.Errorf("expected meta 'first', got %v", res.Selected.Meta)
|
||||
}
|
||||
|
||||
// Cancel with esc (no search text).
|
||||
p2 := NewPopupList("Test", items, 80, 40)
|
||||
res = p2.HandleKey("esc", "")
|
||||
if !res.Cancelled {
|
||||
t.Error("expected Cancelled on esc with no search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_DefaultFilter(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "foo-bar"},
|
||||
{Label: "baz-qux"},
|
||||
{Label: "foobar"},
|
||||
}
|
||||
|
||||
// Exact prefix.
|
||||
result := defaultFilter("foo", items)
|
||||
if len(result) < 2 {
|
||||
t.Fatalf("expected at least 2 matches for 'foo', got %d", len(result))
|
||||
}
|
||||
// "foobar" should rank higher (shorter match) or equal to "foo-bar".
|
||||
if result[0].Label != "foobar" && result[1].Label != "foobar" {
|
||||
t.Error("expected 'foobar' in top results")
|
||||
}
|
||||
|
||||
// No match.
|
||||
result = defaultFilter("zzz", items)
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0 matches for 'zzz', got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_CustomFilterFunc(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha"},
|
||||
{Label: "beta"},
|
||||
{Label: "gamma"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
p.FilterFunc = func(query string, allItems []PopupItem) []PopupItem {
|
||||
// Custom: only return items whose label starts with query.
|
||||
var result []PopupItem
|
||||
for _, item := range allItems {
|
||||
if strings.HasPrefix(item.Label, query) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
p.HandleKey("b", "b")
|
||||
if len(p.filtered) != 1 || p.filtered[0].Label != "beta" {
|
||||
t.Errorf("expected ['beta'], got %v", p.filtered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_Render(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha", Description: "[test]"},
|
||||
{Label: "beta", Description: "[test]", Active: true},
|
||||
}
|
||||
p := NewPopupList("My List", items, 80, 40)
|
||||
p.Subtitle = "Some subtitle"
|
||||
|
||||
rendered := p.Render()
|
||||
if rendered == "" {
|
||||
t.Fatal("expected non-empty rendered output")
|
||||
}
|
||||
|
||||
// Strip ANSI escape sequences for content checking.
|
||||
plain := stripAnsi(rendered)
|
||||
if !strings.Contains(plain, "My List") {
|
||||
t.Error("expected title 'My List' in rendered output")
|
||||
}
|
||||
if !strings.Contains(plain, "alpha") {
|
||||
t.Error("expected 'alpha' in rendered output")
|
||||
}
|
||||
if !strings.Contains(plain, "beta") {
|
||||
t.Error("expected 'beta' in rendered output")
|
||||
}
|
||||
if !strings.Contains(plain, "✓") {
|
||||
t.Error("expected checkmark for active item")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_RenderCentered(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "item1"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
centered := p.RenderCentered(80, 40)
|
||||
if centered == "" {
|
||||
t.Fatal("expected non-empty centered output")
|
||||
}
|
||||
// Should contain newlines for vertical centering.
|
||||
lines := strings.Split(centered, "\n")
|
||||
if len(lines) < 10 {
|
||||
t.Errorf("expected centered output to have many lines, got %d", len(lines))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_EmptyItems(t *testing.T) {
|
||||
p := NewPopupList("Empty", nil, 80, 40)
|
||||
|
||||
rendered := p.Render()
|
||||
if !strings.Contains(rendered, "No items") {
|
||||
t.Error("expected 'No items' for empty list")
|
||||
}
|
||||
|
||||
// Navigate on empty list shouldn't panic.
|
||||
p.HandleKey("down", "")
|
||||
p.HandleKey("up", "")
|
||||
res := p.HandleKey("enter", "")
|
||||
if res.Selected != nil {
|
||||
t.Error("enter on empty list should not select")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_SearchNoResults(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha"},
|
||||
{Label: "beta"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
// Type something that doesn't match.
|
||||
p.HandleKey("z", "z")
|
||||
p.HandleKey("z", "z")
|
||||
p.HandleKey("z", "z")
|
||||
|
||||
rendered := p.Render()
|
||||
if !strings.Contains(rendered, "No matches") {
|
||||
t.Error("expected 'No matches' message for empty search results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_CursorClamping(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha"},
|
||||
{Label: "beta"},
|
||||
{Label: "gamma"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
// Move to last item.
|
||||
p.HandleKey("end", "")
|
||||
if p.cursor != 2 {
|
||||
t.Fatalf("expected cursor 2, got %d", p.cursor)
|
||||
}
|
||||
|
||||
// Search that reduces list to 1 item → cursor should clamp.
|
||||
p.HandleKey("a", "a")
|
||||
p.HandleKey("l", "l")
|
||||
// Only "alpha" should match.
|
||||
if p.cursor >= len(p.filtered) {
|
||||
t.Errorf("cursor %d should be < filtered count %d", p.cursor, len(p.filtered))
|
||||
}
|
||||
}
|
||||
|
||||
// stripAnsi is defined in usage_tracker_render_test.go
|
||||
@@ -1,4 +1,4 @@
|
||||
package ui
|
||||
package prefs
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -118,22 +118,33 @@ func (m ProgressModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// status information and help text. Displays error messages if present or
|
||||
// a completion message when the download finishes.
|
||||
func (m ProgressModel) View() tea.View {
|
||||
var v tea.View
|
||||
v.AltScreen = true
|
||||
v.MouseMode = tea.MouseModeCellMotion
|
||||
v.ReportFocus = true
|
||||
v.KeyboardEnhancements = tea.KeyboardEnhancements{
|
||||
ReportEventTypes: true,
|
||||
}
|
||||
|
||||
if m.err != nil {
|
||||
return tea.NewView(fmt.Sprintf("Error: %s\n", m.err.Error()))
|
||||
v.Content = fmt.Sprintf("Error: %s\n", m.err.Error())
|
||||
return v
|
||||
}
|
||||
|
||||
if m.complete {
|
||||
return tea.NewView(fmt.Sprintf("\n%s%s\n\n%sComplete!\n",
|
||||
v.Content = fmt.Sprintf("\n%s%s\n\n%sComplete!\n",
|
||||
strings.Repeat(" ", padding),
|
||||
m.progress.View(),
|
||||
strings.Repeat(" ", padding)))
|
||||
strings.Repeat(" ", padding))
|
||||
return v
|
||||
}
|
||||
|
||||
pad := strings.Repeat(" ", padding)
|
||||
return tea.NewView(fmt.Sprintf("\n%s%s\n%s%s\n\n%s",
|
||||
v.Content = fmt.Sprintf("\n%s%s\n%s%s\n\n%s",
|
||||
pad, m.progress.View(),
|
||||
pad, m.status,
|
||||
pad+helpStyle("Press 'q' or Ctrl+C to cancel")))
|
||||
pad+helpStyle("Press 'q' or Ctrl+C to cancel"))
|
||||
return v
|
||||
}
|
||||
|
||||
// ProgressReader wraps an io.Reader to intercept and parse Ollama pull operation
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"charm.land/bubbles/v2/textarea"
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -76,7 +78,7 @@ func newInputPrompt(message, placeholder, defaultValue string, width, height int
|
||||
ta.Placeholder = placeholder
|
||||
ta.ShowLineNumbers = false
|
||||
ta.Prompt = ""
|
||||
ta.CharLimit = 1000
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 12) // account for border + padding
|
||||
ta.SetHeight(1)
|
||||
ta.Focus()
|
||||
@@ -204,7 +206,7 @@ func (p *promptOverlay) updateInput(msg tea.KeyPressMsg) (*promptResult, tea.Cmd
|
||||
// AppModel layout. The prompt replaces the normal input area (below the
|
||||
// separator and above the status bar) rather than taking over the full screen.
|
||||
func (p *promptOverlay) Render() string {
|
||||
theme := GetTheme()
|
||||
theme := style.GetTheme()
|
||||
var content string
|
||||
|
||||
switch p.mode {
|
||||
@@ -224,7 +226,7 @@ func (p *promptOverlay) Render() string {
|
||||
)
|
||||
}
|
||||
|
||||
func (p *promptOverlay) viewSelect(theme Theme) string {
|
||||
func (p *promptOverlay) viewSelect(theme style.Theme) string {
|
||||
var lines []string
|
||||
lines = append(lines, lipgloss.NewStyle().Bold(true).Foreground(theme.Text).Render(p.message))
|
||||
lines = append(lines, "")
|
||||
@@ -247,7 +249,7 @@ func (p *promptOverlay) viewSelect(theme Theme) string {
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (p *promptOverlay) viewConfirm(theme Theme) string {
|
||||
func (p *promptOverlay) viewConfirm(theme style.Theme) string {
|
||||
var lines []string
|
||||
lines = append(lines, lipgloss.NewStyle().Bold(true).Foreground(theme.Text).Render(p.message))
|
||||
lines = append(lines, "")
|
||||
@@ -272,7 +274,7 @@ func (p *promptOverlay) viewConfirm(theme Theme) string {
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (p *promptOverlay) viewInput(theme Theme) string {
|
||||
func (p *promptOverlay) viewInput(theme style.Theme) string {
|
||||
var lines []string
|
||||
lines = append(lines, lipgloss.NewStyle().Bold(true).Foreground(theme.Text).Render(p.message))
|
||||
lines = append(lines, "")
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
// Package render provides pure rendering functions for message blocks.
|
||||
// These functions are stateless and can be used by both streaming and
|
||||
// historical message rendering paths, eliminating code duplication.
|
||||
package render
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/indaco/herald"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// UserBlock renders a user message with herald Tip styling.
|
||||
// The width parameter controls line wrapping so long messages don't overflow.
|
||||
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
// Wrap content before passing to herald Alert so long lines break
|
||||
// inside the alert box. Subtract 4 to account for the alert bar
|
||||
// prefix ("│ ") and a small margin.
|
||||
if width > 4 {
|
||||
content = lipgloss.Wrap(content, width-4, "")
|
||||
}
|
||||
|
||||
rendered := ty.Tip(content)
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
// AssistantBlock renders an assistant message with markdown styling.
|
||||
func AssistantBlock(content string, width int, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
rendered := style.ToMarkdown(content, width-4)
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
// ReasoningBlock renders a reasoning/thinking block with muted italic text.
|
||||
// If duration > 0, shows "Thought for Xs" label. Otherwise shows just "Thought".
|
||||
func ReasoningBlock(content string, duration int64, ty *herald.Typography, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Match live streaming styling: muted italic text
|
||||
lines := strings.Split(strings.TrimRight(content, "\n"), "\n")
|
||||
contentStr := strings.TrimLeft(strings.Join(lines, "\n"), " \t\n")
|
||||
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
contentRendered := mutedStyle.Render(ty.Italic(contentStr))
|
||||
|
||||
// Build label based on duration
|
||||
if duration > 0 {
|
||||
var durationStr string
|
||||
if duration < 1000 {
|
||||
durationStr = fmt.Sprintf("%dms", duration)
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.1fs", float64(duration)/1000)
|
||||
}
|
||||
labelPart := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
|
||||
durationPart := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
|
||||
label := labelPart + durationPart
|
||||
rendered := contentRendered + "\n" + label
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought")
|
||||
rendered := contentRendered + "\n" + label
|
||||
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
// SystemBlock renders a system message with herald Note styling.
|
||||
func SystemBlock(content string, ty *herald.Typography, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "No content available"
|
||||
}
|
||||
|
||||
rendered := ty.Note(content)
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
// ErrorBlock renders an error message with herald Caution styling.
|
||||
func ErrorBlock(errorMsg string, ty *herald.Typography, theme style.Theme) string {
|
||||
rendered := ty.Caution(errorMsg)
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
// ToolBlock renders a tool execution result with header and body.
|
||||
func ToolBlock(displayName, params, body string, isError bool, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
var icon string
|
||||
iconColor := theme.Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
iconColor = theme.Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// Style the tool name with color
|
||||
nameColor := theme.Info
|
||||
if isError {
|
||||
nameColor = theme.Error
|
||||
}
|
||||
styledName := lipgloss.NewStyle().Foreground(nameColor).Bold(true).Render(displayName)
|
||||
styledIcon := lipgloss.NewStyle().Foreground(iconColor).Render(icon)
|
||||
|
||||
// Build the content: icon + name + params on first line, then body
|
||||
headerLine := styledIcon + " " + styledName
|
||||
if params != "" {
|
||||
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(body) == "" {
|
||||
body = ty.Italic("(no output)")
|
||||
}
|
||||
|
||||
// Compose: icon + name + params, then body
|
||||
fullContent := ty.Compose(
|
||||
headerLine,
|
||||
"",
|
||||
body,
|
||||
)
|
||||
return styleMarginBottom(theme, fullContent)
|
||||
}
|
||||
|
||||
// styleMarginBottom applies a 1-line margin bottom using the theme.
|
||||
func styleMarginBottom(theme style.Theme, content string) string {
|
||||
return lipgloss.NewStyle().MarginBottom(1).Render(content)
|
||||
}
|
||||
@@ -0,0 +1,693 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
xansi "github.com/charmbracelet/x/ansi"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/selection"
|
||||
)
|
||||
|
||||
// MessageItem is the interface all scrollback messages must implement.
|
||||
// This allows lazy rendering - messages are only rendered when visible.
|
||||
type MessageItem interface {
|
||||
// Render returns the styled content for this message at the given width.
|
||||
// Implementations should cache the result to avoid re-rendering.
|
||||
Render(width int) string
|
||||
|
||||
// Height returns the number of lines this message occupies when rendered.
|
||||
Height() int
|
||||
|
||||
// ID returns a unique identifier for this message (for tracking).
|
||||
ID() string
|
||||
}
|
||||
|
||||
// ScrollList manages a viewport over a list of MessageItems.
|
||||
// It handles offset-based scrolling, lazy rendering, and character-level
|
||||
// text selection (crush-style). Only visible items are rendered on each View() call.
|
||||
type ScrollList struct {
|
||||
items []MessageItem
|
||||
offsetIdx int // Index of first visible item
|
||||
offsetLine int // Lines to skip from first visible item
|
||||
width int
|
||||
height int // Viewport height in lines
|
||||
autoScroll bool // Whether to auto-scroll to bottom on new content
|
||||
itemGap int // Number of blank lines between items (0 = no gap)
|
||||
|
||||
// Character-level text selection (crush-style).
|
||||
sel selection.State
|
||||
}
|
||||
|
||||
// NewScrollList creates a new ScrollList with the given dimensions.
|
||||
func NewScrollList(width, height int) *ScrollList {
|
||||
return &ScrollList{
|
||||
items: []MessageItem{},
|
||||
offsetIdx: 0,
|
||||
offsetLine: 0,
|
||||
width: width,
|
||||
height: height,
|
||||
autoScroll: true,
|
||||
sel: selection.NewState(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetItems replaces the items in the scroll list. If auto-scroll is enabled,
|
||||
// the viewport will scroll to the bottom to show the latest content.
|
||||
func (s *ScrollList) SetItems(items []MessageItem) {
|
||||
s.items = items
|
||||
if s.autoScroll {
|
||||
s.GotoBottom()
|
||||
}
|
||||
}
|
||||
|
||||
// SetHeight updates the viewport height. Called when the terminal is resized.
|
||||
func (s *ScrollList) SetHeight(height int) {
|
||||
s.height = height
|
||||
s.clampOffset()
|
||||
}
|
||||
|
||||
// SetWidth updates the viewport width. Called when the terminal is resized.
|
||||
// This may invalidate cached renders in MessageItems.
|
||||
func (s *ScrollList) SetWidth(width int) {
|
||||
s.width = width
|
||||
s.clampOffset()
|
||||
}
|
||||
|
||||
// SetItemGap sets the number of blank lines between items (0 = no gap).
|
||||
func (s *ScrollList) SetItemGap(gap int) {
|
||||
s.itemGap = gap
|
||||
}
|
||||
|
||||
// ItemGap returns the current gap between items.
|
||||
func (s *ScrollList) ItemGap() int {
|
||||
return s.itemGap
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Mouse event handling — character-level text selection (crush-style)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// HandleMouseDown handles mouse button press. Detects single, double, and
|
||||
// triple clicks for character, word, and line selection respectively.
|
||||
// Returns true if the click was handled.
|
||||
func (s *ScrollList) HandleMouseDown(x, y int) bool {
|
||||
if len(s.items) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
itemIdx, lineIdx := s.getItemAndLineAtY(y)
|
||||
if itemIdx < 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Multi-click detection (crush-style).
|
||||
now := time.Now()
|
||||
if now.Sub(s.sel.LastClickTime) <= selection.DoubleClickThreshold &&
|
||||
abs(x-s.sel.LastClickX) <= selection.ClickTolerance &&
|
||||
abs(y-s.sel.LastClickY) <= selection.ClickTolerance {
|
||||
s.sel.ClickCount++
|
||||
} else {
|
||||
s.sel.ClickCount = 1
|
||||
}
|
||||
s.sel.LastClickTime = now
|
||||
s.sel.LastClickX = x
|
||||
s.sel.LastClickY = y
|
||||
|
||||
switch s.sel.ClickCount {
|
||||
case 1:
|
||||
// Single click: start character-level drag selection.
|
||||
s.sel.MouseDown = true
|
||||
s.sel.MouseDownItemIdx = itemIdx
|
||||
s.sel.MouseDownLineIdx = lineIdx
|
||||
s.sel.MouseDownCol = x
|
||||
s.sel.DragItemIdx = itemIdx
|
||||
s.sel.DragLineIdx = lineIdx
|
||||
s.sel.DragCol = x
|
||||
|
||||
case 2:
|
||||
// Double click: select word at position.
|
||||
s.selectWord(itemIdx, lineIdx, x)
|
||||
|
||||
case 3:
|
||||
// Triple click: select entire line.
|
||||
s.selectLine(itemIdx, lineIdx)
|
||||
s.sel.ClickCount = 0 // Reset after triple
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// HandleMouseDrag handles mouse motion while button is held.
|
||||
// Updates the selection endpoint for character-level precision.
|
||||
// Returns true if selection was updated.
|
||||
func (s *ScrollList) HandleMouseDrag(x, y int) bool {
|
||||
if !s.sel.MouseDown {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(s.items) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
itemIdx, lineIdx := s.getItemAndLineAtY(y)
|
||||
if itemIdx < 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
s.sel.DragItemIdx = itemIdx
|
||||
s.sel.DragLineIdx = lineIdx
|
||||
s.sel.DragCol = x
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// HandleMouseUp handles mouse button release.
|
||||
// Returns true if there was an active selection.
|
||||
func (s *ScrollList) HandleMouseUp() bool {
|
||||
if !s.sel.MouseDown {
|
||||
return false
|
||||
}
|
||||
s.sel.MouseDown = false
|
||||
return s.sel.HasSelection()
|
||||
}
|
||||
|
||||
// HasSelection returns true if there is a non-empty active selection.
|
||||
func (s *ScrollList) HasSelection() bool {
|
||||
return s.sel.HasSelection()
|
||||
}
|
||||
|
||||
// ClearSelection clears the current text selection.
|
||||
func (s *ScrollList) ClearSelection() {
|
||||
s.sel.Clear()
|
||||
}
|
||||
|
||||
// ExtractSelectedText returns the plain text content of the current selection
|
||||
// by walking through selected items and extracting text at the character level
|
||||
// using the ultraviolet cell buffer (ANSI-aware).
|
||||
func (s *ScrollList) ExtractSelectedText() string {
|
||||
r := s.sel.GetRange()
|
||||
if r.IsEmpty() {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
for itemIdx := r.StartItemIdx; itemIdx <= r.EndItemIdx && itemIdx < len(s.items); itemIdx++ {
|
||||
item := s.items[itemIdx]
|
||||
content := item.Render(s.width)
|
||||
contentLines := strings.Split(content, "\n")
|
||||
|
||||
for lineIdx, line := range contentLines {
|
||||
inRange, startCol, endCol := selection.IsLineInRange(r, itemIdx, lineIdx)
|
||||
if !inRange {
|
||||
continue
|
||||
}
|
||||
|
||||
text := selection.ExtractText(line, startCol, endCol)
|
||||
if text != "" {
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// selectWord selects the word at the given position using UAX#29 word
|
||||
// segmentation and display-width-aware column calculations.
|
||||
func (s *ScrollList) selectWord(itemIdx, lineIdx, x int) {
|
||||
if itemIdx < 0 || itemIdx >= len(s.items) {
|
||||
return
|
||||
}
|
||||
|
||||
item := s.items[itemIdx]
|
||||
content := item.Render(s.width)
|
||||
lines := strings.Split(content, "\n")
|
||||
if lineIdx < 0 || lineIdx >= len(lines) {
|
||||
return
|
||||
}
|
||||
|
||||
// Strip ANSI codes for word boundary detection.
|
||||
plainLine := xansi.Strip(lines[lineIdx])
|
||||
startCol, endCol := selection.FindWordBoundaries(plainLine, x)
|
||||
|
||||
if startCol == endCol {
|
||||
// No word at this position — set up single-click drag state.
|
||||
s.sel.MouseDown = true
|
||||
s.sel.MouseDownItemIdx = itemIdx
|
||||
s.sel.MouseDownLineIdx = lineIdx
|
||||
s.sel.MouseDownCol = x
|
||||
s.sel.DragItemIdx = itemIdx
|
||||
s.sel.DragLineIdx = lineIdx
|
||||
s.sel.DragCol = x
|
||||
return
|
||||
}
|
||||
|
||||
// Set selection to the word boundaries.
|
||||
s.sel.MouseDown = true
|
||||
s.sel.MouseDownItemIdx = itemIdx
|
||||
s.sel.MouseDownLineIdx = lineIdx
|
||||
s.sel.MouseDownCol = startCol
|
||||
s.sel.DragItemIdx = itemIdx
|
||||
s.sel.DragLineIdx = lineIdx
|
||||
s.sel.DragCol = endCol
|
||||
}
|
||||
|
||||
// selectLine selects the entire line at the given position.
|
||||
func (s *ScrollList) selectLine(itemIdx, lineIdx int) {
|
||||
if itemIdx < 0 || itemIdx >= len(s.items) {
|
||||
return
|
||||
}
|
||||
|
||||
item := s.items[itemIdx]
|
||||
content := item.Render(s.width)
|
||||
lines := strings.Split(content, "\n")
|
||||
if lineIdx < 0 || lineIdx >= len(lines) {
|
||||
return
|
||||
}
|
||||
|
||||
lineWidth := xansi.StringWidth(lines[lineIdx])
|
||||
|
||||
s.sel.MouseDown = true
|
||||
s.sel.MouseDownItemIdx = itemIdx
|
||||
s.sel.MouseDownLineIdx = lineIdx
|
||||
s.sel.MouseDownCol = 0
|
||||
s.sel.DragItemIdx = itemIdx
|
||||
s.sel.DragLineIdx = lineIdx
|
||||
s.sel.DragCol = lineWidth
|
||||
}
|
||||
|
||||
// getItemAndLineAtY converts a viewport-relative Y coordinate to item index
|
||||
// and line index within that item. Accounts for scroll offset and item gaps.
|
||||
// Returns (-1, -1) if Y is outside the viewport or beyond all items.
|
||||
//
|
||||
// IMPORTANT: Uses Render()+line counting (not Height()) to compute item height,
|
||||
// because Height() on some MessageItem implementations (e.g. StreamingMessageItem
|
||||
// for reasoning blocks) may return 0 when the render cache is empty.
|
||||
func (s *ScrollList) getItemAndLineAtY(y int) (itemIdx, lineIdx int) {
|
||||
if y < 0 || y >= s.height || len(s.items) == 0 {
|
||||
return -1, -1
|
||||
}
|
||||
|
||||
currentY := 0
|
||||
for idx := s.offsetIdx; idx < len(s.items); idx++ {
|
||||
item := s.items[idx]
|
||||
// Compute height the same way View() does: render, then count lines.
|
||||
itemHeight := s.renderedHeight(item)
|
||||
|
||||
// Account for partial visibility of the first item.
|
||||
startLine := 0
|
||||
if idx == s.offsetIdx {
|
||||
startLine = s.offsetLine
|
||||
itemHeight -= s.offsetLine
|
||||
}
|
||||
|
||||
if y >= currentY && y < currentY+itemHeight {
|
||||
return idx, (y - currentY) + startLine
|
||||
}
|
||||
|
||||
currentY += itemHeight
|
||||
|
||||
// Add gap after item (except last).
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
currentY += s.itemGap
|
||||
}
|
||||
|
||||
if currentY >= s.height {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return -1, -1
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Scrolling
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// ScrollBy scrolls the viewport by the given number of lines.
|
||||
// Positive = scroll down, negative = scroll up.
|
||||
func (s *ScrollList) ScrollBy(lines int) {
|
||||
if lines > 0 {
|
||||
// Scroll down
|
||||
for lines > 0 && s.offsetIdx < len(s.items) {
|
||||
if s.offsetIdx >= len(s.items) {
|
||||
break
|
||||
}
|
||||
currentItem := s.items[s.offsetIdx]
|
||||
itemHeight := currentItem.Height()
|
||||
remainingLines := itemHeight - s.offsetLine
|
||||
|
||||
if lines >= remainingLines {
|
||||
// Move to next item
|
||||
s.offsetIdx++
|
||||
s.offsetLine = 0
|
||||
lines -= remainingLines
|
||||
// Consume gap lines between items
|
||||
if s.itemGap > 0 && s.offsetIdx < len(s.items) {
|
||||
if lines >= s.itemGap {
|
||||
lines -= s.itemGap
|
||||
} else {
|
||||
lines = 0
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Stay on current item, skip more lines
|
||||
s.offsetLine += lines
|
||||
lines = 0
|
||||
}
|
||||
}
|
||||
} else if lines < 0 {
|
||||
// Scroll up
|
||||
lines = -lines
|
||||
for lines > 0 && (s.offsetIdx > 0 || s.offsetLine > 0) {
|
||||
if s.offsetLine > 0 {
|
||||
// Scroll within current item
|
||||
if lines >= s.offsetLine {
|
||||
lines -= s.offsetLine
|
||||
s.offsetLine = 0
|
||||
} else {
|
||||
s.offsetLine -= lines
|
||||
lines = 0
|
||||
}
|
||||
} else if s.offsetIdx > 0 {
|
||||
// Consume gap lines between items
|
||||
if s.itemGap > 0 {
|
||||
if lines > s.itemGap {
|
||||
lines -= s.itemGap
|
||||
} else {
|
||||
lines = 0
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Move to previous item
|
||||
s.offsetIdx--
|
||||
if s.offsetIdx < len(s.items) {
|
||||
currentItem := s.items[s.offsetIdx]
|
||||
itemHeight := currentItem.Height()
|
||||
|
||||
if lines >= itemHeight {
|
||||
lines -= itemHeight
|
||||
s.offsetLine = 0
|
||||
} else {
|
||||
s.offsetLine = itemHeight - lines
|
||||
lines = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.clampOffset()
|
||||
}
|
||||
|
||||
// GotoBottom scrolls to the end of the list.
|
||||
func (s *ScrollList) GotoBottom() {
|
||||
if len(s.items) == 0 {
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate total height including gaps
|
||||
totalHeight := 0
|
||||
for i, item := range s.items {
|
||||
rendered := item.Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
totalHeight += itemHeight
|
||||
if s.itemGap > 0 && i < len(s.items)-1 {
|
||||
totalHeight += s.itemGap
|
||||
}
|
||||
}
|
||||
|
||||
// If content fits in viewport, start at top
|
||||
if totalHeight <= s.height {
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, position viewport at bottom
|
||||
remaining := totalHeight - s.height
|
||||
for idx := 0; idx < len(s.items); idx++ {
|
||||
rendered := s.items[idx].Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
if remaining < itemHeight {
|
||||
s.offsetIdx = idx
|
||||
s.offsetLine = remaining
|
||||
return
|
||||
}
|
||||
remaining -= itemHeight
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
remaining -= s.itemGap
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: show last item
|
||||
s.offsetIdx = max(0, len(s.items)-1)
|
||||
s.offsetLine = 0
|
||||
}
|
||||
|
||||
// GotoTop scrolls to the beginning of the list.
|
||||
func (s *ScrollList) GotoTop() {
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
}
|
||||
|
||||
// AtBottom returns true if the viewport is at the bottom of the list.
|
||||
func (s *ScrollList) AtBottom() bool {
|
||||
if len(s.items) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
visibleHeight := 0
|
||||
for idx := s.offsetIdx; idx < len(s.items); idx++ {
|
||||
item := s.items[idx]
|
||||
rendered := item.Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
|
||||
if idx == s.offsetIdx {
|
||||
visibleHeight += itemHeight - s.offsetLine
|
||||
} else {
|
||||
visibleHeight += itemHeight
|
||||
}
|
||||
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
visibleHeight += s.itemGap
|
||||
}
|
||||
|
||||
if visibleHeight >= s.height {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AtTop returns true if the viewport is at the top of the list.
|
||||
func (s *ScrollList) AtTop() bool {
|
||||
return s.offsetIdx == 0 && s.offsetLine == 0
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Rendering
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// View renders the visible portion of the scrollback.
|
||||
// Only items that fit within the viewport height are rendered.
|
||||
// ALWAYS returns exactly s.height lines (padded with empty lines if needed)
|
||||
// to ensure the input/footer stay fixed at the bottom.
|
||||
//
|
||||
// When an active selection exists, character-level highlighting is applied
|
||||
// using ultraviolet ScreenBuffer for ANSI-aware cell manipulation.
|
||||
func (s *ScrollList) View() string {
|
||||
if s.height <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
selRange := s.sel.GetRange()
|
||||
hasSelection := !selRange.IsEmpty()
|
||||
|
||||
var lines []string
|
||||
remainingHeight := s.height
|
||||
|
||||
if len(s.items) > 0 {
|
||||
for idx := s.offsetIdx; idx < len(s.items) && remainingHeight > 0; idx++ {
|
||||
item := s.items[idx]
|
||||
content := item.Render(s.width)
|
||||
contentLines := strings.Split(content, "\n")
|
||||
|
||||
startLine := 0
|
||||
if idx == s.offsetIdx {
|
||||
startLine = s.offsetLine
|
||||
}
|
||||
|
||||
for i := startLine; i < len(contentLines) && remainingHeight > 0; i++ {
|
||||
line := contentLines[i]
|
||||
|
||||
// Apply character-level selection highlighting.
|
||||
if hasSelection {
|
||||
inRange, startCol, endCol := selection.IsLineInRange(selRange, idx, i)
|
||||
if inRange {
|
||||
line = selection.HighlightLine(line, startCol, endCol)
|
||||
}
|
||||
}
|
||||
|
||||
lines = append(lines, line)
|
||||
remainingHeight--
|
||||
}
|
||||
|
||||
// Add gap lines between items.
|
||||
if remainingHeight > 0 && idx < len(s.items)-1 && s.itemGap > 0 {
|
||||
for g := 0; g < s.itemGap && remainingHeight > 0; g++ {
|
||||
lines = append(lines, "")
|
||||
remainingHeight--
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pad with empty lines to ensure exactly s.height lines.
|
||||
for remainingHeight > 0 {
|
||||
lines = append(lines, "")
|
||||
remainingHeight--
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// ScrollPercent returns the current scroll position as a percentage (0.0-1.0).
|
||||
// 0.0 = at top, 1.0 = at bottom. Useful for scroll indicators.
|
||||
func (s *ScrollList) ScrollPercent() float64 {
|
||||
if len(s.items) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
totalHeight := 0
|
||||
for _, item := range s.items {
|
||||
totalHeight += item.Height()
|
||||
}
|
||||
|
||||
if totalHeight <= s.height {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
linesAbove := 0
|
||||
for i := 0; i < s.offsetIdx && i < len(s.items); i++ {
|
||||
linesAbove += s.items[i].Height()
|
||||
}
|
||||
linesAbove += s.offsetLine
|
||||
|
||||
scrollableHeight := totalHeight - s.height
|
||||
if scrollableHeight <= 0 {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
percent := float64(linesAbove) / float64(scrollableHeight)
|
||||
if percent > 1.0 {
|
||||
percent = 1.0
|
||||
}
|
||||
if percent < 0.0 {
|
||||
percent = 0.0
|
||||
}
|
||||
return percent
|
||||
}
|
||||
|
||||
// clampOffset ensures the offset values are within valid bounds after
|
||||
// resizing or scrolling operations.
|
||||
func (s *ScrollList) clampOffset() {
|
||||
if len(s.items) == 0 {
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
return
|
||||
}
|
||||
|
||||
if s.offsetIdx >= len(s.items) {
|
||||
s.offsetIdx = len(s.items) - 1
|
||||
}
|
||||
if s.offsetIdx < 0 {
|
||||
s.offsetIdx = 0
|
||||
}
|
||||
|
||||
if s.offsetIdx < len(s.items) {
|
||||
rendered := s.items[s.offsetIdx].Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
if s.offsetLine >= itemHeight {
|
||||
s.offsetLine = max(0, itemHeight-1)
|
||||
}
|
||||
}
|
||||
if s.offsetLine < 0 {
|
||||
s.offsetLine = 0
|
||||
}
|
||||
|
||||
// Prevent scrolling past the bottom
|
||||
totalHeight := 0
|
||||
for i, item := range s.items {
|
||||
rendered := item.Render(s.width)
|
||||
totalHeight += strings.Count(rendered, "\n") + 1
|
||||
if s.itemGap > 0 && i < len(s.items)-1 {
|
||||
totalHeight += s.itemGap
|
||||
}
|
||||
}
|
||||
|
||||
if totalHeight <= s.height {
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
return
|
||||
}
|
||||
|
||||
linesAbove := 0
|
||||
for i := 0; i < s.offsetIdx; i++ {
|
||||
rendered := s.items[i].Render(s.width)
|
||||
linesAbove += strings.Count(rendered, "\n") + 1
|
||||
if s.itemGap > 0 && i < len(s.items)-1 {
|
||||
linesAbove += s.itemGap
|
||||
}
|
||||
}
|
||||
linesAbove += s.offsetLine
|
||||
|
||||
linesFromCurrentToEnd := totalHeight - linesAbove
|
||||
if linesFromCurrentToEnd < s.height {
|
||||
targetLine := totalHeight - s.height
|
||||
currentLine := 0
|
||||
|
||||
for idx := 0; idx < len(s.items); idx++ {
|
||||
rendered := s.items[idx].Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
|
||||
if currentLine+itemHeight > targetLine {
|
||||
s.offsetIdx = idx
|
||||
s.offsetLine = targetLine - currentLine
|
||||
return
|
||||
}
|
||||
|
||||
currentLine += itemHeight
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
currentLine += s.itemGap
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// renderedHeight returns the height of a message item in lines by actually
|
||||
// rendering it. This is the single source of truth for item height — it
|
||||
// matches exactly what View() produces, unlike item.Height() which may
|
||||
// return stale/zero values for uncached items (e.g. reasoning blocks).
|
||||
func (s *ScrollList) renderedHeight(item MessageItem) int {
|
||||
rendered := item.Render(s.width)
|
||||
if rendered == "" {
|
||||
return 0
|
||||
}
|
||||
return strings.Count(rendered, "\n") + 1
|
||||
}
|
||||
|
||||
// abs returns the absolute value of x.
|
||||
func abs(x int) int {
|
||||
if x < 0 {
|
||||
return -x
|
||||
}
|
||||
return x
|
||||
}
|
||||
@@ -0,0 +1,324 @@
|
||||
// Package selection provides character-level text selection for terminal UIs.
|
||||
//
|
||||
// It handles converting mouse coordinates (in terminal cells) to character
|
||||
// positions within rendered ANSI-styled text, supporting multi-byte characters,
|
||||
// wide characters (CJK, emoji), and word/line selection via double/triple click.
|
||||
//
|
||||
// The approach is modeled after Charm's crush: all coordinate calculations use
|
||||
// display columns (terminal cells), not byte offsets or rune counts. The
|
||||
// ultraviolet ScreenBuffer provides the bridge between rendered ANSI strings
|
||||
// and individual character cells.
|
||||
package selection
|
||||
|
||||
import (
|
||||
"image"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
uv "github.com/charmbracelet/ultraviolet"
|
||||
xansi "github.com/charmbracelet/x/ansi"
|
||||
"github.com/clipperhouse/displaywidth"
|
||||
"github.com/clipperhouse/uax29/v2/words"
|
||||
)
|
||||
|
||||
// DoubleClickThreshold is the maximum time between clicks for multi-click.
|
||||
const DoubleClickThreshold = 400 * time.Millisecond
|
||||
|
||||
// ClickTolerance is the pixel/cell tolerance for multi-click detection.
|
||||
const ClickTolerance = 2
|
||||
|
||||
// State tracks the full state of a mouse text selection.
|
||||
type State struct {
|
||||
// Whether a mouse button is currently held down.
|
||||
MouseDown bool
|
||||
|
||||
// Position where mouse was first pressed (viewport-relative).
|
||||
MouseDownItemIdx int
|
||||
MouseDownLineIdx int
|
||||
MouseDownCol int
|
||||
|
||||
// Current drag position (viewport-relative).
|
||||
DragItemIdx int
|
||||
DragLineIdx int
|
||||
DragCol int
|
||||
|
||||
// Multi-click detection.
|
||||
LastClickTime time.Time
|
||||
LastClickX int
|
||||
LastClickY int
|
||||
ClickCount int
|
||||
}
|
||||
|
||||
// Range represents a normalized (start <= end) selection range.
|
||||
type Range struct {
|
||||
StartItemIdx int
|
||||
StartLine int
|
||||
StartCol int
|
||||
EndItemIdx int
|
||||
EndLine int
|
||||
EndCol int
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the range selects nothing.
|
||||
func (r Range) IsEmpty() bool {
|
||||
return r.StartItemIdx < 0 || r.EndItemIdx < 0 ||
|
||||
(r.StartItemIdx == r.EndItemIdx && r.StartLine == r.EndLine && r.StartCol == r.EndCol)
|
||||
}
|
||||
|
||||
// NewState creates a new empty selection state.
|
||||
func NewState() State {
|
||||
return State{
|
||||
MouseDownItemIdx: -1,
|
||||
DragItemIdx: -1,
|
||||
}
|
||||
}
|
||||
|
||||
// Clear resets all selection state.
|
||||
func (s *State) Clear() {
|
||||
s.MouseDown = false
|
||||
s.MouseDownItemIdx = -1
|
||||
s.MouseDownLineIdx = 0
|
||||
s.MouseDownCol = 0
|
||||
s.DragItemIdx = -1
|
||||
s.DragLineIdx = 0
|
||||
s.DragCol = 0
|
||||
s.LastClickTime = time.Time{}
|
||||
s.LastClickX = 0
|
||||
s.LastClickY = 0
|
||||
s.ClickCount = 0
|
||||
}
|
||||
|
||||
// HasSelection returns true if there is a non-empty active selection.
|
||||
func (s *State) HasSelection() bool {
|
||||
return s.MouseDownItemIdx >= 0 && s.DragItemIdx >= 0 && !s.GetRange().IsEmpty()
|
||||
}
|
||||
|
||||
// GetRange returns the normalized selection range (start <= end).
|
||||
func (s *State) GetRange() Range {
|
||||
if s.MouseDownItemIdx < 0 || s.DragItemIdx < 0 {
|
||||
return Range{StartItemIdx: -1, EndItemIdx: -1}
|
||||
}
|
||||
|
||||
downItem := s.MouseDownItemIdx
|
||||
downLine := s.MouseDownLineIdx
|
||||
downCol := s.MouseDownCol
|
||||
dragItem := s.DragItemIdx
|
||||
dragLine := s.DragLineIdx
|
||||
dragCol := s.DragCol
|
||||
|
||||
// Determine if dragging forward or backward.
|
||||
forward := dragItem > downItem ||
|
||||
(dragItem == downItem && dragLine > downLine) ||
|
||||
(dragItem == downItem && dragLine == downLine && dragCol >= downCol)
|
||||
|
||||
if forward {
|
||||
return Range{
|
||||
StartItemIdx: downItem,
|
||||
StartLine: downLine,
|
||||
StartCol: downCol,
|
||||
EndItemIdx: dragItem,
|
||||
EndLine: dragLine,
|
||||
EndCol: dragCol,
|
||||
}
|
||||
}
|
||||
return Range{
|
||||
StartItemIdx: dragItem,
|
||||
StartLine: dragLine,
|
||||
StartCol: dragCol,
|
||||
EndItemIdx: downItem,
|
||||
EndLine: downLine,
|
||||
EndCol: downCol,
|
||||
}
|
||||
}
|
||||
|
||||
// IsLineInRange checks if a specific line within an item falls inside the
|
||||
// selection range. Returns (inRange, startCol, endCol) where startCol == -1
|
||||
// means the entire line is selected. startCol == endCol means no selection
|
||||
// on this line.
|
||||
func IsLineInRange(r Range, itemIdx, lineIdx int) (bool, int, int) {
|
||||
if r.IsEmpty() {
|
||||
return false, 0, 0
|
||||
}
|
||||
|
||||
// Outside item range entirely.
|
||||
if itemIdx < r.StartItemIdx || itemIdx > r.EndItemIdx {
|
||||
return false, 0, 0
|
||||
}
|
||||
|
||||
// Single-item selection.
|
||||
if r.StartItemIdx == r.EndItemIdx {
|
||||
if itemIdx != r.StartItemIdx {
|
||||
return false, 0, 0
|
||||
}
|
||||
if lineIdx < r.StartLine || lineIdx > r.EndLine {
|
||||
return false, 0, 0
|
||||
}
|
||||
if r.StartLine == r.EndLine {
|
||||
// Single line: specific column range.
|
||||
return true, r.StartCol, r.EndCol
|
||||
}
|
||||
if lineIdx == r.StartLine {
|
||||
return true, r.StartCol, -1 // from startCol to end of line
|
||||
}
|
||||
if lineIdx == r.EndLine {
|
||||
return true, 0, r.EndCol // from start of line to endCol
|
||||
}
|
||||
return true, -1, -1 // full line (middle of multi-line selection)
|
||||
}
|
||||
|
||||
// Multi-item selection.
|
||||
if itemIdx == r.StartItemIdx {
|
||||
if lineIdx < r.StartLine {
|
||||
return false, 0, 0
|
||||
}
|
||||
if lineIdx == r.StartLine {
|
||||
return true, r.StartCol, -1
|
||||
}
|
||||
return true, -1, -1 // full line
|
||||
}
|
||||
if itemIdx == r.EndItemIdx {
|
||||
if lineIdx > r.EndLine {
|
||||
return false, 0, 0
|
||||
}
|
||||
if lineIdx == r.EndLine {
|
||||
return true, 0, r.EndCol
|
||||
}
|
||||
return true, -1, -1 // full line
|
||||
}
|
||||
|
||||
// Middle item: fully selected.
|
||||
return true, -1, -1
|
||||
}
|
||||
|
||||
// FindWordBoundaries finds the start and end column of the word at the given
|
||||
// column position in a plain-text line (ANSI codes already stripped).
|
||||
// Returns (startCol, endCol) where endCol is exclusive.
|
||||
// Uses UAX#29 word segmentation and display-width-aware column tracking.
|
||||
func FindWordBoundaries(line string, col int) (startCol, endCol int) {
|
||||
if line == "" || col < 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
// Segment the line into words using UAX#29.
|
||||
lineCol := 0
|
||||
iter := words.FromString(line)
|
||||
for iter.Next() {
|
||||
token := iter.Value()
|
||||
tokenWidth := displaywidth.String(token)
|
||||
|
||||
graphemeStart := lineCol
|
||||
graphemeEnd := lineCol + tokenWidth
|
||||
lineCol += tokenWidth
|
||||
|
||||
// If clicked before this token, no word here.
|
||||
if col < graphemeStart {
|
||||
return col, col
|
||||
}
|
||||
|
||||
// If clicked within this token, return its boundaries.
|
||||
if col >= graphemeStart && col < graphemeEnd {
|
||||
// Whitespace tokens produce empty selection.
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return col, col
|
||||
}
|
||||
return graphemeStart, graphemeEnd
|
||||
}
|
||||
}
|
||||
|
||||
return col, col
|
||||
}
|
||||
|
||||
// HighlightLine applies reverse-video highlighting to a portion of a rendered
|
||||
// line (which may contain ANSI escape codes). startCol/endCol are in display
|
||||
// columns. If startCol == -1, the entire line is highlighted. If startCol ==
|
||||
// endCol, returns the line unchanged.
|
||||
//
|
||||
// Uses ultraviolet ScreenBuffer for cell-level ANSI manipulation.
|
||||
func HighlightLine(line string, startCol, endCol int) string {
|
||||
if line == "" {
|
||||
return line
|
||||
}
|
||||
|
||||
lineWidth := xansi.StringWidth(line)
|
||||
if lineWidth == 0 {
|
||||
return line
|
||||
}
|
||||
|
||||
// Full-line highlight.
|
||||
if startCol == -1 {
|
||||
startCol = 0
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
if startCol >= endCol || startCol >= lineWidth {
|
||||
return line
|
||||
}
|
||||
if endCol > lineWidth {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
// Parse the styled line into a cell buffer.
|
||||
area := image.Rect(0, 0, lineWidth, 1)
|
||||
buf := uv.NewScreenBuffer(lineWidth, 1)
|
||||
styled := uv.NewStyledString(line)
|
||||
styled.Draw(&buf, area)
|
||||
|
||||
// Apply reverse attribute to cells in the selection range.
|
||||
if buf.Height() > 0 {
|
||||
bufLine := buf.Line(0)
|
||||
for x := startCol; x < endCol && x < len(bufLine); x++ {
|
||||
cell := bufLine.At(x)
|
||||
if cell != nil {
|
||||
cell.Style.Attrs |= uv.AttrReverse
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf.Render()
|
||||
}
|
||||
|
||||
// ExtractText extracts plain text from a rendered ANSI string within the given
|
||||
// column range on a single line. Uses ultraviolet to parse ANSI and extract
|
||||
// character content.
|
||||
func ExtractText(line string, startCol, endCol int) string {
|
||||
if line == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
lineWidth := xansi.StringWidth(line)
|
||||
if lineWidth == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Full-line extraction.
|
||||
if startCol == -1 {
|
||||
startCol = 0
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
if startCol >= endCol || startCol >= lineWidth {
|
||||
return ""
|
||||
}
|
||||
if endCol > lineWidth {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
// Parse to cell buffer.
|
||||
area := image.Rect(0, 0, lineWidth, 1)
|
||||
buf := uv.NewScreenBuffer(lineWidth, 1)
|
||||
styled := uv.NewStyledString(line)
|
||||
styled.Draw(&buf, area)
|
||||
|
||||
var sb strings.Builder
|
||||
if buf.Height() > 0 {
|
||||
bufLine := buf.Line(0)
|
||||
for x := startCol; x < endCol && x < len(bufLine); x++ {
|
||||
cell := bufLine.At(x)
|
||||
if cell != nil && cell.Content != "" {
|
||||
sb.WriteString(cell.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,400 @@
|
||||
package selection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewState(t *testing.T) {
|
||||
s := NewState()
|
||||
if s.MouseDownItemIdx != -1 {
|
||||
t.Errorf("expected MouseDownItemIdx -1, got %d", s.MouseDownItemIdx)
|
||||
}
|
||||
if s.DragItemIdx != -1 {
|
||||
t.Errorf("expected DragItemIdx -1, got %d", s.DragItemIdx)
|
||||
}
|
||||
if s.MouseDown {
|
||||
t.Error("expected MouseDown false")
|
||||
}
|
||||
if s.HasSelection() {
|
||||
t.Error("expected no selection on new state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClear(t *testing.T) {
|
||||
s := NewState()
|
||||
s.MouseDown = true
|
||||
s.MouseDownItemIdx = 2
|
||||
s.DragItemIdx = 3
|
||||
s.ClickCount = 2
|
||||
s.Clear()
|
||||
|
||||
if s.MouseDown {
|
||||
t.Error("expected MouseDown false after clear")
|
||||
}
|
||||
if s.MouseDownItemIdx != -1 {
|
||||
t.Errorf("expected MouseDownItemIdx -1 after clear, got %d", s.MouseDownItemIdx)
|
||||
}
|
||||
if s.DragItemIdx != -1 {
|
||||
t.Errorf("expected DragItemIdx -1 after clear, got %d", s.DragItemIdx)
|
||||
}
|
||||
if s.ClickCount != 0 {
|
||||
t.Errorf("expected ClickCount 0 after clear, got %d", s.ClickCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRange_Forward(t *testing.T) {
|
||||
s := NewState()
|
||||
s.MouseDownItemIdx = 0
|
||||
s.MouseDownLineIdx = 1
|
||||
s.MouseDownCol = 5
|
||||
s.DragItemIdx = 0
|
||||
s.DragLineIdx = 3
|
||||
s.DragCol = 10
|
||||
|
||||
r := s.GetRange()
|
||||
if r.StartItemIdx != 0 || r.StartLine != 1 || r.StartCol != 5 {
|
||||
t.Errorf("unexpected start: item=%d line=%d col=%d", r.StartItemIdx, r.StartLine, r.StartCol)
|
||||
}
|
||||
if r.EndItemIdx != 0 || r.EndLine != 3 || r.EndCol != 10 {
|
||||
t.Errorf("unexpected end: item=%d line=%d col=%d", r.EndItemIdx, r.EndLine, r.EndCol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRange_Backward(t *testing.T) {
|
||||
s := NewState()
|
||||
s.MouseDownItemIdx = 2
|
||||
s.MouseDownLineIdx = 5
|
||||
s.MouseDownCol = 20
|
||||
s.DragItemIdx = 0
|
||||
s.DragLineIdx = 1
|
||||
s.DragCol = 3
|
||||
|
||||
r := s.GetRange()
|
||||
// Should be normalized: drag position becomes start
|
||||
if r.StartItemIdx != 0 || r.StartLine != 1 || r.StartCol != 3 {
|
||||
t.Errorf("unexpected start: item=%d line=%d col=%d", r.StartItemIdx, r.StartLine, r.StartCol)
|
||||
}
|
||||
if r.EndItemIdx != 2 || r.EndLine != 5 || r.EndCol != 20 {
|
||||
t.Errorf("unexpected end: item=%d line=%d col=%d", r.EndItemIdx, r.EndLine, r.EndCol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRange_SameLine(t *testing.T) {
|
||||
s := NewState()
|
||||
s.MouseDownItemIdx = 1
|
||||
s.MouseDownLineIdx = 2
|
||||
s.MouseDownCol = 10
|
||||
s.DragItemIdx = 1
|
||||
s.DragLineIdx = 2
|
||||
s.DragCol = 20
|
||||
|
||||
r := s.GetRange()
|
||||
if r.IsEmpty() {
|
||||
t.Error("expected non-empty range")
|
||||
}
|
||||
if r.StartCol != 10 || r.EndCol != 20 {
|
||||
t.Errorf("expected cols 10-20, got %d-%d", r.StartCol, r.EndCol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRangeIsEmpty(t *testing.T) {
|
||||
// Same point
|
||||
r := Range{StartItemIdx: 0, StartLine: 0, StartCol: 5, EndItemIdx: 0, EndLine: 0, EndCol: 5}
|
||||
if !r.IsEmpty() {
|
||||
t.Error("expected same-point range to be empty")
|
||||
}
|
||||
|
||||
// Negative item idx
|
||||
r = Range{StartItemIdx: -1, EndItemIdx: -1}
|
||||
if !r.IsEmpty() {
|
||||
t.Error("expected negative item idx range to be empty")
|
||||
}
|
||||
|
||||
// Valid range
|
||||
r = Range{StartItemIdx: 0, StartLine: 0, StartCol: 0, EndItemIdx: 0, EndLine: 0, EndCol: 5}
|
||||
if r.IsEmpty() {
|
||||
t.Error("expected valid range to not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSelection(t *testing.T) {
|
||||
s := NewState()
|
||||
if s.HasSelection() {
|
||||
t.Error("new state should have no selection")
|
||||
}
|
||||
|
||||
// Set up a valid selection
|
||||
s.MouseDownItemIdx = 0
|
||||
s.MouseDownLineIdx = 0
|
||||
s.MouseDownCol = 0
|
||||
s.DragItemIdx = 0
|
||||
s.DragLineIdx = 0
|
||||
s.DragCol = 10
|
||||
if !s.HasSelection() {
|
||||
t.Error("expected selection to exist")
|
||||
}
|
||||
|
||||
// Same point = no selection
|
||||
s.DragCol = 0
|
||||
if s.HasSelection() {
|
||||
t.Error("same point should not be a selection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLineInRange_SingleItem_SingleLine(t *testing.T) {
|
||||
r := Range{
|
||||
StartItemIdx: 1, StartLine: 2, StartCol: 5,
|
||||
EndItemIdx: 1, EndLine: 2, EndCol: 15,
|
||||
}
|
||||
|
||||
// Exact line
|
||||
ok, sc, ec := IsLineInRange(r, 1, 2)
|
||||
if !ok || sc != 5 || ec != 15 {
|
||||
t.Errorf("expected (true, 5, 15), got (%v, %d, %d)", ok, sc, ec)
|
||||
}
|
||||
|
||||
// Wrong line
|
||||
ok, _, _ = IsLineInRange(r, 1, 0)
|
||||
if ok {
|
||||
t.Error("line 0 should not be in range")
|
||||
}
|
||||
|
||||
// Wrong item
|
||||
ok, _, _ = IsLineInRange(r, 0, 2)
|
||||
if ok {
|
||||
t.Error("item 0 should not be in range")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLineInRange_SingleItem_MultiLine(t *testing.T) {
|
||||
r := Range{
|
||||
StartItemIdx: 0, StartLine: 1, StartCol: 5,
|
||||
EndItemIdx: 0, EndLine: 4, EndCol: 10,
|
||||
}
|
||||
|
||||
// Start line
|
||||
ok, sc, ec := IsLineInRange(r, 0, 1)
|
||||
if !ok || sc != 5 || ec != -1 {
|
||||
t.Errorf("start line: expected (true, 5, -1), got (%v, %d, %d)", ok, sc, ec)
|
||||
}
|
||||
|
||||
// Middle line
|
||||
ok, sc, ec = IsLineInRange(r, 0, 2)
|
||||
if !ok || sc != -1 || ec != -1 {
|
||||
t.Errorf("middle line: expected (true, -1, -1), got (%v, %d, %d)", ok, sc, ec)
|
||||
}
|
||||
|
||||
// End line
|
||||
ok, sc, ec = IsLineInRange(r, 0, 4)
|
||||
if !ok || sc != 0 || ec != 10 {
|
||||
t.Errorf("end line: expected (true, 0, 10), got (%v, %d, %d)", ok, sc, ec)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLineInRange_MultiItem(t *testing.T) {
|
||||
r := Range{
|
||||
StartItemIdx: 0, StartLine: 3, StartCol: 5,
|
||||
EndItemIdx: 2, EndLine: 1, EndCol: 10,
|
||||
}
|
||||
|
||||
// First item, start line
|
||||
ok, sc, ec := IsLineInRange(r, 0, 3)
|
||||
if !ok || sc != 5 || ec != -1 {
|
||||
t.Errorf("first item start: expected (true, 5, -1), got (%v, %d, %d)", ok, sc, ec)
|
||||
}
|
||||
|
||||
// First item, line after start
|
||||
ok, sc, ec = IsLineInRange(r, 0, 5)
|
||||
if !ok || sc != -1 || ec != -1 {
|
||||
t.Errorf("first item after: expected (true, -1, -1), got (%v, %d, %d)", ok, sc, ec)
|
||||
}
|
||||
|
||||
// Middle item, any line
|
||||
ok, sc, ec = IsLineInRange(r, 1, 0)
|
||||
if !ok || sc != -1 || ec != -1 {
|
||||
t.Errorf("middle item: expected (true, -1, -1), got (%v, %d, %d)", ok, sc, ec)
|
||||
}
|
||||
|
||||
// Last item, end line
|
||||
ok, sc, ec = IsLineInRange(r, 2, 1)
|
||||
if !ok || sc != 0 || ec != 10 {
|
||||
t.Errorf("last item end: expected (true, 0, 10), got (%v, %d, %d)", ok, sc, ec)
|
||||
}
|
||||
|
||||
// Last item, line after end
|
||||
ok, _, _ = IsLineInRange(r, 2, 5)
|
||||
if ok {
|
||||
t.Error("line after end in last item should not be in range")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindWordBoundaries(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
line string
|
||||
col int
|
||||
wantStart int
|
||||
wantEnd int
|
||||
}{
|
||||
{
|
||||
name: "simple word",
|
||||
line: "hello world",
|
||||
col: 2,
|
||||
wantStart: 0,
|
||||
wantEnd: 5,
|
||||
},
|
||||
{
|
||||
name: "second word",
|
||||
line: "hello world",
|
||||
col: 7,
|
||||
wantStart: 6,
|
||||
wantEnd: 11,
|
||||
},
|
||||
{
|
||||
name: "on space",
|
||||
line: "hello world",
|
||||
col: 5,
|
||||
wantStart: 5,
|
||||
wantEnd: 5,
|
||||
},
|
||||
{
|
||||
name: "empty line",
|
||||
line: "",
|
||||
col: 0,
|
||||
wantStart: 0,
|
||||
wantEnd: 0,
|
||||
},
|
||||
{
|
||||
name: "negative col",
|
||||
line: "hello",
|
||||
col: -1,
|
||||
wantStart: 0,
|
||||
wantEnd: 0,
|
||||
},
|
||||
{
|
||||
name: "past end",
|
||||
line: "hello",
|
||||
col: 10,
|
||||
wantStart: 10,
|
||||
wantEnd: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
start, end := FindWordBoundaries(tt.line, tt.col)
|
||||
if start != tt.wantStart || end != tt.wantEnd {
|
||||
t.Errorf("FindWordBoundaries(%q, %d) = (%d, %d), want (%d, %d)",
|
||||
tt.line, tt.col, start, end, tt.wantStart, tt.wantEnd)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractText_PlainText(t *testing.T) {
|
||||
line := "Hello, World!"
|
||||
text := ExtractText(line, 0, 5)
|
||||
if text != "Hello" {
|
||||
t.Errorf("expected 'Hello', got %q", text)
|
||||
}
|
||||
|
||||
text = ExtractText(line, 7, 12)
|
||||
if text != "World" {
|
||||
t.Errorf("expected 'World', got %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractText_FullLine(t *testing.T) {
|
||||
line := "Hello"
|
||||
text := ExtractText(line, -1, -1)
|
||||
if text != "Hello" {
|
||||
t.Errorf("expected 'Hello', got %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractText_Empty(t *testing.T) {
|
||||
text := ExtractText("", 0, 5)
|
||||
if text != "" {
|
||||
t.Errorf("expected empty string, got %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractText_OutOfBounds(t *testing.T) {
|
||||
line := "Hi"
|
||||
text := ExtractText(line, 5, 10)
|
||||
if text != "" {
|
||||
t.Errorf("expected empty string for out of bounds, got %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHighlightLine_PlainText(t *testing.T) {
|
||||
line := "Hello, World!"
|
||||
result := HighlightLine(line, 0, 5)
|
||||
// Should produce a non-empty result different from input (has ANSI codes)
|
||||
if result == "" {
|
||||
t.Error("expected non-empty result")
|
||||
}
|
||||
// Should still contain the text content
|
||||
if len(result) < len(line) {
|
||||
t.Error("result should be at least as long as input (ANSI codes add length)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHighlightLine_Empty(t *testing.T) {
|
||||
result := HighlightLine("", 0, 5)
|
||||
if result != "" {
|
||||
t.Errorf("expected empty for empty input, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHighlightLine_NoSelection(t *testing.T) {
|
||||
line := "Hello"
|
||||
result := HighlightLine(line, 3, 3)
|
||||
// Same startCol and endCol = no change
|
||||
if result != line {
|
||||
t.Errorf("expected no change for zero-width selection, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiClickDetection verifies the click counting logic.
|
||||
func TestMultiClickDetection(t *testing.T) {
|
||||
s := NewState()
|
||||
now := time.Now()
|
||||
|
||||
// First click
|
||||
s.LastClickTime = now
|
||||
s.LastClickX = 10
|
||||
s.LastClickY = 5
|
||||
s.ClickCount = 1
|
||||
|
||||
// Second click within threshold
|
||||
later := now.Add(200 * time.Millisecond)
|
||||
if later.Sub(s.LastClickTime) <= DoubleClickThreshold {
|
||||
if abs(10-s.LastClickX) <= ClickTolerance && abs(5-s.LastClickY) <= ClickTolerance {
|
||||
s.ClickCount++
|
||||
}
|
||||
}
|
||||
if s.ClickCount != 2 {
|
||||
t.Errorf("expected click count 2, got %d", s.ClickCount)
|
||||
}
|
||||
|
||||
// Third click
|
||||
s.LastClickTime = later
|
||||
later2 := later.Add(200 * time.Millisecond)
|
||||
if later2.Sub(s.LastClickTime) <= DoubleClickThreshold {
|
||||
if abs(10-s.LastClickX) <= ClickTolerance && abs(5-s.LastClickY) <= ClickTolerance {
|
||||
s.ClickCount++
|
||||
}
|
||||
}
|
||||
if s.ClickCount != 3 {
|
||||
t.Errorf("expected click count 3, got %d", s.ClickCount)
|
||||
}
|
||||
}
|
||||
|
||||
func abs(x int) int {
|
||||
if x < 0 {
|
||||
return -x
|
||||
}
|
||||
return x
|
||||
}
|
||||
+136
-58
@@ -12,6 +12,7 @@ import (
|
||||
"charm.land/lipgloss/v2"
|
||||
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// SessionSelectedMsg is sent when the user selects a session from the picker.
|
||||
@@ -158,12 +159,12 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
switch {
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("up", "k"))):
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("up"))):
|
||||
if ss.cursor > 0 {
|
||||
ss.cursor--
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("down", "j"))):
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("down"))):
|
||||
if ss.cursor < len(ss.filtered)-1 {
|
||||
ss.cursor++
|
||||
}
|
||||
@@ -250,58 +251,108 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// View implements tea.Model.
|
||||
func (ss *SessionSelectorComponent) View() tea.View {
|
||||
theme := GetTheme()
|
||||
w := ss.width
|
||||
var b strings.Builder
|
||||
theme := style.GetTheme()
|
||||
|
||||
// Full-screen bordered container - uses entire terminal width and height
|
||||
maxWidth := ss.width - 2 // Small margin on each side
|
||||
if maxWidth < 20 {
|
||||
maxWidth = ss.width
|
||||
}
|
||||
maxHeight := ss.height - 2 // Small margin top/bottom to prevent overflow
|
||||
if maxHeight < 10 {
|
||||
maxHeight = ss.height
|
||||
}
|
||||
horizontalPadding := 1
|
||||
innerWidth := maxWidth - 4 // Account for border (2) + padding (2)
|
||||
innerHeight := maxHeight - 4 // Account for border (2) + padding (2)
|
||||
|
||||
// Container style with border - full width/height like a framed panel
|
||||
containerStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Primary).
|
||||
Background(theme.Background).
|
||||
Padding(1, horizontalPadding).
|
||||
Width(maxWidth).
|
||||
Height(maxHeight)
|
||||
|
||||
var contentBuilder strings.Builder
|
||||
|
||||
// ── Header: title + scope badges ─────────────────────────────
|
||||
titleStyle := lipgloss.NewStyle().Bold(true).Foreground(theme.Accent).PaddingLeft(1)
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf("Resume Session (%s)", ss.scope)))
|
||||
b.WriteString("\n")
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(titleStyle.Render(fmt.Sprintf("Resume Session (%s)", ss.scope)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Help / keybindings ───────────────────────────────────────
|
||||
helpStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(1)
|
||||
if w >= 75 {
|
||||
b.WriteString(helpStyle.Render("tab: scope N: named D: delete R: rename type to search esc: cancel"))
|
||||
} else if w >= 50 {
|
||||
b.WriteString(helpStyle.Render("tab scope N named D del type to search esc"))
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
if innerWidth >= 75 {
|
||||
contentBuilder.WriteString(helpStyle.Render("tab: scope N: named D: delete R: rename type to search esc: cancel"))
|
||||
} else if innerWidth >= 50 {
|
||||
contentBuilder.WriteString(helpStyle.Render("tab scope N named D del type to search esc"))
|
||||
} else {
|
||||
b.WriteString(helpStyle.Render("tab N D esc"))
|
||||
contentBuilder.WriteString(helpStyle.Render("tab N D esc"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Search (only shown when active) ──────────────────────────
|
||||
if ss.search != "" {
|
||||
searchStyle := lipgloss.NewStyle().Foreground(theme.Info).PaddingLeft(1)
|
||||
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ss.search)))
|
||||
b.WriteString("\n")
|
||||
searchStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ss.search)))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
// Separator line
|
||||
sepWidth := innerWidth
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Delete confirmation ──────────────────────────────────────
|
||||
if ss.confirmDelete >= 0 && ss.confirmDelete < len(ss.filtered) {
|
||||
warnStyle := lipgloss.NewStyle().Foreground(theme.Error).Bold(true).PaddingLeft(1)
|
||||
warnStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true).
|
||||
Background(theme.Background)
|
||||
name := sessionDisplayName(ss.filtered[ss.confirmDelete])
|
||||
b.WriteString(warnStyle.Render(fmt.Sprintf("Delete %q? (y/N)", truncateRunes(name, 40))))
|
||||
b.WriteString("\n")
|
||||
contentBuilder.WriteString(warnStyle.Render(fmt.Sprintf("Delete %q? (y/N)", truncateRunes(name, 40))))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
// ── Session list ─────────────────────────────────────────────
|
||||
if len(ss.filtered) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
|
||||
emptyStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
if ss.search != "" {
|
||||
b.WriteString(emptyStyle.Render(fmt.Sprintf("No sessions matching %q", ss.search)))
|
||||
contentBuilder.WriteString(emptyStyle.Render(fmt.Sprintf("No sessions matching %q", ss.search)))
|
||||
} else if ss.filter == SessionFilterNamed {
|
||||
b.WriteString(emptyStyle.Render("No named sessions. Press N to show all."))
|
||||
contentBuilder.WriteString(emptyStyle.Render("No named sessions. Press N to show all."))
|
||||
} else if ss.scope == SessionScopeCwd {
|
||||
b.WriteString(emptyStyle.Render("No sessions in current folder. Press tab to view all."))
|
||||
contentBuilder.WriteString(emptyStyle.Render("No sessions in current folder. Press tab to view all."))
|
||||
} else {
|
||||
b.WriteString(emptyStyle.Render("No sessions found"))
|
||||
contentBuilder.WriteString(emptyStyle.Render("No sessions found"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
contentBuilder.WriteString("\n")
|
||||
} else {
|
||||
visH := ss.visibleHeight()
|
||||
// Compute visible window based on inner container height
|
||||
// Chrome: header(2) + separator(1) + footer separator(1) + footer(1) = 5
|
||||
chromeLines := 5
|
||||
if ss.search != "" {
|
||||
chromeLines++
|
||||
}
|
||||
if ss.confirmDelete >= 0 {
|
||||
chromeLines++
|
||||
}
|
||||
visH := max(innerHeight-chromeLines, 3)
|
||||
|
||||
// Center the cursor in the visible window.
|
||||
startIdx := max(0, min(ss.cursor-visH/2, len(ss.filtered)-visH))
|
||||
@@ -312,20 +363,42 @@ func (ss *SessionSelectorComponent) View() tea.View {
|
||||
isCursor := i == ss.cursor
|
||||
isCurrent := info.Path == ss.currentPath
|
||||
isDeleting := i == ss.confirmDelete
|
||||
line := ss.renderEntry(info, isCursor, isCurrent, isDeleting, w)
|
||||
b.WriteString(line)
|
||||
b.WriteString("\n")
|
||||
line := ss.renderEntry(info, isCursor, isCurrent, isDeleting, innerWidth)
|
||||
contentBuilder.WriteString(line)
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
// Scroll position indicator.
|
||||
if len(ss.filtered) > visH {
|
||||
posStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
|
||||
b.WriteString(posStyle.Render(fmt.Sprintf("(%d/%d)", ss.cursor+1, len(ss.filtered))))
|
||||
b.WriteString("\n")
|
||||
posStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(posStyle.Render(fmt.Sprintf("(%d/%d)", ss.cursor+1, len(ss.filtered))))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return tea.NewView(b.String())
|
||||
// Footer separator
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// Footer with filter info
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(footerStyle.Render(fmt.Sprintf("Filter: %s", ss.filter)))
|
||||
|
||||
// Apply the bordered container
|
||||
content := contentBuilder.String()
|
||||
borderedContent := containerStyle.Render(content)
|
||||
|
||||
v := tea.NewView(borderedContent)
|
||||
v.AltScreen = true
|
||||
return v
|
||||
}
|
||||
|
||||
// IsActive returns whether the selector is still accepting input.
|
||||
@@ -403,12 +476,12 @@ func removeByPath(sessions []session.SessionInfo, path string) []session.Session
|
||||
// renderEntry renders a single session line with right-aligned metadata.
|
||||
// Layout: [cursor 2] [message ...variable...] [padding] [count age] [cwd?]
|
||||
func (ss *SessionSelectorComponent) renderEntry(info session.SessionInfo, isCursor, isCurrent, isDeleting bool, width int) string {
|
||||
theme := GetTheme()
|
||||
theme := style.GetTheme()
|
||||
|
||||
// ── Cursor indicator (2 chars) ───────────────────────────────
|
||||
cursorStr := " "
|
||||
if isCursor {
|
||||
cursorStr = lipgloss.NewStyle().Foreground(theme.Accent).Render("› ")
|
||||
cursorStr = lipgloss.NewStyle().Foreground(theme.Accent).Render("> ")
|
||||
}
|
||||
const cursorW = 2
|
||||
|
||||
@@ -436,45 +509,50 @@ func (ss *SessionSelectorComponent) renderEntry(info session.SessionInfo, isCurs
|
||||
msgW := utf8.RuneCountInString(displayText)
|
||||
|
||||
// ── Style the message ────────────────────────────────────────
|
||||
msgStyle := lipgloss.NewStyle()
|
||||
var msgStyle lipgloss.Style
|
||||
switch {
|
||||
case isDeleting:
|
||||
msgStyle = msgStyle.Foreground(theme.Error)
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Error)
|
||||
case isCurrent:
|
||||
msgStyle = msgStyle.Foreground(theme.Accent)
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Accent)
|
||||
case info.Name != "":
|
||||
msgStyle = msgStyle.Foreground(theme.Warning)
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Warning)
|
||||
default:
|
||||
msgStyle = msgStyle.Foreground(theme.Text)
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
}
|
||||
if isCursor {
|
||||
msgStyle = msgStyle.Bold(true)
|
||||
}
|
||||
|
||||
styledMsg := msgStyle.Render(displayText)
|
||||
|
||||
// ── Style the right part ─────────────────────────────────────
|
||||
rightColor := theme.Muted
|
||||
if isDeleting {
|
||||
rightColor = theme.Error
|
||||
}
|
||||
styledRight := lipgloss.NewStyle().Foreground(rightColor).Render(rightPart)
|
||||
var styledRight string
|
||||
|
||||
// ── Assemble with spacing ────────────────────────────────────
|
||||
spacing := max(width-cursorW-msgW-rightW, 1)
|
||||
|
||||
line := cursorStr + styledMsg + strings.Repeat(" ", spacing) + styledRight
|
||||
|
||||
// ── Background highlight for selected row ────────────────────
|
||||
// If selected, use inverted colors like PopupList
|
||||
if isCursor {
|
||||
// Use a subtle background highlight. We apply it by wrapping the
|
||||
// full line in a style with a background color.
|
||||
bgStyle := lipgloss.NewStyle().
|
||||
Background(theme.Highlight).
|
||||
Width(width)
|
||||
line = bgStyle.Render(line)
|
||||
// Inverted colors for selected item
|
||||
msgStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Bold(true)
|
||||
styledRight = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(rightColor).
|
||||
Render(rightPart)
|
||||
cursorStr = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Accent).
|
||||
Render("> ")
|
||||
} else {
|
||||
styledRight = lipgloss.NewStyle().Foreground(rightColor).Render(rightPart)
|
||||
}
|
||||
|
||||
styledMsg := msgStyle.Render(displayText)
|
||||
line := cursorStr + styledMsg + strings.Repeat(" ", spacing) + styledRight
|
||||
|
||||
return line
|
||||
}
|
||||
|
||||
|
||||
+42
-158
@@ -2,25 +2,15 @@ package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/indaco/herald"
|
||||
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
)
|
||||
|
||||
// thinkTagRegex matches ... tags that some models (Qwen, DeepSeek) wrap
|
||||
// reasoning content in. Used to strip these tags from streaming text content.
|
||||
// The (?s) flag makes . match newlines.
|
||||
var thinkTagRegex = regexp.MustCompile(`(?s)` + `` + `think` + `` + `(.*?)` + `` + `/think` + ``)
|
||||
|
||||
// thinkTagOpen and thinkTagClose are the opening and closing think tag strings.
|
||||
const (
|
||||
thinkTagOpen = "<think>"
|
||||
thinkTagClose = "</think>"
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// knightRiderFrames generates a KITT-style scanning animation where a bright
|
||||
@@ -31,7 +21,7 @@ func knightRiderFrames() []string {
|
||||
const numDots = 8
|
||||
const dot = "▪"
|
||||
|
||||
theme := GetTheme()
|
||||
theme := style.GetTheme()
|
||||
|
||||
bright := lipgloss.NewStyle().Foreground(theme.Primary)
|
||||
med := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
@@ -131,13 +121,13 @@ const (
|
||||
// alongside streaming text until the step completes and Reset() is called.
|
||||
//
|
||||
// Tool calls, tool results, user messages, and other non-streaming content
|
||||
// are printed immediately by the parent AppModel via tea.Println(). The
|
||||
// StreamComponent only handles the live streaming text and spinner display.
|
||||
// are added to the ScrollList by the parent AppModel. The StreamComponent
|
||||
// only handles the live streaming text and spinner display.
|
||||
//
|
||||
// Lifecycle is managed entirely by the parent AppModel:
|
||||
// - Parent calls Reset() between agent steps to clear state.
|
||||
// - Parent emits completed responses above the BT region via tea.Println()
|
||||
// then calls Reset(); StreamComponent never calls tea.Quit.
|
||||
// - Content is displayed via StreamingMessageItem in the ScrollList.
|
||||
// - StreamComponent never calls tea.Quit.
|
||||
//
|
||||
// Events handled:
|
||||
// - app.SpinnerEvent{Show:true} → start spinner tick loop
|
||||
@@ -196,15 +186,6 @@ type StreamComponent struct {
|
||||
// ticks from a previous step can be discarded.
|
||||
flushGeneration uint64
|
||||
|
||||
// renderCache holds the last rendered output string. Reused by View()
|
||||
// between flush ticks to avoid redundant markdown re-parsing.
|
||||
renderCache string
|
||||
|
||||
// renderDirty is true when committed content has changed since the
|
||||
// last render. Set on flush tick; cleared after render() rebuilds
|
||||
// the cache.
|
||||
renderDirty bool
|
||||
|
||||
// thinkingVisible controls whether reasoning blocks are expanded or collapsed.
|
||||
thinkingVisible bool
|
||||
|
||||
@@ -214,11 +195,7 @@ type StreamComponent struct {
|
||||
// reasoningDuration holds the total reasoning time, frozen when streaming text begins.
|
||||
reasoningDuration time.Duration
|
||||
|
||||
// inThinkTag tracks whether we're currently inside a section
|
||||
// from models that wrap reasoning in XML-like tags (Qwen, DeepSeek).
|
||||
inThinkTag bool
|
||||
|
||||
// renderer renders streaming assistant text in either compact or standard mode.
|
||||
// renderer renders streaming assistant text.
|
||||
renderer Renderer
|
||||
|
||||
// modelName is displayed in the streaming text header.
|
||||
@@ -239,17 +216,12 @@ type StreamComponent struct {
|
||||
}
|
||||
|
||||
// NewStreamComponent creates a new StreamComponent ready to be embedded in AppModel.
|
||||
func NewStreamComponent(compactMode bool, width int, modelName string) *StreamComponent {
|
||||
func NewStreamComponent(width int, modelName string) *StreamComponent {
|
||||
if width == 0 {
|
||||
width = 80
|
||||
}
|
||||
|
||||
var renderer Renderer
|
||||
if compactMode {
|
||||
renderer = NewCompactRenderer(width, false)
|
||||
} else {
|
||||
renderer = newMessageRenderer(width, false)
|
||||
}
|
||||
renderer := newMessageRenderer(width, false)
|
||||
|
||||
return &StreamComponent{
|
||||
spinnerFrames: knightRiderFrames(),
|
||||
@@ -269,9 +241,6 @@ func (s *StreamComponent) SetHeight(h int) {
|
||||
}
|
||||
if s.height != h {
|
||||
s.height = h
|
||||
// Invalidate cache — height clamp affects output.
|
||||
s.renderCache = ""
|
||||
s.renderDirty = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,16 +259,20 @@ func (s *StreamComponent) Reset() {
|
||||
s.pendingReasoning.Reset()
|
||||
s.flushPending = false
|
||||
s.flushGeneration++
|
||||
s.renderCache = ""
|
||||
s.renderDirty = false
|
||||
s.timestamp = time.Time{}
|
||||
s.reasoningStartTime = time.Time{}
|
||||
s.reasoningDuration = 0
|
||||
}
|
||||
|
||||
// ConsumeOverflow is a no-op in alt screen mode. Overflow is handled by the
|
||||
// ScrollList viewport. Retained to satisfy streamComponentIface.
|
||||
func (s *StreamComponent) ConsumeOverflow() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetRenderedContent returns the rendered assistant message from the accumulated
|
||||
// streaming text. Returns empty string if no text has been accumulated. Used by
|
||||
// the parent AppModel to flush content via tea.Println() before resetting.
|
||||
// the parent AppModel to flush stream content before resetting.
|
||||
//
|
||||
// This commits any pending chunks first so the output includes all received
|
||||
// content, not just what has been flushed by the tick.
|
||||
@@ -327,19 +300,15 @@ func (s *StreamComponent) GetRenderedContent() string {
|
||||
}
|
||||
|
||||
// commitPending moves any pending chunks to the committed content builders.
|
||||
// Called before reading content for scrollback output or on flush tick.
|
||||
// Called before reading content for output or on flush tick.
|
||||
func (s *StreamComponent) commitPending() {
|
||||
if s.pendingStream.Len() > 0 {
|
||||
// Strip ... tags that some models wrap reasoning in
|
||||
cleanedText := thinkTagRegex.ReplaceAllString(s.pendingStream.String(), "")
|
||||
s.streamContent.WriteString(cleanedText)
|
||||
s.streamContent.WriteString(s.pendingStream.String())
|
||||
s.pendingStream.Reset()
|
||||
s.renderDirty = true
|
||||
}
|
||||
if s.pendingReasoning.Len() > 0 {
|
||||
s.reasoningContent.WriteString(s.pendingReasoning.String())
|
||||
s.pendingReasoning.Reset()
|
||||
s.renderDirty = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -362,9 +331,6 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if s.renderer != nil {
|
||||
s.renderer.SetWidth(s.width)
|
||||
}
|
||||
// Invalidate render cache — width change affects wrapping/styling.
|
||||
s.renderCache = ""
|
||||
s.renderDirty = true
|
||||
|
||||
case streamSpinnerTickMsg:
|
||||
// Only continue the tick loop if this tick belongs to the current
|
||||
@@ -417,6 +383,17 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return s, streamFlushTickCmd(s.flushGeneration)
|
||||
}
|
||||
|
||||
case app.ReasoningCompleteEvent:
|
||||
// Freeze reasoning duration when reasoning finishes (before text streaming starts).
|
||||
if s.reasoningDuration == 0 && !s.reasoningStartTime.IsZero() {
|
||||
s.reasoningDuration = time.Since(s.reasoningStartTime)
|
||||
}
|
||||
// Flush any remaining pending reasoning content.
|
||||
if s.pendingReasoning.Len() > 0 {
|
||||
s.reasoningContent.WriteString(s.pendingReasoning.String())
|
||||
s.pendingReasoning.Reset()
|
||||
}
|
||||
|
||||
case app.StreamChunkEvent:
|
||||
s.phase = streamPhaseActive
|
||||
if s.timestamp.IsZero() {
|
||||
@@ -427,43 +404,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.reasoningDuration = time.Since(s.reasoningStartTime)
|
||||
}
|
||||
|
||||
// Handle models that wrap reasoning in tags (Qwen, DeepSeek)
|
||||
// Filter out all content between and tags
|
||||
content := msg.Content
|
||||
|
||||
// Check for opening tag
|
||||
if strings.Contains(content, thinkTagOpen) {
|
||||
parts := strings.SplitN(content, thinkTagOpen, 2)
|
||||
// Content before the tag can be written
|
||||
if !s.inThinkTag && parts[0] != "" {
|
||||
s.pendingStream.WriteString(parts[0])
|
||||
}
|
||||
s.inThinkTag = true
|
||||
// Content after the opening tag is reasoning - don't write it
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
// Check if the same chunk contains the closing tag
|
||||
if strings.Contains(parts[1], thinkTagClose) {
|
||||
innerParts := strings.SplitN(parts[1], thinkTagClose, 2)
|
||||
s.inThinkTag = false
|
||||
// Content after closing tag can be written
|
||||
if len(innerParts) > 1 && innerParts[1] != "" {
|
||||
s.pendingStream.WriteString(innerParts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(content, thinkTagClose) {
|
||||
// Closing tag found
|
||||
parts := strings.SplitN(content, thinkTagClose, 2)
|
||||
s.inThinkTag = false
|
||||
// Content after closing tag can be written
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
s.pendingStream.WriteString(parts[1])
|
||||
}
|
||||
} else if !s.inThinkTag {
|
||||
// Normal content, not inside think tags
|
||||
s.pendingStream.WriteString(content)
|
||||
}
|
||||
// else: inside think tag, don't write this content
|
||||
// <think> tag filtering is handled at the agent layer — chunks here
|
||||
// are already clean text.
|
||||
s.pendingStream.WriteString(msg.Content)
|
||||
|
||||
if !s.flushPending && s.pendingStream.Len() > 0 {
|
||||
s.flushPending = true
|
||||
@@ -504,72 +447,10 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// View implements tea.Model. Renders the current stream region content.
|
||||
// View implements tea.Model. Returns an empty view since rendering is handled
|
||||
// by StreamingMessageItem in the ScrollList. Retained to satisfy tea.Model.
|
||||
func (s *StreamComponent) View() tea.View {
|
||||
fullContent := s.render()
|
||||
visibleContent := s.viewContent(fullContent)
|
||||
return tea.NewView(visibleContent)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Internal rendering
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// render builds the full content string for the stream region. Uses a render
|
||||
// cache to avoid redundant markdown re-parsing between flush ticks. The cache
|
||||
// is invalidated when committed content changes (flush tick), terminal width
|
||||
// changes, or height/thinking visibility changes.
|
||||
func (s *StreamComponent) render() string {
|
||||
if s.phase == streamPhaseIdle {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Return cached render if committed content hasn't changed.
|
||||
if !s.renderDirty {
|
||||
return s.renderCache
|
||||
}
|
||||
|
||||
var sections []string
|
||||
|
||||
// Render reasoning/thinking block above the main text if present.
|
||||
if reasoning := s.reasoningContent.String(); reasoning != "" {
|
||||
sections = append(sections, s.renderReasoningBlock(reasoning))
|
||||
}
|
||||
|
||||
// Render streaming text only. The spinner is rendered in the status bar
|
||||
// by the parent so it never changes the stream region height.
|
||||
text := s.streamContent.String()
|
||||
if text != "" {
|
||||
sections = append(sections, s.renderStreamingText(text))
|
||||
}
|
||||
|
||||
if len(sections) == 0 {
|
||||
s.renderCache = ""
|
||||
s.renderDirty = false
|
||||
return ""
|
||||
}
|
||||
|
||||
content := strings.Join(sections, "\n")
|
||||
|
||||
// Cache FULL content without height clamping.
|
||||
// Height clamping is applied in View() for display only.
|
||||
s.renderCache = content
|
||||
s.renderDirty = false
|
||||
return content
|
||||
}
|
||||
|
||||
// viewContent returns the visible portion of content based on height constraint.
|
||||
// This is called by View() to get the slice that fits in the terminal.
|
||||
func (s *StreamComponent) viewContent(fullContent string) string {
|
||||
if s.height > 0 && fullContent != "" {
|
||||
lines := strings.Split(fullContent, "\n")
|
||||
if len(lines) > s.height {
|
||||
// Keep only the last h lines so the most recent output is visible.
|
||||
lines = lines[len(lines)-s.height:]
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
return fullContent
|
||||
return tea.NewView("")
|
||||
}
|
||||
|
||||
// renderReasoningBlock renders the reasoning/thinking content using blockquote.
|
||||
@@ -630,9 +511,6 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
func (s *StreamComponent) SetThinkingVisible(visible bool) {
|
||||
if s.thinkingVisible != visible {
|
||||
s.thinkingVisible = visible
|
||||
// Invalidate cache — thinking visibility affects rendered output.
|
||||
s.renderCache = ""
|
||||
s.renderDirty = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -709,3 +587,9 @@ func removeToolID(ids []string, id string) []string {
|
||||
func formatToolExecutionMessage(toolName string) string {
|
||||
return toolName
|
||||
}
|
||||
|
||||
// UpdateTheme refreshes the component's typography instance with colors from
|
||||
// the current theme. This is called when the user changes themes via /theme.
|
||||
func (s *StreamComponent) UpdateTheme() {
|
||||
s.ty = createTypography(GetTheme())
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package ui
|
||||
package style
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -294,3 +294,28 @@ func ApplyGradient(text string, colorA, colorB color.Color) string {
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// KitBanner returns the KIT ASCII art title with KITT scanner lights,
|
||||
// rendered with a KITT red gradient.
|
||||
func KitBanner() string {
|
||||
kittDark := lipgloss.Color("#8B0000")
|
||||
kittBright := lipgloss.Color("#FF2200")
|
||||
lines := []string{
|
||||
" ██╗ ██╗ ██╗ ████████╗",
|
||||
" ██║ ██╔╝ ██║ ╚══██╔══╝",
|
||||
" █████╔╝ ██║ ██║",
|
||||
" ██╔═██╗ ██║ ██║",
|
||||
" ██║ ██╗ ██║ ██║",
|
||||
" ╚═╝ ╚═╝ ╚═╝ ╚═╝",
|
||||
"░░ ░░ ░░ ▒▒ ▒▒ ▓▓ ▓▓ ████ ▓▓ ▓▓ ▒▒ ▒▒ ░░ ░░ ░░",
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
result.WriteString("\n")
|
||||
}
|
||||
result.WriteString(ApplyGradient(line, kittDark, kittBright))
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package ui
|
||||
package style
|
||||
|
||||
import (
|
||||
"charm.land/lipgloss/v2"
|
||||
@@ -85,11 +85,13 @@ func GetMarkdownTypography() *herald.Typography {
|
||||
return ty
|
||||
}
|
||||
|
||||
// toMarkdown renders markdown content using herald-md.
|
||||
// The width parameter is currently unused as herald handles wrapping
|
||||
// based on terminal width internally.
|
||||
func toMarkdown(content string, width int) string {
|
||||
// ToMarkdown renders markdown content using herald-md and wraps the result
|
||||
// to the given width so that long lines do not overflow the terminal.
|
||||
func ToMarkdown(content string, width int) string {
|
||||
ty := GetMarkdownTypography()
|
||||
rendered := heraldmd.Render(ty, []byte(content))
|
||||
if width > 0 {
|
||||
rendered = lipgloss.Wrap(rendered, width, "")
|
||||
}
|
||||
return rendered
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user