mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 11:40:13 +00:00
Compare commits
106 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9d38349091 | |||
| fec8bac800 | |||
| e76f5f3d45 | |||
| 1ad493c5c7 | |||
| ea6ddc8792 | |||
| 6d4e8bcec5 | |||
| e2ed345280 | |||
| e542eb797e | |||
| e631fc1b17 | |||
| 290c5a4774 | |||
| 287d60c31e | |||
| 3d45d98895 | |||
| db4be4f9a2 | |||
| 80093e69ed | |||
| ef519ba517 | |||
| d79eb1f0fa | |||
| ac8ee6525d | |||
| e35e8382d6 | |||
| fbb3408a25 | |||
| 44fed9a647 | |||
| e7f11487b9 | |||
| 054c417603 | |||
| 94d62a6ef0 | |||
| 91e6dfd2c8 | |||
| b6a0c4b44c | |||
| 8eb0fa855a | |||
| 3bf696c546 | |||
| 3e461a0539 | |||
| a2ece01ecf | |||
| 623c9fb5ad | |||
| 139506f336 | |||
| 6d424554ad | |||
| 5a3d3fdd7d | |||
| c91225629d | |||
| 5a71cde5ff | |||
| 044d3eb206 | |||
| 80f3a642a3 | |||
| 26f0969e3e | |||
| 4af75901b5 | |||
| 49ff4c0678 | |||
| b0802a5c32 | |||
| dfe65ca227 | |||
| d4ec756ce5 | |||
| 2971e73ee8 | |||
| 5aa6c9e116 | |||
| bca08476de | |||
| 6a599d86af | |||
| fd6f200659 | |||
| b295a25946 | |||
| f0e4e2f757 | |||
| d25249506a | |||
| 971521f534 | |||
| 8c00682367 | |||
| 58caf155c1 | |||
| 3f08bf2424 | |||
| 9fbbab05f6 | |||
| b0991c7aa6 | |||
| 9c90563765 | |||
| f36166bee5 | |||
| 879e81f9b5 | |||
| 727b42acfe | |||
| 4830981570 | |||
| dcfebafcc5 | |||
| 1f5c103667 | |||
| 4caa8ba3dc | |||
| 15ef8ad78b | |||
| 551f2710d9 | |||
| 67bda5cad5 | |||
| 01d7d754ef | |||
| c6304f1e92 | |||
| bc3c733ae3 | |||
| 428ee2b8be | |||
| eb1d7fd07e | |||
| 1e3e5cafd3 | |||
| 0b93e58fb9 | |||
| 2bb01ed72c | |||
| b6ecc36ea1 | |||
| d4f27bc912 | |||
| f12e195390 | |||
| b68b3dd0bf | |||
| 48521bf76d | |||
| 16df3a738c | |||
| 9d0b8c8cef | |||
| d9326fcf21 | |||
| 22c479277e | |||
| 8ae204f12f | |||
| 8b1665a4ce | |||
| 941f1daf0b | |||
| ab7e2bda61 | |||
| 741520927c | |||
| 4c1bda9541 | |||
| 3b69b13556 | |||
| 83a959a379 | |||
| 3491e05e9e | |||
| 0a54a8aa05 | |||
| 3cb3e5dba1 | |||
| 31966c469f | |||
| f03625d6e5 | |||
| d06641dc0a | |||
| bbf1106e27 | |||
| babed03a3d | |||
| 1cd074836f | |||
| ab3ce260c8 | |||
| 8e8cc3946d | |||
| e18e36625e | |||
| be55bc03f1 |
@@ -3,6 +3,7 @@
|
||||
.env
|
||||
.kit/*
|
||||
!.kit/extensions/
|
||||
!.kit/prompts/
|
||||
aidocs/
|
||||
*.log
|
||||
/kit
|
||||
|
||||
@@ -0,0 +1,304 @@
|
||||
//go:build ignore
|
||||
|
||||
// subagent-monitor — live horizontal widget strip for spawned subagents
|
||||
//
|
||||
// Subscribes to subagents spawned by the main Kit agent and displays a
|
||||
// single widget just above the input box. Each subagent occupies one column
|
||||
// in a side-by-side horizontal layout. Columns show scrolling real-time
|
||||
// output as the subagent works. When a subagent finishes its column is
|
||||
// removed automatically.
|
||||
//
|
||||
// Yaegi-safe design notes:
|
||||
// - No sync.Mutex (Yaegi has reflection issues with sync primitives)
|
||||
// - No channels in maps (Yaegi panics on range over map[string]chan)
|
||||
// - All ctx.* calls guarded with nil checks
|
||||
// - Simple data structures only
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per-subagent state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type submonEntry struct {
|
||||
id int
|
||||
callID string
|
||||
task string
|
||||
lines []string
|
||||
started time.Time
|
||||
elapsed time.Duration
|
||||
}
|
||||
|
||||
const (
|
||||
submonColWidth = 34 // visible character width per column
|
||||
submonMaxLines = 5 // scrolling output lines per column
|
||||
submonColGap = 2 // spaces between columns
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Package-level state - all simple types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
submonCtx ext.Context
|
||||
submonHasCtx bool
|
||||
submonEntries []*submonEntry
|
||||
submonNextID int
|
||||
)
|
||||
|
||||
func submonInit() {
|
||||
submonEntries = nil
|
||||
submonNextID = 1
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// String helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func submonPad(s string, w int) string {
|
||||
r := []rune(s)
|
||||
if len(r) >= w {
|
||||
return string(r[:w])
|
||||
}
|
||||
return s + strings.Repeat(" ", w-len(r))
|
||||
}
|
||||
|
||||
func submonTrunc(s string, w int) string {
|
||||
r := []rune(s)
|
||||
if len(r) <= w {
|
||||
return s
|
||||
}
|
||||
if w <= 1 {
|
||||
return "…"
|
||||
}
|
||||
return string(r[:w-1]) + "…"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Widget rendering
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func submonRenderColumn(e *submonEntry) []string {
|
||||
var rows []string
|
||||
|
||||
// Calculate elapsed time on-demand to avoid race conditions with ticker
|
||||
elapsed := e.elapsed
|
||||
if elapsed == 0 && !e.started.IsZero() {
|
||||
elapsed = time.Since(e.started)
|
||||
}
|
||||
secs := int(elapsed.Seconds())
|
||||
timeStr := fmt.Sprintf("%ds", secs)
|
||||
taskMax := submonColWidth - len(timeStr) - 3
|
||||
taskPart := submonTrunc(e.task, taskMax)
|
||||
header := fmt.Sprintf("#%d %s %s", e.id, taskPart, timeStr)
|
||||
rows = append(rows, submonPad(header, submonColWidth))
|
||||
|
||||
display := e.lines
|
||||
if len(display) > submonMaxLines {
|
||||
display = display[len(display)-submonMaxLines:]
|
||||
}
|
||||
for _, l := range display {
|
||||
rows = append(rows, submonPad(" "+submonTrunc(l, submonColWidth-2), submonColWidth))
|
||||
}
|
||||
for len(rows) < submonMaxLines+1 {
|
||||
if len(rows) == 1 && len(e.lines) == 0 {
|
||||
rows = append(rows, submonPad(" waiting…", submonColWidth))
|
||||
} else {
|
||||
rows = append(rows, strings.Repeat(" ", submonColWidth))
|
||||
}
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
func submonBuildWidget() string {
|
||||
if len(submonEntries) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
numCols := len(submonEntries)
|
||||
numRows := submonMaxLines + 1
|
||||
cols := make([][]string, numCols)
|
||||
for i, e := range submonEntries {
|
||||
rows := submonRenderColumn(e)
|
||||
col := make([]string, numRows)
|
||||
for j := 0; j < numRows; j++ {
|
||||
if j < len(rows) {
|
||||
col[j] = rows[j]
|
||||
} else {
|
||||
col[j] = strings.Repeat(" ", submonColWidth)
|
||||
}
|
||||
}
|
||||
cols[i] = col
|
||||
}
|
||||
|
||||
gap := strings.Repeat(" ", submonColGap)
|
||||
var sb strings.Builder
|
||||
for row := 0; row < numRows; row++ {
|
||||
for ci := range cols {
|
||||
if ci > 0 {
|
||||
sb.WriteString(gap)
|
||||
}
|
||||
sb.WriteString(cols[ci][row])
|
||||
}
|
||||
if row < numRows-1 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func submonPushWidget() {
|
||||
if !submonHasCtx {
|
||||
return
|
||||
}
|
||||
if submonCtx.SetWidget == nil {
|
||||
return
|
||||
}
|
||||
|
||||
text := submonBuildWidget()
|
||||
if len(submonEntries) == 0 {
|
||||
if submonCtx.RemoveWidget != nil {
|
||||
submonCtx.RemoveWidget("submon")
|
||||
}
|
||||
return
|
||||
}
|
||||
submonCtx.SetWidget(ext.WidgetConfig{
|
||||
ID: "submon",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: text},
|
||||
Style: ext.WidgetStyle{BorderColor: "#89b4fa"},
|
||||
Priority: 0,
|
||||
})
|
||||
}
|
||||
|
||||
func submonAppendLine(e *submonEntry, line string) {
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if strings.TrimSpace(line) == "" {
|
||||
return
|
||||
}
|
||||
e.lines = append(e.lines, line)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Init
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func Init(api ext.API) {
|
||||
submonInit()
|
||||
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
submonInit()
|
||||
if ctx.RemoveWidget != nil {
|
||||
ctx.RemoveWidget("submon")
|
||||
}
|
||||
})
|
||||
|
||||
api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
})
|
||||
|
||||
// ── SubagentStart ────────────────────────────────────────────────────────
|
||||
api.OnSubagentStart(func(e ext.SubagentStartEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
id := submonNextID
|
||||
submonNextID++
|
||||
entry := &submonEntry{
|
||||
id: id,
|
||||
callID: e.ToolCallID,
|
||||
task: e.Task,
|
||||
started: time.Now(),
|
||||
}
|
||||
submonEntries = append(submonEntries, entry)
|
||||
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SubagentChunk ────────────────────────────────────────────────────────
|
||||
api.OnSubagentChunk(func(e ext.SubagentChunkEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
var entry *submonEntry
|
||||
for _, en := range submonEntries {
|
||||
if en.callID == e.ToolCallID {
|
||||
entry = en
|
||||
break
|
||||
}
|
||||
}
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch e.ChunkType {
|
||||
case "text":
|
||||
for _, line := range strings.Split(e.Content, "\n") {
|
||||
submonAppendLine(entry, line)
|
||||
}
|
||||
case "tool_call":
|
||||
submonAppendLine(entry, "→ "+e.ToolName)
|
||||
case "tool_execution_start":
|
||||
submonAppendLine(entry, "⚙ "+e.ToolName)
|
||||
case "tool_result":
|
||||
if e.IsError {
|
||||
submonAppendLine(entry, "✗ "+e.ToolName)
|
||||
} else {
|
||||
submonAppendLine(entry, "✓ "+e.ToolName)
|
||||
}
|
||||
}
|
||||
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SubagentEnd ──────────────────────────────────────────────────────────
|
||||
api.OnSubagentEnd(func(e ext.SubagentEndEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
var entry *submonEntry
|
||||
for _, en := range submonEntries {
|
||||
if en.callID == e.ToolCallID {
|
||||
entry = en
|
||||
break
|
||||
}
|
||||
}
|
||||
if entry != nil {
|
||||
entry.elapsed = time.Since(entry.started)
|
||||
if e.ErrorMsg != "" {
|
||||
submonAppendLine(entry, "✗ "+submonTrunc(e.ErrorMsg, submonColWidth-2))
|
||||
}
|
||||
}
|
||||
|
||||
submonPushWidget()
|
||||
|
||||
// Remove the entry immediately (no goroutine to avoid races)
|
||||
newEntries := submonEntries[:0]
|
||||
for _, en := range submonEntries {
|
||||
if en.callID != e.ToolCallID {
|
||||
newEntries = append(newEntries, en)
|
||||
}
|
||||
}
|
||||
submonEntries = newEntries
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SessionShutdown ──────────────────────────────────────────────────────
|
||||
api.OnSessionShutdown(func(_ ext.SessionShutdownEvent, ctx ext.Context) {
|
||||
submonInit()
|
||||
// Guard ctx access - may be nil during shutdown
|
||||
if ctx.RemoveWidget != nil {
|
||||
ctx.RemoveWidget("submon")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
---
|
||||
description: Run ACP smoke test against opencode/kimi-k2.5 to verify JSON-RPC stdio works
|
||||
---
|
||||
|
||||
Run the ACP smoke test to verify the Kit ACP server works correctly over JSON-RPC stdio with streaming responses.
|
||||
|
||||
## Steps
|
||||
|
||||
1. Build the kit binary:
|
||||
```bash
|
||||
go build -o output/kit ./cmd/kit
|
||||
```
|
||||
|
||||
2. Run the smoke test Python script against opencode/kimi-k2.5:
|
||||
```bash
|
||||
python3 scripts/acp_smoke_test.py
|
||||
```
|
||||
|
||||
3. Verify the output shows:
|
||||
- `session/new` returns a valid `sessionId`
|
||||
- `session/prompt` streams `agent_thought_chunk` notifications (reasoning)
|
||||
- `session/prompt` streams `agent_message_chunk` notifications (response)
|
||||
- Final result has `stopReason: "end_turn"`
|
||||
- `✓ SMOKE TEST PASSED` at the end
|
||||
|
||||
4. If the test fails, check:
|
||||
- `output/kit` binary exists and is executable
|
||||
- `OPENCODE_API_KEY` or `OPENCODE_ZEN_API_KEY` environment variable is set
|
||||
- `scripts/acp_smoke_test.py` exists
|
||||
- The model `opencode/kimi-k2.5` is available (`kit models opencode | grep kimi-k2.5`)
|
||||
|
||||
5. For testing with a different model, edit the script or set the `MODEL` variable:
|
||||
```bash
|
||||
MODEL=anthropic/claude-sonnet-4-5 python3 scripts/acp_smoke_test.py
|
||||
```
|
||||
|
||||
The smoke test exercises the full ACP protocol: session lifecycle, streaming notifications, and tool-free prompt completion.
|
||||
@@ -0,0 +1,30 @@
|
||||
---
|
||||
description: Stage, commit, and push changes with an auto-generated conventional commit message
|
||||
---
|
||||
|
||||
Review the current git status and diff, then stage all changes, write a concise conventional commit message, commit, and push to the current branch.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Check status**: `git status` — understand what has changed
|
||||
2. **Review the diff**: `git diff` (and `git diff --cached` if anything is already staged) — read the actual changes
|
||||
3. **Stage everything**: `git add -A`
|
||||
4. **Craft the commit message** following Conventional Commits:
|
||||
- Format: `<type>(<scope>): <short summary>`
|
||||
- Types: `feat`, `fix`, `refactor`, `chore`, `docs`, `test`, `perf`, `build`
|
||||
- Scope: optional, the subsystem affected (e.g. `ui`, `cmd`, `config`)
|
||||
- Summary: imperative mood, lowercase, no trailing period, ≤72 chars
|
||||
- Body: add a blank line then bullet points for non-trivial changes
|
||||
- Do **not** include "Generated by" or similar noise
|
||||
5. **Commit**: `git commit -m "<message>"`
|
||||
6. **Push**: `git push`
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Read the actual diff — do not guess from filenames alone
|
||||
- Prefer one well-scoped commit; do not split unless the changes are clearly unrelated
|
||||
- Keep the subject line under 72 characters
|
||||
- Use the body to explain *what* and *why*, not *how*
|
||||
- If there is nothing to commit, say so and stop
|
||||
|
||||
$@
|
||||
@@ -0,0 +1,47 @@
|
||||
---
|
||||
description: Scaffold a new prompt template in .kit/prompts/
|
||||
---
|
||||
|
||||
Create a new kit prompt template. The user wants a prompt that does: $@
|
||||
|
||||
## What a prompt template is
|
||||
|
||||
A prompt template is a `.md` file in `.kit/prompts/` (project-local) or `~/.kit/prompts/` (global).
|
||||
It becomes a `/slug` slash command in the kit input box — typed as `/filename` with optional arguments.
|
||||
|
||||
## File format
|
||||
|
||||
```
|
||||
---
|
||||
description: One-line description shown in autocomplete
|
||||
---
|
||||
|
||||
Body text of the prompt. Use $@ for all user-supplied arguments,
|
||||
$1 $2 etc. for positional arguments.
|
||||
```
|
||||
|
||||
- **Filename** → slug: `commit-push.md` becomes `/commit-push`
|
||||
- **Frontmatter**: only `description` is recognised; keep it under ~80 chars
|
||||
- **Body**: plain markdown; the full text is submitted as the user's message when the template fires
|
||||
- **Arguments**: `$@` expands to everything the user typed after the slash command name;
|
||||
`$1`, `$2` for individual positional args; omit entirely if no arguments are needed
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the workflow** the user described in `$@` — ask a clarifying question if the intent is ambiguous
|
||||
2. **Choose a filename**: short, lowercase, hyphen-separated, descriptive (e.g. `code-review.md`)
|
||||
3. **Write the description**: one sentence, imperative, fits in autocomplete
|
||||
4. **Draft the body**:
|
||||
- Open with a single sentence stating the goal
|
||||
- Use `## Steps` for multi-step workflows; use plain prose for simple prompts
|
||||
- Be specific: name commands, flags, and file paths where relevant
|
||||
- End with `$@` on its own line if the user might want to pass context or a hint; omit if the prompt is self-contained
|
||||
5. **Write the file** to `.kit/prompts/<slug>.md`
|
||||
6. **Confirm** by showing the final file content and the slash command that activates it
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Keep prompts action-oriented — they should tell kit *what to do*, not just *what to think about*
|
||||
- Prefer concrete steps over vague instructions
|
||||
- A prompt that does one thing well beats one that tries to cover every edge case
|
||||
- If the workflow already exists as a prompt, suggest extending it instead of duplicating
|
||||
@@ -0,0 +1,70 @@
|
||||
---
|
||||
description: Semantic version tagging workflow - analyzes commits and tags releases
|
||||
---
|
||||
|
||||
# Release Tagging Workflow
|
||||
|
||||
Tag a new version of this Go project following semantic versioning.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Fetch remote tags**: `git fetch --tags origin`
|
||||
|
||||
2. **Find latest version**: `git tag -l | sort -V | tail -5` to see recent tags
|
||||
|
||||
3. **Analyze changes since last tag**:
|
||||
- `git log <latest-tag>..HEAD --oneline` - list commits
|
||||
- `git diff <latest-tag>..HEAD --stat` - see file stats
|
||||
- `git diff <latest-tag>..HEAD --name-only` - see changed files
|
||||
|
||||
4. **Determine version bump** (Semantic Versioning):
|
||||
- **MAJOR (X.0.0)**: Breaking API changes, incompatible modifications
|
||||
- **MINOR (0.X.0)**: New features, backward-compatible additions
|
||||
- **PATCH (0.0.X)**: Bug fixes, backward-compatible fixes
|
||||
|
||||
Look for indicators:
|
||||
- `feat:` or `feature:` commits → MINOR
|
||||
- `fix:` or `bugfix:` commits → PATCH
|
||||
- `breaking:` or `BREAKING CHANGE:` → MAJOR
|
||||
- Breaking API changes in `pkg/` or public interfaces → MAJOR
|
||||
- New commands, flags, or features → MINOR
|
||||
- Documentation-only changes → PATCH (or skip)
|
||||
|
||||
5. **Calculate new version**: Increment appropriate segment, reset lower segments to 0
|
||||
|
||||
6. **Draft tag message**:
|
||||
- Summarize key changes from commits
|
||||
- Group by type (Features, Fixes, Breaking Changes)
|
||||
- Keep concise but informative
|
||||
|
||||
7. **Create annotated tag**: `git tag -a vX.Y.Z -m "vX.Y.Z - <summary>\n\n<detailed list>"`
|
||||
|
||||
8. **Push tag**: `git push origin vX.Y.Z`
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Always fetch remote tags first to avoid conflicts
|
||||
- Use annotated tags (`-a`) with descriptive messages
|
||||
- Follow semver strictly - when in doubt, prefer conservative bump (patch over minor)
|
||||
- For Go projects, changes to `pkg/` or exported APIs warrant careful version consideration
|
||||
- If no changes since last tag, suggest skipping the release
|
||||
- Include commit summaries in the tag message body
|
||||
|
||||
## Example Tag Message Format
|
||||
|
||||
```
|
||||
v0.30.1 - Bug fixes for model handling and UI improvements
|
||||
|
||||
Fixes:
|
||||
- Properly handle think tags from Qwen/DeepSeek models
|
||||
- Handle custom provider model persistence and bare model names
|
||||
|
||||
Improvements:
|
||||
- UI style refactoring and cleanup
|
||||
```
|
||||
|
||||
Wait for the user to confirm the version and message before executing tag commands.
|
||||
|
||||
---
|
||||
|
||||
$@
|
||||
@@ -1,22 +1,3 @@
|
||||
<!-- OPENSPEC:START -->
|
||||
# OpenSpec Instructions
|
||||
|
||||
These instructions are for AI assistants working in this project.
|
||||
|
||||
Always open `@/openspec/AGENTS.md` when the request:
|
||||
- Mentions planning or proposals (words like proposal, spec, change, plan)
|
||||
- Introduces new capabilities, breaking changes, architecture shifts, or big performance/security work
|
||||
- Sounds ambiguous and you need the authoritative spec before coding
|
||||
|
||||
Use `@/openspec/AGENTS.md` to learn:
|
||||
- How to create and apply change proposals
|
||||
- Spec format and conventions
|
||||
- Project structure and guidelines
|
||||
|
||||
Keep this managed block so 'openspec update' can refresh the instructions.
|
||||
|
||||
<!-- OPENSPEC:END -->
|
||||
|
||||
# KIT Agent Guidelines
|
||||
|
||||
## Build/Test Commands
|
||||
@@ -42,6 +23,33 @@ Keep this managed block so 'openspec update' can refresh the instructions.
|
||||
- **Extension system** (`internal/extensions/`): Yaegi-interpreted Go, 13 lifecycle events, custom tools/commands/widgets/overlays/editor interceptors
|
||||
- **TUI** (`internal/ui/`): Bubble Tea v2 parent-child model (`AppModel` → `InputComponent`, `StreamComponent`, etc.)
|
||||
- **Decoupling pattern**: `cmd/root.go` has converter functions (e.g. `widgetProviderForUI()`) that bridge `internal/extensions/` types to `internal/ui/` types — the UI never imports extensions directly
|
||||
- **Public SDK** (`pkg/kit/`): The public-facing Go SDK for embedding Kit as a library. See rules below.
|
||||
|
||||
## Public SDK (`pkg/kit/`) Rules
|
||||
|
||||
`pkg/kit/` is the **public API surface** consumed by external Go developers. All exported symbols, types, function names, and godoc comments in this package are part of the SDK contract.
|
||||
|
||||
### No Dependency Name Leakage
|
||||
Internal dependency names (e.g. `charm.land/fantasy`, library-specific jargon) **must not** appear in:
|
||||
- **Exported function/method names** — use generic terms (`LLM`, `Provider`, `Message`) instead of library names
|
||||
- **Exported type names** — type aliases should use domain names (e.g. `LLMMessage`, not `FantasyMessage`)
|
||||
- **Godoc comments** on exported symbols — these are visible in `go doc` output and pkg.go.dev
|
||||
- **Struct field names and tags** on exported types
|
||||
|
||||
Using dependency types directly in **function bodies** (private implementation) is fine — that's invisible to SDK consumers.
|
||||
|
||||
### Naming Conventions for SDK Symbols
|
||||
- Type aliases re-exporting dependency types: use `LLM*` prefix (e.g. `LLMMessage`, `LLMUsage`, `LLMResponse`)
|
||||
- Conversion helpers: use `ConvertToLLM*` / `ConvertFromLLM*` (not the dependency name)
|
||||
- Provider queries: use `GetLLMProviders` (not `GetFantasyProviders`)
|
||||
- When wrapping internal methods, the `pkg/kit/` name should be dependency-agnostic even if the `internal/` method still uses the old name
|
||||
|
||||
### Deprecation Pattern
|
||||
When renaming a public SDK symbol, keep the old name as a deprecated wrapper for one release cycle:
|
||||
```go
|
||||
// Deprecated: Use NewName instead.
|
||||
func OldName() { return NewName() }
|
||||
```
|
||||
|
||||
## Key Patterns
|
||||
|
||||
@@ -92,3 +100,21 @@ Positional args are the prompt. `@file` args attach file content. Key flags: `--
|
||||
- Never guess or manually search the filesystem for external projects
|
||||
- Example: `btca ask -r https://github.com/user/repo -q "How does X work?"`
|
||||
- See `.agents/skills/btca-cli/SKILL.md` for full btca usage
|
||||
|
||||
## BTCA Configured Resources
|
||||
The following external repositories are configured in `btca.config.jsonc` for research:
|
||||
|
||||
- bubbletea
|
||||
- lipgloss
|
||||
- bubbles
|
||||
- glamour
|
||||
- fantasy
|
||||
- catwalk
|
||||
- crush
|
||||
- pi
|
||||
- iteratr
|
||||
- yaegi
|
||||
- acp-go-sdk
|
||||
- opencode
|
||||
- herald
|
||||
- herald-md
|
||||
|
||||
@@ -18,7 +18,7 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
|
||||
## Features
|
||||
|
||||
- **Multi-Provider LLM Support**: Anthropic, OpenAI, Google Gemini, Ollama, Azure OpenAI, AWS Bedrock, OpenRouter, and more
|
||||
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls, spawn_subagent - no MCP overhead
|
||||
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls, subagent - no MCP overhead
|
||||
- **MCP Integration**: Connect external MCP servers for expanded capabilities
|
||||
- **Extension System**: Write custom tools, commands, widgets, and UI modifications in Go
|
||||
- **Theming**: 22 built-in color themes (KITT, Catppuccin, Dracula, Nord, etc.) with runtime switching, persistence, and custom theme files
|
||||
@@ -209,7 +209,7 @@ kit auth status # Check authentication status
|
||||
|
||||
# Model database
|
||||
kit models [provider] # List available models (optionally filter by provider)
|
||||
kit models --all # Show all providers (not just Fantasy-compatible)
|
||||
kit models --all # Show all providers (not just LLM-compatible)
|
||||
kit update-models [source] # Update model database (from models.dev, URL, file, or 'embedded')
|
||||
|
||||
# Extension management
|
||||
@@ -287,7 +287,7 @@ kit -e examples/extensions/minimal.go
|
||||
|
||||
### Extension Capabilities
|
||||
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
|
||||
|
||||
**Custom Components**:
|
||||
- **Tools**: Add new tools the LLM can invoke
|
||||
@@ -307,6 +307,12 @@ kit -e examples/extensions/minimal.go
|
||||
- **Themes**: Register and switch color themes via `RegisterTheme`, `SetTheme`, `ListThemes`
|
||||
- **Custom Events**: Inter-extension communication via `EmitCustomEvent`
|
||||
|
||||
**Bridged SDK APIs** (NEW): Extensions can now access internal SDK capabilities:
|
||||
- **Tree Navigation**: Navigate conversation history (`GetTreeNode`, `GetCurrentBranch`, `NavigateTo`), summarize branches (`SummarizeBranch`), and implement fresh context loops (`CollapseBranch`)
|
||||
- **Skill Loading**: Dynamically load and inject skills at runtime (`LoadSkill`, `DiscoverSkills`, `InjectSkillAsContext`)
|
||||
- **Template Parsing**: Parse and render templates with `{{variables}}` (`ParseTemplate`, `RenderTemplate`), parse CLI-style arguments (`ParseArguments`, `SimpleParseArguments`), and evaluate model conditionals (`EvaluateModelConditional`, `RenderWithModelConditionals`)
|
||||
- **Model Resolution**: Resolve model fallback chains (`ResolveModelChain`), query model capabilities (`GetModelCapabilities`, `CheckModelAvailable`), and extract provider/model ID (`GetCurrentProvider`, `GetCurrentModelID`)
|
||||
|
||||
### Extension Examples
|
||||
|
||||
See the `examples/extensions/` directory:
|
||||
@@ -318,6 +324,7 @@ See the `examples/extensions/` directory:
|
||||
- `compact-notify.go` - Notification on compaction
|
||||
- `confirm-destructive.go` - Confirm destructive operations
|
||||
- `context-inject.go` - Inject context into conversations
|
||||
- `conversation-manager.go` - **NEW** Tree navigation, branch summarization, and fresh context loops
|
||||
- `custom-editor-demo.go` - Vim-like modal editor
|
||||
- `dev-reload.go` - Development live-reload
|
||||
- `header-footer-demo.go` - Custom headers and footers
|
||||
@@ -332,6 +339,7 @@ See the `examples/extensions/` directory:
|
||||
- `plan-mode.go` - Read-only planning mode
|
||||
- `project-rules.go` - Project-specific rules
|
||||
- `prompt-demo.go` - Interactive prompts (select/confirm/input)
|
||||
- `prompt-templates.go` - **NEW** Frontmatter-driven templates with model switching and skill injection
|
||||
- `protected-paths.go` - Path protection for sensitive files
|
||||
- `subagent-widget.go` - Multi-agent orchestration with status widget
|
||||
- `subagent-test.go` - Subagent testing utilities
|
||||
@@ -494,7 +502,7 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer host.Close()
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
// Send a prompt
|
||||
response, err := host.Prompt(ctx, "What is 2+2?")
|
||||
@@ -535,23 +543,26 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
### With Callbacks
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(
|
||||
unsub := host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
println("Calling tool:", e.ToolName)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
if e.IsError {
|
||||
println("Tool failed:", e.ToolName)
|
||||
}
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
print(e.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
|
||||
response, err := host.Prompt(
|
||||
ctx,
|
||||
"List files in current directory",
|
||||
func(name, args string) {
|
||||
// Tool call started
|
||||
println("Calling tool:", name)
|
||||
},
|
||||
func(name, args, result string, isError bool) {
|
||||
// Tool call completed
|
||||
if isError {
|
||||
println("Tool failed:", name)
|
||||
}
|
||||
},
|
||||
func(chunk string) {
|
||||
// Streaming text chunk
|
||||
print(chunk)
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
@@ -715,7 +726,7 @@ Use `custom/custom` when pointing Kit at any OpenAI-compatible endpoint with `--
|
||||
kit --provider-url "http://localhost:8080/v1" "Hello"
|
||||
```
|
||||
|
||||
This automatically defaults to `custom/custom` without needing to specify a model. The custom provider routes through fantasy's `openaicompat` provider and supports:
|
||||
This automatically defaults to `custom/custom` without needing to specify a model. The custom provider routes through the `openaicompat` provider and supports:
|
||||
|
||||
- Zero cost tracking (input/output = 0)
|
||||
- 262K context window, 65K output limit
|
||||
|
||||
@@ -76,6 +76,18 @@
|
||||
"name": "opencode",
|
||||
"url": "https://github.com/anomalyco/opencode",
|
||||
"branch": "dev"
|
||||
},
|
||||
{
|
||||
"type": "git",
|
||||
"name": "herald",
|
||||
"url": "https://github.com/indaco/herald",
|
||||
"branch": "main"
|
||||
},
|
||||
{
|
||||
"type": "git",
|
||||
"name": "herald-md",
|
||||
"url": "https://github.com/indaco/herald-md",
|
||||
"branch": "main"
|
||||
}
|
||||
],
|
||||
"model": "claude-haiku-4-5",
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
acp "github.com/coder/acp-go-sdk"
|
||||
|
||||
"github.com/mark3labs/kit/internal/acpserver"
|
||||
@@ -54,6 +55,8 @@ func runACP(cmd *cobra.Command, _ []string) error {
|
||||
conn.SetLogger(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
})))
|
||||
// Also set charmbracelet/log level for acpserver package logging
|
||||
log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
// Wait for either the client to disconnect or a signal.
|
||||
|
||||
+300
-7
@@ -1,9 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/huh/v2"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
@@ -14,7 +18,7 @@ import (
|
||||
// authCmd represents the auth command for managing AI provider authentication.
|
||||
// This command provides subcommands for login, logout, and status checking
|
||||
// of authentication credentials for various AI providers, with OAuth support
|
||||
// for providers like Anthropic.
|
||||
// for providers like Anthropic and OpenAI.
|
||||
var authCmd = &cobra.Command{
|
||||
Use: "auth",
|
||||
Short: "Manage authentication credentials for AI providers",
|
||||
@@ -25,9 +29,11 @@ using OAuth flows. Stored credentials take precedence over environment variables
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI API (OAuth and API key)
|
||||
|
||||
Examples:
|
||||
kit auth login anthropic
|
||||
kit auth login openai
|
||||
kit auth logout anthropic
|
||||
kit auth status`,
|
||||
}
|
||||
@@ -46,9 +52,11 @@ environment variables when making API calls.
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI ChatGPT Plus/Pro (Codex OAuth)
|
||||
|
||||
Example:
|
||||
kit auth login anthropic`,
|
||||
kit auth login anthropic
|
||||
kit auth login openai`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogin,
|
||||
}
|
||||
@@ -61,14 +69,16 @@ var authLogoutCmd = &cobra.Command{
|
||||
Short: "Remove stored authentication credentials for a provider",
|
||||
Long: `Remove stored authentication credentials for an AI provider.
|
||||
|
||||
This will delete the stored API key for the specified provider. You will need
|
||||
to use environment variables or command-line flags for authentication after logout.
|
||||
This will delete the stored API key or OAuth credentials for the specified provider.
|
||||
You will need to use environment variables or command-line flags for authentication after logout.
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API
|
||||
- openai: OpenAI API
|
||||
|
||||
Example:
|
||||
kit auth logout anthropic`,
|
||||
kit auth logout anthropic
|
||||
kit auth logout openai`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogout,
|
||||
}
|
||||
@@ -101,8 +111,10 @@ func runAuthLogin(cmd *cobra.Command, args []string) error {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return loginAnthropic()
|
||||
case "openai":
|
||||
return loginOpenAI()
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,8 +124,10 @@ func runAuthLogout(cmd *cobra.Command, args []string) error {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return logoutAnthropic()
|
||||
case "openai":
|
||||
return logoutOpenAI()
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,8 +171,44 @@ func runAuthStatus(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check OpenAI credentials
|
||||
fmt.Print("\nOpenAI: ")
|
||||
if hasOpenAICreds, err := cm.HasOpenAICredentials(); err != nil {
|
||||
fmt.Printf("Error checking credentials: %v\n", err)
|
||||
} else if hasOpenAICreds {
|
||||
if creds, err := cm.GetOpenAICredentials(); err != nil {
|
||||
fmt.Printf("Error reading credentials: %v\n", err)
|
||||
} else {
|
||||
authType := "API Key"
|
||||
status := "✓ Authenticated"
|
||||
|
||||
if creds.Type == "oauth" {
|
||||
authType = "OAuth (ChatGPT/Codex)"
|
||||
if creds.IsExpired() {
|
||||
status = "⚠️ Token expired (will refresh automatically)"
|
||||
} else if creds.NeedsRefresh() {
|
||||
status = "⚠️ Token expires soon (will refresh automatically)"
|
||||
}
|
||||
}
|
||||
|
||||
accountInfo := ""
|
||||
if creds.Type == "oauth" && creds.AccountID != "" {
|
||||
accountInfo = fmt.Sprintf(" [%s]", creds.AccountID)
|
||||
}
|
||||
|
||||
fmt.Printf("%s (%s%s, stored %s)\n", status, authType, accountInfo, creds.CreatedAt.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
} else {
|
||||
fmt.Println("✗ Not authenticated")
|
||||
// Check if environment variable is set
|
||||
if os.Getenv("OPENAI_API_KEY") != "" {
|
||||
fmt.Println(" (OPENAI_API_KEY environment variable is set)")
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\nTo authenticate with a provider:")
|
||||
fmt.Println(" kit auth login anthropic")
|
||||
fmt.Println(" kit auth login openai")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -282,3 +332,246 @@ func logoutAnthropic() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loginOpenAI() error {
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
// Check if already authenticated
|
||||
if hasAuth, err := cm.HasOpenAICredentials(); err == nil && hasAuth {
|
||||
var reauth bool
|
||||
err := huh.NewConfirm().
|
||||
Title("You are already authenticated with OpenAI (ChatGPT/Codex)").
|
||||
Description("Do you want to re-authenticate?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&reauth).
|
||||
Run()
|
||||
if err != nil || !reauth {
|
||||
fmt.Println("Authentication cancelled.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create OAuth client
|
||||
client := auth.NewOpenAIOAuthClient()
|
||||
|
||||
// Generate authorization URL
|
||||
fmt.Println("🔐 Starting OAuth authentication with OpenAI (ChatGPT/Codex)...")
|
||||
fmt.Println("This will open your browser to authenticate with your ChatGPT account.")
|
||||
fmt.Println()
|
||||
|
||||
authData, err := client.GetAuthorizationURL()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate authorization URL: %w", err)
|
||||
}
|
||||
|
||||
// Start local callback server
|
||||
callbackServer, err := startOpenAICallbackServer(authData.State)
|
||||
if err != nil {
|
||||
fmt.Printf("⚠️ Could not start local callback server: %v\n", err)
|
||||
fmt.Println("Falling back to manual code entry.")
|
||||
}
|
||||
if callbackServer != nil {
|
||||
defer callbackServer.Close()
|
||||
}
|
||||
|
||||
// Display URL and try to open browser
|
||||
fmt.Println("📱 Opening your browser for authentication...")
|
||||
fmt.Println("If the browser doesn't open automatically, please visit this URL:")
|
||||
fmt.Printf("\n%s\n\n", authData.URL)
|
||||
|
||||
// Try to open browser
|
||||
auth.TryOpenBrowser(authData.URL)
|
||||
|
||||
// Wait for callback or manual input
|
||||
var code string
|
||||
if callbackServer != nil {
|
||||
fmt.Println("Waiting for browser authentication...")
|
||||
select {
|
||||
case callbackCode := <-callbackServer.CodeChan:
|
||||
if callbackCode != "" {
|
||||
code = callbackCode
|
||||
fmt.Println("✓ Received authorization code from browser callback.")
|
||||
}
|
||||
case <-time.After(2 * time.Minute):
|
||||
fmt.Println("\n⏱️ Timeout waiting for browser callback.")
|
||||
callbackServer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// If no code from callback, prompt for manual entry
|
||||
if code == "" {
|
||||
fmt.Println("\nAfter authorizing, paste the callback URL or authorization code below.")
|
||||
fmt.Println("(The callback URL will look like: http://localhost:1455/auth/callback?code=...&state=...)")
|
||||
fmt.Println()
|
||||
|
||||
var input string
|
||||
err = huh.NewInput().
|
||||
Title("Callback URL or Code").
|
||||
Description("Paste the full callback URL or just the authorization code").
|
||||
Value(&input).
|
||||
Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read input: %w", err)
|
||||
}
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
if input == "" {
|
||||
return fmt.Errorf("authorization code cannot be empty")
|
||||
}
|
||||
|
||||
// Parse the input (could be full URL or just code)
|
||||
parsedCode, parsedState := auth.ParseOpenAIAuthorizationInput(input)
|
||||
if parsedCode == "" {
|
||||
return fmt.Errorf("could not extract authorization code from input")
|
||||
}
|
||||
|
||||
// Validate state if provided
|
||||
if parsedState != "" && parsedState != authData.State {
|
||||
return fmt.Errorf("state mismatch - possible security issue")
|
||||
}
|
||||
code = parsedCode
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
fmt.Println("\n🔄 Exchanging authorization code for access token...")
|
||||
creds, err := client.ExchangeCode(code, authData.Verifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exchange authorization code: %w", err)
|
||||
}
|
||||
|
||||
// Store the credentials
|
||||
if err := cm.SetOpenAIOAuthCredentials(creds); err != nil {
|
||||
return fmt.Errorf("failed to store credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✅ Successfully authenticated with OpenAI (ChatGPT/Codex)!")
|
||||
fmt.Printf("📁 Credentials stored in: %s\n", cm.GetCredentialsPath())
|
||||
fmt.Printf("👤 Account ID: %s\n", creds.AccountID)
|
||||
fmt.Println("\n🎉 Your OAuth credentials will now be used for OpenAI API calls.")
|
||||
fmt.Println("💡 You can check your authentication status with: kit auth status")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callbackServer holds the HTTP server and channel for receiving the OAuth callback
|
||||
type callbackServer struct {
|
||||
Server *http.Server
|
||||
CodeChan chan string
|
||||
State string
|
||||
}
|
||||
|
||||
// Close shuts down the callback server
|
||||
func (cs *callbackServer) Close() {
|
||||
if cs.Server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = cs.Server.Shutdown(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// startOpenAICallbackServer starts a local HTTP server to receive the OAuth callback
|
||||
func startOpenAICallbackServer(expectedState string) (*callbackServer, error) {
|
||||
codeChan := make(chan string, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
server := &http.Server{
|
||||
Addr: "127.0.0.1:1455",
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check state
|
||||
state := r.URL.Query().Get("state")
|
||||
if state != expectedState {
|
||||
http.Error(w, "State mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
http.Error(w, "Missing authorization code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Send code to channel
|
||||
select {
|
||||
case codeChan <- code:
|
||||
default:
|
||||
}
|
||||
|
||||
// Return success page
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Authentication Successful</title></head>
|
||||
<body style="font-family: sans-serif; text-align: center; padding: 50px;">
|
||||
<h1>✓ Authentication Successful</h1>
|
||||
<p>You can close this window and return to the terminal.</p>
|
||||
</body>
|
||||
</html>`)
|
||||
})
|
||||
|
||||
// Try to start server
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:1455")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("port 1455 not available: %w", err)
|
||||
}
|
||||
_ = listener.Close()
|
||||
|
||||
go func() {
|
||||
_ = server.ListenAndServe()
|
||||
}()
|
||||
|
||||
return &callbackServer{
|
||||
Server: server,
|
||||
CodeChan: codeChan,
|
||||
State: expectedState,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func logoutOpenAI() error {
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
// Check if authenticated
|
||||
hasAuth, err := cm.HasOpenAICredentials()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check authentication status: %w", err)
|
||||
}
|
||||
|
||||
if !hasAuth {
|
||||
fmt.Println("You are not currently authenticated with OpenAI.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Confirm logout
|
||||
var confirm bool
|
||||
err = huh.NewConfirm().
|
||||
Title("Remove OpenAI credentials").
|
||||
Description("Are you sure you want to remove your stored credentials?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&confirm).
|
||||
Run()
|
||||
if err != nil || !confirm {
|
||||
fmt.Println("Logout cancelled.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove credentials
|
||||
if err := cm.RemoveOpenAICredentials(); err != nil {
|
||||
return fmt.Errorf("failed to remove credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ Successfully logged out from OpenAI!")
|
||||
fmt.Println("You will need to use environment variables or command-line flags for authentication.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+1
-1
@@ -55,7 +55,7 @@ func printAllProviders(showAll bool) error {
|
||||
if showAll {
|
||||
providerIDs = kit.GetSupportedProviders()
|
||||
} else {
|
||||
providerIDs = kit.GetFantasyProviders()
|
||||
providerIDs = kit.GetLLMProviders()
|
||||
}
|
||||
sort.Strings(providerIDs)
|
||||
|
||||
|
||||
+538
-116
@@ -10,7 +10,6 @@ import (
|
||||
"strings"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
@@ -38,7 +37,6 @@ var (
|
||||
noExitFlag bool
|
||||
maxSteps int
|
||||
streamFlag bool // Enable streaming output
|
||||
compactMode bool // Enable compact output mode
|
||||
autoCompactFlag bool // Enable auto-compaction near context limit
|
||||
|
||||
// Session management
|
||||
@@ -280,8 +278,6 @@ func init() {
|
||||
IntVar(&maxSteps, "max-steps", 0, "maximum number of agent steps (0 for unlimited)")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&streamFlag, "stream", true, "enable streaming output for faster response display")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&compactMode, "compact", false, "enable compact output mode without fancy styling")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&autoCompactFlag, "auto-compact", false, "auto-compact conversation when near context limit")
|
||||
rootCmd.PersistentFlags().
|
||||
@@ -325,7 +321,6 @@ func init() {
|
||||
_ = viper.BindPFlag("debug", rootCmd.PersistentFlags().Lookup("debug"))
|
||||
_ = viper.BindPFlag("max-steps", rootCmd.PersistentFlags().Lookup("max-steps"))
|
||||
_ = viper.BindPFlag("stream", rootCmd.PersistentFlags().Lookup("stream"))
|
||||
_ = viper.BindPFlag("compact", rootCmd.PersistentFlags().Lookup("compact"))
|
||||
_ = viper.BindPFlag("auto-compact", rootCmd.PersistentFlags().Lookup("auto-compact"))
|
||||
|
||||
_ = viper.BindPFlag("provider-url", rootCmd.PersistentFlags().Lookup("provider-url"))
|
||||
@@ -415,7 +410,7 @@ func runKit(ctx context.Context) error {
|
||||
// normalised to start with "/" so they integrate with the slash-command
|
||||
// autocomplete and dispatch pipeline.
|
||||
func extensionCommandsForUI(k *kit.Kit) []ui.ExtensionCommand {
|
||||
defs := k.ExtensionCommands()
|
||||
defs := k.Extensions().Commands()
|
||||
if len(defs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -429,12 +424,12 @@ func extensionCommandsForUI(k *kit.Kit) []ui.ExtensionCommand {
|
||||
Name: name,
|
||||
Description: d.Description,
|
||||
Execute: func(args string) (string, error) {
|
||||
return d.Execute(args, k.GetExtensionContext())
|
||||
return d.Execute(args, k.Extensions().GetContext())
|
||||
},
|
||||
}
|
||||
if d.Complete != nil {
|
||||
ec.Complete = func(prefix string) []string {
|
||||
return d.Complete(prefix, k.GetExtensionContext())
|
||||
return d.Complete(prefix, k.Extensions().GetContext())
|
||||
}
|
||||
}
|
||||
cmds = append(cmds, ec)
|
||||
@@ -446,11 +441,11 @@ func extensionCommandsForUI(k *kit.Kit) []ui.ExtensionCommand {
|
||||
// ui.WidgetData for the given placement. Returns nil if extensions are
|
||||
// disabled, which is safe — the UI treats a nil GetWidgets as "no widgets".
|
||||
func widgetProviderForUI(k *kit.Kit) func(string) []ui.WidgetData {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func(placement string) []ui.WidgetData {
|
||||
configs := k.GetExtensionWidgets(extensions.WidgetPlacement(placement))
|
||||
configs := k.Extensions().GetWidgets(extensions.WidgetPlacement(placement))
|
||||
if len(configs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -467,25 +462,34 @@ func widgetProviderForUI(k *kit.Kit) func(string) []ui.WidgetData {
|
||||
}
|
||||
}
|
||||
|
||||
// headerFooterProviderForUI returns a provider func that maps an
|
||||
// extensions.HeaderFooterConfig getter into the ui.WidgetData shape
|
||||
// expected by AppModel. The getter argument selects header vs footer.
|
||||
func headerFooterProviderForUI(k *kit.Kit, getter func() *extensions.HeaderFooterConfig) func() *ui.WidgetData {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() *ui.WidgetData {
|
||||
cfg := getter()
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return &ui.WidgetData{
|
||||
Text: cfg.Content.Text,
|
||||
Markdown: cfg.Content.Markdown,
|
||||
BorderColor: cfg.Style.BorderColor,
|
||||
NoBorder: cfg.Style.NoBorder,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// headerProviderForUI returns a function that converts the extension header
|
||||
// to a *ui.WidgetData for the TUI. Returns nil if extensions are disabled,
|
||||
// which is safe — the UI treats a nil GetHeader as "no header".
|
||||
func headerProviderForUI(k *kit.Kit) func() *ui.WidgetData {
|
||||
if !k.HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() *ui.WidgetData {
|
||||
config := k.GetExtensionHeader()
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
return &ui.WidgetData{
|
||||
Text: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
NoBorder: config.Style.NoBorder,
|
||||
}
|
||||
}
|
||||
return headerFooterProviderForUI(k, func() *extensions.HeaderFooterConfig {
|
||||
return k.Extensions().GetHeader()
|
||||
})
|
||||
}
|
||||
|
||||
// toolRendererProviderForUI returns a function that converts extension tool
|
||||
@@ -493,11 +497,11 @@ func headerProviderForUI(k *kit.Kit) func() *ui.WidgetData {
|
||||
// disabled, which is safe — the UI treats a nil GetToolRenderer as "no
|
||||
// custom renderers".
|
||||
func toolRendererProviderForUI(k *kit.Kit) func(string) *ui.ToolRendererData {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func(toolName string) *ui.ToolRendererData {
|
||||
config := k.GetExtensionToolRenderer(toolName)
|
||||
config := k.Extensions().GetToolRenderer(toolName)
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -517,11 +521,11 @@ func toolRendererProviderForUI(k *kit.Kit) func(string) *ui.ToolRendererData {
|
||||
// Returns nil if extensions are disabled, which is safe — the UI treats a
|
||||
// nil GetEditorInterceptor as "no interceptor".
|
||||
func editorInterceptorProviderForUI(k *kit.Kit) func() *ui.EditorInterceptor {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() *ui.EditorInterceptor {
|
||||
config := k.GetExtensionEditor()
|
||||
config := k.Extensions().GetEditor()
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -555,11 +559,11 @@ func editorInterceptorProviderForUI(k *kit.Kit) func() *ui.EditorInterceptor {
|
||||
// visibility overrides to a *ui.UIVisibility for the TUI. Returns nil if
|
||||
// extensions are disabled — the UI treats nil as "show everything".
|
||||
func uiVisibilityProviderForUI(k *kit.Kit) func() *ui.UIVisibility {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() *ui.UIVisibility {
|
||||
v := k.GetExtensionUIVisibility()
|
||||
v := k.Extensions().GetUIVisibility()
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -576,21 +580,9 @@ func uiVisibilityProviderForUI(k *kit.Kit) func() *ui.UIVisibility {
|
||||
// to a *ui.WidgetData for the TUI. Returns nil if extensions are disabled,
|
||||
// which is safe — the UI treats a nil GetFooter as "no footer".
|
||||
func footerProviderForUI(k *kit.Kit) func() *ui.WidgetData {
|
||||
if !k.HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() *ui.WidgetData {
|
||||
config := k.GetExtensionFooter()
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
return &ui.WidgetData{
|
||||
Text: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
NoBorder: config.Style.NoBorder,
|
||||
}
|
||||
}
|
||||
return headerFooterProviderForUI(k, func() *extensions.HeaderFooterConfig {
|
||||
return k.Extensions().GetFooter()
|
||||
})
|
||||
}
|
||||
|
||||
// statusBarProviderForUI returns a function that fetches extension status bar
|
||||
@@ -598,11 +590,11 @@ func footerProviderForUI(k *kit.Kit) func() *ui.WidgetData {
|
||||
// if extensions are disabled, which is safe — the TUI treats a nil
|
||||
// GetStatusBarEntries as "no extension entries".
|
||||
func statusBarProviderForUI(k *kit.Kit) func() []ui.StatusBarEntryData {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() []ui.StatusBarEntryData {
|
||||
entries := k.GetExtensionStatusEntries()
|
||||
entries := k.Extensions().GetStatusEntries()
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -622,30 +614,36 @@ func statusBarProviderForUI(k *kit.Kit) func() []ui.StatusBarEntryData {
|
||||
// and returns (cancelled, reason). Returns nil if extensions are disabled —
|
||||
// the UI treats nil as "no hook".
|
||||
func beforeForkProviderForUI(k *kit.Kit) func(string, bool, string) (bool, string) {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return k.EmitBeforeFork
|
||||
return func(targetID string, isUserMsg bool, userText string) (bool, string) {
|
||||
return k.Extensions().EmitBeforeFork(targetID, isUserMsg, userText)
|
||||
}
|
||||
}
|
||||
|
||||
// beforeSessionSwitchProviderForUI returns a callback that emits a
|
||||
// BeforeSessionSwitch event and returns (cancelled, reason). Returns nil
|
||||
// if extensions are disabled — the UI treats nil as "no hook".
|
||||
func beforeSessionSwitchProviderForUI(k *kit.Kit) func(string) (bool, string) {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return k.EmitBeforeSessionSwitch
|
||||
return func(switchReason string) (bool, string) {
|
||||
return k.Extensions().EmitBeforeSessionSwitch(switchReason)
|
||||
}
|
||||
}
|
||||
|
||||
// globalShortcutsProviderForUI returns a callback that queries the extension
|
||||
// runner for registered keyboard shortcuts. Returns nil if extensions are
|
||||
// disabled — the UI treats nil as "no shortcuts".
|
||||
func globalShortcutsProviderForUI(k *kit.Kit) func() map[string]func() {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return k.GetExtensionShortcuts
|
||||
return func() map[string]func() {
|
||||
return k.Extensions().GetShortcuts()
|
||||
}
|
||||
}
|
||||
|
||||
func runNormalMode(ctx context.Context) error {
|
||||
@@ -677,9 +675,15 @@ func runNormalMode(ctx context.Context) error {
|
||||
// Restore persisted model preference when no explicit --model flag or
|
||||
// config file model is set. Precedence: CLI flag > config file > saved
|
||||
// preference > built-in default. This mirrors how themes are persisted.
|
||||
// Skip custom/* models unless --provider-url is also provided, since the
|
||||
// custom provider requires a URL that was only valid for the previous session.
|
||||
if !modelFlagChanged && !viper.InConfig("model") {
|
||||
if pref := ui.LoadModelPreference(); pref != "" {
|
||||
viper.Set("model", pref)
|
||||
if strings.HasPrefix(pref, "custom/") && viper.GetString("provider-url") == "" {
|
||||
// Don't restore custom models without a provider URL
|
||||
} else {
|
||||
viper.Set("model", pref)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -700,6 +704,15 @@ func runNormalMode(ctx context.Context) error {
|
||||
viper.Set("model", "custom/custom")
|
||||
}
|
||||
|
||||
// When --provider-url is set with an explicit --model that lacks a provider
|
||||
// prefix (no "/"), auto-prefix with "custom/" for OpenAI-compatible endpoints.
|
||||
if viper.GetString("provider-url") != "" && modelFlagChanged {
|
||||
model := viper.GetString("model")
|
||||
if model != "" && !strings.Contains(model, "/") {
|
||||
viper.Set("model", "custom/"+model)
|
||||
}
|
||||
}
|
||||
|
||||
// Load MCP configuration.
|
||||
mcpConfig, err := config.LoadAndValidateConfig()
|
||||
if err != nil {
|
||||
@@ -710,7 +723,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
var spinnerFunc kit.SpinnerFunc
|
||||
if !quietFlag {
|
||||
spinnerFunc = func(fn func() error) error {
|
||||
tempCli, tempErr := ui.NewCLI(viper.GetBool("debug"), viper.GetBool("compact"))
|
||||
tempCli, tempErr := ui.NewCLI(viper.GetBool("debug"))
|
||||
if tempErr == nil {
|
||||
return tempCli.ShowSpinner(fn)
|
||||
}
|
||||
@@ -774,9 +787,9 @@ func runNormalMode(ctx context.Context) error {
|
||||
|
||||
// Load existing messages from resumed/continued sessions.
|
||||
treeSession := kitInstance.GetTreeSession()
|
||||
var messages []fantasy.Message
|
||||
var messages []kit.LLMMessage
|
||||
if treeSession != nil {
|
||||
messages = treeSession.GetFantasyMessages()
|
||||
messages = treeSession.GetLLMMessages()
|
||||
}
|
||||
|
||||
// Create the app.App instance.
|
||||
@@ -799,43 +812,53 @@ func runNormalMode(ctx context.Context) error {
|
||||
appInstance := app.New(appOpts, messages)
|
||||
defer appInstance.Close()
|
||||
|
||||
// Buffer for extension messages during startup (printed after startup banner).
|
||||
var startupExtensionMessages []string
|
||||
|
||||
// Set up extension context and emit SessionStart.
|
||||
if kitInstance.HasExtensions() {
|
||||
if kitInstance.Extensions().HasExtensions() {
|
||||
cwd, _ := os.Getwd()
|
||||
kitInstance.SetExtensionContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: positionalPrompt == "",
|
||||
Print: func(text string) { appInstance.PrintFromExtension("", text) },
|
||||
PrintInfo: func(text string) { appInstance.PrintFromExtension("info", text) },
|
||||
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: positionalPrompt == "",
|
||||
Print: func(text string) {
|
||||
// Capture messages during startup, print after startup banner.
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintInfo: func(text string) {
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintError: func(text string) {
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.Steer(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.SetExtensionWidget(config)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.RemoveExtensionWidget(id)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.SetExtensionHeader(config)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.RemoveExtensionHeader()
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.SetExtensionFooter(config)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.RemoveExtensionFooter()
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
@@ -885,8 +908,8 @@ func runNormalMode(ctx context.Context) error {
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
SetUIVisibility: func(v extensions.UIVisibility) {
|
||||
kitInstance.SetExtensionUIVisibility(v)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
@@ -898,53 +921,52 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
},
|
||||
SetEditor: func(config extensions.EditorConfig) {
|
||||
kitInstance.SetExtensionEditor(config)
|
||||
// Use a goroutine for NotifyWidgetUpdate because this may be
|
||||
// called from within an editor HandleKey callback, which runs
|
||||
// synchronously inside BubbleTea's Update(). Calling prog.Send()
|
||||
// directly from Update() deadlocks the event loop.
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
|
||||
// deadlocks if called synchronously from inside BubbleTea's
|
||||
// Update() handler. All call sites use go-routines uniformly.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
ResetEditor: func() {
|
||||
kitInstance.ResetExtensionEditor()
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.GetSessionMessages()
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionFilePath()
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.AppendExtensionEntry(entryType, data)
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.GetExtensionEntries(entryType)
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
SetEditorText: func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
},
|
||||
SetStatus: func(key string, text string, priority int) {
|
||||
kitInstance.SetExtensionStatus(extensions.StatusBarEntry{
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveStatus: func(key string) {
|
||||
kitInstance.RemoveExtensionStatus(key)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.GetExtensionOption(name)
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.SetExtensionOption(name, value)
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
SetModel: func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.GetExtensionContext().Model
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -953,9 +975,9 @@ func runNormalMode(ctx context.Context) error {
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.EmitModelChange(modelString, previousModel, "extension")
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
@@ -980,7 +1002,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.EmitExtensionCustomEvent(name, data)
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
@@ -989,7 +1011,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
},
|
||||
RenderMessage: func(rendererName, content string) {
|
||||
renderer := kitInstance.GetExtensionMessageRenderer(rendererName)
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
@@ -1002,19 +1024,19 @@ func runNormalMode(ctx context.Context) error {
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
},
|
||||
ReloadExtensions: func() error {
|
||||
err := kitInstance.ReloadExtensions()
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.GetExtensionToolInfos()
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.SetExtensionActiveTools(names)
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
@@ -1085,7 +1107,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: result.Error,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
@@ -1097,8 +1119,398 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
return nil, extResult, err
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: kitInstance.GetChildren,
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
// Find skill by name
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
// Inject via SendMessage as a system context message
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: kitInstance.DiscoverSkillsForExtension,
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
ParseTemplate: kit.ParseTemplate,
|
||||
RenderTemplate: kit.RenderTemplate,
|
||||
ParseArguments: kit.ParseArguments,
|
||||
SimpleParseArguments: kit.SimpleParseArguments,
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
ResolveModelChain: kit.ResolveModelChain,
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: kit.CheckModelAvailable,
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
})
|
||||
kitInstance.Extensions().EmitSessionStart()
|
||||
|
||||
// Restore normal print functions for runtime use.
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: positionalPrompt == "",
|
||||
Print: func(text string) { appInstance.PrintFromExtension("", text) },
|
||||
PrintInfo: func(text string) { appInstance.PrintFromExtension("info", text) },
|
||||
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
},
|
||||
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
},
|
||||
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
// In-process subagent via SDK.
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
// Bridge SDK events to extension SubagentEvents.
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(ctx, sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: kitInstance.GetChildren,
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
ParseTemplate: kit.ParseTemplate,
|
||||
RenderTemplate: kit.RenderTemplate,
|
||||
ParseArguments: kit.ParseArguments,
|
||||
SimpleParseArguments: kit.SimpleParseArguments,
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
ResolveModelChain: kit.ResolveModelChain,
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: kit.CheckModelAvailable,
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
})
|
||||
kitInstance.EmitSessionStart()
|
||||
}
|
||||
|
||||
// Convert extension commands to UI-layer type for the interactive TUI.
|
||||
@@ -1166,7 +1578,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
// Update the extension context's Model field so handlers see it.
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// NOTE: We do NOT call appInstance.NotifyModelChanged() here because
|
||||
// this callback runs synchronously inside BubbleTea's Update(), and
|
||||
// NotifyModelChanged calls prog.Send() which deadlocks. The UI layer
|
||||
@@ -1192,7 +1604,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
emitModelChangeForUI := func(newModel, previousModel, source string) {
|
||||
kitInstance.EmitModelChange(newModel, previousModel, source)
|
||||
kitInstance.Extensions().EmitModelChange(newModel, previousModel, source)
|
||||
}
|
||||
|
||||
// Build thinking level callback.
|
||||
@@ -1222,7 +1634,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return fmt.Errorf("--quiet requires a prompt")
|
||||
}
|
||||
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI)
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, startupExtensionMessages)
|
||||
}
|
||||
|
||||
// runNonInteractiveModeApp executes a single prompt via the app layer and exits,
|
||||
@@ -1278,7 +1690,7 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
|
||||
|
||||
// If --no-exit was requested, hand off to the interactive TUI.
|
||||
if noExit {
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession)
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1333,7 +1745,7 @@ func buildJSONOutput(result *kit.TurnResult, model string) ([]byte, error) {
|
||||
}
|
||||
|
||||
for _, fmsg := range result.Messages {
|
||||
converted := kit.ConvertFromFantasyMessage(fmsg)
|
||||
converted := kit.ConvertFromLLMMessage(fmsg)
|
||||
m := jsonMessage{Role: string(converted.Role)}
|
||||
for _, p := range converted.Parts {
|
||||
switch c := p.(type) {
|
||||
@@ -1376,7 +1788,7 @@ func writeJSONError(err error) {
|
||||
// 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit).
|
||||
//
|
||||
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []ui.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error) error {
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []ui.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, startupExtensionMessages []string) error {
|
||||
// Determine terminal size; fall back gracefully.
|
||||
termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd()))
|
||||
if err != nil || termWidth == 0 {
|
||||
@@ -1387,7 +1799,6 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
|
||||
cwd, _ := os.Getwd()
|
||||
|
||||
appModel := ui.NewAppModel(appInstance, ui.AppModelOptions{
|
||||
CompactMode: viper.GetBool("compact"),
|
||||
ModelName: modelName,
|
||||
ProviderName: providerName,
|
||||
LoadingMessage: loadingMessage,
|
||||
@@ -1423,9 +1834,20 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
|
||||
ShowSessionPicker: resumeFlag,
|
||||
})
|
||||
|
||||
// Print startup info to stdout before Bubble Tea takes over the screen.
|
||||
// Print KIT banner and startup info to stdout before Bubble Tea takes over the screen.
|
||||
fmt.Println(kitBanner())
|
||||
fmt.Println()
|
||||
appModel.PrintStartupInfo()
|
||||
|
||||
// Print any extension messages that were captured during startup.
|
||||
if len(startupExtensionMessages) > 0 {
|
||||
fmt.Println()
|
||||
for _, msg := range startupExtensionMessages {
|
||||
fmt.Println(msg)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
program := tea.NewProgram(appModel)
|
||||
|
||||
// Register the program with the app layer so agent events are sent to the TUI.
|
||||
|
||||
@@ -41,7 +41,6 @@ func BuildAppOptions(mcpConfig *config.Config, modelName string, serverNames, to
|
||||
StreamingEnabled: viper.GetBool("stream"),
|
||||
Quiet: quietFlag,
|
||||
Debug: viper.GetBool("debug"),
|
||||
CompactMode: viper.GetBool("compact"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +130,6 @@ func SetupCLIForNonInteractive(k *kit.Kit) (*ui.CLI, error) {
|
||||
Agent: agentAdapter,
|
||||
ModelString: viper.GetString("model"),
|
||||
Debug: viper.GetBool("debug"),
|
||||
Compact: viper.GetBool("compact"),
|
||||
Quiet: quietFlag,
|
||||
ShowDebug: false,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// TestAllExtensions_Load is a smoke test that verifies every single-file
|
||||
// example extension in this directory can be loaded by the Yaegi interpreter
|
||||
// without errors. This catches syntax errors, missing symbols, bad imports,
|
||||
// and Init signature mismatches.
|
||||
func TestAllExtensions_Load(t *testing.T) {
|
||||
files := extensionFiles(t)
|
||||
|
||||
for _, file := range files {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
ext := harness.LoadFile(file)
|
||||
if ext == nil {
|
||||
t.Fatalf("%s: extension should not be nil after loading", file)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Logf("successfully loaded %d extensions", len(files))
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// extensionFiles returns all single-file extensions in the current directory.
|
||||
// It skips test files, the test template, and files without an Init function.
|
||||
func extensionFiles(t *testing.T) []string {
|
||||
t.Helper()
|
||||
|
||||
skip := map[string]bool{
|
||||
"extension_test_template.go": true,
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(".")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read directory: %v", err)
|
||||
}
|
||||
|
||||
var files []string
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if entry.IsDir() || filepath.Ext(name) != ".go" {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "_test.go") || skip[name] {
|
||||
continue
|
||||
}
|
||||
src, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read %s: %v", name, err)
|
||||
}
|
||||
if !strings.Contains(string(src), "func Init(") {
|
||||
continue
|
||||
}
|
||||
files = append(files, name)
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
t.Fatal("no extensions found — check the directory")
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
// TestAllExtensions_Lifecycle verifies that every extension survives a full
|
||||
// SessionStart → SessionShutdown round-trip without errors.
|
||||
func TestAllExtensions_Lifecycle(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{
|
||||
SessionID: "smoke-test-session",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart error: %v", err)
|
||||
}
|
||||
|
||||
_, err = harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionShutdown error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_CommandSanity checks that every registered command has
|
||||
// a non-empty name, a non-empty description, no spaces in the name, no
|
||||
// leading slash, a non-nil Execute function, and no duplicate names.
|
||||
func TestAllExtensions_CommandSanity(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
cmds := harness.RegisteredCommands()
|
||||
seen := make(map[string]bool)
|
||||
for _, cmd := range cmds {
|
||||
if cmd.Name == "" {
|
||||
t.Error("command has empty name")
|
||||
}
|
||||
if strings.Contains(cmd.Name, " ") {
|
||||
t.Errorf("command %q contains spaces", cmd.Name)
|
||||
}
|
||||
if strings.HasPrefix(cmd.Name, "/") {
|
||||
t.Errorf("command %q has leading slash (framework adds it)", cmd.Name)
|
||||
}
|
||||
if cmd.Description == "" {
|
||||
t.Errorf("command %q has empty description", cmd.Name)
|
||||
}
|
||||
if cmd.Execute == nil {
|
||||
t.Errorf("command %q has nil Execute function", cmd.Name)
|
||||
}
|
||||
if seen[cmd.Name] {
|
||||
t.Errorf("duplicate command name %q", cmd.Name)
|
||||
}
|
||||
seen[cmd.Name] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_ToolSanity checks that every registered tool has a
|
||||
// non-empty name, a non-empty description, at least one executor, valid
|
||||
// JSON in its Parameters field, and no duplicate names.
|
||||
func TestAllExtensions_ToolSanity(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
tools := harness.RegisteredTools()
|
||||
seen := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
if tool.Name == "" {
|
||||
t.Error("tool has empty name")
|
||||
}
|
||||
if tool.Description == "" {
|
||||
t.Errorf("tool %q has empty description", tool.Name)
|
||||
}
|
||||
if tool.Execute == nil && tool.ExecuteWithContext == nil {
|
||||
t.Errorf("tool %q has no executor (both Execute and ExecuteWithContext are nil)", tool.Name)
|
||||
}
|
||||
if tool.Parameters != "" && !json.Valid([]byte(tool.Parameters)) {
|
||||
t.Errorf("tool %q has invalid JSON in Parameters: %s", tool.Name, tool.Parameters)
|
||||
}
|
||||
if seen[tool.Name] {
|
||||
t.Errorf("duplicate tool name %q", tool.Name)
|
||||
}
|
||||
seen[tool.Name] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_ZeroValueEvents fires every event type (as zero-value
|
||||
// structs) at each extension and verifies no errors are returned. Extensions
|
||||
// should be resilient to events they don't handle and to events with empty
|
||||
// fields.
|
||||
func TestAllExtensions_ZeroValueEvents(t *testing.T) {
|
||||
// Build the set of zero-value events for every event type.
|
||||
zeroEvents := []extensions.Event{
|
||||
extensions.ToolCallEvent{},
|
||||
extensions.ToolExecutionStartEvent{},
|
||||
extensions.ToolExecutionEndEvent{},
|
||||
extensions.ToolOutputEvent{},
|
||||
extensions.ToolResultEvent{},
|
||||
extensions.InputEvent{},
|
||||
extensions.BeforeAgentStartEvent{},
|
||||
extensions.AgentStartEvent{},
|
||||
extensions.AgentEndEvent{},
|
||||
extensions.MessageStartEvent{},
|
||||
extensions.MessageUpdateEvent{},
|
||||
extensions.MessageEndEvent{},
|
||||
extensions.SessionStartEvent{},
|
||||
extensions.SessionShutdownEvent{},
|
||||
extensions.ModelChangeEvent{},
|
||||
extensions.ContextPrepareEvent{},
|
||||
extensions.BeforeForkEvent{},
|
||||
extensions.BeforeSessionSwitchEvent{},
|
||||
extensions.BeforeCompactEvent{},
|
||||
extensions.SubagentStartEvent{},
|
||||
extensions.SubagentChunkEvent{},
|
||||
extensions.SubagentEndEvent{},
|
||||
}
|
||||
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
for _, ev := range zeroEvents {
|
||||
_, err := harness.Emit(ev)
|
||||
if err != nil {
|
||||
t.Errorf("event %T returned error: %v", ev, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_WidgetSanity emits SessionStart and then checks that
|
||||
// any widgets set during initialization have non-empty IDs and valid
|
||||
// placements.
|
||||
func TestAllExtensions_WidgetSanity(t *testing.T) {
|
||||
validPlacements := map[extensions.WidgetPlacement]bool{
|
||||
"above": true,
|
||||
"below": true,
|
||||
}
|
||||
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
// Trigger SessionStart so extensions that set widgets on init do so.
|
||||
_, _ = harness.Emit(extensions.SessionStartEvent{
|
||||
SessionID: "widget-sanity-test",
|
||||
})
|
||||
|
||||
// Widgets is an exported field on MockContext; reads are safe
|
||||
// here because Emit returned synchronously.
|
||||
for id, w := range harness.Context().Widgets {
|
||||
if w.ID == "" {
|
||||
t.Errorf("widget stored with key %q has empty ID", id)
|
||||
}
|
||||
if w.ID != id {
|
||||
t.Errorf("widget key %q doesn't match widget ID %q", id, w.ID)
|
||||
}
|
||||
if !validPlacements[w.Placement] {
|
||||
t.Errorf("widget %q has invalid placement %q (want \"above\" or \"below\")", id, w.Placement)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_IdempotentLifecycle verifies that receiving SessionStart
|
||||
// twice and SessionShutdown twice doesn't cause errors — extensions should
|
||||
// be defensive about repeated lifecycle events.
|
||||
func TestAllExtensions_IdempotentLifecycle(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
for i := range 2 {
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{
|
||||
SessionID: "idempotent-test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart #%d error: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
|
||||
for i := range 2 {
|
||||
_, err := harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionShutdown #%d error: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
//go:build ignore
|
||||
|
||||
// bridge_demo.go - Demonstrates the new bridged SDK APIs for extensions.
|
||||
// This extension showcases tree navigation, skill loading, template parsing,
|
||||
// and model resolution capabilities.
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
discoveredSkills []ext.Skill
|
||||
currentBranch []ext.TreeNode
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Register /tree-info command to demonstrate tree navigation
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "tree-info",
|
||||
Description: "Show current conversation tree information",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
info := fmt.Sprintf("Current branch has %d nodes:\n", len(branch))
|
||||
for i, node := range branch {
|
||||
info += fmt.Sprintf(" [%d] %s (%s): %s...\n", i, node.Type, node.ID[:8], truncate(node.Content, 40))
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /discover-skills command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "discover-skills",
|
||||
Description: "Discover and list available skills",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
result := ctx.DiscoverSkills()
|
||||
if result.Error != "" {
|
||||
return "", fmt.Errorf("discovery failed: %s", result.Error)
|
||||
}
|
||||
discoveredSkills = result.Skills
|
||||
|
||||
info := fmt.Sprintf("Discovered %d skills:\n", len(result.Skills))
|
||||
for _, s := range result.Skills {
|
||||
info += fmt.Sprintf(" - %s: %s\n", s.Name, s.Description)
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /parse-template command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "parse-template",
|
||||
Description: "Parse a template and show extracted variables",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
args = "Hello {{name}}, welcome to {{place}}!"
|
||||
}
|
||||
tpl := ctx.ParseTemplate("demo", args)
|
||||
info := fmt.Sprintf("Template: %s\nVariables: %v", tpl.Content, tpl.Variables)
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /render-template command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "render-template",
|
||||
Description: "Render a template with variables (usage: /render-template name=John place=Kit)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
tpl := ctx.ParseTemplate("demo", "Hello {{name}}, welcome to {{place}}!")
|
||||
vars := ctx.ParseArguments(args, ext.ArgumentPattern{
|
||||
Flags: map[string]string{"name": "name", "place": "place"},
|
||||
})
|
||||
rendered := ctx.RenderTemplate(tpl, vars.Vars)
|
||||
ctx.PrintInfo("Rendered: " + rendered)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /check-model command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "check-model",
|
||||
Description: "Check model capabilities and availability",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
model := args
|
||||
if model == "" {
|
||||
model = ctx.Model
|
||||
}
|
||||
|
||||
available := ctx.CheckModelAvailable(model)
|
||||
caps, err := ctx.GetModelCapabilities(model)
|
||||
|
||||
info := fmt.Sprintf("Model: %s\n", model)
|
||||
info += fmt.Sprintf("Available: %v\n", available)
|
||||
if err == "" {
|
||||
info += fmt.Sprintf("Provider: %s\n", caps.Provider)
|
||||
info += fmt.Sprintf("Context Limit: %d\n", caps.ContextLimit)
|
||||
info += fmt.Sprintf("Reasoning: %v\n", caps.Reasoning)
|
||||
} else {
|
||||
info += fmt.Sprintf("Error: %s\n", err)
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /resolve-chain command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "resolve-chain",
|
||||
Description: "Resolve a model chain (usage: /resolve-chain claude-opus,gpt-4o,claude-sonnet)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
args = "anthropic/claude-opus-4,anthropic/claude-sonnet-4,openai/gpt-4o"
|
||||
}
|
||||
prefs := ctx.SimpleParseArguments(args, 1)
|
||||
chain := []string{}
|
||||
if len(prefs) > 1 {
|
||||
// Split the first arg by comma
|
||||
for _, p := range strings.Split(prefs[1], ",") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
chain = append(chain, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := ctx.ResolveModelChain(chain)
|
||||
info, _ := json.MarshalIndent(result, "", " ")
|
||||
ctx.PrintInfo("Resolution Result:\n" + string(info))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /test-conditional command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "test-conditional",
|
||||
Description: "Test model conditional rendering",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
content := `<if-model is="claude-*">This is for Claude models<else>This is for other models</if-model>`
|
||||
rendered := ctx.RenderWithModelConditionals(content)
|
||||
ctx.PrintInfo("Input: " + content)
|
||||
ctx.PrintInfo("Output: " + rendered)
|
||||
ctx.PrintInfo(fmt.Sprintf("Current model matches 'claude-*': %v", ctx.EvaluateModelConditional("claude-*")))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// OnSessionStart: discover skills automatically
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
result := ctx.DiscoverSkills()
|
||||
if result.Error == "" && len(result.Skills) > 0 {
|
||||
discoveredSkills = result.Skills
|
||||
ctx.SetStatus("bridge-demo", fmt.Sprintf("%d skills", len(result.Skills)), 50)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
//go:build ignore
|
||||
|
||||
// conversation-manager.go - Advanced conversation tree navigation and management.
|
||||
// This extension demonstrates:
|
||||
// - Tree navigation (GetTreeNode, GetCurrentBranch, NavigateTo)
|
||||
// - Branch summarization and collapsing
|
||||
// - Interactive tree exploration
|
||||
//
|
||||
// Commands:
|
||||
// /tree - Show conversation tree structure
|
||||
// /branch - Show current branch path
|
||||
// /goto <entry-id> - Navigate to a specific entry
|
||||
// /summarize <n> - Summarize last N messages
|
||||
// /fresh-context - Collapse branch and start fresh
|
||||
// /loop <n> <prompt> - Execute prompt N times with fresh context each iteration
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
loopActive bool
|
||||
loopCount int
|
||||
loopCurrent int
|
||||
loopPrompt string
|
||||
loopStartNode string
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// /tree - Show tree structure
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "tree",
|
||||
Description: "Show conversation tree structure",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
showTree(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /branch - Show current branch
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "branch",
|
||||
Description: "Show current conversation branch",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
showBranch(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /goto - Navigate to entry
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "goto",
|
||||
Description: "Navigate to a specific entry ID (usage: /goto <entry-id>)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
ctx.PrintError("Usage: /goto <entry-id>")
|
||||
return "", nil
|
||||
}
|
||||
result := ctx.NavigateTo(args)
|
||||
if !result.Success {
|
||||
ctx.PrintError(fmt.Sprintf("Navigation failed: %s", result.Error))
|
||||
return "", nil
|
||||
}
|
||||
ctx.PrintInfo(fmt.Sprintf("Navigated to entry: %s", args))
|
||||
|
||||
// Show the node we navigated to
|
||||
node := ctx.GetTreeNode(args)
|
||||
if node != nil {
|
||||
ctx.PrintInfo(fmt.Sprintf("Entry type: %s, Role: %s", node.Type, node.Role))
|
||||
}
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /summarize - Summarize recent messages
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "summarize",
|
||||
Description: "Summarize last N messages (usage: /summarize [n=5])",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
n := 5
|
||||
if args != "" {
|
||||
if parsed, err := strconv.Atoi(args); err == nil && parsed > 0 {
|
||||
n = parsed
|
||||
}
|
||||
}
|
||||
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) < 2 {
|
||||
ctx.PrintError("Not enough messages to summarize")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Find range to summarize
|
||||
startIdx := len(branch) - n - 1
|
||||
if startIdx < 0 {
|
||||
startIdx = 0
|
||||
}
|
||||
endIdx := len(branch) - 1
|
||||
|
||||
fromID := branch[startIdx].ID
|
||||
toID := branch[endIdx].ID
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Summarizing messages %d to %d...", startIdx, endIdx))
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
|
||||
if summary == "" {
|
||||
ctx.PrintError("Failed to generate summary")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#89b4fa",
|
||||
Subtitle: "conversation-manager · Summary",
|
||||
})
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /fresh-context - Collapse and restart
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "fresh-context",
|
||||
Description: "Collapse conversation to summary and start fresh",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) < 3 {
|
||||
ctx.PrintError("Not enough context to collapse")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Keep first message (system), summarize rest
|
||||
fromID := branch[1].ID
|
||||
toID := branch[len(branch)-1].ID
|
||||
|
||||
ctx.PrintInfo("Generating summary for context collapse...")
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
|
||||
if summary == "" {
|
||||
ctx.PrintError("Failed to generate summary")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Collapse the branch
|
||||
result := ctx.CollapseBranch(fromID, toID, summary)
|
||||
if !result.Success {
|
||||
ctx.PrintError(fmt.Sprintf("Collapse failed: %s", result.Error))
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ctx.PrintInfo("Context collapsed. Starting fresh with summary.")
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "conversation-manager · Collapsed Context",
|
||||
})
|
||||
|
||||
// Set a widget showing we're in fresh mode
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "fresh-context",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: "🌱 Fresh Context Mode - Previous conversation collapsed"},
|
||||
Style: ext.WidgetStyle{BorderColor: "#a6e3a1"},
|
||||
})
|
||||
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /loop - Execute with fresh context each iteration
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "loop",
|
||||
Description: "Execute prompt N times with fresh context (usage: /loop 5 analyze this code)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if loopActive {
|
||||
ctx.PrintError("Loop already in progress. Wait for completion.")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Parse arguments
|
||||
parts := strings.SplitN(args, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
ctx.PrintError("Usage: /loop <count> <prompt>")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
count, err := strconv.Atoi(parts[0])
|
||||
if err != nil || count <= 0 || count > 10 {
|
||||
ctx.PrintError("Invalid count (must be 1-10)")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
loopCount = count
|
||||
loopCurrent = 0
|
||||
loopPrompt = parts[1]
|
||||
loopActive = true
|
||||
|
||||
// Store current branch position
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) > 0 {
|
||||
loopStartNode = branch[len(branch)-1].ID
|
||||
}
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Starting loop: %d iterations", loopCount))
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "loop-progress",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: fmt.Sprintf("🔄 Loop: 0/%d - %s", loopCount, loopPrompt)},
|
||||
Style: ext.WidgetStyle{BorderColor: "#fab387"},
|
||||
})
|
||||
|
||||
// Start first iteration
|
||||
executeLoopIteration(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// OnAgentEnd handles loop continuation
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if !loopActive {
|
||||
return
|
||||
}
|
||||
|
||||
loopCurrent++
|
||||
|
||||
if loopCurrent >= loopCount {
|
||||
// Loop complete
|
||||
loopActive = false
|
||||
ctx.RemoveWidget("loop-progress")
|
||||
ctx.PrintInfo(fmt.Sprintf("✅ Loop complete: %d/%d iterations", loopCurrent, loopCount))
|
||||
|
||||
// Show final summary
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) > 0 && loopStartNode != "" {
|
||||
summary := ctx.SummarizeBranch(loopStartNode, branch[len(branch)-1].ID)
|
||||
if summary != "" {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "conversation-manager · Loop Summary",
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Update progress
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "loop-progress",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: fmt.Sprintf("🔄 Loop: %d/%d - %s", loopCurrent, loopCount, loopPrompt)},
|
||||
Style: ext.WidgetStyle{BorderColor: "#fab387"},
|
||||
})
|
||||
|
||||
// Collapse previous iteration for fresh context
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) >= 2 {
|
||||
// Find the user messages (look for the one before the last assistant message)
|
||||
// We want to collapse from the user message that started this iteration
|
||||
// to the last assistant response
|
||||
var collapseStartIdx = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if branch[i].Role == "assistant" {
|
||||
// Found the last assistant message, now find the user message before it
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
if branch[j].Role == "user" {
|
||||
collapseStartIdx = j
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if collapseStartIdx >= 0 {
|
||||
fromID := branch[collapseStartIdx].ID
|
||||
toID := branch[len(branch)-1].ID
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Collapsing iteration %d for fresh context...", loopCurrent))
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
if summary != "" {
|
||||
result := ctx.CollapseBranch(fromID, toID, summary)
|
||||
if result.Success {
|
||||
ctx.PrintInfo("Context collapsed successfully")
|
||||
} else {
|
||||
ctx.PrintError(fmt.Sprintf("Collapse failed: %s", result.Error))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Small delay to let UI update
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Trigger next iteration
|
||||
executeLoopIteration(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// showTree displays the conversation tree structure
|
||||
func showTree(ctx ext.Context) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) == 0 {
|
||||
ctx.PrintInfo("Tree is empty")
|
||||
return
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Conversation Tree (%d nodes):\n\n", len(branch)))
|
||||
|
||||
for i, node := range branch {
|
||||
prefix := " "
|
||||
if i == len(branch)-1 {
|
||||
prefix = "▶ " // Current node
|
||||
} else {
|
||||
prefix = " "
|
||||
}
|
||||
|
||||
roleIcon := "💬"
|
||||
switch node.Role {
|
||||
case "user":
|
||||
roleIcon = "👤"
|
||||
case "assistant":
|
||||
roleIcon = "🤖"
|
||||
case "system":
|
||||
roleIcon = "⚙️"
|
||||
}
|
||||
|
||||
content := truncate(node.Content, 50)
|
||||
if node.Type == "branch_summary" {
|
||||
roleIcon = "📋"
|
||||
content = "[Summary] " + truncate(node.Content, 40)
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("%s%s %s: %s (%s...)\n", prefix, roleIcon, node.Role, node.ID[:8], content))
|
||||
|
||||
// Show children count if any
|
||||
children := ctx.GetChildren(node.ID)
|
||||
if len(children) > 0 {
|
||||
output.WriteString(fmt.Sprintf(" └─ %d branch(es)\n", len(children)))
|
||||
}
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: output.String(),
|
||||
BorderColor: "#89b4fa",
|
||||
Subtitle: "conversation-manager · Tree View",
|
||||
})
|
||||
}
|
||||
|
||||
// showBranch displays the current branch path
|
||||
func showBranch(ctx ext.Context) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) == 0 {
|
||||
ctx.PrintInfo("No active branch")
|
||||
return
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Current Branch (%d nodes from root to leaf):\n\n", len(branch)))
|
||||
|
||||
for i, node := range branch {
|
||||
marker := " "
|
||||
if i == len(branch)-1 {
|
||||
marker = "▶ " // Current leaf
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("%s[%d] %s (%s): %s\n",
|
||||
marker, i, node.Type, node.ID[:8], truncate(node.Content, 40)))
|
||||
}
|
||||
|
||||
// Show current node details
|
||||
leaf := branch[len(branch)-1]
|
||||
output.WriteString(fmt.Sprintf("\nCurrent Leaf:\n"))
|
||||
output.WriteString(fmt.Sprintf(" ID: %s\n", leaf.ID))
|
||||
output.WriteString(fmt.Sprintf(" Type: %s\n", leaf.Type))
|
||||
output.WriteString(fmt.Sprintf(" Role: %s\n", leaf.Role))
|
||||
output.WriteString(fmt.Sprintf(" Model: %s\n", leaf.Model))
|
||||
output.WriteString(fmt.Sprintf(" Children: %d\n", len(leaf.Children)))
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: output.String(),
|
||||
BorderColor: "#cba6f7",
|
||||
Subtitle: "conversation-manager · Branch View",
|
||||
})
|
||||
}
|
||||
|
||||
// executeLoopIteration triggers the next loop iteration
|
||||
func executeLoopIteration(ctx ext.Context) {
|
||||
iterationPrompt := fmt.Sprintf("[%d/%d] %s", loopCurrent+1, loopCount, loopPrompt)
|
||||
ctx.SendMessage(iterationPrompt)
|
||||
}
|
||||
|
||||
// truncate helper
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
@@ -908,7 +908,7 @@ func summarizeToolAction(toolName string, inputJSON string) string {
|
||||
return "searching " + getStr("pattern", "text")
|
||||
case "ls":
|
||||
return "listing " + getStr("path", "directory")
|
||||
case "spawn_subagent":
|
||||
case "subagent":
|
||||
return "spawning subagent"
|
||||
default:
|
||||
return "using " + toolName
|
||||
|
||||
@@ -2,9 +2,7 @@
|
||||
|
||||
// lsp-diagnostics.go — LSP-powered diagnostics for Kit's edit tool.
|
||||
//
|
||||
// Starts language servers on demand and surfaces diagnostics after file edits,
|
||||
// following the same pattern used by Charm's crush editor:
|
||||
//
|
||||
// Starts language servers on demand and surfaces diagnostics after file edits:
|
||||
// 1. After an edit, notify the LSP server of the file change
|
||||
// 2. Wait for the server to publish fresh diagnostics
|
||||
// 3. Append diagnostic output to the edit tool's result
|
||||
@@ -412,7 +410,7 @@ func (c *lspClient) changeFile(absPath, content string) {
|
||||
}
|
||||
|
||||
// waitForDiagnostics polls until the server publishes new diagnostics or
|
||||
// the timeout elapses. Mirrors crush's WaitForDiagnostics pattern.
|
||||
// the timeout elapses.
|
||||
func (c *lspClient) waitForDiagnostics(timeout time.Duration) {
|
||||
c.diagMu.Lock()
|
||||
startVersion := c.diagVersion
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
//go:build ignore
|
||||
|
||||
// prompt-templates.go - Frontmatter-driven prompt templates with model switching.
|
||||
// This extension demonstrates the new bridged SDK APIs:
|
||||
// - Tree navigation for conversation management
|
||||
// - Template parsing with {{variable}} substitution
|
||||
// - Model resolution with fallback chains
|
||||
// - Skill injection
|
||||
//
|
||||
// Usage:
|
||||
// 1. Create ~/.config/kit/prompts/debug.md with frontmatter:
|
||||
// ---
|
||||
// description: Debug Python code
|
||||
// model: claude-sonnet-4-20250514
|
||||
// skill: python
|
||||
// ---
|
||||
// Help me debug this Python code: {{input}}
|
||||
//
|
||||
// 2. In Kit: /debug my_script.py
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// PromptTemplate represents a loaded template with frontmatter
|
||||
type PromptTemplate struct {
|
||||
Name string
|
||||
Description string
|
||||
Model string
|
||||
Skill string
|
||||
Content string
|
||||
Variables []string
|
||||
Path string
|
||||
}
|
||||
|
||||
var (
|
||||
templates = make(map[string]PromptTemplate)
|
||||
templateDir string
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Determine template directory
|
||||
home, _ := os.UserHomeDir()
|
||||
templateDir = filepath.Join(home, ".config", "kit", "prompts")
|
||||
|
||||
// Ensure directory exists
|
||||
os.MkdirAll(templateDir, 0755)
|
||||
|
||||
// Register commands
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "reload-templates",
|
||||
Description: "Reload prompt templates from disk",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
loadTemplates(ctx)
|
||||
ctx.PrintInfo(fmt.Sprintf("Loaded %d templates from %s", len(templates), templateDir))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Dynamic template commands are registered after loading
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
loadTemplates(ctx)
|
||||
registerTemplateCommands(api, ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// loadTemplates discovers and loads all template files
|
||||
func loadTemplates(ctx ext.Context) {
|
||||
templates = make(map[string]PromptTemplate)
|
||||
|
||||
entries, err := os.ReadDir(templateDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(templateDir, entry.Name())
|
||||
tpl, err := loadTemplateFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
name := strings.TrimSuffix(entry.Name(), ".md")
|
||||
templates[name] = tpl
|
||||
}
|
||||
}
|
||||
|
||||
// loadTemplateFile parses a template with YAML frontmatter
|
||||
func loadTemplateFile(path string) (PromptTemplate, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return PromptTemplate{}, err
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
tpl := PromptTemplate{Path: path}
|
||||
|
||||
// Parse frontmatter
|
||||
if strings.HasPrefix(content, "---") {
|
||||
parts := strings.SplitN(content[3:], "---", 2)
|
||||
if len(parts) == 2 {
|
||||
frontmatter := strings.TrimSpace(parts[0])
|
||||
body := strings.TrimSpace(parts[1])
|
||||
|
||||
// Simple line-by-line frontmatter parsing
|
||||
for _, line := range strings.Split(frontmatter, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
key, value, found := strings.Cut(line, ":")
|
||||
if found {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
switch key {
|
||||
case "description":
|
||||
tpl.Description = value
|
||||
case "model":
|
||||
tpl.Model = value
|
||||
case "skill":
|
||||
tpl.Skill = value
|
||||
}
|
||||
}
|
||||
}
|
||||
tpl.Content = body
|
||||
} else {
|
||||
tpl.Content = content
|
||||
}
|
||||
} else {
|
||||
tpl.Content = content
|
||||
}
|
||||
|
||||
// Parse {{variables}} using simple string parsing
|
||||
// (Can't use ctx.ParseTemplate here since we're in Init, not a handler)
|
||||
var vars []string
|
||||
for {
|
||||
start := strings.Index(tpl.Content, "{{")
|
||||
if start == -1 {
|
||||
break
|
||||
}
|
||||
end := strings.Index(tpl.Content[start:], "}}")
|
||||
if end == -1 {
|
||||
break
|
||||
}
|
||||
varName := strings.TrimSpace(tpl.Content[start+2 : start+end])
|
||||
vars = append(vars, varName)
|
||||
tpl.Content = tpl.Content[:start] + "{{" + varName + "}}" + tpl.Content[start+end+2:]
|
||||
}
|
||||
tpl.Variables = vars
|
||||
|
||||
return tpl, nil
|
||||
}
|
||||
|
||||
// registerTemplateCommands dynamically registers commands for each template
|
||||
func registerTemplateCommands(api ext.API, ctx ext.Context) {
|
||||
for name, tpl := range templates {
|
||||
// Skip if already registered (we'd need to track this)
|
||||
tplCopy := tpl // Capture for closure
|
||||
nameCopy := name
|
||||
|
||||
// Build description with metadata
|
||||
desc := tplCopy.Description
|
||||
if desc == "" {
|
||||
desc = fmt.Sprintf("Run %s template", nameCopy)
|
||||
}
|
||||
if tplCopy.Model != "" {
|
||||
desc += fmt.Sprintf(" [%s", tplCopy.Model)
|
||||
if tplCopy.Skill != "" {
|
||||
desc += fmt.Sprintf(" +%s", tplCopy.Skill)
|
||||
}
|
||||
desc += "]"
|
||||
}
|
||||
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: nameCopy,
|
||||
Description: desc,
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
return executeTemplate(ctx, tplCopy, args)
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// executeTemplate runs a template with the given arguments
|
||||
func executeTemplate(ctx ext.Context, tpl PromptTemplate, args string) (string, error) {
|
||||
// Store original model for restoration
|
||||
originalModel := ctx.Model
|
||||
|
||||
// 1. Resolve and switch model if specified
|
||||
if tpl.Model != "" {
|
||||
// Parse model chain (comma-separated)
|
||||
preferences := strings.Split(tpl.Model, ",")
|
||||
for i := range preferences {
|
||||
preferences[i] = strings.TrimSpace(preferences[i])
|
||||
}
|
||||
|
||||
result := ctx.ResolveModelChain(preferences)
|
||||
if result.Error != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Model resolution failed: %s", result.Error))
|
||||
// Continue with current model
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Switching to model: %s", result.Model))
|
||||
if err := ctx.SetModel(result.Model); err != nil {
|
||||
ctx.PrintError(fmt.Sprintf("Failed to switch model: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Inject skill if specified
|
||||
if tpl.Skill != "" {
|
||||
err := ctx.InjectSkillAsContext(tpl.Skill)
|
||||
if err != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Skill injection failed: %s", err))
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Injected skill: %s", tpl.Skill))
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Parse and render template
|
||||
parsed := ctx.ParseTemplate(tpl.Name, tpl.Content)
|
||||
|
||||
// Build variable map
|
||||
vars := make(map[string]string)
|
||||
|
||||
// Simple argument parsing: first arg is $1 (input), rest is $@
|
||||
if len(parsed.Variables) > 0 {
|
||||
argsList := ctx.SimpleParseArguments(args, len(parsed.Variables))
|
||||
for i, varName := range parsed.Variables {
|
||||
if i < len(parsed.Variables) && i+1 < len(argsList) {
|
||||
vars[varName] = argsList[i+1]
|
||||
}
|
||||
}
|
||||
// If single variable, use full args
|
||||
if len(parsed.Variables) == 1 && vars[parsed.Variables[0]] == "" {
|
||||
vars[parsed.Variables[0]] = args
|
||||
}
|
||||
}
|
||||
|
||||
// Render with model conditionals
|
||||
content := ctx.RenderWithModelConditionals(tpl.Content)
|
||||
rendered := ctx.RenderTemplate(ext.PromptTemplate{Name: tpl.Name, Content: content, Variables: parsed.Variables}, vars)
|
||||
|
||||
// 4. Send the rendered prompt
|
||||
ctx.SendMessage(rendered)
|
||||
|
||||
// 5. Schedule model restoration after turn completes
|
||||
// We use a goroutine to wait and restore
|
||||
if tpl.Model != "" && originalModel != "" {
|
||||
go func() {
|
||||
// Note: In a real implementation, we'd use OnAgentEnd event
|
||||
// For now, the user can manually switch back
|
||||
ctx.SetStatus("template-mode", fmt.Sprintf("Template: %s (model will restore)", tpl.Name), 20)
|
||||
}()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Executing template: %s", tpl.Name), nil
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// TestSubagentMonitor_SessionStart verifies OnSessionStart initializes state
|
||||
// without panicking and properly guards nil ctx calls.
|
||||
func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Emit SessionStart - should not panic even with nil ctx functions
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_SubagentLifecycle verifies the full subagent lifecycle
|
||||
// creates entries and emits widget updates.
|
||||
func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Start session
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit SubagentStart
|
||||
_, err = harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit a few chunks
|
||||
for i := range 3 {
|
||||
_, err = harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
ChunkType: "text",
|
||||
Content: fmt.Sprintf("line %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Emit tool call chunk
|
||||
_, err = harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
ChunkType: "tool_call",
|
||||
ToolName: "bash",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk tool_call should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit SubagentEnd
|
||||
_, err = harness.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
Response: "done",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentEnd should not error: %v", err)
|
||||
}
|
||||
|
||||
// Give time for cleanup goroutine
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_MultipleSubagents verifies multiple parallel subagents.
|
||||
func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Start 3 subagents
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Emit chunks for each
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
ChunkType: "text",
|
||||
Content: fmt.Sprintf("output from agent %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// End all subagents
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
Response: "completed",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentEnd %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_SessionShutdown verifies shutdown doesn't panic
|
||||
// even with nil ctx functions.
|
||||
func TestSubagentMonitor_SessionShutdown(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Start then shutdown
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Start a subagent
|
||||
_, err = harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Shutdown - should not panic even with active subagent
|
||||
_, err = harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionShutdown should not error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,7 @@ func Init(api ext.API) {
|
||||
"Subagent Test Extension loaded\n\n" +
|
||||
"/subtest <task> Spawn blocking subagent\n" +
|
||||
"/subbg <task> Spawn background subagent\n\n" +
|
||||
"The LLM can also use the spawn_subagent tool.")
|
||||
"The LLM can also use the subagent tool.")
|
||||
})
|
||||
|
||||
api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) {
|
||||
|
||||
@@ -3,9 +3,9 @@ module github.com/mark3labs/kit
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.0.0
|
||||
charm.land/bubbles/v2 v2.1.0
|
||||
charm.land/bubbletea/v2 v2.0.2
|
||||
charm.land/fantasy v0.16.0
|
||||
charm.land/fantasy v0.17.1
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.2
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
@@ -13,7 +13,9 @@ require (
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/coder/acp-go-sdk v0.6.3
|
||||
github.com/mark3labs/mcp-go v0.45.0
|
||||
github.com/indaco/herald v0.10.0
|
||||
github.com/indaco/herald-md v0.1.0
|
||||
github.com/mark3labs/mcp-go v0.46.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
@@ -23,30 +25,27 @@ require (
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/auth v0.18.2 // indirect
|
||||
cloud.google.com/go/auth v0.19.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.2 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
@@ -54,11 +53,11 @@ require (
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4 // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4 // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
@@ -77,26 +76,22 @@ require (
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/invopop/jsonschema v0.13.0 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.2.12 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.3.0 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.6 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.18 // indirect
|
||||
github.com/mailru/easyjson v0.9.2 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.7 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.19 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
github.com/muesli/mango-cobra v1.3.0 // indirect
|
||||
github.com/muesli/mango-pflag v0.2.0 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/roff v0.1.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
@@ -105,11 +100,9 @@ require (
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.7.17 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.6 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
|
||||
@@ -122,8 +115,8 @@ require (
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.272.0 // indirect
|
||||
google.golang.org/genai v1.51.0 // indirect
|
||||
google.golang.org/api v0.273.0 // indirect
|
||||
google.golang.org/genai v1.52.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect
|
||||
google.golang.org/grpc v1.79.3 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
@@ -132,11 +125,10 @@ require (
|
||||
|
||||
require (
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/charmbracelet/glamour v1.0.0
|
||||
github.com/charmbracelet/x/ansi v0.11.6
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.21 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
charm.land/bubbles/v2 v2.0.0 h1:tE3eK/pHjmtrDiRdoC9uGNLgpopOd8fjhEe31B/ai5s=
|
||||
charm.land/bubbles/v2 v2.0.0/go.mod h1:rCHoleP2XhU8um45NTuOWBPNVHxnkXKTiZqcclL/qOI=
|
||||
charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g=
|
||||
charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY=
|
||||
charm.land/bubbletea/v2 v2.0.2 h1:4CRtRnuZOdFDTWSff9r8QFt/9+z6Emubz3aDMnf/dx0=
|
||||
charm.land/bubbletea/v2 v2.0.2/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
|
||||
charm.land/fantasy v0.16.0 h1:vE/6sR9nPcSD8qXJXX6wR8NXjtWlBVAzwQmTh5pHVrs=
|
||||
charm.land/fantasy v0.16.0/go.mod h1:VZjpXVh7IgeiIzGQybEnKzd68ofDsRj94+kzH1ZCAfQ=
|
||||
charm.land/fantasy v0.17.1 h1:SQzfnyJPDuQWt6e//KKmQmEEXdqHMC0IZz10XwkLcEM=
|
||||
charm.land/fantasy v0.17.1/go.mod h1:FF5ALCCHETacHJPBqU42CtwMInYQ0ul52fdzIHQMbQk=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs=
|
||||
charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
|
||||
cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M=
|
||||
cloud.google.com/go/auth v0.19.0 h1:DGYwtbcsGsT1ywuxsIoWi1u/vlks0moIblQHgSDgQkQ=
|
||||
cloud.google.com/go/auth v0.19.0/go.mod h1:2Aph7BT2KnaSFOM0JDPyiYgNh6PL9vGMiP8CUIXZ+IY=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
@@ -34,46 +34,40 @@ github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs
|
||||
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqbhVg1JzAGDUhXOsU0IDCAo=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.13 h1:5KgbxMaS2coSWRrx9TX/QtWbqzgQkOdEa3sZPhBhCSg=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.13/go.mod h1:8zz7wedqtCbw5e9Mi2doEwDyEgHcEE9YOJp6a8jdSMY=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.13 h1:mA59E3fokBvyEGHKFdnpNNrvaR351cqiHgRg+JzOSRI=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.13/go.mod h1:yoTXOQKea18nrM69wGF9jBdG4WocSZA1h38A+t/MAsk=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 h1:0GFOLzEbOyZABS3PhYfBIx2rNBACYcKty+XGkTgw1ow=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8/go.mod h1:LXypKvk85AROkKhOG6/YEcHFPoX+prKTowKnVdcaIxE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 h1:kiIDLZ005EcKomYYITtfsjn7dtOwHDOFy7IbPXKek2o=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13/go.mod h1:2h/xGEowcW/g38g06g3KpRWDlT+OTfxxI0o1KqayAB8=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 h1:jzKAXIlhZhJbnYwHbvUQZEB8KfgAEuG0dc08Bkda7NU=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17/go.mod h1:Al9fFsXjv4KfbzQHGe6V4NZSZQXecFcvaIF4e70FoRA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 h1:Cng+OOwCHmFljXIxpEVXAGMnBia8MSU6Ch5i9PgBkcU=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9/go.mod h1:LrlIndBDdjA/EeXeyNBle+gyCwTlizzW5ycgWnvIxkk=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 h1:GcLE9ba5ehAQma6wlopUesYg/hbcOhFNWTjELkiWkh4=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 h1:mP49nTpfKtpXLt5SLn8Uv8z6W+03jYVoOSAl/c02nog=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk=
|
||||
github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
|
||||
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
@@ -84,8 +78,6 @@ github.com/charmbracelet/colorprofile v0.4.3 h1:QPa1IWkYI+AOB+fE+mg/5/4HRMZcaXex
|
||||
github.com/charmbracelet/colorprofile v0.4.3/go.mod h1:/zT4BhpD5aGFpqQQqw7a+VtHCzu+zrQtt1zhMt9mR4Q=
|
||||
github.com/charmbracelet/fang v1.0.0 h1:jESBY40agJOlLYnnv9jE0mLqDGTxEk0hkOnx7YGyRlQ=
|
||||
github.com/charmbracelet/fang v1.0.0/go.mod h1:P5/DNb9DddQ0Z0dbc0P3ol4/ix5Po7Ofr2KMBfAqoCo=
|
||||
github.com/charmbracelet/glamour v1.0.0 h1:AWMLOVFHTsysl4WV8T8QgkQ0s/ZNZo7CiE4WKhk8l08=
|
||||
github.com/charmbracelet/glamour v1.0.0/go.mod h1:DSdohgOBkMr2ZQNhw4LZxSGpx3SvpeujNoXrQyH2hxo=
|
||||
github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ=
|
||||
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
|
||||
@@ -94,8 +86,8 @@ github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdR
|
||||
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 h1:BW/sZtyd1JyYy0h5adMm3tzpNyL857LWjuTRET6OhpY=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502 h1:hzWNs3UQRSUTS6YCbLaQnwqKBFXT5Yh1OOw6+26apqg=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502/go.mod h1:mkUCcxn9w9j89JJp3pOza5tmDQZPgIB75UfmQlFYvas=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b h1:ASDO9RT6SNKTQN87jO2bRfxHFJq8cgeYdFzivY2gCeM=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b/go.mod h1:Vo8TffMf0q7Uho/n8e6XpBZvOWtd3g39yX+9P5rRutA=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
@@ -104,14 +96,14 @@ github.com/charmbracelet/x/conpty v0.1.1 h1:s1bUxjoi7EpqiXysVtC+a8RrvPPNcNvAjfi4
|
||||
github.com/charmbracelet/x/conpty v0.1.1/go.mod h1:OmtR77VODEFbiTzGE9G1XiRJAga6011PIm4u5fTNZpk=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd h1:eStB6uX52pgrm6TxQcEKctPrEC+a/9ubJC+P671idOc=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4 h1:pIj18ZCZO4WOVj7jwjLoUb1lC7rS/I8oC3fZWXugNaY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd h1:U8xj0UXwqHzO+UYHZJopKF+gWaQEW8oj60fmiq9TFY4=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4 h1:VSd4zShIAf/4FgEDFJpapEcAPrc7h3dyyN7V9JlJpQw=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -173,51 +165,48 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 h1:fYQaUOiGwll0cGj7jmHT/0nPlcrZDFPrZRhTsoCr8hE=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0/go.mod h1:w2ROXVdfGEVFXzmlciUU4EdjHgWvB5h2n6x/8XSTTJA=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 h1:NIKVuLhDlIV74muWlsMM4CcQZqN6JJ20Qcxd9YMuYcs=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
|
||||
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
|
||||
github.com/kaptinlin/go-i18n v0.2.12 h1:ywDsvb4KDFddMC2dpI/rrIzGU2mWUSvHmWUm9BMsdl4=
|
||||
github.com/kaptinlin/go-i18n v0.2.12/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
|
||||
github.com/indaco/herald v0.10.0 h1:XzahEKX6cr50qZQrUdA3QrQBHg8uGm5jETD0UDi21BI=
|
||||
github.com/indaco/herald v0.10.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.1.0 h1:zmYudYo+uamzKTBcIffJVJYrqk9xDNnVrTh+de2zciw=
|
||||
github.com/indaco/herald-md v0.1.0/go.mod h1:Z1HxPCbSn+/+TFzOM/UbsmKeEk/28NNI6JOTileKXto=
|
||||
github.com/kaptinlin/go-i18n v0.3.0 h1:wP76dvYg04bvwTb+8NB+CmdZ2kL7lSSCQ9B/kFv7QHo=
|
||||
github.com/kaptinlin/go-i18n v0.3.0/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17/go.mod h1:SsfsjqnHG5zuKo1DTBzk1VknaHlL4osHw+X9kZKukpU=
|
||||
github.com/kaptinlin/jsonschema v0.7.6 h1:UUMqZGFAk7nOzQsYAxvgygm4wpDp/nwXxA4VP9mCPCs=
|
||||
github.com/kaptinlin/jsonschema v0.7.6/go.mod h1:GGk/oE+F1lWUfYrzKaCf4QWZmMdytt0LL4XdFEFB0LE=
|
||||
github.com/kaptinlin/messageformat-go v0.4.18 h1:RBlHVWgZyoxTcUgGWBsl2AcyScq/urqbLZvzgryTmSI=
|
||||
github.com/kaptinlin/messageformat-go v0.4.18/go.mod h1:ntI3154RnqJgr7GaC+vZBnIExl2V3sv9selvRNNEM24=
|
||||
github.com/kaptinlin/jsonschema v0.7.7 h1:41BlQJ9dskH0oE5DSzBUrl/w4JQYIr6N6L0B5GNyDoM=
|
||||
github.com/kaptinlin/jsonschema v0.7.7/go.mod h1:rKjWfyySHSxAD7Li2ctYkPlOu960igoKBvZ2ADRtd5Q=
|
||||
github.com/kaptinlin/messageformat-go v0.4.19 h1:A5kuuZ1ybXDQ7kD1aoEWGAOemX7hLsMY0yolgSbgpRI=
|
||||
github.com/kaptinlin/messageformat-go v0.4.19/go.mod h1:utSDTfiXTxl66OC5RIEuObLH7Ue3YjbA2X86SYMBYWg=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mailru/easyjson v0.9.2 h1:dX8U45hQsZpxd80nLvDGihsQ/OxlvTkVUXH2r/8cb2M=
|
||||
github.com/mailru/easyjson v0.9.2/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/mark3labs/mcp-go v0.45.0 h1:s0S8qR/9fWaQ3pHxz7pm1uQ0DrswoSnRIxKIjbiQtkc=
|
||||
github.com/mark3labs/mcp-go v0.45.0/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
|
||||
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@@ -228,22 +217,18 @@ github.com/muesli/mango-cobra v1.3.0 h1:vQy5GvPg3ndOSpduxutqFoINhWk3vD5K2dXo5E8p
|
||||
github.com/muesli/mango-cobra v1.3.0/go.mod h1:Cj1ZrBu3806Qw7UjxnAUgE+7tllUBj1NCLQDwwGx19E=
|
||||
github.com/muesli/mango-pflag v0.2.0 h1:QViokgKDZQCzKhYe1zH8D+UlPJzBSGoP9yx0hBG0t5k=
|
||||
github.com/muesli/mango-pflag v0.2.0/go.mod h1:X9LT1p/pbGA1wjvEbtwnixujKErkP0jVmrxwrw3fL0Y=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
|
||||
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
@@ -279,16 +264,12 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/traefik/yaegi v0.16.1 h1:f1De3DVJqIDKmnasUF6MwmWv1dSEEat0wcpXhD2On3E=
|
||||
github.com/traefik/yaegi v0.16.1/go.mod h1:4eVhbPb3LnD2VigQjhYbEJ69vDRFdT2HQNrXx8eEwUY=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.7.17 h1:p36OVWwRb246iHxA/U4p8OPEpOTESm4n+g+8t0EE5uA=
|
||||
github.com/yuin/goldmark v1.7.17/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs=
|
||||
github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA=
|
||||
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
|
||||
@@ -328,10 +309,14 @@ golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.272.0 h1:eLUQZGnAS3OHn31URRf9sAmRk3w2JjMx37d2k8AjJmA=
|
||||
google.golang.org/api v0.272.0/go.mod h1:wKjowi5LNJc5qarNvDCvNQBn3rVK8nSy6jg2SwRwzIA=
|
||||
google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg=
|
||||
google.golang.org/genai v1.51.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/api v0.273.0 h1:r/Bcv36Xa/te1ugaN1kdJ5LoA5Wj/cL+a4gj6FiPBjQ=
|
||||
google.golang.org/api v0.273.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/genai v1.52.0 h1:ekVIxWHtLUNbt+v0WWi4j3JT4yrHDEbysMcHQcaCQoI=
|
||||
google.golang.org/genai v1.52.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
|
||||
|
||||
+337
-17
@@ -7,8 +7,11 @@ package acpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
@@ -20,6 +23,17 @@ import (
|
||||
// Version is injected at build time; fallback to "dev".
|
||||
var Version = "dev"
|
||||
|
||||
// thinkingTagOpen and thinkingTagClose are the XML-style tags that some models
|
||||
// (Qwen, DeepSeek) wrap reasoning content in. We parse these to extract
|
||||
// reasoning/thinking content and send it as ACP thought updates.
|
||||
// Also support <think> format used by some models.
|
||||
const (
|
||||
thinkingTagOpen = "<thinking>"
|
||||
thinkingTagClose = "</thinking>"
|
||||
shortThinkTagOpen = "<think>"
|
||||
shortThinkTagClose = "</think>"
|
||||
)
|
||||
|
||||
// Agent implements the acp.Agent interface, delegating to Kit for LLM
|
||||
// execution, tool calls, and session management.
|
||||
type Agent struct {
|
||||
@@ -28,6 +42,10 @@ type Agent struct {
|
||||
|
||||
// toolCallCounter provides unique IDs for tool calls within a turn.
|
||||
toolCallCounter atomic.Int64
|
||||
|
||||
// inThinkingTag tracks whether we're currently inside a <thinking> tag
|
||||
// when parsing streaming content from models that wrap reasoning in XML tags.
|
||||
inThinkingTag bool
|
||||
}
|
||||
|
||||
// NewAgent creates a new ACP agent backed by Kit.
|
||||
@@ -111,13 +129,23 @@ func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.Promp
|
||||
)
|
||||
}
|
||||
|
||||
// Extract text from prompt content blocks.
|
||||
promptText := extractPromptText(params.Prompt)
|
||||
if promptText == "" {
|
||||
// Extract text and file attachments from prompt content blocks.
|
||||
promptText, files := extractPromptContent(params.Prompt)
|
||||
if promptText == "" && len(files) == 0 {
|
||||
return acp.PromptResponse{}, acp.NewInvalidParams("empty prompt")
|
||||
}
|
||||
|
||||
log.Debug("acp: prompt", "session", sessionID, "prompt_len", len(promptText))
|
||||
// If we have files but no text prompt, add a default prompt
|
||||
// This is required because the underlying LLM library needs a non-empty prompt
|
||||
// when there are no previous messages in the conversation.
|
||||
if promptText == "" && len(files) > 0 {
|
||||
promptText = "Please analyze the attached file."
|
||||
}
|
||||
|
||||
log.Debug("acp: prompt", "session", sessionID, "prompt_len", len(promptText), "files", len(files))
|
||||
|
||||
// Reset thinking tag state for this new prompt turn
|
||||
a.inThinkingTag = false
|
||||
|
||||
// Create a cancellable context for this prompt turn.
|
||||
promptCtx, cancel := context.WithCancel(ctx)
|
||||
@@ -129,7 +157,13 @@ func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.Promp
|
||||
defer unsub()
|
||||
|
||||
// Run the prompt through Kit's full turn lifecycle.
|
||||
_, err := sess.kit.PromptResult(promptCtx, promptText)
|
||||
// Use PromptResultWithFiles when file attachments are present.
|
||||
var err error
|
||||
if len(files) > 0 {
|
||||
_, err = sess.kit.PromptResultWithFiles(promptCtx, promptText, files)
|
||||
} else {
|
||||
_, err = sess.kit.PromptResult(promptCtx, promptText)
|
||||
}
|
||||
if err != nil {
|
||||
if promptCtx.Err() != nil {
|
||||
return acp.PromptResponse{
|
||||
@@ -162,6 +196,24 @@ func (a *Agent) SetSessionMode(_ context.Context, _ acp.SetSessionModeRequest) (
|
||||
return acp.SetSessionModeResponse{}, nil
|
||||
}
|
||||
|
||||
// SetSessionModel changes the active model for a session.
|
||||
func (a *Agent) SetSessionModel(ctx context.Context, params acp.SetSessionModelRequest) (acp.SetSessionModelResponse, error) {
|
||||
sessionID := string(params.SessionId)
|
||||
sess, ok := a.registry.get(sessionID)
|
||||
if !ok {
|
||||
return acp.SetSessionModelResponse{}, acp.NewInvalidParams(fmt.Sprintf("session not found: %s", sessionID))
|
||||
}
|
||||
|
||||
modelID := string(params.ModelId)
|
||||
log.Debug("acp: set_session_model", "session", sessionID, "model", modelID)
|
||||
|
||||
if err := sess.kit.SetModel(ctx, modelID); err != nil {
|
||||
return acp.SetSessionModelResponse{}, fmt.Errorf("set model: %w", err)
|
||||
}
|
||||
|
||||
return acp.SetSessionModelResponse{}, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Event streaming: Kit events → ACP SessionUpdate notifications
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -178,8 +230,24 @@ func (a *Agent) subscribeEvents(ctx context.Context, k *kit.Kit, sessionID acp.S
|
||||
var update *acp.SessionUpdate
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
u := acp.UpdateAgentMessageText(ev.Chunk)
|
||||
update = &u
|
||||
// Handle models that wrap reasoning in <thinking> tags (Qwen, DeepSeek)
|
||||
// Parse the chunk and separate reasoning from regular text
|
||||
reasoning, text := a.parseThinkingTags(ev.Chunk)
|
||||
|
||||
// Send reasoning update if we have reasoning content
|
||||
if reasoning != "" {
|
||||
u := acp.UpdateAgentThoughtText(reasoning)
|
||||
_ = a.conn.SessionUpdate(ctx, acp.SessionNotification{
|
||||
SessionId: sessionID,
|
||||
Update: u,
|
||||
})
|
||||
}
|
||||
|
||||
// Send text update if we have text content
|
||||
if text != "" {
|
||||
u := acp.UpdateAgentMessageText(text)
|
||||
update = &u
|
||||
}
|
||||
|
||||
case kit.ReasoningDeltaEvent:
|
||||
u := acp.UpdateAgentThoughtText(ev.Delta)
|
||||
@@ -231,19 +299,271 @@ func (a *Agent) subscribeEvents(ctx context.Context, k *kit.Kit, sessionID acp.S
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// extractPromptText extracts the concatenated text content from ACP content
|
||||
// blocks. Non-text blocks are ignored for now.
|
||||
func extractPromptText(blocks []acp.ContentBlock) string {
|
||||
var text string
|
||||
for _, block := range blocks {
|
||||
if block.Text != nil {
|
||||
if text != "" {
|
||||
text += "\n"
|
||||
// extractPromptContent extracts text and file attachments from ACP content blocks.
|
||||
// It converts supported content blocks (image, audio, resource) to Kit's LLMFilePart.
|
||||
func extractPromptContent(blocks []acp.ContentBlock) (string, []kit.LLMFilePart) {
|
||||
var textParts []string
|
||||
var files []kit.LLMFilePart
|
||||
|
||||
log.Debug("acp: extracting content", "blocks", len(blocks))
|
||||
|
||||
for i, block := range blocks {
|
||||
switch {
|
||||
// Text content
|
||||
case block.Text != nil:
|
||||
log.Debug("acp: content block", "index", i, "type", "text", "len", len(block.Text.Text))
|
||||
textParts = append(textParts, block.Text.Text)
|
||||
|
||||
// Image data (base64)
|
||||
case block.Image != nil:
|
||||
mimeType := block.Image.MimeType
|
||||
if mimeType == "" {
|
||||
mimeType = "image/png" // Default fallback
|
||||
}
|
||||
text += block.Text.Text
|
||||
log.Debug("acp: content block", "index", i, "type", "image", "mime", mimeType, "data_len", len(block.Image.Data))
|
||||
if data, err := base64.StdEncoding.DecodeString(block.Image.Data); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: "image.png",
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode image", "error", err)
|
||||
}
|
||||
|
||||
// Audio data (base64)
|
||||
case block.Audio != nil:
|
||||
mimeType := block.Audio.MimeType
|
||||
if mimeType == "" {
|
||||
mimeType = "audio/wav" // Default fallback
|
||||
}
|
||||
log.Debug("acp: content block", "index", i, "type", "audio", "mime", mimeType)
|
||||
if data, err := base64.StdEncoding.DecodeString(block.Audio.Data); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: "audio.wav",
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode audio", "error", err)
|
||||
}
|
||||
|
||||
// Embedded resource (text or binary file content)
|
||||
case block.Resource != nil:
|
||||
log.Debug("acp: content block", "index", i, "type", "resource")
|
||||
res := block.Resource.Resource
|
||||
// Text resource - append as text content with file reference
|
||||
if res.TextResourceContents != nil {
|
||||
uri := res.TextResourceContents.Uri
|
||||
content := res.TextResourceContents.Text
|
||||
mimeType := "text/plain"
|
||||
if res.TextResourceContents.MimeType != nil {
|
||||
mimeType = *res.TextResourceContents.MimeType
|
||||
}
|
||||
log.Debug("acp: text resource", "uri", uri, "mime", mimeType, "len", len(content))
|
||||
// Text files are included as formatted text, NOT as FilePart
|
||||
// FilePart is for binary files (images, audio, PDFs) only
|
||||
textParts = append(textParts, fmt.Sprintf("[File: %s]\n```\n%s\n```", uri, content))
|
||||
}
|
||||
// Binary resource (base64 blob) - these become FilePart
|
||||
if res.BlobResourceContents != nil {
|
||||
uri := res.BlobResourceContents.Uri
|
||||
mimeType := "application/octet-stream"
|
||||
if res.BlobResourceContents.MimeType != nil {
|
||||
mimeType = *res.BlobResourceContents.MimeType
|
||||
}
|
||||
log.Debug("acp: binary resource", "uri", uri, "mime", mimeType, "blob_len", len(res.BlobResourceContents.Blob))
|
||||
if data, err := base64.StdEncoding.DecodeString(res.BlobResourceContents.Blob); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: extractFilenameFromURI(uri),
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode binary resource", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Resource link (file reference without embedded content)
|
||||
case block.ResourceLink != nil:
|
||||
uri := block.ResourceLink.Uri
|
||||
name := block.ResourceLink.Name
|
||||
log.Debug("acp: content block", "index", i, "type", "resource_link", "uri", uri, "name", name)
|
||||
// For resource links, we'll try to read the file from disk
|
||||
// This requires the file URI to be accessible (file:// scheme)
|
||||
if content, err := readResourceFromURI(uri); err == nil {
|
||||
// Detect if it's a text file or binary file
|
||||
mimeType := "text/plain"
|
||||
if block.ResourceLink.MimeType != nil {
|
||||
mimeType = *block.ResourceLink.MimeType
|
||||
}
|
||||
log.Debug("acp: resource link loaded", "uri", uri, "mime", mimeType, "size", len(content))
|
||||
|
||||
// Only create FilePart for binary files (images, audio, PDFs, etc.)
|
||||
// Text files are included as formatted text in the message
|
||||
if isTextMimeType(mimeType) || looksLikeText(content) {
|
||||
textParts = append(textParts, fmt.Sprintf("[File: %s]\n```\n%s\n```", uri, string(content)))
|
||||
} else {
|
||||
// Binary file - create FilePart for models that support it
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: extractFilenameFromURI(uri),
|
||||
Data: content,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// If we can't read it, include as a text reference
|
||||
log.Debug("acp: resource link failed to load", "uri", uri, "error", err)
|
||||
textParts = append(textParts, fmt.Sprintf("[Referenced file: %s]", uri))
|
||||
}
|
||||
|
||||
default:
|
||||
log.Debug("acp: content block", "index", i, "type", "unknown/unhandled")
|
||||
}
|
||||
}
|
||||
return text
|
||||
|
||||
// Debug log the extracted content
|
||||
for i, f := range files {
|
||||
log.Debug("acp: extracted file", "index", i, "filename", f.Filename, "mime", f.MediaType, "size", len(f.Data))
|
||||
}
|
||||
|
||||
return strings.Join(textParts, "\n"), files
|
||||
}
|
||||
|
||||
// parseThinkingTags parses a text chunk for <thinking> or tags and separates
|
||||
// reasoning content from regular text. This handles models (Qwen, DeepSeek)
|
||||
// that wrap reasoning in XML-style tags instead of using proper reasoning events.
|
||||
// Returns (reasoningContent, textContent).
|
||||
func (a *Agent) parseThinkingTags(chunk string) (reasoning string, text string) {
|
||||
// Handle empty chunk
|
||||
if chunk == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Determine which tag format to use (long or short)
|
||||
openTag := thinkingTagOpen
|
||||
closeTag := thinkingTagClose
|
||||
|
||||
if strings.Contains(chunk, shortThinkTagOpen) || strings.Contains(chunk, shortThinkTagClose) {
|
||||
openTag = shortThinkTagOpen
|
||||
closeTag = shortThinkTagClose
|
||||
} else if !strings.Contains(chunk, thinkingTagOpen) && !strings.Contains(chunk, thinkingTagClose) && !a.inThinkingTag {
|
||||
// No tags at all and not in thinking mode - return as text
|
||||
return "", chunk
|
||||
}
|
||||
|
||||
// Check for opening tag
|
||||
if strings.Contains(chunk, openTag) {
|
||||
parts := strings.SplitN(chunk, openTag, 2)
|
||||
|
||||
// Content before the opening tag is regular text
|
||||
if !a.inThinkingTag && parts[0] != "" {
|
||||
text = parts[0]
|
||||
}
|
||||
|
||||
a.inThinkingTag = true
|
||||
|
||||
// Content after the opening tag is reasoning
|
||||
if len(parts) > 1 {
|
||||
// Check if the same chunk contains the closing tag
|
||||
if strings.Contains(parts[1], closeTag) {
|
||||
innerParts := strings.SplitN(parts[1], closeTag, 2)
|
||||
reasoning = innerParts[0]
|
||||
a.inThinkingTag = false
|
||||
|
||||
// Content after closing tag is regular text
|
||||
if len(innerParts) > 1 && innerParts[1] != "" {
|
||||
text += innerParts[1]
|
||||
}
|
||||
} else if parts[1] != "" {
|
||||
// No closing tag yet, all remaining content is reasoning
|
||||
reasoning = parts[1]
|
||||
}
|
||||
}
|
||||
return reasoning, text
|
||||
}
|
||||
|
||||
// Check for closing tag
|
||||
if strings.Contains(chunk, closeTag) {
|
||||
parts := strings.SplitN(chunk, closeTag, 2)
|
||||
a.inThinkingTag = false
|
||||
|
||||
// Content before closing tag is reasoning
|
||||
reasoning = parts[0]
|
||||
|
||||
// Content after closing tag is regular text
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
text = parts[1]
|
||||
}
|
||||
return reasoning, text
|
||||
}
|
||||
|
||||
// No tags found - content goes to current mode
|
||||
if a.inThinkingTag {
|
||||
return chunk, ""
|
||||
}
|
||||
return "", chunk
|
||||
}
|
||||
|
||||
// isTextMimeType returns true if the MIME type indicates text content.
|
||||
func isTextMimeType(mimeType string) bool {
|
||||
return strings.HasPrefix(mimeType, "text/") ||
|
||||
mimeType == "application/json" ||
|
||||
mimeType == "application/xml" ||
|
||||
mimeType == "application/javascript" ||
|
||||
mimeType == "application/typescript" ||
|
||||
mimeType == "application/x-sh" ||
|
||||
mimeType == "application/x-python" ||
|
||||
mimeType == "application/x-yaml" ||
|
||||
mimeType == "application/x-toml"
|
||||
}
|
||||
|
||||
// looksLikeText checks if the content appears to be text (not binary).
|
||||
// It samples the first 512 bytes and checks for null bytes or high
|
||||
// concentration of non-printable characters.
|
||||
func looksLikeText(data []byte) bool {
|
||||
if len(data) == 0 {
|
||||
return true
|
||||
}
|
||||
// Check first 512 bytes (or less if file is smaller)
|
||||
sampleSize := min(len(data), 512)
|
||||
sample := data[:sampleSize]
|
||||
|
||||
// Count non-printable characters
|
||||
nonPrintable := 0
|
||||
for _, b := range sample {
|
||||
// Null byte indicates binary
|
||||
if b == 0 {
|
||||
return false
|
||||
}
|
||||
// Count control characters (except common whitespace)
|
||||
if b < 32 && b != '\n' && b != '\r' && b != '\t' {
|
||||
nonPrintable++
|
||||
}
|
||||
}
|
||||
|
||||
// If more than 30% non-printable, consider it binary
|
||||
return float64(nonPrintable)/float64(sampleSize) < 0.3
|
||||
}
|
||||
|
||||
// extractFilenameFromURI extracts a filename from a file URI or path.
|
||||
func extractFilenameFromURI(uri string) string {
|
||||
// Handle file:// URIs
|
||||
uri = strings.TrimPrefix(uri, "file://")
|
||||
// Extract basename
|
||||
if idx := strings.LastIndex(uri, "/"); idx >= 0 {
|
||||
return uri[idx+1:]
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
// readResourceFromURI attempts to read file content from a file:// URI.
|
||||
func readResourceFromURI(uri string) ([]byte, error) {
|
||||
if !strings.HasPrefix(uri, "file://") {
|
||||
return nil, fmt.Errorf("unsupported URI scheme: %s", uri)
|
||||
}
|
||||
path := uri[7:] // Remove file:// prefix
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
// parseToolArgs attempts to parse a JSON tool args string into a map for
|
||||
|
||||
@@ -62,8 +62,8 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
// work in ACP mode. TUI-dependent features (widgets, prompts, editor)
|
||||
// become no-ops or return cancelled; all data/model/tool APIs work
|
||||
// identically to interactive mode.
|
||||
if kitInstance.HasExtensions() {
|
||||
kitInstance.SetExtensionContext(extensions.Context{
|
||||
if kitInstance.Extensions().HasExtensions() {
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
SessionID: sessionID,
|
||||
CWD: cwd,
|
||||
Model: kitInstance.GetModelString(),
|
||||
@@ -121,31 +121,31 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionFilePath() },
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.Extensions().GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionPath() },
|
||||
AppendEntry: func(entryType, data string) (string, error) {
|
||||
return kitInstance.AppendExtensionEntry(entryType, data)
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.GetExtensionEntries(entryType)
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
|
||||
// Options, model, and tool management.
|
||||
GetOption: func(name string) string { return kitInstance.GetExtensionOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.SetExtensionOption(name, value) },
|
||||
GetOption: func(name string) string { return kitInstance.Extensions().GetOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.Extensions().SetOption(name, value) },
|
||||
SetModel: func(modelString string) error {
|
||||
previousModel := kitInstance.GetExtensionContext().Model
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
|
||||
return err
|
||||
}
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
kitInstance.EmitModelChange(modelString, previousModel, "extension")
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry { return kitInstance.GetAvailableModels() },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.EmitExtensionCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.GetExtensionToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.SetExtensionActiveTools(names) },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.Extensions().EmitCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.Extensions().GetToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.Extensions().SetActiveTools(names) },
|
||||
|
||||
// LLM completions and subagents.
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
@@ -173,7 +173,7 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: result.Error,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
@@ -188,15 +188,15 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
|
||||
// Render — fall back to logging.
|
||||
RenderMessage: func(name, content string) {
|
||||
renderer := kitInstance.GetExtensionMessageRenderer(name)
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(name)
|
||||
if renderer != nil && renderer.Render != nil {
|
||||
content = renderer.Render(content, 80)
|
||||
}
|
||||
log.Info("extension: message", "renderer", name, "content", content)
|
||||
},
|
||||
ReloadExtensions: func() error { return kitInstance.ReloadExtensions() },
|
||||
ReloadExtensions: func() error { return kitInstance.Extensions().Reload() },
|
||||
})
|
||||
kitInstance.EmitSessionStart()
|
||||
kitInstance.Extensions().EmitSessionStart()
|
||||
}
|
||||
|
||||
sess := &acpSession{
|
||||
|
||||
+83
-32
@@ -31,7 +31,7 @@ type AgentConfig struct {
|
||||
CoreTools []fantasy.AgentTool
|
||||
|
||||
// ToolWrapper is an optional function that wraps the combined tool list
|
||||
// before it is passed to the Fantasy agent. Used by the extensions system
|
||||
// before it is passed to the LLM agent. Used by the extensions system
|
||||
// to intercept tool calls/results.
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
@@ -75,9 +75,9 @@ type ToolOutputHandler = core.ToolOutputCallback
|
||||
// tracking during long-running tool-calling conversations.
|
||||
type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64)
|
||||
|
||||
// Agent represents an AI agent with core tool integration using the fantasy library.
|
||||
// Agent represents an AI agent with core tool integration using the LLM library.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
|
||||
// fantasy.AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
// AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
// Additional tools from external MCP servers can be loaded alongside core tools.
|
||||
type Agent struct {
|
||||
toolManager *tools.MCPToolManager
|
||||
@@ -100,7 +100,7 @@ type GenerateWithLoopResult struct {
|
||||
FinalResponse *fantasy.Response
|
||||
// ConversationMessages contains all messages in the conversation including tool calls and results
|
||||
ConversationMessages []fantasy.Message
|
||||
// Messages contains the conversation as custom content blocks (crush-style)
|
||||
// Messages contains the conversation as custom content blocks
|
||||
Messages []message.Message
|
||||
// TotalUsage contains aggregate token usage across all steps
|
||||
TotalUsage fantasy.Usage
|
||||
@@ -112,13 +112,13 @@ type GenerateWithLoopResult struct {
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are always registered.
|
||||
// External MCP tools are loaded from the config if any MCP servers are configured.
|
||||
func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
// Create the LLM provider via fantasy
|
||||
// Create the LLM provider
|
||||
providerResult, err := models.CreateProvider(ctx, agentConfig.ModelConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create model provider: %v", err)
|
||||
}
|
||||
|
||||
// Register core tools (direct fantasy implementations, no MCP overhead).
|
||||
// Register core tools (direct AgentTool implementations, no MCP overhead).
|
||||
// Use caller-provided tools if set, otherwise default to all core tools.
|
||||
coreTools := agentConfig.CoreTools
|
||||
if len(coreTools) == 0 {
|
||||
@@ -158,7 +158,7 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
allTools = agentConfig.ToolWrapper(allTools)
|
||||
}
|
||||
|
||||
// Build fantasy agent options
|
||||
// Build agent options
|
||||
var agentOpts []fantasy.AgentOption
|
||||
|
||||
if agentConfig.SystemPrompt != "" {
|
||||
@@ -183,7 +183,8 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
|
||||
// Pass generation parameters when available.
|
||||
if agentConfig.ModelConfig != nil {
|
||||
if agentConfig.ModelConfig.MaxTokens > 0 {
|
||||
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
|
||||
if agentConfig.ModelConfig.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
|
||||
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(agentConfig.ModelConfig.MaxTokens)))
|
||||
}
|
||||
if agentConfig.ModelConfig.Temperature != nil {
|
||||
@@ -197,7 +198,7 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create the fantasy agent
|
||||
// Create the agent
|
||||
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
// Determine provider type from model string
|
||||
@@ -233,8 +234,8 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the fantasy agent with streaming and callbacks.
|
||||
// Fantasy handles the tool call loop internally. We map fantasy's rich callback system
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally. We map the rich callback system
|
||||
// to kit's existing callback interface for UI integration.
|
||||
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fantasy.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
|
||||
@@ -250,18 +251,21 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
ctx = core.ContextWithToolOutputCallback(ctx, onToolOutput)
|
||||
}
|
||||
|
||||
// Fantasy requires the current user input as Prompt, with prior messages as history.
|
||||
// The agent requires the current user input as Prompt, with prior messages as history.
|
||||
// Extract the last user message text and files as the prompt, and pass everything
|
||||
// before it as Messages. Files (e.g. clipboard images) are passed via the Files
|
||||
// field so Fantasy includes them in the API request.
|
||||
// field so the agent includes them in the API request.
|
||||
prompt, files, history := splitPromptAndHistory(messages)
|
||||
|
||||
// Track current tool call info for callbacks
|
||||
var currentToolName string
|
||||
// Apply message-level cache control for Anthropic models.
|
||||
// This avoids type conflicts with provider-level options.
|
||||
history = applyCacheControlToMessages(history)
|
||||
|
||||
// Track current tool call args for callbacks
|
||||
var currentToolArgs string
|
||||
|
||||
// Use the streaming path when streaming is enabled OR when any callbacks are
|
||||
// provided. Fantasy only exposes tool/step callbacks on AgentStreamCall, so
|
||||
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
|
||||
// Stream is required to observe tool execution in real time. The non-streaming
|
||||
// Generate path is reserved for the simple case with no callbacks at all.
|
||||
hasCallbacks := onToolCall != nil || onToolExecution != nil || onToolResult != nil ||
|
||||
@@ -269,13 +273,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
|
||||
if a.streamingEnabled || hasCallbacks {
|
||||
// Track completed step messages so we can return partial results
|
||||
// on cancellation. Fantasy's Stream() discards accumulated steps
|
||||
// on cancellation. The agent's Stream() discards accumulated steps
|
||||
// when it returns an error, but the OnStepFinish callback fires
|
||||
// for every step that completed before the error occurred.
|
||||
var completedStepMessages []fantasy.Message
|
||||
|
||||
// Use fantasy's streaming agent
|
||||
result, err := a.fantasyAgent.Stream(ctx, fantasy.AgentStreamCall{
|
||||
// Use the streaming agent
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
Prompt: prompt,
|
||||
Files: files,
|
||||
Messages: history,
|
||||
@@ -307,7 +311,6 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
currentToolName = tc.ToolName
|
||||
currentToolArgs = tc.Input
|
||||
|
||||
// Notify about the tool call
|
||||
@@ -364,7 +367,56 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// If a steer channel is attached to the context, wire up a
|
||||
// PrepareStep function that drains the channel between steps
|
||||
// and injects pending steer messages as user messages before
|
||||
// the next LLM call. This enables graceful mid-turn steering
|
||||
// without cancelling in-progress tool execution.
|
||||
if steerCh := steerChFromContext(ctx); steerCh != nil {
|
||||
onConsumed := steerConsumedFromContext(ctx)
|
||||
streamCall.PrepareStep = func(
|
||||
stepCtx context.Context,
|
||||
opts fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
// Drain all pending steer messages (non-blocking).
|
||||
var steered []string
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
steered = append(steered, msg)
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
result := fantasy.PrepareStepResult{
|
||||
Model: opts.Model,
|
||||
Messages: opts.Messages,
|
||||
}
|
||||
if len(steered) > 0 {
|
||||
// Inject each steer message as a user message so the
|
||||
// LLM sees the redirection on the next step.
|
||||
for _, text := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(text))
|
||||
}
|
||||
// Notify that steer messages were consumed.
|
||||
if onConsumed != nil {
|
||||
onConsumed(len(steered))
|
||||
}
|
||||
}
|
||||
|
||||
// Apply message-level cache control for Anthropic models.
|
||||
// This avoids type conflicts with provider-level options.
|
||||
result.Messages = applyCacheControlToMessages(result.Messages)
|
||||
|
||||
return stepCtx, result, nil
|
||||
}
|
||||
}
|
||||
|
||||
result, err := a.fantasyAgent.Stream(ctx, streamCall)
|
||||
if err != nil {
|
||||
// On cancellation (or any error), return a partial result
|
||||
// containing messages from completed steps so the caller can
|
||||
@@ -407,13 +459,11 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
_ = currentToolName // satisfy compiler for non-streaming path
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
}
|
||||
|
||||
// splitPromptAndHistory extracts the last user message as the prompt string,
|
||||
// and returns everything before it as conversation history. Fantasy's agent
|
||||
// and returns everything before it as conversation history. The agent's
|
||||
// requires the current turn's input as Prompt (string), with prior messages
|
||||
// passed separately as Messages (history).
|
||||
func splitPromptAndHistory(messages []fantasy.Message) (string, []fantasy.FilePart, []fantasy.Message) {
|
||||
@@ -456,8 +506,8 @@ func splitPromptAndHistory(messages []fantasy.Message) (string, []fantasy.FilePa
|
||||
return "", nil, messages
|
||||
}
|
||||
|
||||
// convertAgentResult converts a fantasy AgentResult to our GenerateWithLoopResult.
|
||||
// It builds both the legacy fantasy.Message slice and the new custom content blocks.
|
||||
// convertAgentResult converts an AgentResult to our GenerateWithLoopResult.
|
||||
// It builds both the message slice and the new custom content blocks.
|
||||
func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.Message) *GenerateWithLoopResult {
|
||||
// Collect all conversation messages: original + all step messages
|
||||
var allFantasyMessages []fantasy.Message
|
||||
@@ -470,7 +520,7 @@ func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.
|
||||
// Convert to custom content blocks
|
||||
var allMessages []message.Message
|
||||
for _, fm := range allFantasyMessages {
|
||||
allMessages = append(allMessages, message.FromFantasyMessage(fm))
|
||||
allMessages = append(allMessages, message.FromLLMMessage(fm))
|
||||
}
|
||||
|
||||
return &GenerateWithLoopResult{
|
||||
@@ -482,7 +532,7 @@ func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.
|
||||
}
|
||||
}
|
||||
|
||||
// extractToolResultText extracts the text and error status from a fantasy ToolResultContent.
|
||||
// extractToolResultText extracts the text and error status from a ToolResultContent.
|
||||
// For core tools, the result is already clean text (no MCP JSON wrapping).
|
||||
// For MCP tools, it unwraps the MCP content structure.
|
||||
func extractToolResultText(tr fantasy.ToolResultContent) (string, bool) {
|
||||
@@ -495,7 +545,7 @@ func extractToolResultText(tr fantasy.ToolResultContent) (string, bool) {
|
||||
return errResult.Error.Error(), true
|
||||
}
|
||||
|
||||
// Get text directly from the Fantasy result type.
|
||||
// Get text directly from the result type.
|
||||
if textResult, ok := tr.Result.(fantasy.ToolResultOutputContentText); ok {
|
||||
// Try to unwrap MCP JSON structure (for external MCP tools).
|
||||
// Core tools return plain text, so this is a no-op for them.
|
||||
@@ -608,7 +658,7 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
allTools = a.toolWrapper(allTools)
|
||||
}
|
||||
|
||||
// Rebuild fantasy agent options.
|
||||
// Rebuild agent options.
|
||||
var agentOpts []fantasy.AgentOption
|
||||
if a.systemPrompt != "" {
|
||||
agentOpts = append(agentOpts, fantasy.WithSystemPrompt(a.systemPrompt))
|
||||
@@ -628,7 +678,8 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
}
|
||||
|
||||
// Pass generation parameters when available.
|
||||
if config.MaxTokens > 0 {
|
||||
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
|
||||
if config.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
|
||||
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(config.MaxTokens)))
|
||||
}
|
||||
if config.Temperature != nil {
|
||||
@@ -668,7 +719,7 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModel returns the underlying fantasy LanguageModel.
|
||||
// GetModel returns the underlying LanguageModel.
|
||||
func (a *Agent) GetModel() fantasy.LanguageModel {
|
||||
return a.model
|
||||
}
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"charm.land/fantasy"
|
||||
"charm.land/fantasy/providers/anthropic"
|
||||
)
|
||||
|
||||
// cacheControlOptions returns provider options for Anthropic cache control.
|
||||
// This is used at the message level to avoid type conflicts with provider-level options.
|
||||
func cacheControlOptions() fantasy.ProviderOptions {
|
||||
return anthropic.NewProviderCacheControlOptions(&anthropic.ProviderCacheControlOptions{
|
||||
CacheControl: anthropic.CacheControl{
|
||||
Type: "ephemeral",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// applyCacheControlToMessages adds cache control to specific messages.
|
||||
// Anthropic allows max 4 cache blocks per request.
|
||||
// Counts existing cache blocks and only adds new ones up to the limit.
|
||||
func applyCacheControlToMessages(messages []fantasy.Message) []fantasy.Message {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// Make a copy to avoid modifying the original slice
|
||||
result := make([]fantasy.Message, len(messages))
|
||||
copy(result, messages)
|
||||
|
||||
cacheOpts := cacheControlOptions()
|
||||
maxCacheBlocks := 4
|
||||
|
||||
// Helper to check if message already has cache control
|
||||
hasCache := func(msg fantasy.Message) bool {
|
||||
if msg.ProviderOptions == nil {
|
||||
return false
|
||||
}
|
||||
if _, ok := msg.ProviderOptions["anthropic"]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Count existing cache blocks
|
||||
existingCacheCount := 0
|
||||
for _, msg := range result {
|
||||
if hasCache(msg) {
|
||||
existingCacheCount++
|
||||
}
|
||||
}
|
||||
|
||||
// If we're already at or over the limit, don't add more
|
||||
if existingCacheCount >= maxCacheBlocks {
|
||||
return result
|
||||
}
|
||||
|
||||
// How many new cache blocks can we add?
|
||||
remaining := maxCacheBlocks - existingCacheCount
|
||||
|
||||
// First: find and cache the last system message (most important)
|
||||
lastSystemIdx := -1
|
||||
for i, msg := range result {
|
||||
if msg.Role == fantasy.MessageRoleSystem {
|
||||
lastSystemIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
if lastSystemIdx >= 0 && remaining > 0 && !hasCache(result[lastSystemIdx]) {
|
||||
result[lastSystemIdx].ProviderOptions = cacheOpts
|
||||
remaining--
|
||||
}
|
||||
|
||||
// Second: cache the most recent messages (up to remaining limit)
|
||||
// Work backwards from the end to prioritize recent context
|
||||
for i := len(result) - 1; i >= 0 && remaining > 0; i-- {
|
||||
if hasCache(result[i]) {
|
||||
continue
|
||||
}
|
||||
result[i].ProviderOptions = cacheOpts
|
||||
remaining--
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -39,7 +39,7 @@ type AgentCreationOptions struct {
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used.
|
||||
CoreTools []fantasy.AgentTool
|
||||
// ToolWrapper wraps the combined tool list before Fantasy agent creation.
|
||||
// ToolWrapper wraps the combined tool list before agent creation.
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
// ExtraTools are additional tools to include (e.g. from extensions).
|
||||
ExtraTools []fantasy.AgentTool
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
package agent
|
||||
|
||||
import "context"
|
||||
|
||||
// steerChKey is the context key for the steer channel.
|
||||
type steerChKey struct{}
|
||||
|
||||
// steerConsumedKey is the context key for the steer-consumed callback.
|
||||
type steerConsumedKey struct{}
|
||||
|
||||
// ContextWithSteerCh returns a new context with the steer channel attached.
|
||||
// The agent's PrepareStep function checks this channel between steps and
|
||||
// injects any pending steer messages as user messages before the next LLM call.
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan string) context.Context {
|
||||
return context.WithValue(ctx, steerChKey{}, ch)
|
||||
}
|
||||
|
||||
// ContextWithSteerConsumed returns a new context with a callback that fires
|
||||
// when steer messages are consumed by PrepareStep. The count argument is the
|
||||
// number of messages injected in this batch.
|
||||
func ContextWithSteerConsumed(ctx context.Context, fn func(count int)) context.Context {
|
||||
return context.WithValue(ctx, steerConsumedKey{}, fn)
|
||||
}
|
||||
|
||||
// steerChFromContext extracts the steer channel from the context, or nil.
|
||||
func steerChFromContext(ctx context.Context) <-chan string {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan string)
|
||||
return ch
|
||||
}
|
||||
|
||||
// steerConsumedFromContext extracts the steer-consumed callback, or nil.
|
||||
func steerConsumedFromContext(ctx context.Context) func(int) {
|
||||
fn, _ := ctx.Value(steerConsumedKey{}).(func(int))
|
||||
return fn
|
||||
}
|
||||
+266
-57
@@ -3,7 +3,11 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
@@ -16,7 +20,7 @@ import (
|
||||
// queueItem holds a prompt and optional image attachments for the execution queue.
|
||||
type queueItem struct {
|
||||
Prompt string
|
||||
Files []fantasy.FilePart
|
||||
Files []kit.LLMFilePart
|
||||
}
|
||||
|
||||
// App is the application-layer orchestrator. It owns the agentic loop,
|
||||
@@ -65,11 +69,20 @@ type App struct {
|
||||
// rootCtx/rootCancel are used to signal shutdown to all goroutines.
|
||||
rootCtx context.Context
|
||||
rootCancel context.CancelFunc
|
||||
|
||||
// widgetUpdatePending is set to true when a WidgetUpdateEvent has been
|
||||
// sent to the TUI but not yet consumed by its event loop. While the flag
|
||||
// is set, subsequent NotifyWidgetUpdate calls are coalesced (dropped) to
|
||||
// prevent fast extension tickers from flooding the BubbleTea mailbox with
|
||||
// redundant re-render triggers. The flag is cleared after a short debounce
|
||||
// (~1 frame) so new updates are always let through once the TUI has had a
|
||||
// chance to process the pending event.
|
||||
widgetUpdatePending atomic.Bool
|
||||
}
|
||||
|
||||
// New creates a new App with the provided options and pre-loaded messages.
|
||||
// initialMessages may be nil or empty for a fresh session.
|
||||
func New(opts Options, initialMessages []fantasy.Message) *App {
|
||||
func New(opts Options, initialMessages []kit.LLMMessage) *App {
|
||||
rootCtx, rootCancel := context.WithCancel(context.Background())
|
||||
return &App{
|
||||
opts: opts,
|
||||
@@ -113,9 +126,8 @@ func (a *App) Run(prompt string) int {
|
||||
// If the app is idle the prompt executes immediately; otherwise it is queued.
|
||||
// Returns the current queue depth (0 = started immediately, >0 = queued).
|
||||
//
|
||||
// Satisfies ui.AppController (via RunWithImages which converts ImageAttachment
|
||||
// to fantasy.FilePart).
|
||||
func (a *App) RunWithFiles(prompt string, files []fantasy.FilePart) int {
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) RunWithFiles(prompt string, files []kit.LLMFilePart) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -159,11 +171,57 @@ func (a *App) QueueLength() int {
|
||||
return len(a.queue)
|
||||
}
|
||||
|
||||
// Steer cancels the current agent step (if running), clears the queue, and
|
||||
// sends a new message that will execute as soon as the current step finishes
|
||||
// cancelling. If the agent is idle, the message executes immediately.
|
||||
// This is the "steer" delivery mode for SendMessage.
|
||||
func (a *App) Steer(prompt string) {
|
||||
// Steer injects a steering message into the currently running agent turn.
|
||||
// If the agent is in a multi-step tool loop, the message is delivered after
|
||||
// the current tool execution finishes but before the next LLM call (graceful
|
||||
// mid-turn injection via Fantasy's PrepareStep). If the agent is streaming
|
||||
// a text-only response (no pending tool calls), the message waits until the
|
||||
// response completes and then executes as the next turn.
|
||||
//
|
||||
// If the agent is idle, the message starts executing immediately (same as Run).
|
||||
//
|
||||
// Returns the number of pending steer/queue items (0 = started immediately,
|
||||
// >0 = injected/queued). The caller must update UI state based on the return
|
||||
// value — Steer does NOT send events to the program to avoid deadlocking
|
||||
// when called from within Update().
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) Steer(prompt string) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
a.mu.Unlock()
|
||||
return 0
|
||||
}
|
||||
|
||||
if !a.busy {
|
||||
// Not busy — start immediately, same as Run().
|
||||
item := queueItem{Prompt: prompt}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
go a.drainQueue(item)
|
||||
return 0
|
||||
}
|
||||
|
||||
a.mu.Unlock()
|
||||
|
||||
// Agent is busy — inject via the SDK's steer channel. The message
|
||||
// will be picked up by PrepareStep between agent steps (after tool
|
||||
// execution, before next LLM call). If PrepareStep doesn't fire
|
||||
// (text-only response), drainQueue will pick it up after the turn.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.InjectSteer(prompt)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// InterruptAndSend cancels the current agent step (if running), clears the
|
||||
// queue, and sends a new message that will execute as soon as the current
|
||||
// step finishes cancelling. If the agent is idle, the message executes
|
||||
// immediately. This is the hard-cancel delivery mode used by extensions'
|
||||
// CancelAndSend.
|
||||
func (a *App) InterruptAndSend(prompt string) {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -212,6 +270,17 @@ func (a *App) ClearMessages() {
|
||||
}
|
||||
}
|
||||
|
||||
// ReloadMessagesFromTree clears the in-memory message store and reloads it
|
||||
// from the tree session's current branch. Unlike ClearMessages, this does NOT
|
||||
// reset the tree session's leaf pointer. Used after Branch() to sync the
|
||||
// store with the new branch position.
|
||||
func (a *App) ReloadMessagesFromTree() {
|
||||
a.store.Clear()
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
}
|
||||
|
||||
// GetTreeSession returns the tree session manager, or nil if not configured.
|
||||
func (a *App) GetTreeSession() *session.TreeManager {
|
||||
return a.opts.TreeSession
|
||||
@@ -226,10 +295,14 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
_ = old.Close()
|
||||
}
|
||||
a.opts.TreeSession = ts
|
||||
// Also update the kit SDK's tree session so messages are persisted correctly.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.SetTreeSession(ts)
|
||||
}
|
||||
// Reload messages from new session.
|
||||
a.store.Clear()
|
||||
if ts != nil {
|
||||
a.store.Replace(ts.GetFantasyMessages())
|
||||
a.store.Replace(ts.GetLLMMessages())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,12 +313,12 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) AddContextMessage(text string) {
|
||||
msg := fantasy.NewUserMessage(text)
|
||||
a.store.Add(msg)
|
||||
kitMsg := fantasy.NewUserMessage(text)
|
||||
a.store.Add(kitMsg)
|
||||
|
||||
// Persist to tree session if active.
|
||||
if ts := a.opts.TreeSession; ts != nil {
|
||||
_, _ = ts.AppendFantasyMessage(msg)
|
||||
_, _ = ts.AppendLLMMessage(fantasy.NewUserMessage(text))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,6 +356,15 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
if a.program != nil {
|
||||
a.program.Send(msg)
|
||||
}
|
||||
}
|
||||
unsub := a.subscribeSDKEvents(sendFn, nil)
|
||||
defer unsub()
|
||||
|
||||
result, err := a.opts.Kit.Compact(a.rootCtx, nil, customInstructions)
|
||||
if err != nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: err})
|
||||
@@ -295,7 +377,7 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
|
||||
// Sync in-memory store with the compacted session.
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetFantasyMessages())
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
|
||||
a.sendEvent(CompactCompleteEvent{
|
||||
@@ -401,6 +483,13 @@ func (a *App) Close() {
|
||||
|
||||
// Wait for background goroutines.
|
||||
a.wg.Wait()
|
||||
|
||||
// Clean up empty session file on shutdown.
|
||||
if ts := a.opts.TreeSession; ts != nil && ts.IsEmpty() {
|
||||
if path := ts.GetFilePath(); path != "" {
|
||||
_ = os.Remove(path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -425,15 +514,32 @@ func (a *App) drainQueue(first queueItem) {
|
||||
a.mu.Lock()
|
||||
items = append(items, a.queue...)
|
||||
a.queue = a.queue[:0] // Clear the queue
|
||||
queueLen := len(a.queue)
|
||||
a.mu.Unlock()
|
||||
|
||||
// Send queue updated event (queue is now empty)
|
||||
a.sendEvent(QueueUpdatedEvent{Length: queueLen})
|
||||
// Notify UI: all queued messages have been consumed into this batch.
|
||||
a.sendEvent(QueueUpdatedEvent{Length: 0})
|
||||
|
||||
// Process all collected items as a single batch
|
||||
a.runQueueBatch(items)
|
||||
|
||||
// Drain any unconsumed steer messages from the SDK channel.
|
||||
// These arrive when the user steered during a text-only response
|
||||
// (no tool calls, so PrepareStep didn't fire for a second step).
|
||||
// They go to the front of the queue so they run next.
|
||||
if a.opts.Kit != nil {
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
a.mu.Lock()
|
||||
steerItems := make([]queueItem, len(leftover))
|
||||
for i, text := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: text}
|
||||
}
|
||||
a.queue = append(steerItems, a.queue...)
|
||||
a.mu.Unlock()
|
||||
// Notify UI about the consumed steer messages.
|
||||
a.sendEvent(SteerConsumedEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// Check if more items were queued while we were processing
|
||||
a.mu.Lock()
|
||||
hasMore := len(a.queue) > 0
|
||||
@@ -444,6 +550,11 @@ func (a *App) drainQueue(first queueItem) {
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
if hasMore {
|
||||
// Notify UI: these newly queued messages have been consumed into the next batch.
|
||||
a.sendEvent(QueueUpdatedEvent{Length: 0})
|
||||
}
|
||||
|
||||
if !hasMore {
|
||||
// No more items, we're done
|
||||
break
|
||||
@@ -491,7 +602,7 @@ func (a *App) runQueueBatch(items []queueItem) {
|
||||
// call/result pairs; only the in-progress message or tool
|
||||
// call is discarded. Sync the in-memory store to match.
|
||||
if ts := a.opts.TreeSession; ts != nil {
|
||||
a.store.Replace(ts.GetFantasyMessages())
|
||||
a.store.Replace(ts.GetLLMMessages())
|
||||
}
|
||||
a.sendEvent(StepCancelledEvent{})
|
||||
return
|
||||
@@ -510,7 +621,7 @@ func (a *App) runQueueBatch(items []queueItem) {
|
||||
// executeStep runs a single agentic step by delegating to the SDK's
|
||||
// PromptResult() (or PromptResultWithFiles for multimodal), which handles
|
||||
// session persistence, hooks, extension events, and the generation loop.
|
||||
func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.Msg), files []fantasy.FilePart) (*kit.TurnResult, error) {
|
||||
func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.Msg), files []kit.LLMFilePart) (*kit.TurnResult, error) {
|
||||
// Test hook: bypass SDK entirely.
|
||||
if a.opts.PromptFunc != nil {
|
||||
return a.opts.PromptFunc(ctx, prompt)
|
||||
@@ -522,9 +633,10 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -544,8 +656,9 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker.
|
||||
a.updateUsageFromTurnResult(result, prompt)
|
||||
// Update usage tracker. If per-step usage was already recorded from
|
||||
// StepUsageEvent callbacks, avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -569,9 +682,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -604,8 +718,8 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
messages = append(messages, item.Prompt)
|
||||
}
|
||||
|
||||
// TODO: Handle file attachments in batch mode
|
||||
// For now, files are ignored in batch mode (rare edge case)
|
||||
// File attachments are not supported in batch mode; fall back to
|
||||
// processing only the first item that carries files.
|
||||
if hasFiles {
|
||||
// If files exist, fall back to processing just the first item with files
|
||||
for _, item := range items {
|
||||
@@ -626,8 +740,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker (using last item's prompt for tracking).
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt)
|
||||
// Update usage tracker (using last item's prompt for fallback estimation).
|
||||
// If per-step usage was already recorded from StepUsageEvent callbacks,
|
||||
// avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -644,9 +760,10 @@ func (a *App) sendEvent(msg tea.Msg) {
|
||||
}
|
||||
|
||||
// subscribeSDKEvents registers temporary SDK event subscribers that convert
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. Returns an
|
||||
// unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. When stepUsageSeen
|
||||
// is provided, it is set to true after any non-zero StepUsageEvent is observed.
|
||||
// Returns an unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Bool) func() {
|
||||
k := a.opts.Kit
|
||||
var unsubs []func()
|
||||
|
||||
@@ -678,15 +795,10 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
Chunk: ev.Chunk,
|
||||
IsStderr: ev.IsStderr,
|
||||
})
|
||||
case kit.SteerConsumedEvent:
|
||||
sendFn(SteerConsumedEvent{})
|
||||
case kit.StepUsageEvent:
|
||||
if a.opts.UsageTracker != nil {
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(ev.InputTokens),
|
||||
int(ev.OutputTokens),
|
||||
int(ev.CacheReadTokens),
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
}
|
||||
a.recordStepUsage(ev, stepUsageSeen)
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -755,12 +867,32 @@ func (a *App) NotifyModelChanged(provider, model string) {
|
||||
// NotifyWidgetUpdate sends a WidgetUpdateEvent to the TUI so it re-renders
|
||||
// extension widgets. Called from the extension context's SetWidget/RemoveWidget
|
||||
// closures. In non-interactive mode this is a no-op (widgets are TUI-only).
|
||||
//
|
||||
// Coalescing: if a WidgetUpdateEvent is already queued and not yet consumed
|
||||
// by the TUI event loop, additional calls within the same ~16 ms window are
|
||||
// dropped. This prevents fast extension tickers from flooding BubbleTea's
|
||||
// mailbox with redundant re-render triggers.
|
||||
func (a *App) NotifyWidgetUpdate() {
|
||||
// Coalesce: only one pending update at a time.
|
||||
if !a.widgetUpdatePending.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
// Reset the pending flag after a short debounce so subsequent calls
|
||||
// within the same render cycle are also coalesced, but new updates
|
||||
// after the cycle are allowed through.
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
} else {
|
||||
// No program registered (non-interactive mode); clear the flag so
|
||||
// future calls are never permanently blocked.
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -856,29 +988,106 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
}
|
||||
}
|
||||
|
||||
// recordStepUsage applies token/cost usage reported for a completed step.
|
||||
// Step usage events arrive even when a turn is later cancelled, so this keeps
|
||||
// the usage widget accurate on all stop paths.
|
||||
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool) {
|
||||
hasUsage := ev.InputTokens > 0 || ev.OutputTokens > 0 || ev.CacheReadTokens > 0 || ev.CacheWriteTokens > 0
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] recordStepUsage: hasUsage=%v input=%d output=%d cacheRead=%d cacheWrite=%d",
|
||||
hasUsage, ev.InputTokens, ev.OutputTokens, ev.CacheReadTokens, ev.CacheWriteTokens)
|
||||
}
|
||||
if !hasUsage {
|
||||
return
|
||||
}
|
||||
if stepUsageSeen != nil {
|
||||
stepUsageSeen.Store(true)
|
||||
}
|
||||
if a.opts.UsageTracker == nil {
|
||||
return
|
||||
}
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(ev.InputTokens),
|
||||
int(ev.OutputTokens),
|
||||
int(ev.CacheReadTokens),
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// NOTE: We do NOT call SetContextTokens here. Context fill is set once
|
||||
// at turn completion via updateUsageFromTurnResult using FinalUsage.InputTokens,
|
||||
// which reflects the full accumulated context. Per-step context tokens would
|
||||
// cause the display to jump around during multi-step tool calls.
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
// configured UsageTracker. This is the SDK-path equivalent of updateUsage.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string) {
|
||||
// configured UsageTracker. Called once per turn after the turn completes.
|
||||
//
|
||||
// When sawStepUsage is true, totals were already accumulated incrementally via
|
||||
// StepUsageEvent callbacks; in that case this method only updates context fill.
|
||||
// Otherwise it falls back to TotalUsage from the API response.
|
||||
//
|
||||
// NOTE: We only use ACTUAL token counts from API responses for cost tracking.
|
||||
// Estimation is never used for costs - only API-reported tokens are accurate.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string, sawStepUsage bool) {
|
||||
if a.opts.UsageTracker == nil || result == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if result.TotalUsage != nil {
|
||||
inputTokens := int(result.TotalUsage.InputTokens)
|
||||
outputTokens := int(result.TotalUsage.OutputTokens)
|
||||
if inputTokens > 0 && outputTokens > 0 {
|
||||
cacheReadTokens := int(result.TotalUsage.CacheReadTokens)
|
||||
cacheWriteTokens := int(result.TotalUsage.CacheCreationTokens)
|
||||
a.opts.UsageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
// Debug logging for token tracking
|
||||
if a.opts.Debug {
|
||||
if result.TotalUsage != nil {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult TotalUsage: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.TotalUsage.InputTokens, result.TotalUsage.OutputTokens,
|
||||
result.TotalUsage.CacheReadTokens, result.TotalUsage.CacheCreationTokens)
|
||||
} else {
|
||||
a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response)
|
||||
return
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: TotalUsage=nil")
|
||||
}
|
||||
if result.FinalUsage != nil {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult FinalUsage: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.FinalUsage.InputTokens, result.FinalUsage.OutputTokens,
|
||||
result.FinalUsage.CacheReadTokens, result.FinalUsage.CacheCreationTokens)
|
||||
} else {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: FinalUsage=nil")
|
||||
}
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: sawStepUsage=%v", sawStepUsage)
|
||||
}
|
||||
|
||||
if result.FinalUsage != nil {
|
||||
if ct := int(result.FinalUsage.InputTokens) + int(result.FinalUsage.OutputTokens); ct > 0 {
|
||||
a.opts.UsageTracker.SetContextTokens(ct)
|
||||
// --- Accumulate cost/token totals for the session ---
|
||||
// Only use actual API-reported tokens for cost tracking.
|
||||
// If sawStepUsage is true, totals were already updated via StepUsageEvent.
|
||||
// Check any token field > 0 (not just InputTokens) because cached prompts
|
||||
// can result in InputTokens=0 while OutputTokens>0 (OpenAI-compatible behavior).
|
||||
hasTotalUsage := result.TotalUsage != nil &&
|
||||
(result.TotalUsage.InputTokens > 0 ||
|
||||
result.TotalUsage.OutputTokens > 0 ||
|
||||
result.TotalUsage.CacheReadTokens > 0 ||
|
||||
result.TotalUsage.CacheCreationTokens > 0)
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: hasTotalUsage=%v", hasTotalUsage)
|
||||
}
|
||||
if !sawStepUsage && hasTotalUsage {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: calling UpdateUsage input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.TotalUsage.InputTokens, result.TotalUsage.OutputTokens,
|
||||
result.TotalUsage.CacheReadTokens, result.TotalUsage.CacheCreationTokens)
|
||||
}
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(result.TotalUsage.InputTokens),
|
||||
int(result.TotalUsage.OutputTokens),
|
||||
int(result.TotalUsage.CacheReadTokens),
|
||||
int(result.TotalUsage.CacheCreationTokens),
|
||||
)
|
||||
}
|
||||
|
||||
// --- Context window fill (drives the % bar) ---
|
||||
// Use FinalUsage.InputTokens as the context window fill. The API's InputTokens
|
||||
// already includes the full conversation history (system prompt + all previous
|
||||
// messages + current user message). Adding OutputTokens would double-count since
|
||||
// the output becomes part of the input for the next turn.
|
||||
if result.FinalUsage != nil && result.FinalUsage.InputTokens > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: calling SetContextTokens=%d (FinalUsage.InputTokens)",
|
||||
result.FinalUsage.InputTokens)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(int(result.FinalUsage.InputTokens))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,47 @@ import (
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
type usageUpdaterStub struct {
|
||||
mu sync.Mutex
|
||||
|
||||
updateCalls int
|
||||
estimateCalls int
|
||||
contextCalls int
|
||||
|
||||
lastUpdateInput int
|
||||
lastUpdateOutput int
|
||||
lastUpdateCacheRead int
|
||||
lastUpdateCacheWrite int
|
||||
lastContextTokens int
|
||||
lastEstimateInput string
|
||||
lastEstimateOutput string
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.updateCalls++
|
||||
s.lastUpdateInput = inputTokens
|
||||
s.lastUpdateOutput = outputTokens
|
||||
s.lastUpdateCacheRead = cacheReadTokens
|
||||
s.lastUpdateCacheWrite = cacheWriteTokens
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.estimateCalls++
|
||||
s.lastEstimateInput = inputText
|
||||
s.lastEstimateOutput = outputText
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) SetContextTokens(tokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.contextCalls++
|
||||
s.lastContextTokens = tokens
|
||||
}
|
||||
|
||||
// turnResult builds a minimal TurnResult with response text t.
|
||||
func turnResult(t string) *kit.TurnResult {
|
||||
return &kit.TurnResult{Response: t}
|
||||
@@ -489,3 +530,133 @@ func TestQueueLength_reflects(t *testing.T) {
|
||||
t.Fatalf("expected 3, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecordStepUsage_updatesTracker verifies that per-step usage updates are
|
||||
// recorded immediately for cost tracking. Context tokens are NOT updated here
|
||||
// (only via updateUsageFromTurnResult) to avoid display jumps during multi-step
|
||||
// tool calls.
|
||||
func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.recordStepUsage(kit.StepUsageEvent{
|
||||
InputTokens: 120,
|
||||
OutputTokens: 45,
|
||||
CacheReadTokens: 5,
|
||||
CacheWriteTokens: 2,
|
||||
}, nil)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 1 {
|
||||
t.Fatalf("expected 1 update call, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.lastUpdateInput != 120 || usage.lastUpdateOutput != 45 || usage.lastUpdateCacheRead != 5 || usage.lastUpdateCacheWrite != 2 {
|
||||
t.Fatalf("unexpected usage update payload: in=%d out=%d cache_read=%d cache_write=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput, usage.lastUpdateCacheRead, usage.lastUpdateCacheWrite)
|
||||
}
|
||||
// Context tokens should NOT be updated by recordStepUsage (only by updateUsageFromTurnResult)
|
||||
if usage.contextCalls != 0 {
|
||||
t.Fatalf("expected 0 context token updates from recordStepUsage, got %d", usage.contextCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen ensures we avoid
|
||||
// double-counting totals once StepUsageEvent-based updates were already applied.
|
||||
func TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 999,
|
||||
OutputTokens: 111,
|
||||
CacheReadTokens: 7,
|
||||
CacheCreationTokens: 3,
|
||||
},
|
||||
FinalUsage: &kit.LLMUsage{InputTokens: 456},
|
||||
}, "prompt", true)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 0 {
|
||||
t.Fatalf("expected no total usage update when sawStepUsage=true, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.estimateCalls != 0 {
|
||||
t.Fatalf("expected no estimate update when sawStepUsage=true, got %d", usage.estimateCalls)
|
||||
}
|
||||
// Context tokens should be InputTokens only (456)
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 456 {
|
||||
t.Fatalf("expected final context tokens=456 (InputTokens only), got calls=%d tokens=%d", usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero verifies that usage
|
||||
// is recorded when InputTokens=0 but OutputTokens>0 (OpenAI-compatible cache behavior).
|
||||
func TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
// Simulate OpenAI-compatible behavior: all prompt tokens cached, InputTokens=0
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 0, // All cached - subtracted from prompt
|
||||
OutputTokens: 150, // Actual generated tokens
|
||||
CacheReadTokens: 500, // Cache hit
|
||||
CacheCreationTokens: 0,
|
||||
},
|
||||
FinalUsage: &kit.LLMUsage{InputTokens: 0, OutputTokens: 150},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 1 {
|
||||
t.Fatalf("expected 1 update call when InputTokens=0 but OutputTokens>0, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.lastUpdateInput != 0 || usage.lastUpdateOutput != 150 {
|
||||
t.Fatalf("expected input=0 output=150, got input=%d output=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput)
|
||||
}
|
||||
if usage.lastUpdateCacheRead != 500 {
|
||||
t.Fatalf("expected cache_read=500, got %d", usage.lastUpdateCacheRead)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly verifies that context
|
||||
// window fill uses InputTokens only (not input+output). The API's InputTokens
|
||||
// already includes the full conversation history; adding output would double-count.
|
||||
func TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 200,
|
||||
},
|
||||
FinalUsage: &kit.LLMUsage{
|
||||
InputTokens: 1000, // Full context including history
|
||||
OutputTokens: 200,
|
||||
},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
// Context tokens should be InputTokens only (1000), not input+output (1200)
|
||||
// because InputTokens already includes the full conversation history
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 1000 {
|
||||
t.Fatalf("expected context tokens=1000 (InputTokens only), got calls=%d tokens=%d",
|
||||
usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package app
|
||||
|
||||
import "charm.land/fantasy"
|
||||
import kit "github.com/mark3labs/kit/pkg/kit"
|
||||
|
||||
// StreamChunkEvent is sent by the app layer when a streaming text delta arrives
|
||||
// from the LLM. Each chunk contains an incremental portion of the response.
|
||||
@@ -118,8 +118,8 @@ type SpinnerEvent struct {
|
||||
// MessageCreatedEvent is sent when a new message is added to the message store.
|
||||
// This allows the TUI to stay in sync with the conversation history.
|
||||
type MessageCreatedEvent struct {
|
||||
// Message is the fantasy message that was added to the store.
|
||||
Message fantasy.Message
|
||||
// Message is the message that was added to the store.
|
||||
Message kit.LLMMessage
|
||||
}
|
||||
|
||||
// CompactCompleteEvent is sent when a /compact operation finishes successfully.
|
||||
@@ -141,6 +141,12 @@ type CompactErrorEvent struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// SteerConsumedEvent is sent when one or more steering messages have been
|
||||
// consumed — either injected mid-turn via PrepareStep, or drained into the
|
||||
// queue after a turn completes. The TUI uses this to clear the steering
|
||||
// badge from the display.
|
||||
type SteerConsumedEvent struct{}
|
||||
|
||||
// ModelChangedEvent is sent when an extension changes the active model via
|
||||
// ctx.SetModel. The TUI updates the model name shown in the status bar and
|
||||
// message attribution.
|
||||
|
||||
@@ -3,14 +3,14 @@ package app
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// MessageStore is a thread-safe in-memory store for the conversation history.
|
||||
// On-disk persistence is handled by the TreeManager at the app/SDK layer.
|
||||
type MessageStore struct {
|
||||
mu sync.RWMutex
|
||||
messages []fantasy.Message
|
||||
messages []kit.LLMMessage
|
||||
}
|
||||
|
||||
// NewMessageStore creates an empty MessageStore.
|
||||
@@ -20,14 +20,14 @@ func NewMessageStore() *MessageStore {
|
||||
|
||||
// NewMessageStoreWithMessages creates a MessageStore pre-populated with the
|
||||
// given messages. This is used when loading an existing session at startup.
|
||||
func NewMessageStoreWithMessages(msgs []fantasy.Message) *MessageStore {
|
||||
cp := make([]fantasy.Message, len(msgs))
|
||||
func NewMessageStoreWithMessages(msgs []kit.LLMMessage) *MessageStore {
|
||||
cp := make([]kit.LLMMessage, len(msgs))
|
||||
copy(cp, msgs)
|
||||
return &MessageStore{messages: cp}
|
||||
}
|
||||
|
||||
// Add appends a single message to the store.
|
||||
func (s *MessageStore) Add(msg fantasy.Message) {
|
||||
func (s *MessageStore) Add(msg kit.LLMMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.messages = append(s.messages, msg)
|
||||
@@ -36,22 +36,22 @@ func (s *MessageStore) Add(msg fantasy.Message) {
|
||||
// Replace replaces the entire message history with the given slice. This is
|
||||
// used after an agent step returns the full updated conversation (including
|
||||
// tool calls and results).
|
||||
func (s *MessageStore) Replace(msgs []fantasy.Message) {
|
||||
func (s *MessageStore) Replace(msgs []kit.LLMMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
cp := make([]fantasy.Message, len(msgs))
|
||||
cp := make([]kit.LLMMessage, len(msgs))
|
||||
copy(cp, msgs)
|
||||
s.messages = cp
|
||||
}
|
||||
|
||||
// GetAll returns a snapshot copy of the current message slice.
|
||||
// The returned slice is safe to modify without affecting the store.
|
||||
func (s *MessageStore) GetAll() []fantasy.Message {
|
||||
func (s *MessageStore) GetAll() []kit.LLMMessage {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
cp := make([]fantasy.Message, len(s.messages))
|
||||
cp := make([]kit.LLMMessage, len(s.messages))
|
||||
copy(cp, s.messages)
|
||||
return cp
|
||||
}
|
||||
|
||||
@@ -4,16 +4,29 @@ import (
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// makeTextMsg builds a minimal fantasy.Message with a single TextPart.
|
||||
func makeTextMsg(role, text string) fantasy.Message {
|
||||
return fantasy.Message{
|
||||
Role: fantasy.MessageRole(role),
|
||||
// makeTextMsg builds a minimal kit.LLMMessage using fantasy.NewUserMessage
|
||||
// or constructing with the given role.
|
||||
func makeTextMsg(role, text string) kit.LLMMessage {
|
||||
return kit.LLMMessage{
|
||||
Role: kit.LLMMessageRole(role),
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: text}},
|
||||
}
|
||||
}
|
||||
|
||||
// textOf extracts the plain text from an LLMMessage for assertions.
|
||||
func textOf(msg kit.LLMMessage) string {
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(fantasy.TextPart); ok {
|
||||
return tp.Text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// NewMessageStore / NewMessageStoreWithMessages
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -29,7 +42,7 @@ func TestNewMessageStore_empty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewMessageStoreWithMessages_preloaded(t *testing.T) {
|
||||
msgs := []fantasy.Message{
|
||||
msgs := []kit.LLMMessage{
|
||||
makeTextMsg("user", "hello"),
|
||||
makeTextMsg("assistant", "hi"),
|
||||
}
|
||||
@@ -42,7 +55,7 @@ func TestNewMessageStoreWithMessages_preloaded(t *testing.T) {
|
||||
// NewMessageStoreWithMessages must deep-copy the slice so that external
|
||||
// modifications don't affect the store.
|
||||
func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
|
||||
msgs := []fantasy.Message{makeTextMsg("user", "hello")}
|
||||
msgs := []kit.LLMMessage{makeTextMsg("user", "hello")}
|
||||
s := NewMessageStoreWithMessages(msgs)
|
||||
|
||||
// Mutate the source slice.
|
||||
@@ -52,9 +65,8 @@ func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(got))
|
||||
}
|
||||
tp, ok := got[0].Content[0].(fantasy.TextPart)
|
||||
if !ok || tp.Text != "hello" {
|
||||
t.Fatalf("store was mutated by external slice change; got %q", tp.Text)
|
||||
if textOf(got[0]) != "hello" {
|
||||
t.Fatalf("store was mutated by external slice change; got %q", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,9 +92,8 @@ func TestAdd_preservesOrder(t *testing.T) {
|
||||
}
|
||||
got := s.GetAll()
|
||||
for i, expected := range texts {
|
||||
tp, ok := got[i].Content[0].(fantasy.TextPart)
|
||||
if !ok || tp.Text != expected {
|
||||
t.Fatalf("message[%d]: expected %q, got %q", i, expected, tp.Text)
|
||||
if textOf(got[i]) != expected {
|
||||
t.Fatalf("message[%d]: expected %q, got %q", i, expected, textOf(got[i]))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -95,7 +106,7 @@ func TestReplace_swapsHistory(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s.Add(makeTextMsg("user", "old"))
|
||||
|
||||
replacement := []fantasy.Message{
|
||||
replacement := []kit.LLMMessage{
|
||||
makeTextMsg("user", "new1"),
|
||||
makeTextMsg("assistant", "new2"),
|
||||
}
|
||||
@@ -105,25 +116,22 @@ func TestReplace_swapsHistory(t *testing.T) {
|
||||
t.Fatalf("expected 2 messages after replace, got %d", s.Len())
|
||||
}
|
||||
got := s.GetAll()
|
||||
tp0, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
tp1, _ := got[1].Content[0].(fantasy.TextPart)
|
||||
if tp0.Text != "new1" || tp1.Text != "new2" {
|
||||
t.Fatalf("unexpected messages after replace: %q %q", tp0.Text, tp1.Text)
|
||||
if textOf(got[0]) != "new1" || textOf(got[1]) != "new2" {
|
||||
t.Fatalf("unexpected messages after replace: %q %q", textOf(got[0]), textOf(got[1]))
|
||||
}
|
||||
}
|
||||
|
||||
// Replace must deep-copy the incoming slice.
|
||||
func TestReplace_isolatesInput(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
replacement := []fantasy.Message{makeTextMsg("user", "original")}
|
||||
replacement := []kit.LLMMessage{makeTextMsg("user", "original")}
|
||||
s.Replace(replacement)
|
||||
|
||||
replacement[0] = makeTextMsg("user", "mutated")
|
||||
|
||||
got := s.GetAll()
|
||||
tp, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "original" {
|
||||
t.Fatalf("store was mutated by external slice change after Replace; got %q", tp.Text)
|
||||
if textOf(got[0]) != "original" {
|
||||
t.Fatalf("store was mutated by external slice change after Replace; got %q", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,9 +148,8 @@ func TestGetAll_returnsCopy(t *testing.T) {
|
||||
got[0] = makeTextMsg("user", "mutated")
|
||||
|
||||
internal := s.GetAll()
|
||||
tp, _ := internal[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "hello" {
|
||||
t.Fatalf("GetAll returned non-copy; store was mutated to %q", tp.Text)
|
||||
if textOf(internal[0]) != "hello" {
|
||||
t.Fatalf("GetAll returned non-copy; store was mutated to %q", textOf(internal[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,9 +186,8 @@ func TestClear_allowsSubsequentAdds(t *testing.T) {
|
||||
t.Fatalf("expected 1 message after Clear+Add, got %d", s.Len())
|
||||
}
|
||||
got := s.GetAll()
|
||||
tp, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "after" {
|
||||
t.Fatalf("expected %q, got %q", "after", tp.Text)
|
||||
if textOf(got[0]) != "after" {
|
||||
t.Fatalf("expected %q, got %q", "after", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -67,10 +67,6 @@ type Options struct {
|
||||
// Debug enables verbose debug logging.
|
||||
Debug bool
|
||||
|
||||
// CompactMode selects the compact renderer instead of the block renderer for
|
||||
// message formatting.
|
||||
CompactMode bool
|
||||
|
||||
// UsageTracker is an optional callback for recording token usage after each
|
||||
// agent step. When non-nil, the app layer calls UpdateUsage (or
|
||||
// EstimateAndUpdateUsage as a fallback) using the usage data returned by the
|
||||
|
||||
@@ -10,9 +10,10 @@ import (
|
||||
)
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
// Currently supports Anthropic credentials with both OAuth and API key authentication methods.
|
||||
// Currently supports Anthropic and OpenAI credentials with both OAuth and API key authentication methods.
|
||||
type CredentialStore struct {
|
||||
Anthropic *AnthropicCredentials `json:"anthropic,omitempty"`
|
||||
OpenAI *OpenAICredentials `json:"openai,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicCredentials holds Anthropic API credentials supporting both OAuth
|
||||
@@ -28,13 +29,44 @@ type AnthropicCredentials struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// OpenAICredentials holds OpenAI API credentials supporting both OAuth
|
||||
// and API key authentication methods. The Type field indicates which authentication
|
||||
// method is being used. For OAuth, tokens are stored with expiration timestamps
|
||||
// for automatic refresh. For API keys, only the key itself is stored.
|
||||
type OpenAICredentials struct {
|
||||
Type string `json:"type"` // "oauth" or "api_key"
|
||||
APIKey string `json:"api_key,omitempty"` // For API key auth
|
||||
AccessToken string `json:"access_token,omitempty"` // For OAuth
|
||||
RefreshToken string `json:"refresh_token,omitempty"` // For OAuth
|
||||
ExpiresAt int64 `json:"expires_at,omitempty"` // For OAuth
|
||||
AccountID string `json:"account_id,omitempty"` // For OAuth (ChatGPT account ID)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// oauthTokenExpired reports whether an OAuth token with the given type and
|
||||
// expiry unix timestamp is past its expiry. Returns false for API key
|
||||
// credentials or when no expiry is set.
|
||||
func oauthTokenExpired(credType string, expiresAt int64) bool {
|
||||
if credType != "oauth" || expiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= expiresAt
|
||||
}
|
||||
|
||||
// oauthTokenNeedsRefresh reports whether an OAuth token will expire within the
|
||||
// next 5 minutes, allowing proactive refresh before it becomes invalid.
|
||||
// Returns false for API key credentials or when no expiry is set.
|
||||
func oauthTokenNeedsRefresh(credType string, expiresAt int64) bool {
|
||||
if credType != "oauth" || expiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (expiresAt - 300) // 5 minutes buffer
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) IsExpired() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= c.ExpiresAt
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
@@ -42,10 +74,21 @@ func (c *AnthropicCredentials) IsExpired() bool {
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) NeedsRefresh() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) IsExpired() bool {
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
// will expire within the next 5 minutes. This allows for proactive token refresh
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) NeedsRefresh() bool {
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// CredentialManager handles secure storage and retrieval of authentication credentials.
|
||||
@@ -212,6 +255,142 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAICredentials stores OpenAI API key credentials. It validates the
|
||||
// API key format before storing. The API key must start with "sk-" and be
|
||||
// at least 20 characters long. Returns an error if the API key is invalid or
|
||||
// if storage fails.
|
||||
func (cm *CredentialManager) SetOpenAICredentials(apiKey string) error {
|
||||
if err := validateOpenAIAPIKey(apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = &OpenAICredentials{
|
||||
Type: "api_key",
|
||||
APIKey: apiKey,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetOpenAICredentials retrieves stored OpenAI credentials. Returns nil if
|
||||
// no credentials are stored. The returned credentials may be either OAuth or API
|
||||
// key type, check the Type field to determine which.
|
||||
func (cm *CredentialManager) GetOpenAICredentials() (*OpenAICredentials, error) {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return store.OpenAI, nil
|
||||
}
|
||||
|
||||
// RemoveOpenAICredentials removes stored OpenAI credentials from storage.
|
||||
// If this was the only credential stored, the entire credentials file is removed.
|
||||
// Returns an error if the removal fails.
|
||||
func (cm *CredentialManager) RemoveOpenAICredentials() error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = nil
|
||||
|
||||
// If store is empty, remove the file entirely
|
||||
if store.Anthropic == nil && store.OpenAI == nil {
|
||||
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// HasOpenAICredentials checks if valid OpenAI credentials are stored.
|
||||
// Returns true if either a non-empty OAuth access token or API key is present,
|
||||
// false otherwise. Returns an error if credentials cannot be loaded.
|
||||
func (cm *CredentialManager) HasOpenAICredentials() (bool, error) {
|
||||
creds, err := cm.GetOpenAICredentials()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if creds == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check based on credential type
|
||||
switch creds.Type {
|
||||
case "oauth":
|
||||
return creds.AccessToken != "", nil
|
||||
case "api_key":
|
||||
return creds.APIKey != "", nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAIOAuthCredentials stores OpenAI OAuth credentials in the credential manager's secure storage.
|
||||
// The credentials should include access token, refresh token, and expiration information.
|
||||
// Returns an error if the credentials cannot be saved.
|
||||
func (cm *CredentialManager) SetOpenAIOAuthCredentials(creds *OpenAICredentials) error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = creds
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetValidOpenAIAccessToken returns a valid access token for API requests. For OAuth credentials,
|
||||
// it automatically refreshes the token if it's expired or about to expire. For API key
|
||||
// credentials, it simply returns the API key. Returns an error if no credentials are found,
|
||||
// if token refresh fails, or if the credential type is unknown.
|
||||
func (cm *CredentialManager) GetValidOpenAIAccessToken() (string, error) {
|
||||
creds, err := cm.GetOpenAICredentials()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if creds == nil {
|
||||
return "", fmt.Errorf("no credentials found")
|
||||
}
|
||||
|
||||
// For API key auth, return the API key
|
||||
if creds.Type == "api_key" {
|
||||
return creds.APIKey, nil
|
||||
}
|
||||
|
||||
// For OAuth, check if token needs refresh
|
||||
if creds.Type == "oauth" {
|
||||
if creds.NeedsRefresh() {
|
||||
// Refresh the token
|
||||
client := NewOpenAIOAuthClient()
|
||||
newCreds, err := client.RefreshToken(creds.RefreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Update stored credentials
|
||||
if err := cm.SetOpenAIOAuthCredentials(newCreds); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
return newCreds.AccessToken, nil
|
||||
}
|
||||
|
||||
return creds.AccessToken, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unknown credential type: %s", creds.Type)
|
||||
}
|
||||
|
||||
// GetCredentialsPath returns the absolute path to the credentials JSON file.
|
||||
// This is useful for debugging or displaying the storage location to users.
|
||||
func (cm *CredentialManager) GetCredentialsPath() string {
|
||||
@@ -238,6 +417,26 @@ func validateAnthropicAPIKey(apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOpenAIAPIKey validates the format of an OpenAI API key
|
||||
func validateOpenAIAPIKey(apiKey string) error {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
|
||||
if apiKey == "" {
|
||||
return fmt.Errorf("API key cannot be empty")
|
||||
}
|
||||
|
||||
// OpenAI API keys typically start with "sk-" and are quite long
|
||||
if !strings.HasPrefix(apiKey, "sk-") {
|
||||
return fmt.Errorf("invalid OpenAI API key format (should start with 'sk-')")
|
||||
}
|
||||
|
||||
if len(apiKey) < 20 {
|
||||
return fmt.Errorf("API key appears to be too short")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
|
||||
// 1. Command-line flag value (highest priority)
|
||||
// 2. Stored credentials (OAuth or API key)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -30,6 +31,7 @@ type OAuthClient struct {
|
||||
type AuthData struct {
|
||||
URL string
|
||||
Verifier string
|
||||
State string // Optional state parameter for CSRF protection
|
||||
}
|
||||
|
||||
// NewOAuthClient creates a new OAuth client configured for Anthropic's OAuth service.
|
||||
@@ -199,6 +201,270 @@ func (c *OAuthClient) parseCodeAndState(code string) (parsedCode, parsedState st
|
||||
return
|
||||
}
|
||||
|
||||
// OpenAIOAuthClient handles OAuth 2.0 authentication flow with OpenAI Codex (ChatGPT Plus/Pro).
|
||||
// This uses OpenAI's auth0-based OAuth service for ChatGPT account authentication.
|
||||
type OpenAIOAuthClient struct {
|
||||
ClientID string
|
||||
AuthorizeURL string
|
||||
TokenURL string
|
||||
RedirectURI string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OAuth client configured for OpenAI Codex OAuth.
|
||||
// This uses the public client ID for CLI applications with PKCE for security.
|
||||
func NewOpenAIOAuthClient() *OpenAIOAuthClient {
|
||||
return &OpenAIOAuthClient{
|
||||
// Public client ID for OpenAI Codex CLI OAuth
|
||||
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
AuthorizeURL: "https://auth.openai.com/oauth/authorize",
|
||||
TokenURL: "https://auth.openai.com/oauth/token",
|
||||
RedirectURI: "http://localhost:1455/auth/callback",
|
||||
Scopes: "openid profile email offline_access",
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthorizationURL generates a complete authorization URL for the OAuth flow with
|
||||
// PKCE parameters. Returns an AuthData structure containing the URL for user
|
||||
// authentication and the PKCE verifier for the subsequent code exchange.
|
||||
func (c *OpenAIOAuthClient) GetAuthorizationURL() (*AuthData, error) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
|
||||
}
|
||||
|
||||
// Generate random state
|
||||
stateBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(stateBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
state := fmt.Sprintf("%x", stateBytes)
|
||||
|
||||
params := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {c.ClientID},
|
||||
"redirect_uri": {c.RedirectURI},
|
||||
"scope": {c.Scopes},
|
||||
"code_challenge": {challenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
"id_token_add_organizations": {"true"},
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
"originator": {"kit"},
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s?%s", c.AuthorizeURL, params.Encode())
|
||||
|
||||
return &AuthData{
|
||||
URL: authURL,
|
||||
Verifier: verifier,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges an authorization code for access and refresh tokens.
|
||||
// The code parameter should be the authorization code received from the OAuth callback.
|
||||
// The verifier parameter must be the same PKCE verifier generated during GetAuthorizationURL.
|
||||
// Returns OpenAICredentials containing the tokens, expiration, and account ID.
|
||||
func (c *OpenAIOAuthClient) ExchangeCode(code, verifier string) (*OpenAICredentials, error) {
|
||||
return c.exchangeAuthorizationCode(code, verifier, c.RedirectURI)
|
||||
}
|
||||
|
||||
// exchangeAuthorizationCode performs the token exchange with the OAuth server
|
||||
func (c *OpenAIOAuthClient) exchangeAuthorizationCode(code, verifier, redirectUri string) (*OpenAICredentials, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {c.ClientID},
|
||||
"code": {code},
|
||||
"code_verifier": {verifier},
|
||||
"redirect_uri": {redirectUri},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", c.TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make token request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" || tokenResp.RefreshToken == "" {
|
||||
return nil, fmt.Errorf("token response missing required fields")
|
||||
}
|
||||
|
||||
// Extract account ID from JWT token
|
||||
accountID := extractOpenAIAccountID(tokenResp.AccessToken)
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("failed to extract account ID from token")
|
||||
}
|
||||
|
||||
return &OpenAICredentials{
|
||||
Type: "oauth",
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
CreatedAt: time.Now(),
|
||||
AccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired or expiring access token using a refresh token.
|
||||
// Returns new OpenAICredentials with updated access token, refresh token (may be
|
||||
// rotated), and new expiration timestamp. Returns an error if the refresh fails or
|
||||
// the refresh token is invalid.
|
||||
func (c *OpenAIOAuthClient) RefreshToken(refreshToken string) (*OpenAICredentials, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"client_id": {c.ClientID},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", c.TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make refresh request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode refresh response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" || tokenResp.RefreshToken == "" {
|
||||
return nil, fmt.Errorf("refresh response missing required fields")
|
||||
}
|
||||
|
||||
// Extract account ID from JWT token
|
||||
accountID := extractOpenAIAccountID(tokenResp.AccessToken)
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("failed to extract account ID from refreshed token")
|
||||
}
|
||||
|
||||
return &OpenAICredentials{
|
||||
Type: "oauth",
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
CreatedAt: time.Now(),
|
||||
AccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractOpenAIAccountID extracts the ChatGPT account ID from a JWT access token.
|
||||
// The account ID is stored in the claim path https://api.openai.com/auth.chatgpt_account_id
|
||||
func extractOpenAIAccountID(token string) string {
|
||||
// JWT tokens are base64-encoded JSON payloads
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Decode payload (second part)
|
||||
payload := parts[1]
|
||||
// Add padding if needed
|
||||
if len(payload)%4 != 0 {
|
||||
payload += strings.Repeat("=", 4-len(payload)%4)
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Navigate to the claim path: https://api.openai.com/auth.chatgpt_account_id
|
||||
authPath, ok := claims["https://api.openai.com/auth"].(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
accountID, ok := authPath["chatgpt_account_id"].(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return accountID
|
||||
}
|
||||
|
||||
// ParseOpenAIAuthorizationInput parses various forms of authorization input:
|
||||
// - Full callback URL: http://localhost:1455/auth/callback?code=xxx&state=yyy
|
||||
// - Code#State format: abc123#state456
|
||||
// - Query string: code=abc123&state=state456
|
||||
// - Just the code: abc123
|
||||
func ParseOpenAIAuthorizationInput(input string) (code, state string) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Try parsing as URL
|
||||
if strings.HasPrefix(input, "http") {
|
||||
if u, err := url.Parse(input); err == nil {
|
||||
return u.Query().Get("code"), u.Query().Get("state")
|
||||
}
|
||||
}
|
||||
|
||||
// Try code#state format
|
||||
if strings.Contains(input, "#") {
|
||||
parts := strings.SplitN(input, "#", 2)
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
|
||||
// Try query string format
|
||||
if strings.Contains(input, "code=") {
|
||||
if values, err := url.ParseQuery(input); err == nil {
|
||||
return values.Get("code"), values.Get("state")
|
||||
}
|
||||
}
|
||||
|
||||
// Assume it's just the code
|
||||
return input, ""
|
||||
}
|
||||
|
||||
// SetOAuthCredentials stores OAuth credentials in the credential manager's secure storage.
|
||||
// The credentials should include access token, refresh token, and expiration information.
|
||||
// Returns an error if the credentials cannot be saved.
|
||||
|
||||
@@ -428,6 +428,10 @@ type PreviousCompaction struct {
|
||||
ModifiedFiles []string
|
||||
}
|
||||
|
||||
// StreamCallback is called for each chunk of text during streaming compaction.
|
||||
// Return a non-nil error to cancel the stream.
|
||||
type StreamCallback func(delta string) error
|
||||
|
||||
// Compact summarises older messages using the LLM, returning the compaction
|
||||
// result and a new message slice (summary message + preserved recent
|
||||
// messages).
|
||||
@@ -442,6 +446,8 @@ type PreviousCompaction struct {
|
||||
//
|
||||
// prev carries file tracking from a previous compaction for cumulative
|
||||
// tracking. Pass nil if there is no prior compaction.
|
||||
// onChunk is an optional callback for streaming summary text. Pass nil for
|
||||
// non-streaming compaction.
|
||||
func Compact(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
@@ -449,6 +455,7 @@ func Compact(
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
prev *PreviousCompaction,
|
||||
onChunk StreamCallback,
|
||||
) (*CompactionResult, []fantasy.Message, error) {
|
||||
opts.defaults()
|
||||
|
||||
@@ -487,9 +494,9 @@ func Compact(
|
||||
var err error
|
||||
|
||||
if IsSplitTurn(messages, cutPoint) {
|
||||
summaryText, err = compactSplitTurn(ctx, model, oldMessages, messages, cutPoint, opts, customInstructions)
|
||||
summaryText, err = compactSplitTurn(ctx, model, oldMessages, messages, cutPoint, opts, customInstructions, onChunk)
|
||||
} else {
|
||||
summaryText, err = compactNormal(ctx, model, oldMessages, opts, customInstructions)
|
||||
summaryText, err = compactNormal(ctx, model, oldMessages, opts, customInstructions, onChunk)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -527,15 +534,17 @@ func Compact(
|
||||
}
|
||||
|
||||
// compactNormal generates a summary for a clean turn-boundary cut.
|
||||
// If onChunk is provided, text deltas are streamed to it.
|
||||
func compactNormal(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
oldMessages []fantasy.Message,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
conversationText := serializeMessages(oldMessages)
|
||||
return generateSummary(ctx, model, conversationText, opts, customInstructions)
|
||||
return generateSummary(ctx, model, conversationText, opts, customInstructions, onChunk)
|
||||
}
|
||||
|
||||
// compactSplitTurn handles the case where the cut point lands mid-turn.
|
||||
@@ -546,6 +555,7 @@ func compactNormal(
|
||||
//
|
||||
// The merged result preserves context from both the older history and the
|
||||
// beginning of the current long turn.
|
||||
// If onChunk is provided, both summaries and the separator are streamed.
|
||||
func compactSplitTurn(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
@@ -554,6 +564,7 @@ func compactSplitTurn(
|
||||
cutPoint int,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
// Find where the split turn starts.
|
||||
turnStart := findTurnStart(allMessages, cutPoint)
|
||||
@@ -573,12 +584,19 @@ func compactSplitTurn(
|
||||
// Generate history summary if there are complete turns before the split.
|
||||
if len(historyMessages) >= 2 {
|
||||
historySummary, err = generateSummary(ctx, model,
|
||||
serializeMessages(historyMessages), opts, "")
|
||||
serializeMessages(historyMessages), opts, "", onChunk)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("split turn history summary failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the separator between history and turn prefix summaries.
|
||||
if onChunk != nil && historySummary != "" {
|
||||
if err := onChunk("\n\n---\n\n## Current Turn (in progress)\n\n"); err != nil {
|
||||
return "", fmt.Errorf("streaming separator failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate turn prefix summary.
|
||||
turnPrefixText := serializeMessages(turnPrefixMessages)
|
||||
turnPrefixPrompt := "The messages above are the BEGINNING of a long turn that was split. " +
|
||||
@@ -588,16 +606,10 @@ func compactSplitTurn(
|
||||
turnPrefixPrompt += "\n\nAdditional instructions: " + customInstructions
|
||||
}
|
||||
|
||||
summaryAgent := fantasy.NewAgent(model,
|
||||
fantasy.WithSystemPrompt(defaultSystemPrompt),
|
||||
)
|
||||
result, err := summaryAgent.Generate(ctx, fantasy.AgentCall{
|
||||
Prompt: turnPrefixText + "\n\n" + turnPrefixPrompt,
|
||||
})
|
||||
turnPrefixSummary, err := generateSummary(ctx, model, turnPrefixText, opts, turnPrefixPrompt, onChunk)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("split turn prefix summary failed: %w", err)
|
||||
}
|
||||
turnPrefixSummary := result.Response.Content.Text()
|
||||
|
||||
// Merge the two summaries.
|
||||
if historySummary != "" && turnPrefixSummary != "" {
|
||||
@@ -610,12 +622,14 @@ func compactSplitTurn(
|
||||
}
|
||||
|
||||
// generateSummary calls the LLM to produce a structured summary.
|
||||
// If onChunk is provided, the summary is streamed using Agent.Stream().
|
||||
func generateSummary(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
conversationText string,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
userPrompt := opts.SummaryPrompt
|
||||
if userPrompt == "" {
|
||||
@@ -628,8 +642,31 @@ func generateSummary(
|
||||
summaryAgent := fantasy.NewAgent(model,
|
||||
fantasy.WithSystemPrompt(defaultSystemPrompt),
|
||||
)
|
||||
|
||||
prompt := conversationText + "\n\n" + userPrompt
|
||||
|
||||
// Use streaming if onChunk is provided.
|
||||
if onChunk != nil {
|
||||
var fullText strings.Builder
|
||||
_, err := summaryAgent.Stream(ctx, fantasy.AgentStreamCall{
|
||||
Prompt: prompt,
|
||||
OnTextDelta: func(_, delta string) error {
|
||||
if delta != "" {
|
||||
fullText.WriteString(delta)
|
||||
return onChunk(delta)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("compaction summarisation (streaming) failed: %w", err)
|
||||
}
|
||||
return fullText.String(), nil
|
||||
}
|
||||
|
||||
// Non-streaming path.
|
||||
result, err := summaryAgent.Generate(ctx, fantasy.AgentCall{
|
||||
Prompt: conversationText + "\n\n" + userPrompt,
|
||||
Prompt: prompt,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("compaction summarisation failed: %w", err)
|
||||
|
||||
@@ -243,7 +243,7 @@ func TestCompact_TooFewMessages(t *testing.T) {
|
||||
makeTextMessageN(fantasy.MessageRoleUser, 400),
|
||||
}
|
||||
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil)
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -262,7 +262,7 @@ func TestCompact_WithinBudget(t *testing.T) {
|
||||
makeTextMessageN(fantasy.MessageRoleAssistant, 400),
|
||||
}
|
||||
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil)
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
@@ -191,7 +191,6 @@ type Config struct {
|
||||
Model string `json:"model,omitempty" yaml:"model,omitempty"`
|
||||
MaxSteps int `json:"max-steps,omitempty" yaml:"max-steps,omitempty"`
|
||||
Debug bool `json:"debug,omitempty" yaml:"debug,omitempty"`
|
||||
Compact bool `json:"compact,omitempty" yaml:"compact,omitempty"`
|
||||
SystemPrompt string `json:"system-prompt,omitempty" yaml:"system-prompt,omitempty"`
|
||||
ProviderAPIKey string `json:"provider-api-key,omitempty" yaml:"provider-api-key,omitempty"`
|
||||
ProviderURL string `json:"provider-url,omitempty" yaml:"provider-url,omitempty"`
|
||||
@@ -403,10 +402,9 @@ func FilepathOr[T any](key string, value *T) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filepath.Join(home, absPath[2:])
|
||||
absPath = filepath.Join(home, absPath[2:])
|
||||
}
|
||||
if !filepath.IsAbs(absPath) {
|
||||
// base := GetConfigPath()
|
||||
base := configPath
|
||||
if base == "" {
|
||||
fmt.Fprintf(os.Stderr, "unable to build relative path to config.")
|
||||
|
||||
+5
-18
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -39,20 +40,8 @@ func toolOutputCallbackFromContext(ctx context.Context) ToolOutputCallback {
|
||||
const defaultBashTimeout = 120 * time.Second
|
||||
const maxBashTimeout = 600 * time.Second
|
||||
|
||||
var bannedCommands = []string{
|
||||
"alias ", "bg ", "bind ", "builtin ",
|
||||
"caller ", "command ", "compgen ",
|
||||
"complete ", "compopt ", "coproc ",
|
||||
"dirs ", "disown ", "enable ",
|
||||
"fc ", "fg ", "hash ", "help ",
|
||||
"history ", "jobs ", "kill ",
|
||||
"logout ", "mapfile ", "popd ",
|
||||
"pushd ", "readonly ", "select ",
|
||||
"set ", "shopt ", "source ",
|
||||
"suspend ", "times ", "trap ",
|
||||
"type ", "typeset ", "ulimit ",
|
||||
"umask ", "unalias ", "wait ",
|
||||
}
|
||||
// bannedCmdRe matches bash builtin commands that are not allowed for security reasons.
|
||||
var bannedCmdRe = regexp.MustCompile(`^(alias|bg|bind|builtin|caller|command|compgen|complete|compopt|coproc|dirs|disown|enable|fc|fg|hash|help|history|jobs|kill|logout|mapfile|popd|pushd|readonly|select|set|shopt|source|suspend|times|trap|type|typeset|ulimit|umask|unalias|wait)\s`)
|
||||
|
||||
type bashArgs struct {
|
||||
Command string `json:"command"`
|
||||
@@ -94,10 +83,8 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
|
||||
// Check for banned commands
|
||||
for _, banned := range bannedCommands {
|
||||
if strings.HasPrefix(args.Command, banned) {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", args.Command)), nil
|
||||
}
|
||||
if bannedCmdRe.MatchString(args.Command) {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", args.Command)), nil
|
||||
}
|
||||
|
||||
// Determine timeout
|
||||
|
||||
+234
-44
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
@@ -13,19 +14,45 @@ import (
|
||||
udiff "github.com/aymanbagabas/go-udiff"
|
||||
)
|
||||
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
// Edit represents a single replacement in a multi-edit operation.
|
||||
type Edit struct {
|
||||
OldText string `json:"old_text"`
|
||||
NewText string `json:"new_text"`
|
||||
}
|
||||
|
||||
// editArgs holds the arguments for the edit tool.
|
||||
// Supports both single-edit mode (old_text/new_text) and multi-edit mode (edits array).
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
OldText string `json:"old_text"` // Single-edit mode
|
||||
NewText string `json:"new_text"` // Single-edit mode
|
||||
Edits []Edit `json:"edits"` // Multi-edit mode
|
||||
}
|
||||
|
||||
// replacement represents a normalized edit ready for processing.
|
||||
type replacement struct {
|
||||
oldText string // normalized old text for matching
|
||||
newText string // normalized new text
|
||||
originalOld string // original old text for metadata
|
||||
originalNew string // original new text for metadata
|
||||
index int // index in the original edits array (for error messages)
|
||||
}
|
||||
|
||||
// matchedReplacement represents a replacement with its match location.
|
||||
type matchedReplacement struct {
|
||||
replacement
|
||||
start int // start index in normalized content
|
||||
end int // end index in normalized content
|
||||
usedFuzzyMatch bool // true if fuzzy matching was used
|
||||
}
|
||||
|
||||
// NewEditTool creates the edit core tool.
|
||||
func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
cfg := ApplyOptions(opts)
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "edit",
|
||||
Description: "Edit a file by replacing exact text. The old_text must match exactly (including whitespace). Use this for precise, surgical edits. Fails if old_text is not found or matches multiple locations.",
|
||||
Description: "Edit a file by replacing exact text. Supports single edit via old_text/new_text, or multiple edits via the edits array. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Parameters: map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
@@ -33,14 +60,32 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace (must match exactly)",
|
||||
"description": "Exact text to find and replace (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text to replace the old text with",
|
||||
"description": "New text to replace the old text with (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"edits": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Array of edits for multi-region replacement. Each edit must have unique, non-overlapping old_text. All matches are against the original file content.",
|
||||
"items": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace for this edit",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text for this edit",
|
||||
},
|
||||
},
|
||||
"required": []string{"old_text", "new_text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "old_text", "new_text"},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeEdit(ctx, call, cfg.WorkDir)
|
||||
@@ -51,7 +96,7 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args editArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("path, old_text, and new_text parameters are required"), nil
|
||||
return fantasy.NewTextErrorResponse("failed to parse arguments: " + err.Error()), nil
|
||||
}
|
||||
if args.Path == "" {
|
||||
return fantasy.NewTextErrorResponse("path parameter is required"), nil
|
||||
@@ -69,56 +114,201 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
|
||||
content := string(contentBytes)
|
||||
|
||||
// Normalize line endings for matching
|
||||
normalized := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
normalizedOld := strings.ReplaceAll(args.OldText, "\r\n", "\n")
|
||||
|
||||
// Try exact match first
|
||||
count := strings.Count(normalized, normalizedOld)
|
||||
|
||||
// If no exact match, try fuzzy matching
|
||||
if count == 0 {
|
||||
if idx, matchLen := fuzzyMatch(normalized, normalizedOld); idx >= 0 {
|
||||
// Apply fuzzy match — the matched text is the original content slice
|
||||
matchedText := normalized[idx : idx+matchLen]
|
||||
newContent := normalized[:idx] + args.NewText + normalized[idx+matchLen:]
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, matchedText, args.NewText)), nil
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("old_text not found in %s", args.Path)), nil
|
||||
// Normalize and validate input
|
||||
replacements, err := normalizeEditInput(args)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("found %d matches for old_text in %s. Provide more context to identify the correct match.", count, args.Path)), nil
|
||||
// Apply all edits
|
||||
newContent, applied, err := applyEdits(content, replacements)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Apply the edit
|
||||
newContent := strings.Replace(normalized, normalizedOld, args.NewText, 1)
|
||||
|
||||
// Write the file
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, normalizedOld, args.NewText)), nil
|
||||
// Generate diff
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
diff := generateDiff(absPath, normalizedContent, newContent)
|
||||
|
||||
// Build response with fuzzy match indication
|
||||
fuzzyCount := 0
|
||||
for _, m := range applied {
|
||||
if m.usedFuzzyMatch {
|
||||
fuzzyCount++
|
||||
}
|
||||
}
|
||||
|
||||
var msg string
|
||||
if len(applied) == 1 {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff)
|
||||
}
|
||||
} else {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied %d edits (%d fuzzy) to %s\n%s", len(applied), fuzzyCount, args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied %d edits to %s\n%s", len(applied), args.Path, diff)
|
||||
}
|
||||
}
|
||||
|
||||
resp := fantasy.NewTextResponse(msg)
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, applied)), nil
|
||||
}
|
||||
|
||||
// normalizeEditInput validates and normalizes the edit input.
|
||||
// Returns error if both single-edit and multi-edit modes are used.
|
||||
func normalizeEditInput(args editArgs) ([]replacement, error) {
|
||||
singleMode := args.OldText != "" || args.NewText != ""
|
||||
multiMode := len(args.Edits) > 0
|
||||
|
||||
if singleMode && multiMode {
|
||||
return nil, fmt.Errorf("cannot use old_text/new_text together with edits array")
|
||||
}
|
||||
|
||||
if !singleMode && !multiMode {
|
||||
return nil, fmt.Errorf("must provide either old_text/new_text or edits array")
|
||||
}
|
||||
|
||||
if singleMode {
|
||||
if args.OldText == "" {
|
||||
return nil, fmt.Errorf("old_text is required when using single-edit mode")
|
||||
}
|
||||
if args.NewText == "" {
|
||||
return nil, fmt.Errorf("new_text is required when using single-edit mode")
|
||||
}
|
||||
return []replacement{{
|
||||
oldText: strings.ReplaceAll(args.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(args.NewText, "\r\n", "\n"),
|
||||
originalOld: args.OldText,
|
||||
originalNew: args.NewText,
|
||||
index: 0,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Multi-edit mode
|
||||
var reps []replacement
|
||||
for i, edit := range args.Edits {
|
||||
if edit.OldText == "" {
|
||||
return nil, fmt.Errorf("edits[%d].old_text is required", i)
|
||||
}
|
||||
reps = append(reps, replacement{
|
||||
oldText: strings.ReplaceAll(edit.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(edit.NewText, "\r\n", "\n"),
|
||||
originalOld: edit.OldText,
|
||||
originalNew: edit.NewText,
|
||||
index: i,
|
||||
})
|
||||
}
|
||||
return reps, nil
|
||||
}
|
||||
|
||||
// applyEdits applies multiple replacements to the content.
|
||||
// All matches are against the original content (non-incremental).
|
||||
// Returns the new content, the applied matches, and any error.
|
||||
func applyEdits(content string, edits []replacement) (string, []matchedReplacement, error) {
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
|
||||
// Find all matches
|
||||
var matched []matchedReplacement
|
||||
for _, edit := range edits {
|
||||
m, err := findMatch(normalizedContent, edit)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
matched = append(matched, *m)
|
||||
}
|
||||
|
||||
// Sort by position
|
||||
sort.Slice(matched, func(i, j int) bool {
|
||||
return matched[i].start < matched[j].start
|
||||
})
|
||||
|
||||
// Check for overlaps
|
||||
for i := 1; i < len(matched); i++ {
|
||||
if matched[i-1].end > matched[i].start {
|
||||
return "", nil, fmt.Errorf("edits[%d] and edits[%d] overlap; merge them into a single edit",
|
||||
matched[i-1].index, matched[i].index)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply edits in reverse order (end to start) to maintain stable offsets
|
||||
result := normalizedContent
|
||||
for i := len(matched) - 1; i >= 0; i-- {
|
||||
m := matched[i]
|
||||
result = result[:m.start] + m.newText + result[m.end:]
|
||||
}
|
||||
|
||||
return result, matched, nil
|
||||
}
|
||||
|
||||
// findMatch finds a unique match for the edit in the content.
|
||||
// Returns error if not found or ambiguous.
|
||||
func findMatch(content string, edit replacement) (*matchedReplacement, error) {
|
||||
// Try exact match first
|
||||
count := strings.Count(content, edit.oldText)
|
||||
|
||||
if count == 0 {
|
||||
// Try fuzzy match
|
||||
idx, matchLen := fuzzyMatch(content, edit.oldText)
|
||||
if idx < 0 {
|
||||
return nil, fmt.Errorf("edits[%d]: could not find old_text in file. The text must match exactly (including whitespace)", edit.index)
|
||||
}
|
||||
// Use the matched text from content for the replacement
|
||||
matchedText := content[idx : idx+matchLen]
|
||||
return &matchedReplacement{
|
||||
replacement: replacement{
|
||||
oldText: matchedText,
|
||||
newText: edit.newText,
|
||||
originalOld: edit.originalOld,
|
||||
originalNew: edit.originalNew,
|
||||
index: edit.index,
|
||||
},
|
||||
start: idx,
|
||||
end: idx + matchLen,
|
||||
usedFuzzyMatch: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return nil, fmt.Errorf("found %d matches for edits[%d].old_text; each old_text must be unique, provide more context to identify the correct match", count, edit.index)
|
||||
}
|
||||
|
||||
// Single exact match
|
||||
idx := strings.Index(content, edit.oldText)
|
||||
return &matchedReplacement{
|
||||
replacement: edit,
|
||||
start: idx,
|
||||
end: idx + len(edit.oldText),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// editDiffMeta builds the structured metadata attached to edit tool responses.
|
||||
func editDiffMeta(path, oldText, newText string) map[string]any {
|
||||
func editDiffMeta(path string, applied []matchedReplacement) map[string]any {
|
||||
var diffBlocks []map[string]any
|
||||
totalAdditions, totalDeletions := 0, 0
|
||||
|
||||
for _, m := range applied {
|
||||
diffBlocks = append(diffBlocks, map[string]any{
|
||||
"old_text": m.originalOld,
|
||||
"new_text": m.originalNew,
|
||||
})
|
||||
totalAdditions += strings.Count(m.originalNew, "\n") + 1
|
||||
totalDeletions += strings.Count(m.originalOld, "\n") + 1
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"file_diffs": []map[string]any{{
|
||||
"path": path,
|
||||
"additions": strings.Count(newText, "\n") + 1,
|
||||
"deletions": strings.Count(oldText, "\n") + 1,
|
||||
"diff_blocks": []map[string]any{{
|
||||
"old_text": oldText,
|
||||
"new_text": newText,
|
||||
}},
|
||||
"path": path,
|
||||
"additions": totalAdditions,
|
||||
"deletions": totalDeletions,
|
||||
"diff_blocks": diffBlocks,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -715,3 +715,315 @@ func TestExecuteEdit_MetadataContainsFileDiffs(t *testing.T) {
|
||||
t.Fatal("file_diffs should be a non-empty array")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Multi-edit tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Basic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "multi.txt")
|
||||
writeFileOrFail(t, path, "line1\nline2\nline3\nline4\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "line1", NewText: "LINE1"},
|
||||
{OldText: "line3", NewText: "LINE3"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "LINE1") {
|
||||
t.Error("first edit not applied: missing LINE1")
|
||||
}
|
||||
if !strings.Contains(gotStr, "LINE3") {
|
||||
t.Error("second edit not applied: missing LINE3")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line2") {
|
||||
t.Error("line2 was modified but should be untouched")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line4") {
|
||||
t.Error("line4 was modified but should be untouched")
|
||||
}
|
||||
|
||||
// Check response mentions multiple edits
|
||||
if !strings.Contains(resp.Content, "2 edits") {
|
||||
t.Errorf("response should mention '2 edits', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NonIncrementalMatching(t *testing.T) {
|
||||
// All edits are matched against the original content, not incrementally
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "noninc.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
want := "AAA\nBBB\nccc\n"
|
||||
if gotStr != want {
|
||||
t.Errorf("got %q, want %q", gotStr, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_OverlapDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "overlap.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "hello world", NewText: "GOODBYE"}, // Overlaps with first edit
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for overlapping edits")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "overlap") {
|
||||
t.Errorf("expected 'overlap' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_DuplicateDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "dup.txt")
|
||||
writeFileOrFail(t, path, "hello\nworld\nhello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "world", NewText: "WORLD"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for ambiguous old_text (duplicate matches)")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "unique") {
|
||||
t.Errorf("expected 'unique' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello\nworld\nhello\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "notfound.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "nonexistent", NewText: "REPLACEMENT"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for not found")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "edits[0]") {
|
||||
t.Errorf("expected 'edits[0]' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_EmptyArray(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for empty edits array")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "mixed.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(map[string]any{
|
||||
"path": path,
|
||||
"old_text": "hello",
|
||||
"new_text": "HELLO",
|
||||
"edits": []Edit{
|
||||
{OldText: "hello", NewText: "HI"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error when mixing single and multi-edit modes")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "cannot use") {
|
||||
t.Errorf("expected 'cannot use' in error, got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_FuzzyMatch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "fuzzy_multi.txt")
|
||||
// File has trailing whitespace
|
||||
original := "func foo() { \n\treturn 1 \n}\nfunc bar() { \n\treturn 2 \n}\n"
|
||||
writeFileOrFail(t, path, original)
|
||||
|
||||
// Search without trailing whitespace (common LLM behavior)
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "func foo() {\n\treturn 1\n}", NewText: "func foo() {\n\treturn 10\n}"},
|
||||
{OldText: "func bar() {\n\treturn 2\n}", NewText: "func bar() {\n\treturn 20\n}"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "return 10") {
|
||||
t.Error("first edit not applied")
|
||||
}
|
||||
if !strings.Contains(gotStr, "return 20") {
|
||||
t.Error("second edit not applied")
|
||||
}
|
||||
|
||||
// Response should mention fuzzy match
|
||||
if !strings.Contains(resp.Content, "fuzzy") {
|
||||
t.Errorf("response should mention 'fuzzy', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Metadata(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "meta_multi.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
var meta map[string]any
|
||||
if err := json.Unmarshal([]byte(resp.Metadata), &meta); err != nil {
|
||||
t.Fatalf("metadata is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
diffs, ok := meta["file_diffs"].([]any)
|
||||
if !ok || len(diffs) == 0 {
|
||||
t.Fatal("metadata missing file_diffs")
|
||||
}
|
||||
|
||||
firstDiff, ok := diffs[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("first diff is not an object")
|
||||
}
|
||||
|
||||
// Check that diff_blocks contains both edits
|
||||
diffBlocks, ok := firstDiff["diff_blocks"].([]any)
|
||||
if !ok || len(diffBlocks) != 2 {
|
||||
t.Fatalf("expected 2 diff_blocks, got %d", len(diffBlocks))
|
||||
}
|
||||
|
||||
// Verify each block has old_text and new_text
|
||||
for i, block := range diffBlocks {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("diff_block[%d] is not an object", i)
|
||||
}
|
||||
if _, ok := b["old_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing old_text", i)
|
||||
}
|
||||
if _, ok := b["new_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing new_text", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,14 +28,14 @@ type SubagentSpawnResult struct {
|
||||
// SubagentSpawnFunc is a callback that spawns an in-process subagent. The
|
||||
// parent Kit instance injects this into the context so the core tool can
|
||||
// call back without importing pkg/kit (which would create a cycle).
|
||||
// The toolCallID parameter is the LLM-assigned ID of the spawn_subagent
|
||||
// The toolCallID parameter is the LLM-assigned ID of the subagent
|
||||
// tool call, enabling the parent to correlate subagent events.
|
||||
type SubagentSpawnFunc func(ctx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration) (*SubagentSpawnResult, error)
|
||||
|
||||
type subagentCtxKey struct{}
|
||||
|
||||
// WithSubagentSpawner stores a spawn function in the context so that the
|
||||
// spawn_subagent core tool can create in-process subagents.
|
||||
// subagent core tool can create in-process subagents.
|
||||
func WithSubagentSpawner(ctx context.Context, fn SubagentSpawnFunc) context.Context {
|
||||
return context.WithValue(ctx, subagentCtxKey{}, fn)
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func getSubagentSpawner(ctx context.Context) SubagentSpawnFunc {
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// spawn_subagent tool
|
||||
// subagent tool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type subagentArgs struct {
|
||||
@@ -59,11 +59,11 @@ type subagentArgs struct {
|
||||
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// NewSubagentTool creates the spawn_subagent core tool.
|
||||
// NewSubagentTool creates the subagent core tool.
|
||||
func NewSubagentTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "spawn_subagent",
|
||||
Name: "subagent",
|
||||
Description: `Spawn a subagent to perform a task autonomously.
|
||||
|
||||
The subagent runs as a separate in-process Kit instance with full tool access
|
||||
|
||||
@@ -86,7 +86,7 @@ func ReadOnlyTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
}
|
||||
}
|
||||
|
||||
// SubagentTools returns all core tools except spawn_subagent. This prevents
|
||||
// SubagentTools returns all core tools except subagent. This prevents
|
||||
// infinite recursion when a subagent is itself a Kit instance.
|
||||
func SubagentTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
return []fantasy.AgentTool{
|
||||
|
||||
+319
-1
@@ -572,6 +572,102 @@ type Context struct {
|
||||
// })
|
||||
// // handle.Kill() to cancel, handle.Wait() to block
|
||||
SpawnSubagent func(SubagentConfig) (*SubagentHandle, *SubagentResult, error)
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// GetTreeNode returns a node by ID with full metadata and children.
|
||||
// Returns nil if entry not found.
|
||||
GetTreeNode func(entryID string) *TreeNode
|
||||
|
||||
// GetCurrentBranch returns the path from root to current leaf.
|
||||
// Each node contains full metadata (unlike GetMessages which flattens).
|
||||
GetCurrentBranch func() []TreeNode
|
||||
|
||||
// GetChildren returns direct child IDs of an entry.
|
||||
GetChildren func(entryID string) []string
|
||||
|
||||
// NavigateTo branches/forks the session to the specified entry ID.
|
||||
// Equivalent to SDK's Branch() but for extensions.
|
||||
NavigateTo func(entryID string) TreeNavigationResult
|
||||
|
||||
// SummarizeBranch uses LLM to summarize a branch range.
|
||||
// Returns summary text or error string (empty if success).
|
||||
SummarizeBranch func(fromID, toID string) string
|
||||
|
||||
// CollapseBranch replaces a branch range with a summary entry.
|
||||
// This is the "fresh context" primitive for context window management.
|
||||
CollapseBranch func(fromID, toID, summary string) TreeNavigationResult
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// LoadSkill loads a single skill file from path.
|
||||
// Parses YAML frontmatter, returns skill with content ready for injection.
|
||||
LoadSkill func(path string) (*Skill, string)
|
||||
|
||||
// LoadSkillsFromDir discovers and loads all skills from a directory.
|
||||
LoadSkillsFromDir func(dir string) SkillLoadResult
|
||||
|
||||
// DiscoverSkills finds skills in standard locations.
|
||||
// Checks ~/.config/kit/skills/, .kit/skills/, .agents/skills/
|
||||
DiscoverSkills func() SkillLoadResult
|
||||
|
||||
// InjectSkillAsContext sends a skill's content as a system message.
|
||||
// Looks up skill by name from discovered skills.
|
||||
InjectSkillAsContext func(skillName string) string
|
||||
|
||||
// InjectRawSkillAsContext loads and immediately injects a skill file.
|
||||
InjectRawSkillAsContext func(path string) string
|
||||
|
||||
// GetAvailableSkills returns all currently loaded/discovered skills.
|
||||
GetAvailableSkills func() []Skill
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// ParseTemplate extracts {{variables}} from template content.
|
||||
ParseTemplate func(name, content string) PromptTemplate
|
||||
|
||||
// RenderTemplate substitutes variables into template content.
|
||||
RenderTemplate func(tpl PromptTemplate, vars map[string]string) string
|
||||
|
||||
// ParseArguments parses command-line style arguments.
|
||||
ParseArguments func(input string, pattern ArgumentPattern) ParseResult
|
||||
|
||||
// SimpleParseArguments parses $1, $2, $@ style arguments.
|
||||
// Returns slice where [0]=full input, [1]=$1, [2]=$2, ... [n]=$@
|
||||
SimpleParseArguments func(input string, count int) []string
|
||||
|
||||
// EvaluateModelConditional checks if condition matches current model.
|
||||
// Condition supports wildcards: * matches any, ? matches single char.
|
||||
EvaluateModelConditional func(condition string) bool
|
||||
|
||||
// RenderWithModelConditionals processes <if-model> blocks in content.
|
||||
RenderWithModelConditionals func(content string) string
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// ResolveModelChain attempts each model in order until one is available.
|
||||
ResolveModelChain func(preferences []string) ModelResolutionResult
|
||||
|
||||
// GetModelCapabilities returns capabilities for a specific model.
|
||||
// If model is empty, uses current model.
|
||||
GetModelCapabilities func(model string) (ModelCapabilities, string)
|
||||
|
||||
// CheckModelAvailable verifies if a model string is valid.
|
||||
CheckModelAvailable func(model string) bool
|
||||
|
||||
// GetCurrentProvider returns just the provider part of current model.
|
||||
GetCurrentProvider func() string
|
||||
|
||||
// GetCurrentModelID returns just the model ID part of current model.
|
||||
GetCurrentModelID func() string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -598,6 +694,148 @@ type SessionMessage struct {
|
||||
Timestamp string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tree navigation types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TreeNode represents a node in the session tree for navigation.
|
||||
// Extensions use this to traverse conversation history and implement
|
||||
// features like "fresh context" loops and branch summarization.
|
||||
type TreeNode struct {
|
||||
// ID is the unique entry identifier.
|
||||
ID string
|
||||
// ParentID links this entry to its parent (empty if root).
|
||||
ParentID string
|
||||
// Type is the entry type: "message", "branch_summary", "model_change", "extension_data", "tool_execution".
|
||||
Type string
|
||||
// Role is the message role for message entries: "user", "assistant", "system", "tool".
|
||||
Role string
|
||||
// Content is the text content or summary.
|
||||
Content string
|
||||
// Model is the model that generated this (for assistant messages).
|
||||
Model string
|
||||
// Provider is the provider used.
|
||||
Provider string
|
||||
// Timestamp is the RFC3339-formatted creation time.
|
||||
Timestamp string
|
||||
// Children is the list of child entry IDs for tree traversal.
|
||||
Children []string
|
||||
}
|
||||
|
||||
// TreeNavigationResult reports success or failure of tree operations.
|
||||
type TreeNavigationResult struct {
|
||||
// Success is true if the operation completed.
|
||||
Success bool
|
||||
// Error describes what went wrong (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Skill types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Skill represents a loaded skill file with parsed YAML frontmatter.
|
||||
type Skill struct {
|
||||
// Name is the human-readable identifier.
|
||||
Name string
|
||||
// Description summarizes what this skill provides.
|
||||
Description string
|
||||
// Content is the markdown body (frontmatter stripped).
|
||||
Content string
|
||||
// Path is the absolute filesystem path.
|
||||
Path string
|
||||
// Tags are optional labels for categorization.
|
||||
Tags []string
|
||||
// When controls automatic inclusion: "always", "on-demand", or file-glob.
|
||||
When string
|
||||
}
|
||||
|
||||
// SkillLoadResult reports skills loaded from a directory.
|
||||
type SkillLoadResult struct {
|
||||
// Skills is the list of loaded skills.
|
||||
Skills []Skill
|
||||
// Error describes loading failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Template parsing types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// PromptTemplate represents a parsed template with variable placeholders.
|
||||
type PromptTemplate struct {
|
||||
// Name is the template identifier.
|
||||
Name string
|
||||
// Content is the original template content.
|
||||
Content string
|
||||
// Variables are the extracted {{variable}} names.
|
||||
Variables []string
|
||||
}
|
||||
|
||||
// ArgumentPattern defines how to parse command arguments.
|
||||
type ArgumentPattern struct {
|
||||
// Positional names for $1, $2, etc.
|
||||
Positional []string
|
||||
// Rest is the variable name for $@ (all remaining).
|
||||
Rest string
|
||||
// Flags maps flag names to variable names (e.g., "--loop" -> "loop").
|
||||
Flags map[string]string
|
||||
}
|
||||
|
||||
// ParseResult reports argument parsing outcome.
|
||||
type ParseResult struct {
|
||||
// Vars maps variable names to values for positional args.
|
||||
Vars map[string]string
|
||||
// Flags maps flag names to values.
|
||||
Flags map[string]string
|
||||
// Rest is remaining unparsed text.
|
||||
Rest string
|
||||
// Error describes parsing failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ModelConditional represents an <if-model> block for evaluation.
|
||||
type ModelConditional struct {
|
||||
// Condition is the model pattern (e.g., "claude-*", "anthropic/*").
|
||||
Condition string
|
||||
// Content is rendered if condition matches.
|
||||
Content string
|
||||
// Else is rendered if condition doesn't match.
|
||||
Else string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model resolution types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ModelCapabilities describes what a model supports.
|
||||
type ModelCapabilities struct {
|
||||
// Provider is the provider ID (e.g., "anthropic").
|
||||
Provider string
|
||||
// ModelID is the model identifier (e.g., "claude-sonnet-4-20250929").
|
||||
ModelID string
|
||||
// ContextLimit is the maximum context window in tokens.
|
||||
ContextLimit int
|
||||
// OutputLimit is the maximum output tokens.
|
||||
OutputLimit int
|
||||
// Reasoning indicates if the model supports reasoning/thinking.
|
||||
Reasoning bool
|
||||
// Streaming indicates if the model supports streaming.
|
||||
Streaming bool
|
||||
}
|
||||
|
||||
// ModelResolutionResult reports model chain resolution outcome.
|
||||
type ModelResolutionResult struct {
|
||||
// Model is the selected model in "provider/model" format.
|
||||
Model string
|
||||
// Capabilities describes the selected model.
|
||||
Capabilities ModelCapabilities
|
||||
// Attempted lists models tried before success.
|
||||
Attempted []string
|
||||
// Error describes resolution failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ExtensionEntry represents persisted extension data stored in the session.
|
||||
// Extensions use AppendEntry to save custom state and GetEntries to retrieve
|
||||
// it on session resume.
|
||||
@@ -750,6 +988,9 @@ type API struct {
|
||||
registerOption func(OptionDef)
|
||||
registerShortcutFn func(ShortcutDef, func(Context))
|
||||
registerMessageRendererFn func(MessageRendererConfig)
|
||||
onSubagentStart func(func(SubagentStartEvent, Context))
|
||||
onSubagentChunk func(func(SubagentChunkEvent, Context))
|
||||
onSubagentEnd func(func(SubagentEndEvent, Context))
|
||||
}
|
||||
|
||||
// OnToolCall registers a handler that fires before a tool executes.
|
||||
@@ -781,6 +1022,27 @@ func (a *API) OnToolResult(handler func(ToolResultEvent, Context) *ToolResultRes
|
||||
a.onToolResult(handler)
|
||||
}
|
||||
|
||||
// OnSubagentStart registers a handler that fires when a subagent tool
|
||||
// call begins executing. Use the ToolCallID to correlate with subsequent
|
||||
// OnSubagentChunk and OnSubagentEnd events for the same subagent.
|
||||
func (a *API) OnSubagentStart(handler func(SubagentStartEvent, Context)) {
|
||||
a.onSubagentStart(handler)
|
||||
}
|
||||
|
||||
// OnSubagentChunk registers a handler for real-time events from a running
|
||||
// subagent. ChunkType identifies the kind of event ("text", "tool_call",
|
||||
// "tool_result", "tool_execution_start", "tool_execution_end", etc.).
|
||||
// Correlate with OnSubagentStart via the ToolCallID field.
|
||||
func (a *API) OnSubagentChunk(handler func(SubagentChunkEvent, Context)) {
|
||||
a.onSubagentChunk(handler)
|
||||
}
|
||||
|
||||
// OnSubagentEnd registers a handler that fires when a subagent call
|
||||
// completes. ErrorMsg is non-empty when the subagent failed.
|
||||
func (a *API) OnSubagentEnd(handler func(SubagentEndEvent, Context)) {
|
||||
a.onSubagentEnd(handler)
|
||||
}
|
||||
|
||||
// OnInput registers a handler that fires when user input is received.
|
||||
// Return a non-nil InputResult to transform or handle the input.
|
||||
func (a *API) OnInput(handler func(InputEvent, Context) *InputResult) {
|
||||
@@ -1781,9 +2043,65 @@ type BeforeCompactResult struct {
|
||||
func (BeforeCompactResult) isResult() {}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Theme types (exposed to Yaegi — concrete structs, string hex colors)
|
||||
// Subagent lifecycle events (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentStartEvent fires when a subagent tool call begins executing.
|
||||
type SubagentStartEvent struct {
|
||||
// ToolCallID is the LLM-assigned ID of the subagent tool call.
|
||||
// Use this to correlate SubagentChunkEvent and SubagentEndEvent.
|
||||
ToolCallID string
|
||||
// Task is the task description passed to the subagent.
|
||||
Task string
|
||||
}
|
||||
|
||||
func (e SubagentStartEvent) Type() EventType { return SubagentStart }
|
||||
|
||||
// SubagentChunkEvent fires for each real-time event from a running subagent.
|
||||
// Type field indicates the kind of event; read the relevant fields accordingly.
|
||||
type SubagentChunkEvent struct {
|
||||
// ToolCallID matches the SubagentStartEvent.ToolCallID for this subagent.
|
||||
ToolCallID string
|
||||
// Task is the task description (repeated for convenience).
|
||||
Task string
|
||||
// ChunkType identifies the event kind:
|
||||
// "text" — LLM text chunk (read Content)
|
||||
// "reasoning" — reasoning/thinking delta (read Content)
|
||||
// "tool_call" — subagent called a tool (read ToolName, ToolArgs)
|
||||
// "tool_result" — tool returned a result (read ToolName, ToolResult, IsError)
|
||||
// "tool_execution_start" — tool began executing (read ToolName)
|
||||
// "tool_execution_end" — tool finished executing (read ToolName)
|
||||
// "turn_start" — subagent turn began
|
||||
// "turn_end" — subagent turn ended
|
||||
ChunkType string
|
||||
// Content carries text for "text" and "reasoning" chunk types.
|
||||
Content string
|
||||
// ToolName is set on tool-related chunk types.
|
||||
ToolName string
|
||||
// ToolArgs is the JSON-encoded tool arguments for "tool_call" chunks.
|
||||
ToolArgs string
|
||||
// ToolResult is the tool output for "tool_result" chunks.
|
||||
ToolResult string
|
||||
// IsError is true when a "tool_result" chunk represents an error.
|
||||
IsError bool
|
||||
}
|
||||
|
||||
func (e SubagentChunkEvent) Type() EventType { return SubagentChunk }
|
||||
|
||||
// SubagentEndEvent fires when a subagent tool call completes.
|
||||
type SubagentEndEvent struct {
|
||||
// ToolCallID matches the SubagentStartEvent.ToolCallID for this subagent.
|
||||
ToolCallID string
|
||||
// Task is the task description.
|
||||
Task string
|
||||
// Response is the subagent's final text response (empty on error).
|
||||
Response string
|
||||
// ErrorMsg is non-empty when the subagent failed.
|
||||
ErrorMsg string
|
||||
}
|
||||
|
||||
func (e SubagentEndEvent) Type() EventType { return SubagentEnd }
|
||||
|
||||
// ThemeColor is an adaptive color pair with light and dark hex values.
|
||||
// Either field may be empty to inherit from the default theme.
|
||||
type ThemeColor struct {
|
||||
|
||||
@@ -71,6 +71,18 @@ const (
|
||||
// BeforeCompact fires before context compaction runs. Handlers can
|
||||
// cancel compaction by returning Cancel=true.
|
||||
BeforeCompact EventType = "before_compact"
|
||||
|
||||
// SubagentStart fires when a subagent tool call begins executing.
|
||||
// Carries the tool call ID and the task description.
|
||||
SubagentStart EventType = "subagent_start"
|
||||
|
||||
// SubagentChunk fires for each real-time event emitted by a running
|
||||
// subagent: text chunks, tool calls, tool results, etc.
|
||||
SubagentChunk EventType = "subagent_chunk"
|
||||
|
||||
// SubagentEnd fires when a subagent tool call completes (success
|
||||
// or error). Carries the final response and any error message.
|
||||
SubagentEnd EventType = "subagent_end"
|
||||
)
|
||||
|
||||
// AllEventTypes returns every supported event type.
|
||||
@@ -82,6 +94,7 @@ func AllEventTypes() []EventType {
|
||||
SessionStart, SessionShutdown,
|
||||
ModelChange, ContextPrepare,
|
||||
BeforeFork, BeforeSessionSwitch, BeforeCompact,
|
||||
SubagentStart, SubagentChunk, SubagentEnd,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import "testing"
|
||||
|
||||
func TestAllEventTypes_Count(t *testing.T) {
|
||||
all := AllEventTypes()
|
||||
if len(all) != 18 {
|
||||
t.Fatalf("expected 18 event types, got %d", len(all))
|
||||
if len(all) != 21 {
|
||||
t.Fatalf("expected 21 event types, got %d", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,6 +55,9 @@ func TestEventType_TypeMethod(t *testing.T) {
|
||||
{BeforeForkEvent{TargetID: "abc"}, BeforeFork},
|
||||
{BeforeSessionSwitchEvent{Reason: "new"}, BeforeSessionSwitch},
|
||||
{BeforeCompactEvent{EstimatedTokens: 1000}, BeforeCompact},
|
||||
{SubagentStartEvent{ToolCallID: "x", Task: "t"}, SubagentStart},
|
||||
{SubagentChunkEvent{ToolCallID: "x", ChunkType: "text"}, SubagentChunk},
|
||||
{SubagentEndEvent{ToolCallID: "x"}, SubagentEnd},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -47,46 +47,56 @@ func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
// pathSet is a thread-safe helper for deduplicating and ordering file paths.
|
||||
type pathSet struct {
|
||||
m map[string]bool
|
||||
list []string
|
||||
}
|
||||
|
||||
func newPathSet() *pathSet {
|
||||
return &pathSet{m: make(map[string]bool)}
|
||||
}
|
||||
|
||||
func (ps *pathSet) add(p string) bool {
|
||||
abs, err := filepath.Abs(p)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if ps.m[abs] {
|
||||
return false
|
||||
}
|
||||
ps.m[abs] = true
|
||||
ps.list = append(ps.list, abs)
|
||||
return true
|
||||
}
|
||||
|
||||
// discoverExtensionPaths returns deduplicated paths to extension files in
|
||||
// load-order (global first, then project-local, then explicit).
|
||||
func discoverExtensionPaths(extraPaths []string) []string {
|
||||
seen := make(map[string]bool)
|
||||
var paths []string
|
||||
|
||||
add := func(p string) {
|
||||
abs, err := filepath.Abs(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if seen[abs] {
|
||||
return
|
||||
}
|
||||
seen[abs] = true
|
||||
paths = append(paths, abs)
|
||||
}
|
||||
ps := newPathSet()
|
||||
|
||||
// Global extensions: $XDG_CONFIG_HOME/kit/extensions/ (default ~/.config/kit/extensions/)
|
||||
globalDir := globalExtensionsDir()
|
||||
for _, p := range findExtensionsInDir(globalDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Global installed git packages: $XDG_DATA_HOME/kit/git/
|
||||
globalGitDir := globalGitInstallRoot()
|
||||
for _, p := range findExtensionsInGitPackages(globalGitDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Project-local extensions: .kit/extensions/
|
||||
localDir := filepath.Join(".kit", "extensions")
|
||||
for _, p := range findExtensionsInDir(localDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Project-local installed git packages: .kit/git/
|
||||
projectGitDir := filepath.Join(".kit", "git")
|
||||
for _, p := range findExtensionsInGitPackages(projectGitDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Explicit paths (highest precedence)
|
||||
@@ -97,14 +107,14 @@ func discoverExtensionPaths(extraPaths []string) []string {
|
||||
}
|
||||
if info.IsDir() {
|
||||
for _, found := range findExtensionsInDir(p) {
|
||||
add(found)
|
||||
ps.add(found)
|
||||
}
|
||||
} else if strings.HasSuffix(p, ".go") {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
}
|
||||
|
||||
return paths
|
||||
return ps.list
|
||||
}
|
||||
|
||||
// findExtensionsInDir returns .go files in dir and main.go in immediate subdirs.
|
||||
@@ -580,6 +590,24 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
registerShortcutFn: func(def ShortcutDef, handler func(Context)) {
|
||||
ext.Shortcuts = append(ext.Shortcuts, ShortcutEntry{Def: def, Handler: handler})
|
||||
},
|
||||
onSubagentStart: func(h func(SubagentStartEvent, Context)) {
|
||||
reg(SubagentStart, func(e Event, c Context) Result {
|
||||
h(e.(SubagentStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentChunk: func(h func(SubagentChunkEvent, Context)) {
|
||||
reg(SubagentChunk, func(e Event, c Context) Result {
|
||||
h(e.(SubagentChunkEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentEnd: func(h func(SubagentEndEvent, Context)) {
|
||||
reg(SubagentEnd, func(e Event, c Context) Result {
|
||||
h(e.(SubagentEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// Call Init — the extension registers its handlers, tools, commands.
|
||||
|
||||
@@ -56,11 +56,261 @@ func NewRunner(exts []LoadedExtension) *Runner {
|
||||
}
|
||||
|
||||
// SetContext updates the runtime context (session ID, model, etc.) that is
|
||||
// passed to every handler invocation. Thread-safe.
|
||||
// passed to every handler invocation. Nil function fields are replaced with
|
||||
// safe no-ops so extension handlers never panic on a missing callback.
|
||||
// Thread-safe.
|
||||
func (r *Runner) SetContext(ctx Context) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.ctx = ctx
|
||||
r.ctx = normalizeContext(ctx)
|
||||
}
|
||||
|
||||
// normalizeContext replaces nil function fields in ctx with no-op stubs so
|
||||
// that extension handlers can call any ctx method without a nil-function panic.
|
||||
func normalizeContext(ctx Context) Context {
|
||||
if ctx.Print == nil {
|
||||
ctx.Print = func(string) {}
|
||||
}
|
||||
if ctx.PrintInfo == nil {
|
||||
ctx.PrintInfo = func(string) {}
|
||||
}
|
||||
if ctx.PrintError == nil {
|
||||
ctx.PrintError = func(string) {}
|
||||
}
|
||||
if ctx.PrintBlock == nil {
|
||||
ctx.PrintBlock = func(PrintBlockOpts) {}
|
||||
}
|
||||
if ctx.SendMessage == nil {
|
||||
ctx.SendMessage = func(string) {}
|
||||
}
|
||||
if ctx.CancelAndSend == nil {
|
||||
ctx.CancelAndSend = func(string) {}
|
||||
}
|
||||
if ctx.SetWidget == nil {
|
||||
ctx.SetWidget = func(WidgetConfig) {}
|
||||
}
|
||||
if ctx.RemoveWidget == nil {
|
||||
ctx.RemoveWidget = func(string) {}
|
||||
}
|
||||
if ctx.SetHeader == nil {
|
||||
ctx.SetHeader = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveHeader == nil {
|
||||
ctx.RemoveHeader = func() {}
|
||||
}
|
||||
if ctx.SetFooter == nil {
|
||||
ctx.SetFooter = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveFooter == nil {
|
||||
ctx.RemoveFooter = func() {}
|
||||
}
|
||||
if ctx.PromptSelect == nil {
|
||||
ctx.PromptSelect = func(PromptSelectConfig) PromptSelectResult {
|
||||
return PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptConfirm == nil {
|
||||
ctx.PromptConfirm = func(PromptConfirmConfig) PromptConfirmResult {
|
||||
return PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptInput == nil {
|
||||
ctx.PromptInput = func(PromptInputConfig) PromptInputResult {
|
||||
return PromptInputResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptMultiSelect == nil {
|
||||
ctx.PromptMultiSelect = func(PromptMultiSelectConfig) PromptMultiSelectResult {
|
||||
return PromptMultiSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.ShowOverlay == nil {
|
||||
ctx.ShowOverlay = func(OverlayConfig) OverlayResult {
|
||||
return OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
}
|
||||
if ctx.SetEditor == nil {
|
||||
ctx.SetEditor = func(EditorConfig) {}
|
||||
}
|
||||
if ctx.ResetEditor == nil {
|
||||
ctx.ResetEditor = func() {}
|
||||
}
|
||||
if ctx.SetEditorText == nil {
|
||||
ctx.SetEditorText = func(string) {}
|
||||
}
|
||||
if ctx.SetUIVisibility == nil {
|
||||
ctx.SetUIVisibility = func(UIVisibility) {}
|
||||
}
|
||||
if ctx.SetStatus == nil {
|
||||
ctx.SetStatus = func(string, string, int) {}
|
||||
}
|
||||
if ctx.RemoveStatus == nil {
|
||||
ctx.RemoveStatus = func(string) {}
|
||||
}
|
||||
if ctx.GetContextStats == nil {
|
||||
ctx.GetContextStats = func() ContextStats { return ContextStats{} }
|
||||
}
|
||||
if ctx.GetMessages == nil {
|
||||
ctx.GetMessages = func() []SessionMessage { return nil }
|
||||
}
|
||||
if ctx.GetSessionPath == nil {
|
||||
ctx.GetSessionPath = func() string { return "" }
|
||||
}
|
||||
if ctx.AppendEntry == nil {
|
||||
ctx.AppendEntry = func(string, string) (string, error) { return "", nil }
|
||||
}
|
||||
if ctx.GetEntries == nil {
|
||||
ctx.GetEntries = func(string) []ExtensionEntry { return nil }
|
||||
}
|
||||
if ctx.GetOption == nil {
|
||||
ctx.GetOption = func(string) string { return "" }
|
||||
}
|
||||
if ctx.SetOption == nil {
|
||||
ctx.SetOption = func(string, string) {}
|
||||
}
|
||||
if ctx.SetModel == nil {
|
||||
ctx.SetModel = func(string) error { return nil }
|
||||
}
|
||||
if ctx.GetAvailableModels == nil {
|
||||
ctx.GetAvailableModels = func() []ModelInfoEntry { return nil }
|
||||
}
|
||||
if ctx.EmitCustomEvent == nil {
|
||||
ctx.EmitCustomEvent = func(string, string) {}
|
||||
}
|
||||
if ctx.GetAllTools == nil {
|
||||
ctx.GetAllTools = func() []ToolInfo { return nil }
|
||||
}
|
||||
if ctx.SetActiveTools == nil {
|
||||
ctx.SetActiveTools = func([]string) {}
|
||||
}
|
||||
if ctx.Exit == nil {
|
||||
ctx.Exit = func() {}
|
||||
}
|
||||
if ctx.Complete == nil {
|
||||
ctx.Complete = func(CompleteRequest) (CompleteResponse, error) {
|
||||
return CompleteResponse{}, nil
|
||||
}
|
||||
}
|
||||
if ctx.SuspendTUI == nil {
|
||||
ctx.SuspendTUI = func(callback func()) error { callback(); return nil }
|
||||
}
|
||||
if ctx.RenderMessage == nil {
|
||||
ctx.RenderMessage = func(string, string) {}
|
||||
}
|
||||
if ctx.RegisterTheme == nil {
|
||||
ctx.RegisterTheme = func(string, ThemeColorConfig) {}
|
||||
}
|
||||
if ctx.SetTheme == nil {
|
||||
ctx.SetTheme = func(string) error { return nil }
|
||||
}
|
||||
if ctx.ListThemes == nil {
|
||||
ctx.ListThemes = func() []string { return nil }
|
||||
}
|
||||
if ctx.ReloadExtensions == nil {
|
||||
ctx.ReloadExtensions = func() error { return nil }
|
||||
}
|
||||
if ctx.SpawnSubagent == nil {
|
||||
ctx.SpawnSubagent = func(SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.GetTreeNode == nil {
|
||||
ctx.GetTreeNode = func(string) *TreeNode { return nil }
|
||||
}
|
||||
if ctx.GetCurrentBranch == nil {
|
||||
ctx.GetCurrentBranch = func() []TreeNode { return nil }
|
||||
}
|
||||
if ctx.GetChildren == nil {
|
||||
ctx.GetChildren = func(string) []string { return nil }
|
||||
}
|
||||
if ctx.NavigateTo == nil {
|
||||
ctx.NavigateTo = func(string) TreeNavigationResult {
|
||||
return TreeNavigationResult{Success: false, Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
if ctx.SummarizeBranch == nil {
|
||||
ctx.SummarizeBranch = func(string, string) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
if ctx.CollapseBranch == nil {
|
||||
ctx.CollapseBranch = func(string, string, string) TreeNavigationResult {
|
||||
return TreeNavigationResult{Success: false, Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.LoadSkill == nil {
|
||||
ctx.LoadSkill = func(string) (*Skill, string) { return nil, "" }
|
||||
}
|
||||
if ctx.LoadSkillsFromDir == nil {
|
||||
ctx.LoadSkillsFromDir = func(string) SkillLoadResult { return SkillLoadResult{} }
|
||||
}
|
||||
if ctx.DiscoverSkills == nil {
|
||||
ctx.DiscoverSkills = func() SkillLoadResult { return SkillLoadResult{} }
|
||||
}
|
||||
if ctx.InjectSkillAsContext == nil {
|
||||
ctx.InjectSkillAsContext = func(string) string { return "" }
|
||||
}
|
||||
if ctx.InjectRawSkillAsContext == nil {
|
||||
ctx.InjectRawSkillAsContext = func(string) string { return "" }
|
||||
}
|
||||
if ctx.GetAvailableSkills == nil {
|
||||
ctx.GetAvailableSkills = func() []Skill { return nil }
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.ParseTemplate == nil {
|
||||
ctx.ParseTemplate = func(string, string) PromptTemplate { return PromptTemplate{} }
|
||||
}
|
||||
if ctx.RenderTemplate == nil {
|
||||
ctx.RenderTemplate = func(PromptTemplate, map[string]string) string { return "" }
|
||||
}
|
||||
if ctx.ParseArguments == nil {
|
||||
ctx.ParseArguments = func(string, ArgumentPattern) ParseResult { return ParseResult{} }
|
||||
}
|
||||
if ctx.SimpleParseArguments == nil {
|
||||
ctx.SimpleParseArguments = func(string, int) []string { return nil }
|
||||
}
|
||||
if ctx.EvaluateModelConditional == nil {
|
||||
ctx.EvaluateModelConditional = func(string) bool { return false }
|
||||
}
|
||||
if ctx.RenderWithModelConditionals == nil {
|
||||
ctx.RenderWithModelConditionals = func(string) string { return "" }
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.ResolveModelChain == nil {
|
||||
ctx.ResolveModelChain = func([]string) ModelResolutionResult {
|
||||
return ModelResolutionResult{Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
if ctx.GetModelCapabilities == nil {
|
||||
ctx.GetModelCapabilities = func(string) (ModelCapabilities, string) {
|
||||
return ModelCapabilities{}, "not implemented"
|
||||
}
|
||||
}
|
||||
if ctx.CheckModelAvailable == nil {
|
||||
ctx.CheckModelAvailable = func(string) bool { return false }
|
||||
}
|
||||
if ctx.GetCurrentProvider == nil {
|
||||
ctx.GetCurrentProvider = func() string { return "" }
|
||||
}
|
||||
if ctx.GetCurrentModelID == nil {
|
||||
ctx.GetCurrentModelID = func() string { return "" }
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// GetContext returns a snapshot of the current runtime context. Thread-safe.
|
||||
|
||||
@@ -173,10 +173,10 @@ type subagentJSONOutput struct {
|
||||
} `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
var subagentCounter uint64
|
||||
var subagentCounter atomic.Uint64
|
||||
|
||||
func generateSubagentID() string {
|
||||
n := atomic.AddUint64(&subagentCounter, 1)
|
||||
n := subagentCounter.Add(1)
|
||||
return fmt.Sprintf("sub-%d-%d", time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
|
||||
@@ -119,10 +119,33 @@ func Symbols() interp.Exports {
|
||||
"SubagentHandle": reflect.ValueOf((*SubagentHandle)(nil)),
|
||||
"SubagentEvent": reflect.ValueOf((*SubagentEvent)(nil)),
|
||||
|
||||
// Subagent lifecycle events
|
||||
"SubagentStartEvent": reflect.ValueOf((*SubagentStartEvent)(nil)),
|
||||
"SubagentChunkEvent": reflect.ValueOf((*SubagentChunkEvent)(nil)),
|
||||
"SubagentEndEvent": reflect.ValueOf((*SubagentEndEvent)(nil)),
|
||||
|
||||
// Theme types
|
||||
"ThemeColor": reflect.ValueOf((*ThemeColor)(nil)),
|
||||
"ThemeColorConfig": reflect.ValueOf((*ThemeColorConfig)(nil)),
|
||||
|
||||
// Tree navigation types
|
||||
"TreeNode": reflect.ValueOf((*TreeNode)(nil)),
|
||||
"TreeNavigationResult": reflect.ValueOf((*TreeNavigationResult)(nil)),
|
||||
|
||||
// Skill types
|
||||
"Skill": reflect.ValueOf((*Skill)(nil)),
|
||||
"SkillLoadResult": reflect.ValueOf((*SkillLoadResult)(nil)),
|
||||
|
||||
// Template parsing types
|
||||
"PromptTemplate": reflect.ValueOf((*PromptTemplate)(nil)),
|
||||
"ArgumentPattern": reflect.ValueOf((*ArgumentPattern)(nil)),
|
||||
"ParseResult": reflect.ValueOf((*ParseResult)(nil)),
|
||||
"ModelConditional": reflect.ValueOf((*ModelConditional)(nil)),
|
||||
|
||||
// Model resolution types
|
||||
"ModelCapabilities": reflect.ValueOf((*ModelCapabilities)(nil)),
|
||||
"ModelResolutionResult": reflect.ValueOf((*ModelResolutionResult)(nil)),
|
||||
|
||||
// Event structs
|
||||
"ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)),
|
||||
"ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)),
|
||||
|
||||
@@ -171,5 +171,23 @@ func NewTestAPI(ext *LoadedExtension) API {
|
||||
registerMessageRendererFn: func(config MessageRendererConfig) {
|
||||
ext.MessageRenderers = append(ext.MessageRenderers, config)
|
||||
},
|
||||
onSubagentStart: func(h func(SubagentStartEvent, Context)) {
|
||||
reg(SubagentStart, func(e Event, c Context) Result {
|
||||
h(e.(SubagentStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentChunk: func(h func(SubagentChunkEvent, Context)) {
|
||||
reg(SubagentChunk, func(e Event, c Context) Result {
|
||||
h(e.(SubagentChunkEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentEnd: func(h func(SubagentEndEvent, Context)) {
|
||||
reg(SubagentEnd, func(e Event, c Context) Result {
|
||||
h(e.(SubagentEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,14 +42,14 @@ func ExtensionToolsAsFantasy(defs []ToolDef, runner *Runner) []fantasy.AgentTool
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind classification.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"spawn_subagent": "agent",
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"subagent": "agent",
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
|
||||
+60
-10
@@ -4,11 +4,44 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// thinkTagRegex matches ... tags that some models (Qwen, DeepSeek) wrap
|
||||
// reasoning content in. Used to strip these tags from text content.
|
||||
// The (?s) flag makes . match newlines.
|
||||
var thinkTagRegex = regexp.MustCompile(`(?s)` + `` + `think` + `` + `(.*?)` + `` + `/think` + ``)
|
||||
|
||||
// sanitizeToolCallID ensures the ID matches Anthropic's required pattern:
|
||||
// ^[a-zA-Z0-9_-]+$ (alphanumeric, underscores, and hyphens only).
|
||||
// Invalid characters are replaced with underscores.
|
||||
func sanitizeToolCallID(id string) string {
|
||||
var sb strings.Builder
|
||||
for _, r := range id {
|
||||
switch {
|
||||
case (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z'):
|
||||
sb.WriteRune(r)
|
||||
case r >= '0' && r <= '9':
|
||||
sb.WriteRune(r)
|
||||
case r == '_' || r == '-':
|
||||
sb.WriteRune(r)
|
||||
default:
|
||||
// Replace invalid characters with underscore
|
||||
sb.WriteByte('_')
|
||||
}
|
||||
}
|
||||
result := sb.String()
|
||||
// Ensure non-empty (Anthropic requires at least one character)
|
||||
if result == "" {
|
||||
return "tool_0"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ContentPart is the marker interface for all message content block types.
|
||||
// A message contains a heterogeneous slice of ContentPart values, enabling
|
||||
// rich structured messages that carry text, reasoning, tool calls, tool
|
||||
@@ -88,9 +121,9 @@ const (
|
||||
)
|
||||
|
||||
// Message is a single conversation message containing a heterogeneous slice
|
||||
// of ContentPart blocks. This design (borrowed from crush) enables a single
|
||||
// assistant message to carry text, reasoning, and multiple tool calls as
|
||||
// discrete, typed blocks rather than flattening everything into strings.
|
||||
// of ContentPart blocks. This design enables a single assistant message to
|
||||
// carry text, reasoning, and multiple tool calls as discrete, typed blocks
|
||||
// rather than flattening everything into strings.
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Role MessageRole `json:"role"`
|
||||
@@ -285,12 +318,18 @@ func UnmarshalParts(data []byte) ([]ContentPart, error) {
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
// --- Fantasy bridge ---
|
||||
// --- LLM bridge ---
|
||||
|
||||
// ToFantasyMessages converts a Message to one or more fantasy.Message values.
|
||||
// An assistant message with tool calls produces a single fantasy message with
|
||||
// ToLLMMessages converts a Message to one or more LLM message values.
|
||||
// An assistant message with tool calls produces a single message with
|
||||
// mixed TextPart and ToolCallPart content. Tool-role messages produce
|
||||
// ToolResultPart entries.
|
||||
func (m *Message) ToLLMMessages() []fantasy.Message {
|
||||
return m.ToFantasyMessages()
|
||||
}
|
||||
|
||||
// Deprecated: Use ToLLMMessages instead.
|
||||
// ToFantasyMessages converts a Message to one or more LLM message values.
|
||||
func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
switch m.Role {
|
||||
case RoleAssistant:
|
||||
@@ -312,7 +351,7 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
// Add tool calls
|
||||
for _, tc := range m.ToolCalls() {
|
||||
parts = append(parts, fantasy.ToolCallPart{
|
||||
ToolCallID: tc.ID,
|
||||
ToolCallID: sanitizeToolCallID(tc.ID),
|
||||
ToolName: tc.Name,
|
||||
Input: tc.Input,
|
||||
})
|
||||
@@ -340,7 +379,7 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
}
|
||||
}
|
||||
parts = append(parts, fantasy.ToolResultPart{
|
||||
ToolCallID: result.ToolCallID,
|
||||
ToolCallID: sanitizeToolCallID(result.ToolCallID),
|
||||
Output: output,
|
||||
})
|
||||
}
|
||||
@@ -389,7 +428,14 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
}
|
||||
}
|
||||
|
||||
// FromFantasyMessage converts a fantasy.Message into our Message type,
|
||||
// FromLLMMessage converts an LLM message into our Message type,
|
||||
// extracting all content parts into the appropriate block types.
|
||||
func FromLLMMessage(msg fantasy.Message) Message {
|
||||
return FromFantasyMessage(msg)
|
||||
}
|
||||
|
||||
// Deprecated: Use FromLLMMessage instead.
|
||||
// FromFantasyMessage converts an LLM message into our Message type,
|
||||
// extracting all content parts into the appropriate block types.
|
||||
func FromFantasyMessage(msg fantasy.Message) Message {
|
||||
m := Message{
|
||||
@@ -403,7 +449,11 @@ func FromFantasyMessage(msg fantasy.Message) Message {
|
||||
switch p := part.(type) {
|
||||
case fantasy.TextPart:
|
||||
if p.Text != "" {
|
||||
m.Parts = append(m.Parts, TextContent{Text: p.Text})
|
||||
// Strip ... tags that some models wrap reasoning in
|
||||
cleanedText := thinkTagRegex.ReplaceAllString(p.Text, "")
|
||||
if cleanedText != "" {
|
||||
m.Parts = append(m.Parts, TextContent{Text: cleanedText})
|
||||
}
|
||||
}
|
||||
case fantasy.ToolCallPart:
|
||||
m.Parts = append(m.Parts, ToolCall{
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeToolCallID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "valid alphanumeric ID",
|
||||
input: "call_123abc",
|
||||
expected: "call_123abc",
|
||||
},
|
||||
{
|
||||
name: "ID with dots (OpenCode/Kimi style)",
|
||||
input: "call.123.abc",
|
||||
expected: "call_123_abc",
|
||||
},
|
||||
{
|
||||
name: "ID with colons",
|
||||
input: "tool:123:abc",
|
||||
expected: "tool_123_abc",
|
||||
},
|
||||
{
|
||||
name: "ID with special characters",
|
||||
input: "tool@#$%^&*()",
|
||||
expected: "tool_________",
|
||||
},
|
||||
{
|
||||
name: "Anthropic style ID (already valid)",
|
||||
input: "toolu_0123456789ABCDEF",
|
||||
expected: "toolu_0123456789ABCDEF",
|
||||
},
|
||||
{
|
||||
name: "OpenAI style ID (already valid)",
|
||||
input: "call_O17Uplv4lJvD6DVdIvFFeRMw",
|
||||
expected: "call_O17Uplv4lJvD6DVdIvFFeRMw",
|
||||
},
|
||||
{
|
||||
name: "ID with hyphens",
|
||||
input: "my-tool-call-123",
|
||||
expected: "my-tool-call-123",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "tool_0",
|
||||
},
|
||||
{
|
||||
name: "only special characters",
|
||||
input: "@#$%",
|
||||
expected: "____",
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid",
|
||||
input: "call_123.abc-def@ghi",
|
||||
expected: "call_123_abc-def_ghi",
|
||||
},
|
||||
{
|
||||
name: "Unicode characters",
|
||||
input: "tool_日本語",
|
||||
expected: "tool____",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitizeToolCallID(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("sanitizeToolCallID(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeToolCallID_MatchesAnthropicPattern(t *testing.T) {
|
||||
// Test that sanitized IDs match Anthropic's required pattern: ^[a-zA-Z0-9_-]+$
|
||||
// This is a simplified check - in reality the pattern allows alphanumeric, underscore, hyphen
|
||||
testIDs := []string{
|
||||
"call.123.abc",
|
||||
"tool:123:def",
|
||||
"id@#$%^&*()",
|
||||
"mixed.valid-id_test",
|
||||
"",
|
||||
}
|
||||
|
||||
for _, id := range testIDs {
|
||||
sanitized := sanitizeToolCallID(id)
|
||||
|
||||
// Verify each character is valid
|
||||
for i, r := range sanitized {
|
||||
valid := (r >= 'a' && r <= 'z') ||
|
||||
(r >= 'A' && r <= 'Z') ||
|
||||
(r >= '0' && r <= '9') ||
|
||||
r == '_' ||
|
||||
r == '-'
|
||||
|
||||
if !valid {
|
||||
t.Errorf("sanitizeToolCallID(%q) = %q, contains invalid character at position %d: %q",
|
||||
id, sanitized, i, string(r))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify non-empty
|
||||
if sanitized == "" {
|
||||
t.Errorf("sanitizeToolCallID(%q) returned empty string", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"maps"
|
||||
"os"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"charm.land/fantasy/providers/openai"
|
||||
)
|
||||
|
||||
// buildCacheProviderOptions returns caching options for supported models.
|
||||
// Caching is enabled by default for all supported models to reduce costs.
|
||||
// Set KIT_DISABLE_CACHE=1 or ProviderConfig.DisableCaching=true to opt out.
|
||||
func buildCacheProviderOptions(modelInfo *ModelInfo, config *ProviderConfig) fantasy.ProviderOptions {
|
||||
// Check explicit opt-out via config
|
||||
if config.DisableCaching {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check global opt-out via environment
|
||||
if os.Getenv("KIT_DISABLE_CACHE") != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if model supports caching
|
||||
if modelInfo == nil || !modelInfo.SupportsCaching() {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch modelInfo.CacheType() {
|
||||
case "anthropic-ephemeral":
|
||||
// Provider-level Anthropic caching disabled - use message-level caching instead.
|
||||
return nil
|
||||
case "openai-prompt-cache":
|
||||
return buildOpenAICacheOptions(config, modelInfo.ID)
|
||||
case "google-cached-content":
|
||||
// Google caching not yet implemented.
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// buildOpenAICacheOptions enables prompt caching for OpenAI models.
|
||||
// Uses a deterministic cache key based on system prompt and model ID.
|
||||
func buildOpenAICacheOptions(config *ProviderConfig, modelID string) fantasy.ProviderOptions {
|
||||
cacheKey := generateCacheKey(config.SystemPrompt, modelID)
|
||||
|
||||
return fantasy.ProviderOptions{
|
||||
openai.Name: &openai.ProviderOptions{
|
||||
PromptCacheKey: &cacheKey,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// generateCacheKey creates a deterministic cache key from system prompt and model.
|
||||
// This ensures the same system prompt + model combination gets cache hits.
|
||||
func generateCacheKey(systemPrompt, modelID string) string {
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = "default"
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
h.Write([]byte(systemPrompt))
|
||||
h.Write([]byte(modelID))
|
||||
|
||||
// Prefix with "kit-" to identify KIT-generated cache keys
|
||||
return "kit-" + hex.EncodeToString(h.Sum(nil))[:24]
|
||||
}
|
||||
|
||||
// mergeProviderOptions merges multiple ProviderOptions maps.
|
||||
// Later maps take precedence over earlier ones.
|
||||
func mergeProviderOptions(opts ...fantasy.ProviderOptions) fantasy.ProviderOptions {
|
||||
result := make(fantasy.ProviderOptions)
|
||||
|
||||
for _, opt := range opts {
|
||||
maps.Copy(result, opt)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
func TestModelInfo_SupportsCaching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
family string
|
||||
expected bool
|
||||
}{
|
||||
{"Claude model", "claude-3-5-sonnet", true},
|
||||
{"Claude 4 model", "claude-4-opus", true},
|
||||
{"GPT model", "gpt-4", true},
|
||||
{"GPT-5 model", "gpt-5", true},
|
||||
{"O1 model", "o1", true},
|
||||
{"O3 model", "o3", true},
|
||||
{"O4 model", "o4-mini", true},
|
||||
{"Codex model", "codex", true},
|
||||
{"Gemini model", "gemini-2.5-pro", true},
|
||||
{"Gemini 1.5 model", "gemini-1.5-flash", true},
|
||||
{"Llama model", "llama-3", false},
|
||||
{"Unknown model", "unknown", false},
|
||||
{"Empty family", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &ModelInfo{Family: tt.family}
|
||||
if got := m.SupportsCaching(); got != tt.expected {
|
||||
t.Errorf("ModelInfo.SupportsCaching() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelInfo_CacheType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
family string
|
||||
expected string
|
||||
}{
|
||||
{"Claude model", "claude-3-5-sonnet", "anthropic-ephemeral"},
|
||||
{"GPT model", "gpt-4", "openai-prompt-cache"},
|
||||
{"O1 model", "o1", "openai-prompt-cache"},
|
||||
{"Gemini model", "gemini-2.5-pro", "google-cached-content"},
|
||||
{"Unknown model", "llama-3", ""},
|
||||
{"Empty family", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &ModelInfo{Family: tt.family}
|
||||
if got := m.CacheType(); got != tt.expected {
|
||||
t.Errorf("ModelInfo.CacheType() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCacheKey(t *testing.T) {
|
||||
key1 := generateCacheKey("system prompt", "model-id")
|
||||
key2 := generateCacheKey("system prompt", "model-id")
|
||||
if key1 != key2 {
|
||||
t.Errorf("generateCacheKey should be deterministic: got %q and %q", key1, key2)
|
||||
}
|
||||
|
||||
key3 := generateCacheKey("different prompt", "model-id")
|
||||
if key1 == key3 {
|
||||
t.Errorf("generateCacheKey should produce different keys for different inputs")
|
||||
}
|
||||
|
||||
key4 := generateCacheKey("", "model-id")
|
||||
key5 := generateCacheKey("default", "model-id")
|
||||
if key4 != key5 {
|
||||
t.Errorf("generateCacheKey should treat empty prompt as 'default'")
|
||||
}
|
||||
|
||||
if len(key1) < 4 || key1[:4] != "kit-" {
|
||||
t.Errorf("generateCacheKey should produce keys with 'kit-' prefix, got %q", key1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_Disabled(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: true}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when DisableCaching=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_EnvironmentVariable(t *testing.T) {
|
||||
_ = os.Setenv("KIT_DISABLE_CACHE", "1")
|
||||
defer func() { _ = os.Unsetenv("KIT_DISABLE_CACHE") }()
|
||||
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when KIT_DISABLE_CACHE is set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_UnsupportedModel(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "llama-3", ID: "llama-3-70b"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil for unsupported model families")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_NilModelInfo(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
|
||||
if opts := buildCacheProviderOptions(nil, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when modelInfo is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_Anthropic(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
opts := buildCacheProviderOptions(modelInfo, config)
|
||||
// Provider-level Anthropic caching is disabled; message-level caching is used instead
|
||||
if opts != nil {
|
||||
t.Logf("Provider-level Anthropic caching disabled; using message-level caching")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_OpenAI(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
config := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test system prompt",
|
||||
}
|
||||
modelInfo := &ModelInfo{Family: "gpt-4", ID: "gpt-4o"}
|
||||
|
||||
opts := buildCacheProviderOptions(modelInfo, config)
|
||||
if opts == nil {
|
||||
t.Fatalf("buildCacheProviderOptions should return options for OpenAI models")
|
||||
}
|
||||
|
||||
if _, ok := opts["openai"]; !ok {
|
||||
t.Errorf("buildCacheProviderOptions should include 'openai' key for GPT models")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachingPriorityOverThinking(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
// Anthropic uses message-level caching; provider-level returns nil
|
||||
config1 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
ThinkingLevel: ThinkingOff,
|
||||
}
|
||||
modelInfo1 := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
opts1 := buildCacheProviderOptions(modelInfo1, config1)
|
||||
if opts1 != nil {
|
||||
t.Logf("Provider-level Anthropic caching disabled; using message-level caching")
|
||||
}
|
||||
|
||||
// OpenAI provider-level caching works with thinking enabled
|
||||
config2 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test prompt",
|
||||
ThinkingLevel: ThinkingMedium,
|
||||
}
|
||||
modelInfo2 := &ModelInfo{Family: "gpt-4", ID: "gpt-4o"}
|
||||
opts2 := buildCacheProviderOptions(modelInfo2, config2)
|
||||
if opts2 == nil {
|
||||
t.Errorf("OpenAI caching should work with thinking enabled")
|
||||
}
|
||||
|
||||
// OpenAI caching also works with thinking disabled
|
||||
config3 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test prompt",
|
||||
ThinkingLevel: ThinkingOff,
|
||||
}
|
||||
opts3 := buildCacheProviderOptions(modelInfo2, config3)
|
||||
if opts3 == nil {
|
||||
t.Errorf("OpenAI caching should work when thinking is OFF")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeProviderOptions(t *testing.T) {
|
||||
opts1 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "value1"},
|
||||
}
|
||||
opts2 := fantasy.ProviderOptions{
|
||||
"provider2": &testProviderData{value: "value2"},
|
||||
}
|
||||
|
||||
merged := mergeProviderOptions(opts1, opts2)
|
||||
|
||||
if len(merged) != 2 {
|
||||
t.Errorf("mergeProviderOptions should combine options from multiple maps, got %d items", len(merged))
|
||||
}
|
||||
|
||||
if _, ok := merged["provider1"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider1' key")
|
||||
}
|
||||
|
||||
if _, ok := merged["provider2"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider2' key")
|
||||
}
|
||||
|
||||
// Later options should override earlier ones
|
||||
opts3 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "overridden"},
|
||||
}
|
||||
merged2 := mergeProviderOptions(opts1, opts3)
|
||||
|
||||
if data, ok := merged2["provider1"].(*testProviderData); ok {
|
||||
if data.value != "overridden" {
|
||||
t.Errorf("later options should override earlier ones, got %q", data.value)
|
||||
}
|
||||
}
|
||||
|
||||
if mergeProviderOptions() != nil {
|
||||
t.Errorf("mergeProviderOptions with no args should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// testProviderData is a simple implementation of ProviderOptionsData for testing
|
||||
type testProviderData struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (t *testProviderData) Options() {}
|
||||
|
||||
func (t *testProviderData) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + t.value + `"`), nil
|
||||
}
|
||||
|
||||
func (t *testProviderData) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
+17
-11
@@ -17,15 +17,21 @@ type modelsDBProvider struct {
|
||||
|
||||
// modelsDBModel represents a model entry from models.dev/api.json.
|
||||
type modelsDBModel struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Family string `json:"family,omitempty"`
|
||||
Attachment bool `json:"attachment"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
ToolCall bool `json:"tool_call"`
|
||||
Temperature bool `json:"temperature"`
|
||||
Cost modelsDBCost `json:"cost"`
|
||||
Limit modelsDBLimit `json:"limit"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Family string `json:"family,omitempty"`
|
||||
Attachment bool `json:"attachment"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
ToolCall bool `json:"tool_call"`
|
||||
Temperature bool `json:"temperature"`
|
||||
Cost modelsDBCost `json:"cost"`
|
||||
Limit modelsDBLimit `json:"limit"`
|
||||
Provider *modelsDBModelProvider `json:"provider,omitempty"` // Model-specific provider override
|
||||
}
|
||||
|
||||
// modelsDBModelProvider represents a provider reference within a model.
|
||||
type modelsDBModelProvider struct {
|
||||
NPM string `json:"npm"`
|
||||
}
|
||||
|
||||
// modelsDBCost represents model pricing from models.dev.
|
||||
@@ -42,10 +48,10 @@ type modelsDBLimit struct {
|
||||
Output int `json:"output"`
|
||||
}
|
||||
|
||||
// npmToFantasyProvider maps npm package names from models.dev to fantasy
|
||||
// npmToLLMProvider maps npm package names from models.dev to LLM
|
||||
// provider identifiers. Providers not in this map but with an api URL
|
||||
// can be auto-routed through openaicompat.
|
||||
var npmToFantasyProvider = map[string]string{
|
||||
var npmToLLMProvider = map[string]string{
|
||||
"@ai-sdk/anthropic": "anthropic",
|
||||
"@ai-sdk/openai": "openai",
|
||||
"@ai-sdk/google": "google",
|
||||
|
||||
+382
-29
@@ -10,6 +10,7 @@ import (
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -22,6 +23,7 @@ import (
|
||||
"charm.land/fantasy/providers/openaicompat"
|
||||
"charm.land/fantasy/providers/openrouter"
|
||||
"charm.land/fantasy/providers/vercel"
|
||||
openaisdk "github.com/charmbracelet/openai-go"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/ui/progress"
|
||||
@@ -155,6 +157,7 @@ type ProviderConfig struct {
|
||||
MainGPU *int32
|
||||
TLSSkipVerify bool
|
||||
ThinkingLevel ThinkingLevel
|
||||
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
|
||||
}
|
||||
|
||||
// ProviderResult contains the result of provider creation.
|
||||
@@ -169,6 +172,9 @@ type ProviderResult struct {
|
||||
// ProviderOptions contains provider-specific options to be passed to the
|
||||
// fantasy agent (e.g. OpenAI Responses API reasoning options).
|
||||
ProviderOptions fantasy.ProviderOptions
|
||||
// SkipMaxOutputTokens indicates that this provider doesn't support the
|
||||
// max_output_tokens parameter (e.g., OpenAI Codex OAuth API).
|
||||
SkipMaxOutputTokens bool
|
||||
}
|
||||
|
||||
// ParseModelString parses a model string in "provider/model" format (e.g. "anthropic/claude-sonnet-4-5").
|
||||
@@ -234,49 +240,86 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
validateModelConfig(config, modelInfo)
|
||||
}
|
||||
|
||||
// Create the base provider
|
||||
var result *ProviderResult
|
||||
var createErr error
|
||||
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return createAnthropicProvider(ctx, config, modelName)
|
||||
result, createErr = createAnthropicProvider(ctx, config, modelName)
|
||||
case "openai":
|
||||
return createOpenAIProvider(ctx, config, modelName)
|
||||
result, createErr = createOpenAIProvider(ctx, config, modelName)
|
||||
case "google", "gemini":
|
||||
return createGoogleProvider(ctx, config, modelName)
|
||||
result, createErr = createGoogleProvider(ctx, config, modelName)
|
||||
case "ollama":
|
||||
return createOllamaProvider(ctx, config, modelName)
|
||||
result, createErr = createOllamaProvider(ctx, config, modelName)
|
||||
case "azure":
|
||||
return createAzureProvider(ctx, config, modelName)
|
||||
result, createErr = createAzureProvider(ctx, config, modelName)
|
||||
case "google-vertex-anthropic":
|
||||
return createVertexAnthropicProvider(ctx, config, modelName)
|
||||
result, createErr = createVertexAnthropicProvider(ctx, config, modelName)
|
||||
case "openrouter":
|
||||
return createOpenRouterProvider(ctx, config, modelName)
|
||||
result, createErr = createOpenRouterProvider(ctx, config, modelName)
|
||||
case "bedrock":
|
||||
return createBedrockProvider(ctx, config, modelName)
|
||||
result, createErr = createBedrockProvider(ctx, config, modelName)
|
||||
case "vercel":
|
||||
return createVercelProvider(ctx, config, modelName)
|
||||
result, createErr = createVercelProvider(ctx, config, modelName)
|
||||
case "custom":
|
||||
return createCustomProvider(ctx, config, modelName)
|
||||
result, createErr = createCustomProvider(ctx, config, modelName)
|
||||
default:
|
||||
return autoRouteProvider(ctx, config, provider, modelName, registry)
|
||||
result, createErr = autoRouteProvider(ctx, config, provider, modelName, registry)
|
||||
}
|
||||
|
||||
if createErr != nil {
|
||||
return nil, createErr
|
||||
}
|
||||
|
||||
// AUTOMATICALLY ENABLE CACHING for supported models (unless disabled).
|
||||
// This works for BOTH native and auto-routed providers by detecting
|
||||
// the model family from the model metadata.
|
||||
if cacheOpts := buildCacheProviderOptions(modelInfo, config); cacheOpts != nil {
|
||||
if result.ProviderOptions == nil {
|
||||
result.ProviderOptions = cacheOpts
|
||||
} else {
|
||||
// Merge cache options with existing provider options.
|
||||
// Only add cache options for providers that don't already have
|
||||
// options set, to avoid type conflicts (e.g., Anthropic has
|
||||
// different types for regular options vs cache control options).
|
||||
for k, v := range cacheOpts {
|
||||
if _, exists := result.ProviderOptions[k]; !exists {
|
||||
result.ProviderOptions[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// autoRouteProvider attempts to create a provider by looking up its npm package
|
||||
// in the models.dev database and routing through the appropriate fantasy provider.
|
||||
// For openai-compatible providers, it uses the api URL from models.dev.
|
||||
// Models may have a provider override that specifies a different npm package than
|
||||
// the provider's default (e.g., opencode's claude-opus-4-6 uses @ai-sdk/anthropic).
|
||||
func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, modelName string, registry *ModelsRegistry) (*ProviderResult, error) {
|
||||
providerInfo := registry.GetProviderInfo(provider)
|
||||
if providerInfo == nil {
|
||||
return nil, fmt.Errorf("unsupported provider: %s (not found in model database)", provider)
|
||||
}
|
||||
|
||||
// Determine the fantasy provider for this npm package
|
||||
fantasyProvider := npmToFantasyProvider[providerInfo.NPM]
|
||||
if fantasyProvider == "" && providerInfo.API != "" {
|
||||
// Unknown npm but has API URL → route through openaicompat
|
||||
fantasyProvider = "openaicompat"
|
||||
// Check for model-specific provider override
|
||||
npmPackage := providerInfo.NPM
|
||||
if modelInfo := registry.LookupModel(provider, modelName); modelInfo != nil && modelInfo.ProviderNPM != "" {
|
||||
npmPackage = modelInfo.ProviderNPM
|
||||
}
|
||||
|
||||
switch fantasyProvider {
|
||||
// Determine the LLM provider for this npm package
|
||||
llmProvider := npmToLLMProvider[npmPackage]
|
||||
if llmProvider == "" && providerInfo.API != "" {
|
||||
// Unknown npm but has API URL → route through openaicompat
|
||||
llmProvider = "openaicompat"
|
||||
}
|
||||
|
||||
switch llmProvider {
|
||||
case "openaicompat":
|
||||
return createAutoRoutedOpenAICompatProvider(ctx, config, modelName, providerInfo)
|
||||
case "anthropic":
|
||||
@@ -290,7 +333,7 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
|
||||
}
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no fantasy mapping)", provider, providerInfo.NPM)
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no LLM provider mapping)", provider, npmPackage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -348,7 +391,10 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf
|
||||
opts = append(opts, anthropic.WithAPIKey(apiKey))
|
||||
|
||||
if config.ProviderURL != "" {
|
||||
opts = append(opts, anthropic.WithBaseURL(config.ProviderURL))
|
||||
// The anthropic client appends "/v1/messages" to the base URL.
|
||||
// If the provider URL ends with "/v1", strip it to avoid double "/v1/v1" paths.
|
||||
baseURL := strings.TrimSuffix(config.ProviderURL, "/v1")
|
||||
opts = append(opts, anthropic.WithBaseURL(baseURL))
|
||||
}
|
||||
|
||||
if config.TLSSkipVerify {
|
||||
@@ -496,10 +542,15 @@ func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort
|
||||
// SendReasoning to true and configures the thinking budget. For thinking-off
|
||||
// or non-reasoning models the returned map is nil.
|
||||
//
|
||||
// NOTE: With message-level caching, thinking and caching can work together.
|
||||
// Message-level cache control (ProviderCacheControlOptions) doesn't conflict
|
||||
// with provider-level thinking options (ProviderOptions).
|
||||
//
|
||||
// Anthropic requires max_tokens > thinking.budget_tokens. If the configured
|
||||
// MaxTokens is too low, it is bumped to budget + 4096 to leave room for the
|
||||
// actual response.
|
||||
func buildAnthropicProviderOptions(config *ProviderConfig, modelName string) fantasy.ProviderOptions {
|
||||
// Thinking is OFF by default. If user hasn't explicitly enabled it, return nil.
|
||||
if config.ThinkingLevel == "" || config.ThinkingLevel == ThinkingOff {
|
||||
return nil
|
||||
}
|
||||
@@ -610,13 +661,52 @@ func createVertexAnthropicProvider(ctx context.Context, config *ProviderConfig,
|
||||
|
||||
func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
apiKey := config.ProviderAPIKey
|
||||
source := "command-line flag"
|
||||
var accountID string
|
||||
var isCodexOAuth bool
|
||||
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("OPENAI_API_KEY")
|
||||
}
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("OpenAI API key not provided. Use --provider-api-key flag or OPENAI_API_KEY environment variable")
|
||||
// Check stored credentials first
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err == nil {
|
||||
if creds, err := cm.GetOpenAICredentials(); err == nil && creds != nil {
|
||||
if creds.Type == "oauth" && creds.AccessToken != "" {
|
||||
// For OAuth, get a valid access token (may refresh if needed)
|
||||
token, err := cm.GetValidOpenAIAccessToken()
|
||||
if err == nil && token != "" {
|
||||
apiKey = token
|
||||
accountID = creds.AccountID
|
||||
isCodexOAuth = true
|
||||
source = "stored Codex OAuth credentials"
|
||||
}
|
||||
} else if creds.Type == "api_key" && creds.APIKey != "" {
|
||||
apiKey = creds.APIKey
|
||||
source = "stored API key"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to environment variable
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("OPENAI_API_KEY")
|
||||
source = "OPENAI_API_KEY environment variable"
|
||||
}
|
||||
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("OpenAI API key not provided. Use 'kit auth login openai', --provider-api-key flag, or OPENAI_API_KEY environment variable")
|
||||
}
|
||||
|
||||
if os.Getenv("DEBUG") != "" || os.Getenv("KIT_DEBUG") != "" {
|
||||
fmt.Fprintf(os.Stderr, "Using OpenAI API key from: %s\n", source)
|
||||
}
|
||||
|
||||
// For Codex OAuth, use the ChatGPT backend API with custom headers
|
||||
if isCodexOAuth {
|
||||
return createOpenAICodexProvider(ctx, config, modelName, apiKey, accountID)
|
||||
}
|
||||
|
||||
// Regular OpenAI API key flow
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithAPIKey(apiKey))
|
||||
opts = append(opts, openai.WithUseResponsesAPI())
|
||||
@@ -645,6 +735,135 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
|
||||
}
|
||||
|
||||
// createOpenAICodexProvider creates a provider for ChatGPT/Codex OAuth tokens.
|
||||
// Uses the chatgpt.com/backend-api/codex endpoint with special headers.
|
||||
func createOpenAICodexProvider(ctx context.Context, config *ProviderConfig, modelName, token, accountID string) (*ProviderResult, error) {
|
||||
// Check for spark models which are not accessible via OAuth
|
||||
if detectCodexModelFamily(modelName) == "gpt-codex-spark" {
|
||||
return nil, fmt.Errorf("gpt-codex-spark models are not accessible via ChatGPT OAuth. " +
|
||||
"These models require special access or a different authentication method. " +
|
||||
"Please use regular Codex models like 'openai/gpt-5.3-codex' instead")
|
||||
}
|
||||
|
||||
// Use the ChatGPT backend API with /codex path
|
||||
baseURL := "https://chatgpt.com/backend-api/codex"
|
||||
if config.ProviderURL != "" {
|
||||
baseURL = config.ProviderURL
|
||||
}
|
||||
|
||||
// Build custom HTTP client with required headers
|
||||
httpClient := createCodexHTTPClient(token, accountID, config.TLSSkipVerify)
|
||||
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithAPIKey(token))
|
||||
opts = append(opts, openai.WithBaseURL(baseURL))
|
||||
opts = append(opts, openai.WithUseResponsesAPI())
|
||||
opts = append(opts, openai.WithHTTPClient(httpClient))
|
||||
|
||||
provider, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex provider: %w", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex model: %w", err)
|
||||
}
|
||||
|
||||
providerOpts := buildCodexProviderOptions(config, modelName)
|
||||
|
||||
return &ProviderResult{
|
||||
Model: model,
|
||||
ProviderOptions: providerOpts,
|
||||
SkipMaxOutputTokens: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildCodexProviderOptions returns fantasy.ProviderOptions configured for
|
||||
// OpenAI Codex API. The Codex API requires the system prompt to be passed
|
||||
// as 'instructions' rather than as a system message.
|
||||
func buildCodexProviderOptions(config *ProviderConfig, modelName string) fantasy.ProviderOptions {
|
||||
store := false
|
||||
opts := &openai.ResponsesProviderOptions{
|
||||
Store: &store,
|
||||
}
|
||||
|
||||
if config.SystemPrompt != "" {
|
||||
opts.Instructions = &config.SystemPrompt
|
||||
}
|
||||
|
||||
if openai.IsResponsesReasoningModel(modelName) {
|
||||
opts.ReasoningEffort = thinkingLevelToReasoningEffort(config.ThinkingLevel)
|
||||
}
|
||||
|
||||
return fantasy.ProviderOptions{openai.Name: opts}
|
||||
}
|
||||
|
||||
// detectCodexModelFamily determines the model family from the model name
|
||||
func detectCodexModelFamily(modelName string) string {
|
||||
modelName = strings.ToLower(modelName)
|
||||
if strings.Contains(modelName, "spark") {
|
||||
return "gpt-codex-spark"
|
||||
}
|
||||
if strings.Contains(modelName, "codex-mini") || strings.Contains(modelName, "mini-latest") {
|
||||
return "gpt-codex-mini"
|
||||
}
|
||||
if strings.Contains(modelName, "codex") {
|
||||
return "gpt-codex"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// createCodexHTTPClient creates an HTTP client with headers required for ChatGPT/Codex API
|
||||
func createCodexHTTPClient(token, accountID string, skipVerify bool) *http.Client {
|
||||
var base http.RoundTripper
|
||||
if skipVerify {
|
||||
base = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &codexTransport{
|
||||
base: base,
|
||||
token: token,
|
||||
accountID: accountID,
|
||||
},
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// codexTransport is a custom RoundTripper that adds ChatGPT/Codex specific headers
|
||||
type codexTransport struct {
|
||||
base http.RoundTripper
|
||||
token string
|
||||
accountID string
|
||||
}
|
||||
|
||||
func (t *codexTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Add required headers for ChatGPT/Codex API
|
||||
// These headers mimic the official pi client to avoid Cloudflare blocking
|
||||
newReq.Header.Set("Authorization", "Bearer "+t.token)
|
||||
if t.accountID != "" {
|
||||
newReq.Header.Set("chatgpt-account-id", t.accountID)
|
||||
}
|
||||
newReq.Header.Set("originator", "kit")
|
||||
newReq.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36")
|
||||
newReq.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
newReq.Header.Set("Accept", "text/event-stream")
|
||||
newReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
newReq.Header.Set("Cache-Control", "no-cache")
|
||||
newReq.Header.Set("Pragma", "no-cache")
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
apiKey := firstNonEmpty(
|
||||
config.ProviderAPIKey,
|
||||
@@ -781,6 +1000,133 @@ func createVercelProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
|
||||
// thinkTagRegex matches <think>...</think> tags for extracting reasoning content
|
||||
// from models that wrap thinking in XML-like tags (e.g., Qwen, DeepSeek).
|
||||
var thinkTagRegex = regexp.MustCompile(`(?s)<think>(.*?)</think>`)
|
||||
|
||||
// customExtraContentFunc extracts reasoning from <think> tags in the content field.
|
||||
// This handles models like Qwen and DeepSeek that return reasoning wrapped in XML tags
|
||||
// rather than using a separate reasoning_content field.
|
||||
func customExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
|
||||
var content []fantasy.Content
|
||||
if choice.Message.Content == "" {
|
||||
return content
|
||||
}
|
||||
|
||||
// Check for <think> tags in the content
|
||||
matches := thinkTagRegex.FindStringSubmatch(choice.Message.Content)
|
||||
if len(matches) > 1 {
|
||||
// Found reasoning content in <think> tags
|
||||
reasoning := strings.TrimSpace(matches[1])
|
||||
if reasoning != "" {
|
||||
content = append(content, fantasy.ReasoningContent{
|
||||
Text: reasoning,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// customStreamExtraFunc handles streaming responses with <think> tags.
|
||||
// It extracts reasoning content and emits proper reasoning events.
|
||||
func customStreamExtraFunc(
|
||||
chunk openaisdk.ChatCompletionChunk,
|
||||
yield func(fantasy.StreamPart) bool,
|
||||
ctx map[string]any,
|
||||
) (map[string]any, bool) {
|
||||
if len(chunk.Choices) == 0 {
|
||||
return ctx, true
|
||||
}
|
||||
|
||||
const reasoningStartedKey = "reasoning_started"
|
||||
const reasoningBufferKey = "reasoning_buffer"
|
||||
const inThinkTagKey = "in_think_tag"
|
||||
|
||||
reasoningStarted, _ := ctx[reasoningStartedKey].(bool)
|
||||
inThinkTag, _ := ctx[inThinkTagKey].(bool)
|
||||
reasoningBuffer, _ := ctx[reasoningBufferKey].(string)
|
||||
|
||||
for i, choice := range chunk.Choices {
|
||||
content := choice.Delta.Content
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for <think> tag start
|
||||
if strings.Contains(content, "<think>") {
|
||||
inThinkTag = true
|
||||
ctx[inThinkTagKey] = true
|
||||
|
||||
// Emit reasoning start event
|
||||
if !reasoningStarted {
|
||||
reasoningStarted = true
|
||||
ctx[reasoningStartedKey] = true
|
||||
if !yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeReasoningStart,
|
||||
ID: fmt.Sprintf("%d", i),
|
||||
}) {
|
||||
return ctx, false
|
||||
}
|
||||
}
|
||||
|
||||
// Extract content after <think>
|
||||
parts := strings.SplitN(content, "<think>", 2)
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
reasoningBuffer += parts[1]
|
||||
ctx[reasoningBufferKey] = reasoningBuffer
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for </think> tag end
|
||||
if strings.Contains(content, "</think>") {
|
||||
inThinkTag = false
|
||||
ctx[inThinkTagKey] = false
|
||||
|
||||
// Extract content before </think>
|
||||
parts := strings.SplitN(content, "</think>", 2)
|
||||
if len(parts) > 0 {
|
||||
reasoningBuffer += parts[0]
|
||||
}
|
||||
|
||||
// Emit the accumulated reasoning
|
||||
if reasoningBuffer != "" {
|
||||
if !yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeReasoningDelta,
|
||||
ID: fmt.Sprintf("%d", i),
|
||||
Delta: reasoningBuffer,
|
||||
}) {
|
||||
return ctx, false
|
||||
}
|
||||
ctx[reasoningBufferKey] = ""
|
||||
}
|
||||
|
||||
// Emit reasoning end
|
||||
if !yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeReasoningEnd,
|
||||
ID: fmt.Sprintf("%d", i),
|
||||
}) {
|
||||
return ctx, false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate reasoning content while in think tag
|
||||
if inThinkTag {
|
||||
reasoningBuffer += content
|
||||
ctx[reasoningBufferKey] = reasoningBuffer
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, true
|
||||
}
|
||||
|
||||
// customToPromptFunc converts prompts to OpenAI format using the default conversion.
|
||||
func customToPromptFunc(prompt fantasy.Prompt, systemPrompt, user string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
|
||||
return openai.DefaultToPrompt(prompt, systemPrompt, user)
|
||||
}
|
||||
|
||||
func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
if config.ProviderURL == "" {
|
||||
return nil, fmt.Errorf("custom provider requires --provider-url")
|
||||
@@ -795,16 +1141,23 @@ func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
apiKey = "custom"
|
||||
}
|
||||
|
||||
var opts []openaicompat.Option
|
||||
opts = append(opts, openaicompat.WithBaseURL(config.ProviderURL))
|
||||
opts = append(opts, openaicompat.WithAPIKey(apiKey))
|
||||
opts = append(opts, openaicompat.WithName("custom"))
|
||||
// Use the openai provider directly with custom hooks to handle <think> tags
|
||||
// from models like Qwen and DeepSeek that wrap reasoning in XML tags.
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithBaseURL(config.ProviderURL))
|
||||
opts = append(opts, openai.WithAPIKey(apiKey))
|
||||
opts = append(opts, openai.WithName("custom"))
|
||||
opts = append(opts, openai.WithLanguageModelOptions(
|
||||
openai.WithLanguageModelExtraContentFunc(customExtraContentFunc),
|
||||
openai.WithLanguageModelStreamExtraFunc(customStreamExtraFunc),
|
||||
openai.WithLanguageModelToPromptFunc(customToPromptFunc),
|
||||
))
|
||||
|
||||
if config.TLSSkipVerify {
|
||||
opts = append(opts, openaicompat.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
||||
opts = append(opts, openai.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
||||
}
|
||||
|
||||
p, err := openaicompat.New(opts...)
|
||||
p, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create custom provider: %w", err)
|
||||
}
|
||||
|
||||
@@ -17,11 +17,51 @@ var embeddedModelsJSON []byte
|
||||
type ModelInfo struct {
|
||||
ID string
|
||||
Name string
|
||||
Family string // Model family (e.g., "claude", "gpt", "gemini")
|
||||
Attachment bool
|
||||
Reasoning bool
|
||||
Temperature bool
|
||||
Cost Cost
|
||||
Limit Limit
|
||||
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
||||
}
|
||||
|
||||
// SupportsCaching returns true if this model family supports prompt caching.
|
||||
// This enables automatic cost savings for supported models regardless of provider.
|
||||
func (m *ModelInfo) SupportsCaching() bool {
|
||||
switch {
|
||||
case strings.HasPrefix(m.Family, "claude"):
|
||||
return true
|
||||
case strings.HasPrefix(m.Family, "gpt"),
|
||||
strings.HasPrefix(m.Family, "o1"),
|
||||
strings.HasPrefix(m.Family, "o3"),
|
||||
strings.HasPrefix(m.Family, "o4"),
|
||||
strings.HasPrefix(m.Family, "codex"):
|
||||
return true
|
||||
case strings.HasPrefix(m.Family, "gemini"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// CacheType returns the appropriate cache mechanism for this model family.
|
||||
// Returns empty string if caching is not supported.
|
||||
func (m *ModelInfo) CacheType() string {
|
||||
switch {
|
||||
case strings.HasPrefix(m.Family, "claude"):
|
||||
return "anthropic-ephemeral"
|
||||
case strings.HasPrefix(m.Family, "gpt"),
|
||||
strings.HasPrefix(m.Family, "o1"),
|
||||
strings.HasPrefix(m.Family, "o3"),
|
||||
strings.HasPrefix(m.Family, "o4"),
|
||||
strings.HasPrefix(m.Family, "codex"):
|
||||
return "openai-prompt-cache"
|
||||
case strings.HasPrefix(m.Family, "gemini"):
|
||||
return "google-cached-content"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Cost represents the pricing information for a model.
|
||||
@@ -78,9 +118,14 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
for providerID, dp := range dbProviders {
|
||||
modelsMap := make(map[string]ModelInfo, len(dp.Models))
|
||||
for modelID, dm := range dp.Models {
|
||||
providerNPM := ""
|
||||
if dm.Provider != nil {
|
||||
providerNPM = dm.Provider.NPM
|
||||
}
|
||||
modelsMap[modelID] = ModelInfo{
|
||||
ID: dm.ID,
|
||||
Name: dm.Name,
|
||||
Family: dm.Family,
|
||||
Attachment: dm.Attachment,
|
||||
Reasoning: dm.Reasoning,
|
||||
Temperature: dm.Temperature,
|
||||
@@ -94,6 +139,7 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
Context: dm.Limit.Context,
|
||||
Output: dm.Limit.Output,
|
||||
},
|
||||
ProviderNPM: providerNPM,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -219,6 +265,15 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
|
||||
}
|
||||
}
|
||||
|
||||
// For openai, check stored credentials (OAuth / API key)
|
||||
if provider == "openai" {
|
||||
if cm, err := auth.NewCredentialManager(); err == nil {
|
||||
if has, _ := cm.HasOpenAICredentials(); has {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
envVars, err := r.getRequiredEnvVars(provider)
|
||||
if err != nil {
|
||||
// Unknown provider — nothing to validate
|
||||
@@ -293,27 +348,32 @@ func (r *ModelsRegistry) GetSupportedProviders() []string {
|
||||
return providers
|
||||
}
|
||||
|
||||
// GetFantasyProviders returns provider IDs that can be used with fantasy,
|
||||
// GetLLMProviders returns provider IDs that have LLM support,
|
||||
// either through a native provider or via openaicompat auto-routing.
|
||||
func (r *ModelsRegistry) GetFantasyProviders() []string {
|
||||
func (r *ModelsRegistry) GetLLMProviders() []string {
|
||||
var providers []string
|
||||
for providerID, info := range r.providers {
|
||||
if isProviderFantasySupported(providerID, &info) {
|
||||
if isProviderLLMSupported(providerID, &info) {
|
||||
providers = append(providers, providerID)
|
||||
}
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// isProviderFantasySupported checks if a provider can be used with fantasy.
|
||||
func isProviderFantasySupported(providerID string, info *ProviderInfo) bool {
|
||||
// Deprecated: Use GetLLMProviders instead.
|
||||
func (r *ModelsRegistry) GetFantasyProviders() []string {
|
||||
return r.GetLLMProviders()
|
||||
}
|
||||
|
||||
// isProviderLLMSupported checks if a provider can be used with the LLM layer.
|
||||
func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
// Ollama is always supported (via openaicompat pointed at localhost)
|
||||
if providerID == "ollama" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if npm maps to a fantasy provider
|
||||
if _, ok := npmToFantasyProvider[info.NPM]; ok {
|
||||
// Check if npm maps to an LLM provider
|
||||
if _, ok := npmToLLMProvider[info.NPM]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -96,6 +96,7 @@ func ListAllSessions() ([]SessionInfo, error) {
|
||||
}
|
||||
|
||||
// listSessionsInDir reads all .jsonl files in a directory and extracts session info.
|
||||
// Empty sessions (no messages) are automatically cleaned up and not returned.
|
||||
func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
@@ -117,6 +118,11 @@ func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
if err != nil {
|
||||
continue // skip malformed session files
|
||||
}
|
||||
// Clean up and skip empty sessions (no messages)
|
||||
if info.MessageCount == 0 {
|
||||
_ = os.Remove(path)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, *info)
|
||||
}
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ func OpenTreeSession(path string) (*TreeManager, error) {
|
||||
|
||||
// Set leaf to the last entry.
|
||||
if len(tm.entries) > 0 {
|
||||
tm.leafID = tm.entryID(tm.entries[len(tm.entries)-1])
|
||||
tm.leafID = tm.EntryID(tm.entries[len(tm.entries)-1])
|
||||
}
|
||||
|
||||
// Open file for appending.
|
||||
@@ -242,9 +242,14 @@ func (tm *TreeManager) AppendMessage(msg message.Message) (string, error) {
|
||||
return entry.ID, nil
|
||||
}
|
||||
|
||||
// AppendFantasyMessage converts a fantasy.Message and appends it.
|
||||
// AppendLLMMessage converts an LLM message and appends it.
|
||||
func (tm *TreeManager) AppendLLMMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendMessage(message.FromLLMMessage(msg))
|
||||
}
|
||||
|
||||
// Deprecated: Use AppendLLMMessage instead.
|
||||
func (tm *TreeManager) AppendFantasyMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendMessage(message.FromFantasyMessage(msg))
|
||||
return tm.AppendLLMMessage(msg)
|
||||
}
|
||||
|
||||
// AppendModelChange records a model/provider change.
|
||||
@@ -521,7 +526,7 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
for _, entry := range branch {
|
||||
// Once we reach the first kept entry, stop skipping.
|
||||
if skipping {
|
||||
entryID := tm.entryID(entry)
|
||||
entryID := tm.EntryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
@@ -535,7 +540,7 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToFantasyMessages()
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
@@ -628,6 +633,11 @@ func (tm *TreeManager) MessageCount() int {
|
||||
return count
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the session has no messages (only header).
|
||||
func (tm *TreeManager) IsEmpty() bool {
|
||||
return tm.MessageCount() == 0
|
||||
}
|
||||
|
||||
// Close closes the underlying file handle.
|
||||
func (tm *TreeManager) Close() error {
|
||||
tm.mu.Lock()
|
||||
@@ -679,7 +689,7 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
skipping := lastCompaction != nil
|
||||
for _, entry := range branch {
|
||||
if skipping {
|
||||
entryID := tm.entryID(entry)
|
||||
entryID := tm.EntryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
@@ -693,7 +703,7 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToFantasyMessages()
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
@@ -732,31 +742,41 @@ func (tm *TreeManager) GetLastCompaction() *CompactionEntry {
|
||||
|
||||
// --- Legacy bridge ---
|
||||
|
||||
// AddFantasyMessages appends multiple fantasy messages as entries. This is
|
||||
// AddLLMMessages appends multiple LLM messages as entries. This is
|
||||
// used when syncing from the agent's ConversationMessages after a step.
|
||||
func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error {
|
||||
func (tm *TreeManager) AddLLMMessages(msgs []fantasy.Message) error {
|
||||
for _, msg := range msgs {
|
||||
if _, err := tm.AppendFantasyMessage(msg); err != nil {
|
||||
if _, err := tm.AppendLLMMessage(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFantasyMessages builds the context and returns just the messages.
|
||||
// Deprecated: Use AddLLMMessages instead.
|
||||
func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error {
|
||||
return tm.AddLLMMessages(msgs)
|
||||
}
|
||||
|
||||
// GetLLMMessages builds the context and returns just the messages.
|
||||
// This satisfies the same conceptual role as the old Manager.GetMessages().
|
||||
func (tm *TreeManager) GetFantasyMessages() []fantasy.Message {
|
||||
func (tm *TreeManager) GetLLMMessages() []fantasy.Message {
|
||||
msgs, _, _ := tm.BuildContext()
|
||||
return msgs
|
||||
}
|
||||
|
||||
// Deprecated: Use GetLLMMessages instead.
|
||||
func (tm *TreeManager) GetFantasyMessages() []fantasy.Message {
|
||||
return tm.GetLLMMessages()
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
// addEntryToIndex adds an entry to the in-memory indices.
|
||||
func (tm *TreeManager) addEntryToIndex(entry any) {
|
||||
tm.entries = append(tm.entries, entry)
|
||||
|
||||
id := tm.entryID(entry)
|
||||
id := tm.EntryID(entry)
|
||||
parentID := tm.entryParentID(entry)
|
||||
|
||||
if id != "" {
|
||||
@@ -793,8 +813,8 @@ func (tm *TreeManager) writeEntry(entry any) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// entryID extracts the ID from any entry type.
|
||||
func (tm *TreeManager) entryID(entry any) string {
|
||||
// EntryID extracts the ID from any entry type.
|
||||
func (tm *TreeManager) EntryID(entry any) string {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
return e.ID
|
||||
|
||||
@@ -127,9 +127,7 @@ func (p *MCPConnectionPool) GetConnection(ctx context.Context, serverName string
|
||||
return conn, nil
|
||||
} else {
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Connection %s unhealthy, removing", serverName))
|
||||
}
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Connection %s unhealthy, removing", serverName))
|
||||
}
|
||||
_ = conn.client.Close()
|
||||
delete(p.connections, serverName)
|
||||
|
||||
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -70,7 +71,7 @@ func TestMCPToolManager_LoadTools_GracefulFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
// The error should mention that all servers failed
|
||||
if err != nil && !contains(err.Error(), "all MCP servers failed") {
|
||||
if err != nil && !strings.Contains(err.Error(), "all MCP servers failed") {
|
||||
t.Errorf("Expected error message to mention all servers failed, got: %v", err)
|
||||
}
|
||||
|
||||
@@ -459,13 +460,3 @@ func sliceEqual(a, b []any) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Helper function to check if a string contains a substring
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
+269
-25
@@ -149,11 +149,13 @@ func TestInputComponent_QuitReturnsTeaQuit(t *testing.T) {
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TestInputComponent_ClearCallsClearMessages verifies that /clear (and its
|
||||
// aliases) calls appCtrl.ClearMessages() and returns no submitMsg.
|
||||
// TestInputComponent_ClearForwardsAsSubmitMsg verifies that /clear (and its
|
||||
// aliases) are forwarded as submitMsg to the parent model so that the parent
|
||||
// can call ClearMessages(), update scrollback, and print the confirmation
|
||||
// message in one place. InputComponent must NOT call ClearMessages() directly.
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestInputComponent_ClearCallsClearMessages(t *testing.T) {
|
||||
func TestInputComponent_ClearForwardsAsSubmitMsg(t *testing.T) {
|
||||
aliases := []string{"/clear", "/c", "/cls"}
|
||||
for _, alias := range aliases {
|
||||
t.Run(alias, func(t *testing.T) {
|
||||
@@ -164,22 +166,29 @@ func TestInputComponent_ClearCallsClearMessages(t *testing.T) {
|
||||
|
||||
_, cmd := sendInputMsg(c, tea.KeyPressMsg{Code: tea.KeyEnter})
|
||||
|
||||
if ctrl.clearMsgCalled != 1 {
|
||||
t.Fatalf("%s: expected ClearMessages() called once, got %d", alias, ctrl.clearMsgCalled)
|
||||
// InputComponent must NOT call ClearMessages() directly.
|
||||
if ctrl.clearMsgCalled != 0 {
|
||||
t.Fatalf("%s: InputComponent must not call ClearMessages(), got %d", alias, ctrl.clearMsgCalled)
|
||||
}
|
||||
// No cmd should be returned (no submitMsg forwarded to parent).
|
||||
if cmd != nil {
|
||||
msg := runCmd(cmd)
|
||||
if _, ok := msg.(submitMsg); ok {
|
||||
t.Fatalf("%s: /clear should not emit submitMsg, got submitMsg", alias)
|
||||
}
|
||||
// A submitMsg must be emitted so the parent model handles /clear.
|
||||
if cmd == nil {
|
||||
t.Fatalf("%s: expected submitMsg cmd, got nil", alias)
|
||||
}
|
||||
msg := runCmd(cmd)
|
||||
sm, ok := msg.(submitMsg)
|
||||
if !ok {
|
||||
t.Fatalf("%s: expected submitMsg, got %T", alias, msg)
|
||||
}
|
||||
if sm.Text != alias {
|
||||
t.Fatalf("%s: expected submitMsg text %q, got %q", alias, alias, sm.Text)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInputComponent_ClearNilCtrl_NoPanic verifies that /clear with a nil
|
||||
// appCtrl does not panic.
|
||||
// appCtrl does not panic. Since /clear is now forwarded to the parent via
|
||||
// submitMsg, no appCtrl interaction happens in InputComponent at all.
|
||||
func TestInputComponent_ClearNilCtrl_NoPanic(t *testing.T) {
|
||||
c := newTestInput(nil)
|
||||
c.textarea.SetValue("/clear")
|
||||
@@ -266,10 +275,9 @@ func TestInputComponent_UnknownSlashCommand_ForwardsAsSubmit(t *testing.T) {
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// newTestStream creates a StreamComponent with a fixed width and model name,
|
||||
// in non-compact mode.
|
||||
// newTestStream creates a StreamComponent with a fixed width and model name.
|
||||
func newTestStream() *StreamComponent {
|
||||
return NewStreamComponent(false, 80, "test-model")
|
||||
return NewStreamComponent(80, "test-model")
|
||||
}
|
||||
|
||||
// sendStreamMsg calls component.Update and returns the updated component.
|
||||
@@ -349,7 +357,7 @@ func TestStreamComponent_SpinnerKeepsRunningDuringStreaming(t *testing.T) {
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "hello"})
|
||||
|
||||
// Flush pending chunks (simulates the 16ms tick firing).
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{})
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
|
||||
if !c.spinning {
|
||||
t.Fatal("expected spinning=true after first chunk")
|
||||
@@ -376,7 +384,7 @@ func TestStreamComponent_ChunkAccumulation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Flush pending chunks (simulates the 16ms tick firing).
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{})
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
|
||||
got := c.streamContent.String()
|
||||
want := "Hello, world!"
|
||||
@@ -396,6 +404,7 @@ func TestStreamComponent_ToolExecution_IsStarting_ShowsSpinner(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
_, cmd := c.Update(app.ToolExecutionEvent{
|
||||
ToolCallID: "call-exec-1",
|
||||
ToolName: "exec_tool",
|
||||
IsStarting: true,
|
||||
})
|
||||
@@ -403,8 +412,9 @@ func TestStreamComponent_ToolExecution_IsStarting_ShowsSpinner(t *testing.T) {
|
||||
if !c.spinning {
|
||||
t.Fatal("expected spinning=true during tool execution")
|
||||
}
|
||||
if len(c.activeTools) != 1 || !strings.Contains(c.activeTools[0], "exec_tool") {
|
||||
t.Fatalf("expected activeTools to contain tool name, got %v", c.activeTools)
|
||||
tools := c.activeToolDisplays()
|
||||
if len(tools) != 1 || !strings.Contains(tools[0], "exec_tool") {
|
||||
t.Fatalf("expected activeTools to contain tool name, got %v", tools)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Fatal("expected tick cmd from ToolExecutionEvent{IsStarting:true}")
|
||||
@@ -418,11 +428,13 @@ func TestStreamComponent_ToolExecution_NotStarting_KeepsSpinning(t *testing.T) {
|
||||
c = sendStreamMsg(c, app.SpinnerEvent{Show: true})
|
||||
// Simulate a tool starting
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{
|
||||
ToolCallID: "call-some-1",
|
||||
ToolName: "some_tool",
|
||||
IsStarting: true,
|
||||
})
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{
|
||||
ToolCallID: "call-some-1",
|
||||
ToolName: "some_tool",
|
||||
IsStarting: false,
|
||||
})
|
||||
@@ -440,9 +452,9 @@ func TestStreamComponent_ParallelToolExecution(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
// Start three tools in parallel
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "grep", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "find", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read", ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-grep", ToolName: "grep", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-find", ToolName: "find", IsStarting: true})
|
||||
|
||||
if len(c.activeTools) != 3 {
|
||||
t.Fatalf("expected 3 active tools, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
@@ -455,19 +467,44 @@ func TestStreamComponent_ParallelToolExecution(t *testing.T) {
|
||||
}
|
||||
|
||||
// Finish one tool
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "grep", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-grep", ToolName: "grep", IsStarting: false})
|
||||
if len(c.activeTools) != 2 {
|
||||
t.Fatalf("expected 2 active tools after one finished, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
}
|
||||
|
||||
// Finish remaining tools
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "read", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "find", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read", ToolName: "read", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-find", ToolName: "find", IsStarting: false})
|
||||
if len(c.activeTools) != 0 {
|
||||
t.Fatalf("expected 0 active tools after all finished, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_ParallelSameToolName_UsesToolCallID verifies finishing one
|
||||
// tool call does not remove another concurrent call with the same tool name.
|
||||
func TestStreamComponent_ParallelSameToolName_UsesToolCallID(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-1", ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-2", ToolName: "read", IsStarting: true})
|
||||
|
||||
tools := c.activeToolDisplays()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("expected 2 active read calls, got %d (%v)", len(tools), tools)
|
||||
}
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-1", ToolName: "read", IsStarting: false})
|
||||
tools = c.activeToolDisplays()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 active read call after finishing one ID, got %d (%v)", len(tools), tools)
|
||||
}
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-2", ToolName: "read", IsStarting: false})
|
||||
if len(c.activeToolDisplays()) != 0 {
|
||||
t.Fatalf("expected no active tools after finishing both IDs, got %v", c.activeToolDisplays())
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TestStreamComponent_GetRenderedContent verifies the method returns rendered
|
||||
// text when content is accumulated, and empty string when not.
|
||||
@@ -621,3 +658,210 @@ func TestStreamComponent_StaleTick_Discarded(t *testing.T) {
|
||||
t.Fatal("current-gen tick should reschedule")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_StaleFlushTick_Discarded verifies that flush ticks from a
|
||||
// previous generation (e.g. pre-Reset) are ignored.
|
||||
func TestStreamComponent_StaleFlushTick_Discarded(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
// Start a pending flush and capture its generation.
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "old"})
|
||||
staleGen := c.flushGeneration
|
||||
if !c.flushPending {
|
||||
t.Fatal("precondition: expected flushPending=true after first chunk")
|
||||
}
|
||||
|
||||
// Reset should invalidate in-flight flush ticks.
|
||||
c.Reset()
|
||||
if c.flushGeneration == staleGen {
|
||||
t.Fatal("expected flushGeneration to change after Reset")
|
||||
}
|
||||
|
||||
// New content in a new generation.
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "new"})
|
||||
if got := c.pendingStream.String(); got != "new" {
|
||||
t.Fatalf("expected pendingStream='new', got %q", got)
|
||||
}
|
||||
|
||||
// Stale flush tick should be ignored.
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: staleGen})
|
||||
if got := c.pendingStream.String(); got != "new" {
|
||||
t.Fatalf("stale flush tick should not commit pending stream, got %q", got)
|
||||
}
|
||||
|
||||
// Current generation flush should commit.
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
if got := c.pendingStream.String(); got != "" {
|
||||
t.Fatalf("expected pendingStream empty after current flush, got %q", got)
|
||||
}
|
||||
if got := c.streamContent.String(); got != "new" {
|
||||
t.Fatalf("expected streamContent='new' after current flush, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_ConsumeOverflow_NoHeight verifies that when height is
|
||||
// unconstrained (0), ConsumeOverflow always returns "".
|
||||
func TestStreamComponent_ConsumeOverflow_NoHeight(t *testing.T) {
|
||||
c := newTestStream()
|
||||
// Commit some content directly.
|
||||
c.streamContent.WriteString("line1\nline2\nline3")
|
||||
c.phase = streamPhaseActive
|
||||
c.renderDirty = true
|
||||
|
||||
if got := c.ConsumeOverflow(); got != "" {
|
||||
t.Fatalf("expected empty with height=0, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_ConsumeOverflow_NoOverflow verifies that when content fits
|
||||
// within the allocated height, ConsumeOverflow returns "".
|
||||
func TestStreamComponent_ConsumeOverflow_NoOverflow(t *testing.T) {
|
||||
c := newTestStream()
|
||||
c.streamContent.WriteString("line1\nline2")
|
||||
c.phase = streamPhaseActive
|
||||
c.renderDirty = true
|
||||
c.height = 20 // plenty of room
|
||||
|
||||
if got := c.ConsumeOverflow(); got != "" {
|
||||
t.Fatalf("expected empty when content fits, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_ConsumeOverflow_EmitsTopLines verifies that when the
|
||||
// rendered content has more lines than the allocated height, ConsumeOverflow
|
||||
// returns the top overflow lines and advances the internal pointer.
|
||||
func TestStreamComponent_ConsumeOverflow_EmitsTopLines(t *testing.T) {
|
||||
c := newTestStream()
|
||||
c.height = 2
|
||||
|
||||
// Build raw content that when "rendered" (plain text for this test)
|
||||
// is 5 lines — we bypass the markdown renderer by writing directly to
|
||||
// streamContent and using a nil renderer.
|
||||
c.renderer = nil
|
||||
c.streamContent.WriteString("a\nb\nc\nd\ne")
|
||||
c.phase = streamPhaseActive
|
||||
c.renderDirty = true
|
||||
|
||||
// First call: should return lines a, b, c (5 lines - 2 visible = 3 overflow).
|
||||
overflow1 := c.ConsumeOverflow()
|
||||
if overflow1 == "" {
|
||||
t.Fatal("expected overflow, got empty")
|
||||
}
|
||||
overflowLines := strings.Split(overflow1, "\n")
|
||||
if len(overflowLines) != 3 {
|
||||
t.Fatalf("expected 3 overflow lines, got %d: %q", len(overflowLines), overflow1)
|
||||
}
|
||||
if overflowLines[0] != "a" || overflowLines[1] != "b" || overflowLines[2] != "c" {
|
||||
t.Fatalf("unexpected overflow lines: %v", overflowLines)
|
||||
}
|
||||
|
||||
// Second call without new content should return "" (pointer already advanced).
|
||||
overflow2 := c.ConsumeOverflow()
|
||||
if overflow2 != "" {
|
||||
t.Fatalf("expected empty on second call, got %q", overflow2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_ConsumeOverflow_IncrementalFlush verifies that as new
|
||||
// content arrives, ConsumeOverflow incrementally returns only newly overflowed
|
||||
// lines on each call.
|
||||
func TestStreamComponent_ConsumeOverflow_IncrementalFlush(t *testing.T) {
|
||||
c := newTestStream()
|
||||
c.height = 2
|
||||
c.renderer = nil
|
||||
c.phase = streamPhaseActive
|
||||
|
||||
// Start with 3 lines — 1 overflows.
|
||||
c.streamContent.WriteString("a\nb\nc")
|
||||
c.renderDirty = true
|
||||
|
||||
overflow1 := c.ConsumeOverflow()
|
||||
if overflow1 != "a" {
|
||||
t.Fatalf("expected 'a', got %q", overflow1)
|
||||
}
|
||||
|
||||
// Add 2 more lines — 2 additional overflows.
|
||||
c.streamContent.WriteString("\nd\ne")
|
||||
c.renderDirty = true
|
||||
|
||||
overflow2 := c.ConsumeOverflow()
|
||||
want := "b\nc"
|
||||
if overflow2 != want {
|
||||
t.Fatalf("expected %q, got %q", want, overflow2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_ConsumeOverflow_ResetClearsPointer verifies that Reset()
|
||||
// resets the scrollback pointer so the next response starts fresh.
|
||||
func TestStreamComponent_ConsumeOverflow_ResetClearsPointer(t *testing.T) {
|
||||
c := newTestStream()
|
||||
c.height = 1
|
||||
c.renderer = nil
|
||||
c.phase = streamPhaseActive
|
||||
|
||||
c.streamContent.WriteString("a\nb")
|
||||
c.renderDirty = true
|
||||
overflow := c.ConsumeOverflow()
|
||||
if overflow != "a" {
|
||||
t.Fatalf("expected 'a', got %q", overflow)
|
||||
}
|
||||
|
||||
c.Reset()
|
||||
if c.scrollbackFlushedLines != 0 {
|
||||
t.Fatalf("expected scrollbackFlushedLines=0 after Reset, got %d", c.scrollbackFlushedLines)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_GetRenderedContent_SkipsFlushedLines verifies that
|
||||
// GetRenderedContent skips lines already emitted via ConsumeOverflow so the
|
||||
// caller doesn't re-print content already in the terminal scrollback.
|
||||
func TestStreamComponent_GetRenderedContent_SkipsFlushedLines(t *testing.T) {
|
||||
c := newTestStream()
|
||||
c.height = 2
|
||||
c.renderer = nil
|
||||
c.phase = streamPhaseActive
|
||||
|
||||
// 5 lines → 3 overflow, 2 visible.
|
||||
c.streamContent.WriteString("a\nb\nc\nd\ne")
|
||||
c.renderDirty = true
|
||||
|
||||
// Consume the overflow: lines a, b, c.
|
||||
overflow := c.ConsumeOverflow()
|
||||
if overflow != "a\nb\nc" {
|
||||
t.Fatalf("expected 'a\\nb\\nc', got %q", overflow)
|
||||
}
|
||||
if c.scrollbackFlushedLines != 3 {
|
||||
t.Fatalf("expected flushedLines=3, got %d", c.scrollbackFlushedLines)
|
||||
}
|
||||
|
||||
// GetRenderedContent should only return the non-flushed portion: d, e.
|
||||
got := c.GetRenderedContent()
|
||||
if got != "d\ne" {
|
||||
t.Fatalf("expected 'd\\ne', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_GetRenderedContent_AllFlushed verifies that when all
|
||||
// lines have been pushed via ConsumeOverflow, GetRenderedContent returns "".
|
||||
func TestStreamComponent_GetRenderedContent_AllFlushed(t *testing.T) {
|
||||
c := newTestStream()
|
||||
c.height = 1
|
||||
c.renderer = nil
|
||||
c.phase = streamPhaseActive
|
||||
|
||||
// 2 lines → height=1, so 1 overflow.
|
||||
c.streamContent.WriteString("a\nb")
|
||||
c.renderDirty = true
|
||||
|
||||
// Consume overflow (line a), leaving 1 visible line (b).
|
||||
_ = c.ConsumeOverflow()
|
||||
|
||||
// Now bump height so everything overflows — simulate a resize that made
|
||||
// the viewable area 0, forcing all content to be "flushed".
|
||||
c.scrollbackFlushedLines = 2 // pretend both lines were flushed
|
||||
|
||||
got := c.GetRenderedContent()
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty when all lines flushed, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
+16
-24
@@ -11,33 +11,26 @@ import (
|
||||
)
|
||||
|
||||
// CLI manages the command-line interface for KIT, providing message rendering,
|
||||
// user input handling, and display management. It supports both standard and compact
|
||||
// display modes, handles streaming responses, tracks token usage, and manages the
|
||||
// overall conversation flow between the user and AI assistants.
|
||||
// user input handling, and display management. It handles streaming responses,
|
||||
// tracks token usage, and manages the overall conversation flow between the
|
||||
// user and AI assistants.
|
||||
type CLI struct {
|
||||
renderer Renderer
|
||||
usageTracker *UsageTracker
|
||||
width int
|
||||
compactMode bool
|
||||
debug bool
|
||||
modelName string
|
||||
}
|
||||
|
||||
// NewCLI creates and initializes a new CLI instance with the specified display modes.
|
||||
// The debug parameter enables debug message rendering, while compact enables a more
|
||||
// condensed display format. Returns an initialized CLI ready for interaction or an
|
||||
// NewCLI creates and initializes a new CLI instance. The debug parameter enables
|
||||
// debug message rendering. Returns an initialized CLI ready for interaction or an
|
||||
// error if initialization fails.
|
||||
func NewCLI(debug bool, compact bool) (*CLI, error) {
|
||||
func NewCLI(debug bool) (*CLI, error) {
|
||||
cli := &CLI{
|
||||
compactMode: compact,
|
||||
debug: debug,
|
||||
debug: debug,
|
||||
}
|
||||
cli.updateSize()
|
||||
if compact {
|
||||
cli.renderer = NewCompactRenderer(cli.width, debug)
|
||||
} else {
|
||||
cli.renderer = newMessageRenderer(cli.width, debug)
|
||||
}
|
||||
cli.renderer = newMessageRenderer(cli.width, debug)
|
||||
|
||||
return cli, nil
|
||||
}
|
||||
@@ -179,9 +172,8 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) {
|
||||
}
|
||||
|
||||
// UpdateUsageFromResponse records token usage using metadata from the fantasy
|
||||
// response when available. Falls back to text-based estimation if the metadata is
|
||||
// missing or appears unreliable. This provides more accurate usage tracking when
|
||||
// providers supply token count information.
|
||||
// response. Only actual API-reported tokens are used for cost tracking.
|
||||
// If the provider doesn't report token counts, no usage is recorded.
|
||||
func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText string) {
|
||||
if c.usageTracker == nil {
|
||||
return
|
||||
@@ -191,19 +183,19 @@ func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText stri
|
||||
inputTokens := int(usage.InputTokens)
|
||||
outputTokens := int(usage.OutputTokens)
|
||||
|
||||
// Validate that the metadata seems reasonable
|
||||
if inputTokens > 0 && outputTokens > 0 {
|
||||
// Only use actual API-reported tokens for cost tracking.
|
||||
// We intentionally do NOT estimate tokens - estimation is inaccurate
|
||||
// and should never be used for cost calculations.
|
||||
if inputTokens > 0 {
|
||||
cacheReadTokens := int(usage.CacheReadTokens)
|
||||
cacheWriteTokens := int(usage.CacheCreationTokens)
|
||||
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
// Per-response usage is a single API call, so it represents the
|
||||
// actual context window fill level.
|
||||
c.usageTracker.SetContextTokens(inputTokens + outputTokens)
|
||||
} else {
|
||||
// Fallback to estimation if no metadata is available.
|
||||
// EstimateAndUpdateUsage sets context tokens internally.
|
||||
c.usageTracker.EstimateAndUpdateUsage(inputText, response.Content.Text())
|
||||
}
|
||||
// If inputTokens is 0, the provider didn't report usage - we skip recording
|
||||
// rather than estimating, to ensure cost accuracy.
|
||||
}
|
||||
|
||||
// DisplayUsageAfterResponse renders and displays token usage information immediately
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"github.com/atotto/clipboard"
|
||||
)
|
||||
|
||||
// CopyToClipboard writes text to both the system clipboard and via OSC 52.
|
||||
// Returns a tea.Cmd that can be used in Bubble Tea's Update flow.
|
||||
func CopyToClipboard(text string) tea.Cmd {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tea.Sequence(
|
||||
// Method 1: OSC 52 escape sequence (works in modern terminals)
|
||||
tea.SetClipboard(text),
|
||||
|
||||
// Method 2: Native system clipboard (atotto/clipboard)
|
||||
func() tea.Msg {
|
||||
// Best effort - ignore errors
|
||||
_ = clipboard.WriteAll(text)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// CopyToClipboardWithMessage writes text to clipboard and returns a toast notification.
|
||||
func CopyToClipboardWithMessage(text string, message string) tea.Cmd {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tea.Sequence(
|
||||
CopyToClipboard(text),
|
||||
func() tea.Msg {
|
||||
return ToastMsg{Message: message, Type: ToastInfo}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// ToastType represents the type of toast notification.
|
||||
type ToastType int
|
||||
|
||||
const (
|
||||
ToastInfo ToastType = iota
|
||||
ToastSuccess
|
||||
ToastWarning
|
||||
ToastError
|
||||
)
|
||||
|
||||
// ToastMsg is a message to display a toast notification.
|
||||
type ToastMsg struct {
|
||||
Message string
|
||||
Type ToastType
|
||||
}
|
||||
|
||||
// IsClipboardSupported returns true if the clipboard is supported on this platform.
|
||||
func IsClipboardSupported() bool {
|
||||
// atotto/clipboard supports Linux (with xclip or xsel), macOS, Windows
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "windows":
|
||||
return true
|
||||
case "linux":
|
||||
// Check if xclip or xsel is available
|
||||
// This is a best-effort check
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// CopySelection represents a text selection with start/end positions.
|
||||
type CopySelection struct {
|
||||
StartItemIdx int // Index of item where selection starts
|
||||
StartLine int // Line within item where selection starts
|
||||
StartCol int // Column where selection starts
|
||||
EndItemIdx int // Index of item where selection ends
|
||||
EndLine int // Line within item where selection ends
|
||||
EndCol int // Column where selection ends
|
||||
Active bool // Whether selection is currently active
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the selection has no content.
|
||||
func (s CopySelection) IsEmpty() bool {
|
||||
return !s.Active || (s.StartItemIdx == s.EndItemIdx && s.StartLine == s.EndLine && s.StartCol == s.EndCol)
|
||||
}
|
||||
|
||||
// String returns a string representation for debugging.
|
||||
func (s CopySelection) String() string {
|
||||
return fmt.Sprintf("Selection{item:%d-%d, line:%d-%d, col:%d-%d, active:%v}",
|
||||
s.StartItemIdx, s.EndItemIdx, s.StartLine, s.EndLine, s.StartCol, s.EndCol, s.Active)
|
||||
}
|
||||
@@ -1,494 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
// CompactRenderer handles rendering messages in a space-efficient compact format,
|
||||
// optimized for terminals with limited vertical space. It displays messages with
|
||||
// minimal decorations while maintaining readability and essential information.
|
||||
type CompactRenderer struct {
|
||||
width int
|
||||
debug bool
|
||||
|
||||
// getToolRenderer returns extension-provided rendering overrides for a
|
||||
// specific tool. May be nil if no extensions are loaded. Used in
|
||||
// RenderToolMessage to check for custom header/body formatting before
|
||||
// falling back to builtin renderers.
|
||||
getToolRenderer func(toolName string) *ToolRendererData
|
||||
}
|
||||
|
||||
// NewCompactRenderer creates and initializes a new CompactRenderer with the specified
|
||||
// terminal width and debug mode setting. The width parameter determines line wrapping,
|
||||
// while debug enables additional diagnostic output in rendered messages.
|
||||
func NewCompactRenderer(width int, debug bool) *CompactRenderer {
|
||||
return &CompactRenderer{
|
||||
width: width,
|
||||
debug: debug,
|
||||
}
|
||||
}
|
||||
|
||||
// SetWidth updates the terminal width for the renderer, affecting how content
|
||||
// is wrapped and formatted in subsequent render operations.
|
||||
func (r *CompactRenderer) SetWidth(width int) {
|
||||
r.width = width
|
||||
}
|
||||
|
||||
// RenderUserMessage renders a user's input message in compact format with a
|
||||
// distinctive symbol (>) and label. The content is formatted to preserve structure
|
||||
// while minimizing vertical space usage. Returns a UIMessage with formatted content
|
||||
// and metadata.
|
||||
func (r *CompactRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Info).Render(">")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Info).Bold(true).Render("User")
|
||||
|
||||
// Only run markdown rendering when the message contains code spans or
|
||||
// fenced code blocks. Plain text is rendered directly so that newlines
|
||||
// are preserved without the extra paragraph spacing glamour adds.
|
||||
var compactContent string
|
||||
if strings.Contains(content, "`") {
|
||||
mdContent := strings.ReplaceAll(content, "\n", "\n\n")
|
||||
compactContent = r.formatUserAssistantContent(mdContent)
|
||||
compactContent = removeBlankLines(compactContent)
|
||||
} else {
|
||||
compactContent = content
|
||||
}
|
||||
|
||||
// Handle multi-line content
|
||||
lines := strings.Split(compactContent, "\n")
|
||||
var formattedLines []string
|
||||
|
||||
for i, line := range lines {
|
||||
if i == 0 {
|
||||
// First line includes symbol and label
|
||||
formattedLines = append(formattedLines, fmt.Sprintf("%s %s %s", symbol, label, line))
|
||||
} else {
|
||||
// Subsequent lines without indentation for compact mode
|
||||
formattedLines = append(formattedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return UIMessage{
|
||||
Type: UserMessage,
|
||||
Content: strings.Join(formattedLines, "\n"),
|
||||
Height: len(formattedLines),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderAssistantMessage renders an AI assistant's response in compact format with
|
||||
// a distinctive symbol (<) and the model name as label. Empty content is ignored
|
||||
// and returns an empty message. Returns a UIMessage with formatted content and metadata.
|
||||
func (r *CompactRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
|
||||
// Ignore empty responses - don't render anything
|
||||
compactContent := r.formatUserAssistantContent(content)
|
||||
if compactContent == "" {
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
Content: "",
|
||||
Height: 0,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Primary).Render("<")
|
||||
|
||||
// Use the full model name, fallback to "Assistant" if empty
|
||||
if modelName == "" {
|
||||
modelName = "Assistant"
|
||||
}
|
||||
label := lipgloss.NewStyle().Foreground(theme.Primary).Bold(true).Render(modelName)
|
||||
|
||||
// Handle multi-line content
|
||||
lines := strings.Split(compactContent, "\n")
|
||||
var formattedLines []string
|
||||
|
||||
for i, line := range lines {
|
||||
if i == 0 {
|
||||
// First line includes symbol and label
|
||||
formattedLines = append(formattedLines, fmt.Sprintf("%s %s %s", symbol, label, line))
|
||||
} else {
|
||||
// Subsequent lines without indentation for compact mode
|
||||
formattedLines = append(formattedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
Content: strings.Join(formattedLines, "\n"),
|
||||
Height: len(formattedLines),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolCallMessage renders a tool call notification in compact format, showing
|
||||
// the tool being executed with its arguments in a single line. The tool name is
|
||||
// highlighted and arguments are displayed in a muted color for visual distinction.
|
||||
func (r *CompactRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("[")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Tool).Bold(true).Render(toolName)
|
||||
|
||||
// Format args for compact display
|
||||
argsDisplay := r.formatToolArgs(toolArgs)
|
||||
if argsDisplay != "" {
|
||||
argsDisplay = lipgloss.NewStyle().Foreground(theme.Muted).Render(argsDisplay)
|
||||
}
|
||||
|
||||
line := fmt.Sprintf("%s %s %s", symbol, label, argsDisplay)
|
||||
|
||||
return UIMessage{
|
||||
Type: ToolCallMessage,
|
||||
Content: line,
|
||||
Height: 1,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolMessage renders a unified tool block in compact format, combining
|
||||
// the tool invocation header (icon + display name + params) with the execution
|
||||
// result body. Status is indicated by icon: checkmark for success, cross for error.
|
||||
func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
// Resolve extension renderer once for all overrides.
|
||||
var extRd *ToolRendererData
|
||||
if r.getToolRenderer != nil {
|
||||
extRd = r.getToolRenderer(toolName)
|
||||
}
|
||||
|
||||
// Status icon
|
||||
var icon string
|
||||
iconColor := theme.Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
iconColor = theme.Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
iconStr := lipgloss.NewStyle().Foreground(iconColor).Bold(true).Render(icon)
|
||||
|
||||
// Extension can override display name.
|
||||
displayName := toolDisplayName(toolName)
|
||||
if extRd != nil && extRd.DisplayName != "" {
|
||||
displayName = extRd.DisplayName
|
||||
}
|
||||
nameStr := lipgloss.NewStyle().Foreground(theme.Info).Bold(true).Render(displayName)
|
||||
|
||||
// Format params — check extension renderer first.
|
||||
paramBudget := max(r.width-10-len(displayName), 20)
|
||||
var params string
|
||||
if extRd != nil && extRd.RenderHeader != nil {
|
||||
params = extRd.RenderHeader(toolArgs, paramBudget)
|
||||
}
|
||||
if params == "" {
|
||||
params = formatToolParams(toolArgs, paramBudget)
|
||||
}
|
||||
|
||||
// Build header line
|
||||
header := iconStr + " " + nameStr
|
||||
if params != "" {
|
||||
header += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
}
|
||||
|
||||
// Format body: check extension renderer first, then compact builtin, then default.
|
||||
var body string
|
||||
if extRd != nil && extRd.RenderBody != nil {
|
||||
body = extRd.RenderBody(toolResult, isError, r.width-4)
|
||||
// Apply markdown rendering if requested and body is non-empty.
|
||||
if body != "" && extRd.BodyMarkdown {
|
||||
body = strings.TrimSuffix(toMarkdown(body, r.width-4), "\n")
|
||||
}
|
||||
}
|
||||
if body == "" {
|
||||
if isError {
|
||||
body = lipgloss.NewStyle().Foreground(theme.Error).Render(r.formatToolResult(toolResult))
|
||||
} else {
|
||||
// Use compact summary renderers instead of full tool body renderers.
|
||||
body = renderToolBodyCompact(toolName, toolArgs, toolResult, r.width-4)
|
||||
if body == "" {
|
||||
formatted := r.formatToolResult(toolResult)
|
||||
if formatted == "" {
|
||||
body = lipgloss.NewStyle().Foreground(theme.Muted).Italic(true).Render("(no output)")
|
||||
} else {
|
||||
body = lipgloss.NewStyle().Foreground(theme.Muted).Render(formatted)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Combine header + indented body
|
||||
var lines []string
|
||||
lines = append(lines, header)
|
||||
if body != "" {
|
||||
for line := range strings.SplitSeq(body, "\n") {
|
||||
lines = append(lines, " "+line)
|
||||
}
|
||||
}
|
||||
|
||||
return UIMessage{
|
||||
Type: ToolMessage,
|
||||
Content: strings.Join(lines, "\n"),
|
||||
Height: len(lines),
|
||||
}
|
||||
}
|
||||
|
||||
// RenderSystemMessage renders a system notification or informational message in
|
||||
// compact format with a distinctive symbol (*) and "System" label. Content is
|
||||
// formatted to fit on a single line for minimal space usage.
|
||||
func (r *CompactRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Muted).Render("◇")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Muted).Bold(true).Render("System")
|
||||
|
||||
compactContent := r.formatCompactContent(content)
|
||||
|
||||
line := fmt.Sprintf("%s %-8s %s", symbol, label, compactContent)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
Content: line,
|
||||
Height: 1,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderErrorMessage renders an error notification in compact format with a
|
||||
// distinctive error symbol (!) and styling to ensure visibility. The error
|
||||
// content is displayed in a single line with appropriate color highlighting.
|
||||
func (r *CompactRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Error).Render("!")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Error).Bold(true).Render("Error")
|
||||
|
||||
compactContent := lipgloss.NewStyle().Foreground(theme.Error).Render(r.formatCompactContent(errorMsg))
|
||||
|
||||
line := fmt.Sprintf("%s %-8s %s", symbol, label, compactContent)
|
||||
|
||||
return UIMessage{
|
||||
Type: ErrorMessage,
|
||||
Content: line,
|
||||
Height: 1,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugMessage renders diagnostic information in compact format when debug
|
||||
// mode is enabled. Messages are truncated if they exceed the available width to
|
||||
// maintain single-line display.
|
||||
func (r *CompactRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Tool).Bold(true).Render("Debug")
|
||||
|
||||
// Truncate message if too long
|
||||
content := message
|
||||
if len(content) > r.width-20 {
|
||||
content = content[:r.width-23] + "..."
|
||||
}
|
||||
|
||||
line := fmt.Sprintf("%s %-8s %s", symbol, label, content)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
Content: line,
|
||||
Height: 1,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugConfigMessage renders configuration settings in compact format for
|
||||
// debugging purposes. Config entries are displayed as key=value pairs separated
|
||||
// by commas, truncated if necessary to fit on a single line.
|
||||
func (r *CompactRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Tool).Bold(true).Render("Debug")
|
||||
|
||||
// Format config as compact key=value pairs
|
||||
var configPairs []string
|
||||
for key, value := range config {
|
||||
if value != nil {
|
||||
configPairs = append(configPairs, fmt.Sprintf("%s=%v", key, value))
|
||||
}
|
||||
}
|
||||
|
||||
content := strings.Join(configPairs, ", ")
|
||||
if len(content) > r.width-20 {
|
||||
content = content[:r.width-23] + "..."
|
||||
}
|
||||
|
||||
line := fmt.Sprintf("%s %-8s %s", symbol, label, content)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
Content: line,
|
||||
Height: 1,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// formatCompactContent formats content for compact single-line display
|
||||
func (r *CompactRenderer) formatCompactContent(content string) string {
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Remove markdown formatting for compact display
|
||||
content = strings.ReplaceAll(content, "\n", " ")
|
||||
content = strings.ReplaceAll(content, "\t", " ")
|
||||
|
||||
// Collapse multiple spaces
|
||||
for strings.Contains(content, " ") {
|
||||
content = strings.ReplaceAll(content, " ", " ")
|
||||
}
|
||||
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// Truncate if too long (unless in debug mode)
|
||||
maxLen := max(
|
||||
// Reserve space for symbol and label more conservatively
|
||||
r.width-28,
|
||||
// Minimum width for readability
|
||||
40)
|
||||
if !r.debug && len(content) > maxLen {
|
||||
content = content[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// formatUserAssistantContent formats user and assistant content using glamour markdown rendering
|
||||
func (r *CompactRenderer) formatUserAssistantContent(content string) string {
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Calculate available width more conservatively
|
||||
// Account for: symbol (1) + spaces (2) + label (up to 20 chars) + space (1) + margin (4)
|
||||
availableWidth := max(r.width-28,
|
||||
// Minimum width for readability
|
||||
40)
|
||||
|
||||
// Use glamour to render markdown content with proper width
|
||||
rendered := toMarkdown(content, availableWidth)
|
||||
return strings.TrimSuffix(rendered, "\n")
|
||||
}
|
||||
|
||||
// wrapText wraps text to the specified width, preserving existing line breaks
|
||||
func (r *CompactRenderer) wrapText(text string, width int) string {
|
||||
if width <= 0 {
|
||||
return text
|
||||
}
|
||||
|
||||
lines := strings.Split(text, "\n")
|
||||
var wrappedLines []string
|
||||
|
||||
for _, line := range lines {
|
||||
if len(line) <= width {
|
||||
wrappedLines = append(wrappedLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
// Wrap long lines
|
||||
words := strings.Fields(line)
|
||||
if len(words) == 0 {
|
||||
wrappedLines = append(wrappedLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
currentLine := ""
|
||||
for _, word := range words {
|
||||
// If adding this word would exceed the width, start a new line
|
||||
if len(currentLine)+len(word)+1 > width && currentLine != "" {
|
||||
wrappedLines = append(wrappedLines, currentLine)
|
||||
currentLine = word
|
||||
} else {
|
||||
if currentLine == "" {
|
||||
currentLine = word
|
||||
} else {
|
||||
currentLine += " " + word
|
||||
}
|
||||
}
|
||||
}
|
||||
if currentLine != "" {
|
||||
wrappedLines = append(wrappedLines, currentLine)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(wrappedLines, "\n")
|
||||
}
|
||||
|
||||
// formatToolArgs formats tool arguments for compact display
|
||||
func (r *CompactRenderer) formatToolArgs(args string) string {
|
||||
if args == "" || args == "{}" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Remove JSON braces and format compactly
|
||||
args = strings.TrimSpace(args)
|
||||
if strings.HasPrefix(args, "{") && strings.HasSuffix(args, "}") {
|
||||
args = strings.TrimPrefix(args, "{")
|
||||
args = strings.TrimSuffix(args, "}")
|
||||
args = strings.TrimSpace(args)
|
||||
}
|
||||
|
||||
// Remove quotes around simple values
|
||||
args = strings.ReplaceAll(args, `"`, "")
|
||||
|
||||
// Remove parameter names (e.g., "command: ls" -> "ls", "path: /home" -> "/home")
|
||||
// Look for pattern "key: value" and extract just the value
|
||||
if colonIndex := strings.Index(args, ":"); colonIndex != -1 {
|
||||
args = strings.TrimSpace(args[colonIndex+1:])
|
||||
}
|
||||
|
||||
return r.formatCompactContent(args)
|
||||
}
|
||||
|
||||
// formatToolResult formats tool results preserving formatting but limiting to 5 lines
|
||||
func (r *CompactRenderer) formatToolResult(result string) string {
|
||||
if result == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if this is bash output with stdout/stderr tags
|
||||
if strings.Contains(result, "<stdout>") || strings.Contains(result, "<stderr>") {
|
||||
result = r.formatBashOutput(result)
|
||||
}
|
||||
|
||||
// Calculate available width more conservatively
|
||||
availableWidth := max(r.width-28,
|
||||
// Minimum width for readability
|
||||
40)
|
||||
|
||||
// First wrap the text to prevent long lines (tool results are usually plain text, not markdown)
|
||||
wrappedResult := r.wrapText(result, availableWidth)
|
||||
|
||||
// Then limit to 5 lines
|
||||
lines := strings.Split(wrappedResult, "\n")
|
||||
if len(lines) > 5 {
|
||||
lines = lines[:5]
|
||||
// Add truncation indicator
|
||||
if len(lines) == 5 && lines[4] != "" {
|
||||
lines[4] = lines[4] + "..."
|
||||
} else {
|
||||
lines = append(lines, "...")
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// formatBashOutput formats bash command output by removing stdout/stderr tags
|
||||
// and styling appropriately. Delegates tag parsing to the shared parseBashOutput
|
||||
// helper.
|
||||
func (r *CompactRenderer) formatBashOutput(result string) string {
|
||||
return parseBashOutput(result, getTheme())
|
||||
}
|
||||
@@ -35,8 +35,11 @@ func GetTheme() Theme {
|
||||
|
||||
// SetTheme updates the global UI theme, affecting all subsequent rendering
|
||||
// operations. This allows runtime theme switching for different visual preferences.
|
||||
// It also invalidates the markdownTypographyCache so the next call to
|
||||
// GetMarkdownTypography picks up the new theme.
|
||||
func SetTheme(theme Theme) {
|
||||
currentTheme = theme
|
||||
markdownTypographyCache = nil // invalidate cached renderer; colors may have changed
|
||||
}
|
||||
|
||||
// MarkdownThemeColors defines colors for markdown rendering and syntax highlighting.
|
||||
@@ -291,45 +294,3 @@ func ApplyGradient(text string, colorA, colorB color.Color) string {
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// CreateGradientText creates styled text with a gradient effect between two colors.
|
||||
func CreateGradientText(text string, startColor, endColor color.Color) string {
|
||||
return ApplyGradient(text, startColor, endColor)
|
||||
}
|
||||
|
||||
// Compact styling utilities
|
||||
|
||||
// StyleCompactSymbol creates a lipgloss style for message type indicators in
|
||||
// compact mode, using bold colored text to distinguish different message categories.
|
||||
func StyleCompactSymbol(symbol string, c color.Color) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(c).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleCompactLabel creates a lipgloss style for message labels in compact mode
|
||||
// with fixed width for alignment and bold colored text for readability.
|
||||
func StyleCompactLabel(c color.Color) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(c).
|
||||
Bold(true).
|
||||
Width(8)
|
||||
}
|
||||
|
||||
// StyleCompactContent creates a simple lipgloss style for message content in
|
||||
// compact mode, applying only color without additional formatting.
|
||||
func StyleCompactContent(c color.Color) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(c)
|
||||
}
|
||||
|
||||
// FormatCompactLine assembles a complete compact mode message line with consistent
|
||||
// spacing and styling. Combines a symbol, fixed-width label, and content with their
|
||||
// respective colors to create a uniform appearance across all message types.
|
||||
func FormatCompactLine(symbol, label, content string, symbolColor, labelColor, contentColor color.Color) string {
|
||||
styledSymbol := StyleCompactSymbol(symbol, symbolColor).Render(symbol)
|
||||
styledLabel := StyleCompactLabel(labelColor).Render(label)
|
||||
styledContent := StyleCompactContent(contentColor).Render(content)
|
||||
|
||||
return fmt.Sprintf("%s %-8s %s", styledSymbol, styledLabel, styledContent)
|
||||
}
|
||||
|
||||
@@ -25,7 +25,6 @@ type CLISetupOptions struct {
|
||||
Agent AgentInterface
|
||||
ModelString string
|
||||
Debug bool
|
||||
Compact bool
|
||||
Quiet bool
|
||||
ShowDebug bool // Whether to show debug config
|
||||
ProviderAPIKey string // For OAuth detection
|
||||
@@ -76,7 +75,7 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
|
||||
return nil, nil // No CLI in quiet mode
|
||||
}
|
||||
|
||||
cli, err := NewCLI(opts.Debug, opts.Compact)
|
||||
cli, err := NewCLI(opts.Debug)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create CLI: %v", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// FileSuggestion represents a single file or directory suggestion for the @
|
||||
@@ -345,44 +344,14 @@ func scoreFilePath(query, path string) int {
|
||||
}
|
||||
|
||||
// Fuzzy character match on basename.
|
||||
if score := fuzzyCharMatch(query, baseNameLower); score > 0 {
|
||||
if score := fuzzyCharacterMatch(query, baseNameLower); score > 0 {
|
||||
return score
|
||||
}
|
||||
|
||||
// Fuzzy character match on full path.
|
||||
if score := fuzzyCharMatch(query, pathLower); score > 0 {
|
||||
if score := fuzzyCharacterMatch(query, pathLower); score > 0 {
|
||||
return score - 50
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// fuzzyCharMatch performs character-by-character fuzzy matching. Returns a
|
||||
// positive score if all query characters appear in order in the target.
|
||||
func fuzzyCharMatch(query, target string) int {
|
||||
if utf8.RuneCountInString(query) > utf8.RuneCountInString(target) {
|
||||
return 0
|
||||
}
|
||||
|
||||
qRunes := []rune(query)
|
||||
tRunes := []rune(target)
|
||||
qi := 0
|
||||
score := 100
|
||||
consecutive := 0
|
||||
|
||||
for ti := 0; ti < len(tRunes) && qi < len(qRunes); ti++ {
|
||||
if tRunes[ti] == qRunes[qi] {
|
||||
qi++
|
||||
consecutive++
|
||||
score += consecutive * 5
|
||||
} else {
|
||||
consecutive = 0
|
||||
score -= 2
|
||||
}
|
||||
}
|
||||
|
||||
if qi < len(qRunes) {
|
||||
return 0
|
||||
}
|
||||
return score
|
||||
}
|
||||
|
||||
@@ -7,29 +7,29 @@ import (
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
// Renderer is the interface satisfied by both MessageRenderer and
|
||||
// CompactRenderer. It allows model.go and cli.go to call rendering methods
|
||||
// without branching on compact mode.
|
||||
// Renderer is the interface satisfied by MessageRenderer. It allows model.go
|
||||
// and cli.go to call rendering methods uniformly.
|
||||
type Renderer interface {
|
||||
RenderUserMessage(content string, timestamp time.Time) UIMessage
|
||||
RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage
|
||||
RenderReasoningBlock(content string, timestamp time.Time) UIMessage
|
||||
RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage
|
||||
RenderSystemMessage(content string, timestamp time.Time) UIMessage
|
||||
RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage
|
||||
RenderDebugMessage(message string, timestamp time.Time) UIMessage
|
||||
RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage
|
||||
SetWidth(width int)
|
||||
UpdateTheme()
|
||||
}
|
||||
|
||||
// Compile-time checks that both renderers satisfy the Renderer interface.
|
||||
// Compile-time check that MessageRenderer satisfies the Renderer interface.
|
||||
var _ Renderer = (*MessageRenderer)(nil)
|
||||
var _ Renderer = (*CompactRenderer)(nil)
|
||||
|
||||
// parseBashOutput parses <stdout>/<stderr> tagged output from bash tool
|
||||
// results, styling stderr with the theme's error color. Returns the
|
||||
// combined, styled output string with tags stripped.
|
||||
//
|
||||
// Shared by both MessageRenderer and CompactRenderer.
|
||||
// Shared by MessageRenderer.
|
||||
func parseBashOutput(result string, theme Theme) string {
|
||||
var formattedResult strings.Builder
|
||||
remaining := result
|
||||
|
||||
+11
-7
@@ -113,19 +113,23 @@ func fuzzyScore(query string, cmd *SlashCommand) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// fuzzyCharacterMatch performs character-by-character fuzzy matching
|
||||
// fuzzyCharacterMatch performs character-by-character fuzzy matching using
|
||||
// rune-safe iteration so multi-byte Unicode characters are handled correctly.
|
||||
// Returns a positive score if all query runes appear in order within target.
|
||||
func fuzzyCharacterMatch(query, target string) int {
|
||||
if len(query) > len(target) {
|
||||
qRunes := []rune(query)
|
||||
tRunes := []rune(target)
|
||||
if len(qRunes) > len(tRunes) {
|
||||
return 0
|
||||
}
|
||||
|
||||
queryIdx := 0
|
||||
qi := 0
|
||||
score := 100
|
||||
consecutiveMatches := 0
|
||||
|
||||
for i := 0; i < len(target) && queryIdx < len(query); i++ {
|
||||
if target[i] == query[queryIdx] {
|
||||
queryIdx++
|
||||
for ti := 0; ti < len(tRunes) && qi < len(qRunes); ti++ {
|
||||
if tRunes[ti] == qRunes[qi] {
|
||||
qi++
|
||||
consecutiveMatches++
|
||||
score += consecutiveMatches * 10
|
||||
} else {
|
||||
@@ -135,7 +139,7 @@ func fuzzyCharacterMatch(query, target string) int {
|
||||
}
|
||||
|
||||
// Must match all characters in query
|
||||
if queryIdx < len(query) {
|
||||
if qi < len(qRunes) {
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
+51
-18
@@ -65,6 +65,10 @@ type InputComponent struct {
|
||||
// hideHint suppresses the "enter submit · ctrl+j..." hint text.
|
||||
hideHint bool
|
||||
|
||||
// agentBusy indicates the agent is currently working. When true, the
|
||||
// hint text shows steering shortcut (Ctrl+S) instead of submit.
|
||||
agentBusy bool
|
||||
|
||||
// pendingImages holds clipboard images attached to the next submission.
|
||||
// Images are added via Ctrl+V and cleared on submit or Ctrl+U.
|
||||
pendingImages []ImageAttachment
|
||||
@@ -405,21 +409,14 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
|
||||
}
|
||||
|
||||
// Resolve via canonical command lookup so aliases are handled uniformly.
|
||||
// Only /quit and /clear are handled locally — /clear-queue must go
|
||||
// through the parent model so it can update queueCount directly
|
||||
// (calling ClearQueue here would skip the UI state update since we
|
||||
// can't send events from within Update without deadlocking).
|
||||
// Only /quit is handled locally — all other slash commands (including
|
||||
// /clear and /clear-queue) are forwarded to the parent model via
|
||||
// submitMsg so the parent can update its own state (scrollback, queue
|
||||
// counts, etc.) in one place.
|
||||
if sc := GetCommandByName(trimmed); sc != nil {
|
||||
switch sc.Name {
|
||||
case "/quit":
|
||||
return tea.Quit
|
||||
|
||||
case "/clear":
|
||||
if s.appCtrl != nil {
|
||||
s.appCtrl.ClearMessages()
|
||||
}
|
||||
// Don't forward to app.Run(); just clear silently.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -489,10 +486,8 @@ func (s *InputComponent) View() tea.View {
|
||||
view.WriteString("\n")
|
||||
view.WriteString(inputBoxStyle.Render(s.textarea.View()))
|
||||
|
||||
if s.showPopup && len(s.filtered) > 0 {
|
||||
view.WriteString("\n")
|
||||
view.WriteString(s.renderPopup())
|
||||
}
|
||||
// Popup is now rendered as a centered overlay in AppModel.View()
|
||||
// instead of inline here to prevent bottom overflow
|
||||
|
||||
// Show image attachment indicator when images are pending.
|
||||
if len(s.pendingImages) > 0 {
|
||||
@@ -514,7 +509,16 @@ func (s *InputComponent) View() tea.View {
|
||||
// Adapt hint text to available width (accounting for left padding of 3).
|
||||
var hint string
|
||||
availableHintWidth := s.width - 3
|
||||
if availableHintWidth >= 67 {
|
||||
if s.agentBusy {
|
||||
// When the agent is working, show steering shortcut.
|
||||
if availableHintWidth >= 55 {
|
||||
hint = "enter queue • ctrl+s steer • esc esc cancel"
|
||||
} else if availableHintWidth >= 35 {
|
||||
hint = "↵ queue • ^S steer • esc×2 cancel"
|
||||
} else {
|
||||
hint = "^S steer"
|
||||
}
|
||||
} else if availableHintWidth >= 67 {
|
||||
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
|
||||
} else if availableHintWidth >= 40 {
|
||||
hint = "↵ submit • ctrl+j newline • ctrl+v image"
|
||||
@@ -527,11 +531,40 @@ func (s *InputComponent) View() tea.View {
|
||||
view.WriteString(helpStyle.Render(hint))
|
||||
}
|
||||
|
||||
return tea.NewView(containerStyle.Render(view.String()))
|
||||
v := tea.NewView(containerStyle.Render(view.String()))
|
||||
v.AltScreen = true
|
||||
v.MouseMode = tea.MouseModeCellMotion
|
||||
v.ReportFocus = true
|
||||
v.KeyboardEnhancements = tea.KeyboardEnhancements{
|
||||
ReportEventTypes: true,
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// renderPopup renders the autocomplete popup for slash command suggestions.
|
||||
func (s *InputComponent) renderPopup() string {
|
||||
// When rendered inline (not centered), returns the styled popup content.
|
||||
// RenderPopupCentered renders the popup as a centered overlay.
|
||||
func (s *InputComponent) RenderPopupCentered(termWidth, termHeight int) string {
|
||||
if !s.showPopup || len(s.filtered) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
popupContent := s.renderPopupWithOptions(true)
|
||||
|
||||
// Center popup using lipgloss.Place
|
||||
positioned := lipgloss.Place(
|
||||
termWidth,
|
||||
termHeight,
|
||||
lipgloss.Center,
|
||||
lipgloss.Center,
|
||||
popupContent,
|
||||
)
|
||||
|
||||
return positioned
|
||||
}
|
||||
|
||||
// renderPopupWithOptions renders the popup content with optional center styling.
|
||||
func (s *InputComponent) renderPopupWithOptions(centered bool) string {
|
||||
theme := GetTheme()
|
||||
popupWidth := max(s.width-4, 20)
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
|
||||
@@ -0,0 +1,398 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// MessageItem implementations for ScrollList
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TextMessageItem represents a completed text message (user or assistant)
|
||||
// in the scrollback. It uses pre-rendered styled content from MessageRenderer.
|
||||
type TextMessageItem struct {
|
||||
id string
|
||||
role string // "user" or "assistant"
|
||||
content string // Raw content (for re-rendering if needed)
|
||||
preRendered string // Pre-rendered styled content from MessageRenderer
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// NewTextMessageItem creates a new text message for the scrollback.
|
||||
// The content should be pre-rendered using MessageRenderer for proper styling.
|
||||
func NewTextMessageItem(id string, role string, content string) *TextMessageItem {
|
||||
return &TextMessageItem{
|
||||
id: id,
|
||||
role: role,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewStyledMessageItem creates a message item with pre-rendered styled content.
|
||||
// This is the preferred way to create messages when you have styled content from MessageRenderer.
|
||||
func NewStyledMessageItem(id string, role string, rawContent string, preRendered string) *TextMessageItem {
|
||||
return &TextMessageItem{
|
||||
id: id,
|
||||
role: role,
|
||||
content: rawContent,
|
||||
preRendered: preRendered,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TextMessageItem) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *TextMessageItem) Render(width int) string {
|
||||
// If we have pre-rendered styled content, return it
|
||||
if m.preRendered != "" {
|
||||
return m.preRendered
|
||||
}
|
||||
|
||||
// Fallback to simple formatting if no pre-rendered content
|
||||
return m.renderContent(width)
|
||||
}
|
||||
|
||||
func (m *TextMessageItem) Height() int {
|
||||
rendered := m.Render(0) // Width doesn't matter since we use pre-rendered
|
||||
if rendered == "" {
|
||||
return 0
|
||||
}
|
||||
return strings.Count(rendered, "\n") + 1
|
||||
}
|
||||
|
||||
func (m *TextMessageItem) renderContent(width int) string {
|
||||
var parts []string
|
||||
|
||||
// Role indicator
|
||||
if m.role == "user" {
|
||||
parts = append(parts, "│ ▸ You")
|
||||
} else {
|
||||
parts = append(parts, "") // Assistant messages start without role
|
||||
}
|
||||
|
||||
// Content with simple wrapping
|
||||
contentWidth := max(width-4, 20)
|
||||
|
||||
for line := range strings.SplitSeq(m.content, "\n") {
|
||||
if len(line) <= contentWidth {
|
||||
parts = append(parts, "│ "+line)
|
||||
} else {
|
||||
// Basic wrap
|
||||
for len(line) > contentWidth {
|
||||
parts = append(parts, "│ "+line[:contentWidth])
|
||||
line = line[contentWidth:]
|
||||
}
|
||||
if len(line) > 0 {
|
||||
parts = append(parts, "│ "+line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// StreamingMessageItem - Live streaming assistant/reasoning text
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// StreamingMessageItem represents actively streaming assistant or reasoning text.
|
||||
// It accumulates content chunks and re-renders on each update for live display.
|
||||
type StreamingMessageItem struct {
|
||||
id string
|
||||
role string // "assistant" or "reasoning"
|
||||
content string // Accumulated streaming content
|
||||
timestamp time.Time
|
||||
startTime time.Time // When streaming started (for live duration counter)
|
||||
modelName string
|
||||
streaming bool // true while actively streaming
|
||||
finalDuration time.Duration // Frozen duration when complete
|
||||
cachedRender string
|
||||
cachedWidth int
|
||||
}
|
||||
|
||||
// NewStreamingMessageItem creates a new streaming message item.
|
||||
func NewStreamingMessageItem(id, role string, modelName string) *StreamingMessageItem {
|
||||
now := time.Now()
|
||||
return &StreamingMessageItem{
|
||||
id: id,
|
||||
role: role,
|
||||
timestamp: now,
|
||||
startTime: now,
|
||||
modelName: modelName,
|
||||
streaming: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the unique identifier.
|
||||
func (s *StreamingMessageItem) ID() string {
|
||||
return s.id
|
||||
}
|
||||
|
||||
// Render renders the streaming message with live content.
|
||||
func (s *StreamingMessageItem) Render(width int) string {
|
||||
// For reasoning, never cache - we need live duration updates
|
||||
// For assistant, cache is OK
|
||||
if s.role != "reasoning" && s.cachedWidth == width && s.cachedRender != "" {
|
||||
return s.cachedRender
|
||||
}
|
||||
|
||||
// Get renderer from context
|
||||
renderer := newMessageRenderer(width, false)
|
||||
|
||||
var rendered string
|
||||
if s.role == "reasoning" {
|
||||
// Render as reasoning/thinking block with live duration counter
|
||||
theme := GetTheme()
|
||||
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
ty := createTypography(theme)
|
||||
content := strings.TrimLeft(s.content, " \t\n")
|
||||
|
||||
var parts []string
|
||||
parts = append(parts, mutedStyle.Render(ty.Italic(content)))
|
||||
|
||||
// Add live duration counter (updates on each render)
|
||||
var duration time.Duration
|
||||
if s.finalDuration > 0 {
|
||||
// Streaming complete, show frozen duration
|
||||
duration = s.finalDuration
|
||||
} else if !s.startTime.IsZero() {
|
||||
// Still streaming, show live duration
|
||||
duration = time.Since(s.startTime)
|
||||
}
|
||||
|
||||
if duration > 0 {
|
||||
var durationStr string
|
||||
if duration < time.Second {
|
||||
durationStr = fmt.Sprintf("%dms", duration.Milliseconds())
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.1fs", duration.Seconds())
|
||||
}
|
||||
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
|
||||
durationStyled := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
|
||||
parts = append(parts, label+durationStyled)
|
||||
}
|
||||
|
||||
rendered = styleMarginBottom1.Render(strings.Join(parts, "\n"))
|
||||
} else {
|
||||
// Render as assistant message
|
||||
msg := renderer.RenderAssistantMessage(s.content, s.timestamp, s.modelName)
|
||||
rendered = msg.Content
|
||||
}
|
||||
|
||||
// Cache and return (but reasoning is never cached due to live duration)
|
||||
if s.role != "reasoning" {
|
||||
s.cachedRender = rendered
|
||||
s.cachedWidth = width
|
||||
}
|
||||
return rendered
|
||||
}
|
||||
|
||||
// Height returns the number of lines.
|
||||
func (s *StreamingMessageItem) Height() int {
|
||||
if s.cachedRender == "" {
|
||||
return 0
|
||||
}
|
||||
return strings.Count(s.cachedRender, "\n") + 1
|
||||
}
|
||||
|
||||
// AppendChunk adds a content chunk and invalidates the render cache.
|
||||
func (s *StreamingMessageItem) AppendChunk(chunk string) {
|
||||
s.content += chunk
|
||||
s.cachedWidth = 0 // Invalidate cache
|
||||
}
|
||||
|
||||
// MarkComplete marks the streaming message as complete and freezes the duration.
|
||||
func (s *StreamingMessageItem) MarkComplete() {
|
||||
s.streaming = false
|
||||
// Freeze the duration for reasoning blocks
|
||||
if s.role == "reasoning" && !s.startTime.IsZero() {
|
||||
s.finalDuration = time.Since(s.startTime)
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// StreamingBashOutputItem - Live bash command output
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// StreamingBashOutputItem represents live bash command output.
|
||||
type StreamingBashOutputItem struct {
|
||||
id string
|
||||
command string
|
||||
stdoutLines []string
|
||||
stderrLines []string
|
||||
maxLines int
|
||||
complete bool
|
||||
cachedRender string
|
||||
cachedWidth int
|
||||
}
|
||||
|
||||
// NewStreamingBashOutputItem creates a new streaming bash output item.
|
||||
func NewStreamingBashOutputItem(id string, command string) *StreamingBashOutputItem {
|
||||
return &StreamingBashOutputItem{
|
||||
id: id,
|
||||
command: command,
|
||||
stdoutLines: make([]string, 0),
|
||||
stderrLines: make([]string, 0),
|
||||
maxLines: 100, // Cap lines to prevent memory issues
|
||||
complete: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *StreamingBashOutputItem) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *StreamingBashOutputItem) Render(width int) string {
|
||||
// Return cached if width matches and complete
|
||||
if m.complete && m.cachedWidth == width && m.cachedRender != "" {
|
||||
return m.cachedRender
|
||||
}
|
||||
|
||||
theme := GetTheme()
|
||||
var parts []string
|
||||
|
||||
// Header with command
|
||||
if m.command != "" {
|
||||
headerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true)
|
||||
parts = append(parts, headerStyle.Render(fmt.Sprintf("▸ %s", m.command)))
|
||||
}
|
||||
|
||||
const lineIndent = " "
|
||||
lineWidth := width - len(lineIndent)
|
||||
|
||||
// Stdout lines
|
||||
if len(m.stdoutLines) > 0 {
|
||||
outputStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Text).
|
||||
Background(theme.CodeBg).
|
||||
PaddingLeft(1).
|
||||
Width(lineWidth)
|
||||
for _, line := range m.stdoutLines {
|
||||
parts = append(parts, lineIndent+outputStyle.Render(line))
|
||||
}
|
||||
}
|
||||
|
||||
// Stderr lines
|
||||
if len(m.stderrLines) > 0 {
|
||||
stderrStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Background(theme.CodeBg).
|
||||
PaddingLeft(1).
|
||||
Width(lineWidth)
|
||||
for _, line := range m.stderrLines {
|
||||
parts = append(parts, lineIndent+stderrStyle.Render(line))
|
||||
}
|
||||
}
|
||||
|
||||
result := strings.Join(parts, "\n")
|
||||
if m.complete {
|
||||
m.cachedRender = result
|
||||
m.cachedWidth = width
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *StreamingBashOutputItem) Height() int {
|
||||
if m.cachedRender != "" {
|
||||
return strings.Count(m.cachedRender, "\n") + 1
|
||||
}
|
||||
// Estimate: command header + stdout + stderr
|
||||
return 1 + len(m.stdoutLines) + len(m.stderrLines)
|
||||
}
|
||||
|
||||
// AppendStdout adds a stdout line to the output.
|
||||
func (m *StreamingBashOutputItem) AppendStdout(line string) {
|
||||
m.stdoutLines = append(m.stdoutLines, line)
|
||||
// Cap lines
|
||||
if len(m.stdoutLines) > m.maxLines {
|
||||
m.stdoutLines = m.stdoutLines[len(m.stdoutLines)-m.maxLines:]
|
||||
}
|
||||
m.cachedWidth = 0 // Invalidate cache
|
||||
}
|
||||
|
||||
// AppendStderr adds a stderr line to the output.
|
||||
func (m *StreamingBashOutputItem) AppendStderr(line string) {
|
||||
m.stderrLines = append(m.stderrLines, line)
|
||||
// Cap lines
|
||||
if len(m.stderrLines) > m.maxLines {
|
||||
m.stderrLines = m.stderrLines[len(m.stderrLines)-m.maxLines:]
|
||||
}
|
||||
m.cachedWidth = 0 // Invalidate cache
|
||||
}
|
||||
|
||||
// MarkComplete marks the bash output as complete.
|
||||
func (m *StreamingBashOutputItem) MarkComplete() {
|
||||
m.complete = true
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// SystemMessageItem - System messages (commands, info, errors)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// SystemMessageItem represents a system message (commands, info, errors).
|
||||
type SystemMessageItem struct {
|
||||
id string
|
||||
content string
|
||||
timestamp time.Time
|
||||
cachedRender string
|
||||
cachedWidth int
|
||||
}
|
||||
|
||||
// NewSystemMessageItem creates a new system message for the scrollback.
|
||||
func NewSystemMessageItem(id, content string) *SystemMessageItem {
|
||||
return &SystemMessageItem{
|
||||
id: id,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Render(width int) string {
|
||||
// Return cached render if width matches
|
||||
if m.cachedWidth == width && m.cachedRender != "" {
|
||||
return m.cachedRender
|
||||
}
|
||||
|
||||
// Simple system message formatting
|
||||
rendered := "│ " + strings.ReplaceAll(m.content, "\n", "\n│ ")
|
||||
|
||||
// Cache and return
|
||||
m.cachedRender = rendered
|
||||
m.cachedWidth = width
|
||||
return rendered
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Height() int {
|
||||
if m.cachedRender != "" {
|
||||
return strings.Count(m.cachedRender, "\n") + 1
|
||||
}
|
||||
// Estimate
|
||||
if m.cachedWidth > 0 {
|
||||
return (len(m.content) / max(m.cachedWidth-10, 40)) + 3
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Helper: generateMessageID
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
var messageCounter = 0
|
||||
|
||||
func generateMessageID() string {
|
||||
messageCounter++
|
||||
return fmt.Sprintf("msg-%d-%d", time.Now().UnixNano(), messageCounter)
|
||||
}
|
||||
+156
-377
@@ -3,17 +3,14 @@ package ui
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/indaco/herald"
|
||||
)
|
||||
|
||||
// ansiEscapeRe matches ANSI escape sequences used for terminal styling.
|
||||
var ansiEscapeRe = regexp.MustCompile(`\x1b\[[0-9;]*m`)
|
||||
|
||||
// MessageType represents different categories of messages displayed in the UI,
|
||||
// each with distinct visual styling and formatting rules.
|
||||
type MessageType int
|
||||
@@ -22,9 +19,9 @@ const (
|
||||
UserMessage MessageType = iota
|
||||
AssistantMessage
|
||||
ToolMessage
|
||||
ToolCallMessage // New type for showing tool calls in progress
|
||||
SystemMessage // New type for KIT system messages (help, tools, etc.)
|
||||
ErrorMessage // New type for error messages
|
||||
ToolCallMessage
|
||||
SystemMessage
|
||||
ErrorMessage
|
||||
)
|
||||
|
||||
// UIMessage encapsulates a fully rendered message ready for display in the UI,
|
||||
@@ -40,29 +37,9 @@ type UIMessage struct {
|
||||
Streaming bool
|
||||
}
|
||||
|
||||
// Helper functions to get theme colors
|
||||
func getTheme() Theme {
|
||||
return GetTheme()
|
||||
}
|
||||
|
||||
// toolDisplayNames maps raw tool names to human-friendly display names.
|
||||
var toolDisplayNames = map[string]string{
|
||||
"bash": "Bash",
|
||||
"read": "Read",
|
||||
"write": "Write",
|
||||
"edit": "Edit",
|
||||
"grep": "Grep",
|
||||
"find": "Find",
|
||||
"ls": "Ls",
|
||||
"run_shell_cmd": "Bash",
|
||||
}
|
||||
|
||||
// toolDisplayName returns a human-friendly display name for a tool.
|
||||
// Falls back to capitalizing the first letter of the raw name.
|
||||
// toolDisplayName returns a human-friendly display name for a tool,
|
||||
// title-casing the first letter of the raw name.
|
||||
func toolDisplayName(rawName string) string {
|
||||
if display, ok := toolDisplayNames[rawName]; ok {
|
||||
return display
|
||||
}
|
||||
if rawName != "" {
|
||||
return strings.ToUpper(rawName[:1]) + rawName[1:]
|
||||
}
|
||||
@@ -70,8 +47,6 @@ func toolDisplayName(rawName string) string {
|
||||
}
|
||||
|
||||
// formatToolParams formats tool input parameters for inline header display.
|
||||
// Extracts the primary parameter (command/filePath) first, then shows
|
||||
// remaining params as (key=val, ...). Truncates to maxWidth.
|
||||
func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
args := strings.TrimSpace(toolArgs)
|
||||
if args == "" || args == "{}" {
|
||||
@@ -80,7 +55,6 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
|
||||
// Fallback: strip braces and return raw content
|
||||
args = strings.TrimPrefix(args, "{")
|
||||
args = strings.TrimSuffix(args, "}")
|
||||
args = strings.TrimSpace(args)
|
||||
@@ -94,7 +68,6 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Identify primary parameter by checking known keys in priority order
|
||||
primaryKeys := []string{"command", "filePath", "path", "pattern", "query", "url"}
|
||||
var primaryKey string
|
||||
var primaryVal string
|
||||
@@ -111,14 +84,13 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
result.WriteString(primaryVal)
|
||||
}
|
||||
|
||||
// Collect remaining parameters, skipping body-content keys (already
|
||||
// rendered in the tool body) and any values that are too large.
|
||||
bodyKeys := map[string]bool{
|
||||
"content": true,
|
||||
"old_text": true,
|
||||
"new_text": true,
|
||||
"oldText": true,
|
||||
"newText": true,
|
||||
"edits": true,
|
||||
"todos": true,
|
||||
}
|
||||
var remaining []string
|
||||
@@ -154,65 +126,35 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
}
|
||||
|
||||
// MessageRenderer handles the formatting and rendering of different message types
|
||||
// with consistent styling, markdown support, and appropriate visual hierarchies
|
||||
// for the standard (non-compact) display mode.
|
||||
type MessageRenderer struct {
|
||||
width int
|
||||
debug bool
|
||||
|
||||
// getToolRenderer returns extension-provided rendering overrides for a
|
||||
// specific tool. May be nil if no extensions are loaded. Used in
|
||||
// RenderToolMessage to check for custom header/body formatting before
|
||||
// falling back to builtin renderers.
|
||||
width int
|
||||
debug bool
|
||||
ty *herald.Typography
|
||||
getToolRenderer func(toolName string) *ToolRendererData
|
||||
}
|
||||
|
||||
// newMessageRenderer creates and initializes a new MessageRenderer with the specified
|
||||
// terminal width and debug mode setting. The width parameter determines line wrapping
|
||||
// and layout calculations.
|
||||
// newMessageRenderer creates and initializes a new MessageRenderer
|
||||
func newMessageRenderer(width int, debug bool) *MessageRenderer {
|
||||
return &MessageRenderer{
|
||||
width: width,
|
||||
debug: debug,
|
||||
ty: createTypography(GetTheme()),
|
||||
}
|
||||
}
|
||||
|
||||
// SetWidth updates the terminal width for the renderer, affecting how content
|
||||
// is wrapped and formatted in subsequent render operations.
|
||||
// SetWidth updates the terminal width for the renderer
|
||||
func (r *MessageRenderer) SetWidth(width int) {
|
||||
r.width = width
|
||||
}
|
||||
|
||||
// RenderUserMessage renders a user's input message with distinctive right-aligned
|
||||
// formatting, including the system username, timestamp, and markdown-rendered content.
|
||||
// The message is displayed with a colored right border for visual distinction.
|
||||
// RenderUserMessage renders a user's input message using herald Tip alert
|
||||
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
// Only run markdown rendering when the message contains code spans or
|
||||
// fenced code blocks. Plain text is rendered directly so that newlines
|
||||
// are preserved without the extra paragraph spacing glamour adds.
|
||||
var messageContent string
|
||||
if strings.Contains(content, "`") {
|
||||
// Glamour treats single \n as a soft break, so convert to paragraph
|
||||
// breaks and collapse the resulting blank lines after rendering.
|
||||
mdContent := strings.ReplaceAll(content, "\n", "\n\n")
|
||||
messageContent = r.renderMarkdown(mdContent, r.width-8)
|
||||
messageContent = removeBlankLines(messageContent)
|
||||
} else {
|
||||
messageContent = content
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
fullContent := strings.TrimSuffix(messageContent, "\n")
|
||||
|
||||
// Left border with Blue color for user messages.
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Info),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
rendered := r.ty.Tip(content)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: UserMessage,
|
||||
@@ -222,12 +164,8 @@ func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderAssistantMessage renders an AI assistant's response with left-aligned formatting,
|
||||
// including the model name, timestamp, and markdown-rendered content. Empty responses
|
||||
// are ignored and return an empty message. The message features a colored left border
|
||||
// for visual distinction.
|
||||
// RenderAssistantMessage renders an AI assistant's response
|
||||
func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
|
||||
// Ignore empty responses - don't render anything
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
@@ -237,17 +175,9 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
}
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
messageContent := r.renderMarkdown(content, r.width-8)
|
||||
fullContent := strings.TrimSuffix(messageContent, "\n")
|
||||
|
||||
// Left border with Primary (Mauve) color for assistant messages.
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithBorderColor(theme.Primary),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
// Use markdown rendering with Chroma syntax highlighting
|
||||
rendered := toMarkdown(content, r.width-4)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
@@ -257,30 +187,44 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
}
|
||||
}
|
||||
|
||||
// RenderSystemMessage renders KIT system messages such as help text, command outputs,
|
||||
// and informational notifications. These messages are displayed with a distinctive system
|
||||
// color border and "KIT System" label to differentiate them from user and AI content.
|
||||
func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
var messageContent string
|
||||
// RenderReasoningBlock renders a reasoning/thinking block with the same styling
|
||||
// as live streaming: muted italic text with margin. This is used when resuming
|
||||
// sessions to display saved reasoning content.
|
||||
func (r *MessageRenderer) RenderReasoningBlock(content string, timestamp time.Time) UIMessage {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
messageContent = "No content available"
|
||||
} else if strings.Contains(content, "`") {
|
||||
messageContent = r.renderMarkdown(content, r.width-8)
|
||||
} else {
|
||||
messageContent = content
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
Content: "",
|
||||
Height: 0,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
fullContent := "◇ " + strings.TrimSuffix(messageContent, "\n")
|
||||
theme := GetTheme()
|
||||
// Match live streaming styling: muted italic text
|
||||
// Same as stream.go renderReasoningBlock()
|
||||
lines := strings.Split(strings.TrimRight(content, "\n"), "\n")
|
||||
contentStr := strings.TrimLeft(strings.Join(lines, "\n"), " \t\n")
|
||||
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
rendered := mutedStyle.Render(r.ty.Italic(contentStr))
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithNoBorder(),
|
||||
WithForeground(theme.Muted),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderSystemMessage renders KIT system messages using herald Note alert
|
||||
func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "No content available"
|
||||
}
|
||||
|
||||
rendered := r.ty.Note(content)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
@@ -290,27 +234,9 @@ func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugMessage renders diagnostic and debugging information with special formatting
|
||||
// including a debug icon, colored border, and structured layout. Debug messages are only
|
||||
// displayed when debug mode is enabled and help developers troubleshoot issues.
|
||||
// RenderDebugMessage renders diagnostic and debugging information
|
||||
func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := getTheme()
|
||||
style := baseStyle.
|
||||
Width(r.width - 3).
|
||||
BorderLeft(true).
|
||||
Foreground(theme.Muted).
|
||||
BorderForeground(theme.Tool).
|
||||
BorderStyle(lipgloss.ThickBorder()).
|
||||
PaddingLeft(1).
|
||||
MarginLeft(2).
|
||||
MarginBottom(1)
|
||||
|
||||
header := baseStyle.
|
||||
Foreground(theme.Tool).
|
||||
Bold(true).
|
||||
Render("🔍 Debug Output")
|
||||
header := r.ty.H6("🔍 Debug Output")
|
||||
|
||||
lines := strings.Split(message, "\n")
|
||||
var formattedLines []string
|
||||
@@ -320,87 +246,52 @@ func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time
|
||||
}
|
||||
}
|
||||
|
||||
content := baseStyle.
|
||||
Foreground(theme.Muted).
|
||||
Render(strings.Join(formattedLines, "\n"))
|
||||
|
||||
fullContent := lipgloss.JoinVertical(lipgloss.Left,
|
||||
content := r.ty.Compose(
|
||||
header,
|
||||
content,
|
||||
r.ty.P(strings.Join(formattedLines, "\n")),
|
||||
)
|
||||
content = styleMarginBottom1.Render(content)
|
||||
|
||||
return UIMessage{
|
||||
Content: style.Render(fullContent),
|
||||
Height: lipgloss.Height(style.Render(fullContent)),
|
||||
Content: content,
|
||||
Height: lipgloss.Height(content),
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugConfigMessage renders configuration settings in a formatted debug display
|
||||
// with key-value pairs shown in a structured layout. Used to display runtime configuration
|
||||
// for debugging purposes with a distinctive icon and border styling.
|
||||
// RenderDebugConfigMessage renders configuration settings
|
||||
func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := getTheme()
|
||||
style := baseStyle.
|
||||
Width(r.width - 1).
|
||||
BorderLeft(true).
|
||||
Foreground(theme.Muted).
|
||||
BorderForeground(theme.Tool).
|
||||
BorderStyle(lipgloss.ThickBorder()).
|
||||
PaddingLeft(1)
|
||||
|
||||
header := baseStyle.
|
||||
Foreground(theme.Tool).
|
||||
Bold(true).
|
||||
Render("🔧 Debug Configuration")
|
||||
header := r.ty.H6("🔧 Debug Configuration")
|
||||
|
||||
var configLines []string
|
||||
for key, value := range config {
|
||||
if value != nil {
|
||||
configLines = append(configLines, fmt.Sprintf(" %s: %v", key, value))
|
||||
configLines = append(configLines, fmt.Sprintf("%s: %v", key, value))
|
||||
}
|
||||
}
|
||||
|
||||
configContent := baseStyle.
|
||||
Foreground(theme.Muted).
|
||||
Render(strings.Join(configLines, "\n"))
|
||||
|
||||
parts := []string{header}
|
||||
var content string
|
||||
if len(configLines) > 0 {
|
||||
parts = append(parts, configContent)
|
||||
content = r.ty.Compose(
|
||||
header,
|
||||
r.ty.P(strings.Join(configLines, "\n")),
|
||||
)
|
||||
} else {
|
||||
content = header
|
||||
}
|
||||
|
||||
rendered := style.Render(
|
||||
lipgloss.JoinVertical(lipgloss.Left, parts...),
|
||||
)
|
||||
content = styleMarginBottom1.Render(content)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Content: content,
|
||||
Height: lipgloss.Height(content),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderErrorMessage renders error notifications with distinctive red coloring and
|
||||
// bold text to ensure visibility. Error messages include timestamp information and
|
||||
// are displayed with an error-colored border for immediate recognition.
|
||||
// RenderErrorMessage renders error notifications
|
||||
func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
errorContent := lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true).
|
||||
Render(errorMsg)
|
||||
|
||||
rendered := renderContentBlock(
|
||||
errorContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Error),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
rendered := r.ty.Caution(errorMsg)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: ErrorMessage,
|
||||
@@ -410,93 +301,18 @@ func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolCallMessage renders a notification that a tool is being executed, showing
|
||||
// the tool name, formatted arguments (if any), and execution timestamp. The message
|
||||
// uses tool-specific coloring to distinguish it from regular conversation messages.
|
||||
func (r *MessageRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage {
|
||||
// Format timestamp
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
|
||||
// Format arguments with better presentation
|
||||
theme := getTheme()
|
||||
var argsContent string
|
||||
if toolArgs != "" && toolArgs != "{}" {
|
||||
argsContent = lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true).
|
||||
Render(fmt.Sprintf("Arguments: %s", r.formatToolArgs(toolArgs)))
|
||||
}
|
||||
|
||||
// Create info line
|
||||
info := fmt.Sprintf(" Executing %s (%s)", toolName, timeStr)
|
||||
|
||||
// Combine parts
|
||||
var fullContent string
|
||||
if argsContent != "" {
|
||||
fullContent = argsContent + "\n" +
|
||||
lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(info)
|
||||
} else {
|
||||
fullContent = lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(info)
|
||||
}
|
||||
|
||||
// Use the new block renderer
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Tool),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
|
||||
return UIMessage{
|
||||
Type: ToolCallMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolMessage renders a unified tool block combining the tool invocation
|
||||
// header (icon + display name + params) with the execution result body. The
|
||||
// border color indicates status: green for success, red for error. This replaces
|
||||
// the previous two-block approach (separate call + result blocks).
|
||||
// RenderToolMessage renders a unified tool block
|
||||
func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
// Resolve extension renderer once for all overrides.
|
||||
var extRd *ToolRendererData
|
||||
if r.getToolRenderer != nil {
|
||||
extRd = r.getToolRenderer(toolName)
|
||||
}
|
||||
|
||||
// --- Header: [icon] [name] [params] ---
|
||||
var icon string
|
||||
borderColor := theme.Success
|
||||
iconColor := theme.Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
borderColor = theme.Error
|
||||
iconColor = theme.Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// Extension can override border color (applies to both success and error).
|
||||
if extRd != nil && extRd.BorderColor != "" {
|
||||
borderColor = lipgloss.Color(extRd.BorderColor)
|
||||
}
|
||||
|
||||
iconStr := lipgloss.NewStyle().Foreground(iconColor).Bold(true).Render(icon)
|
||||
|
||||
// Extension can override display name.
|
||||
displayName := toolDisplayName(toolName)
|
||||
if extRd != nil && extRd.DisplayName != "" {
|
||||
displayName = extRd.DisplayName
|
||||
}
|
||||
nameStr := lipgloss.NewStyle().Foreground(theme.Info).Bold(true).Render(displayName)
|
||||
|
||||
// Format params with width budget for the header line.
|
||||
// Check extension renderer for custom header params first.
|
||||
paramBudget := max(r.width-10-len(displayName), 20)
|
||||
var params string
|
||||
if extRd != nil && extRd.RenderHeader != nil {
|
||||
@@ -506,97 +322,70 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
params = formatToolParams(toolArgs, paramBudget)
|
||||
}
|
||||
|
||||
header := iconStr + " " + nameStr
|
||||
if params != "" {
|
||||
header += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
var icon string
|
||||
iconColor := GetTheme().Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
iconColor = GetTheme().Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// --- Body: check extension renderer first, then builtin, then default ---
|
||||
// Style the tool name with color
|
||||
theme := GetTheme()
|
||||
nameColor := theme.Info
|
||||
if isError {
|
||||
nameColor = theme.Error
|
||||
}
|
||||
styledName := lipgloss.NewStyle().Foreground(nameColor).Bold(true).Render(displayName)
|
||||
styledIcon := lipgloss.NewStyle().Foreground(iconColor).Render(icon)
|
||||
|
||||
// Build the content: icon + name + params on first line, then body
|
||||
headerLine := styledIcon + " " + styledName
|
||||
if params != "" {
|
||||
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
}
|
||||
|
||||
// Get body content
|
||||
var body string
|
||||
if extRd != nil && extRd.RenderBody != nil {
|
||||
body = extRd.RenderBody(toolResult, isError, r.width-8)
|
||||
// Apply markdown rendering if requested and body is non-empty.
|
||||
if body != "" && extRd.BodyMarkdown {
|
||||
body = strings.TrimSuffix(toMarkdown(body, r.width-8), "\n")
|
||||
}
|
||||
}
|
||||
if body == "" {
|
||||
if isError {
|
||||
body = lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Render(toolResult)
|
||||
body = r.formatToolResult(toolName, toolResult)
|
||||
} else {
|
||||
body = renderToolBody(toolName, toolArgs, toolResult, r.width-8)
|
||||
if body == "" {
|
||||
body = r.formatToolResult(toolName, toolResult, r.width-8)
|
||||
body = r.formatToolResult(toolName, toolResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(body) == "" {
|
||||
body = lipgloss.NewStyle().
|
||||
Italic(true).
|
||||
Foreground(theme.Muted).
|
||||
Render("(no output)")
|
||||
body = r.ty.Italic("(no output)")
|
||||
}
|
||||
|
||||
// Combine header + body into a single block.
|
||||
fullContent := header + "\n\n" + strings.TrimSuffix(body, "\n")
|
||||
|
||||
// Build rendering options; extension can override background.
|
||||
blockOpts := []renderingOption{
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(borderColor),
|
||||
WithMarginBottom(1),
|
||||
}
|
||||
if extRd != nil && extRd.Background != "" {
|
||||
blockOpts = append(blockOpts, WithBackground(lipgloss.Color(extRd.Background)))
|
||||
}
|
||||
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
blockOpts...,
|
||||
// Compose: icon + name + params, then body
|
||||
fullContent := r.ty.Compose(
|
||||
headerLine,
|
||||
"",
|
||||
body,
|
||||
)
|
||||
fullContent = styleMarginBottom1.Render(fullContent)
|
||||
|
||||
return UIMessage{
|
||||
Type: ToolMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Content: fullContent,
|
||||
Height: lipgloss.Height(fullContent),
|
||||
}
|
||||
}
|
||||
|
||||
// formatToolArgs formats tool arguments for display
|
||||
func (r *MessageRenderer) formatToolArgs(args string) string {
|
||||
// Remove outer braces and clean up JSON formatting
|
||||
args = strings.TrimSpace(args)
|
||||
if strings.HasPrefix(args, "{") && strings.HasSuffix(args, "}") {
|
||||
args = strings.TrimPrefix(args, "{")
|
||||
args = strings.TrimSuffix(args, "}")
|
||||
args = strings.TrimSpace(args)
|
||||
}
|
||||
|
||||
// If it's empty after cleanup, return a placeholder
|
||||
if args == "" {
|
||||
return "(no arguments)"
|
||||
}
|
||||
|
||||
// Truncate if too long, but skip truncation in debug mode
|
||||
if !r.debug {
|
||||
maxLen := 100
|
||||
if len(args) > maxLen {
|
||||
return args[:maxLen] + "..."
|
||||
}
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// formatToolResult formats tool results based on tool type
|
||||
func (r *MessageRenderer) formatToolResult(toolName, result string, width int) string {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
// Truncate very long results only if not in debug mode
|
||||
func (r *MessageRenderer) formatToolResult(toolName, result string) string {
|
||||
if !r.debug {
|
||||
maxLines := 10
|
||||
lines := strings.Split(result, "\n")
|
||||
@@ -605,58 +394,48 @@ func (r *MessageRenderer) formatToolResult(toolName, result string, width int) s
|
||||
}
|
||||
}
|
||||
|
||||
// Format bash/command output with better formatting
|
||||
if strings.Contains(toolName, "bash") || strings.Contains(toolName, "command") || strings.Contains(toolName, "shell") || toolName == "run_shell_cmd" {
|
||||
theme := getTheme()
|
||||
|
||||
// Split result into sections if it contains both stdout and stderr
|
||||
if strings.Contains(toolName, "bash") || strings.Contains(toolName, "command") ||
|
||||
strings.Contains(toolName, "shell") {
|
||||
if strings.Contains(result, "<stdout>") || strings.Contains(result, "<stderr>") {
|
||||
return r.formatBashOutput(result, width, theme)
|
||||
}
|
||||
|
||||
// For simple output, just render as monospace text with proper line breaks
|
||||
return baseStyle.
|
||||
Width(width).
|
||||
Foreground(theme.Muted).
|
||||
Render(result)
|
||||
}
|
||||
|
||||
// For other tools, render as muted text
|
||||
theme := getTheme()
|
||||
return baseStyle.
|
||||
Width(width).
|
||||
Foreground(theme.Muted).
|
||||
Render(result)
|
||||
}
|
||||
|
||||
// formatBashOutput formats bash command output with proper section handling.
|
||||
// Delegates tag parsing to the shared parseBashOutput helper.
|
||||
func (r *MessageRenderer) formatBashOutput(result string, width int, theme Theme) string {
|
||||
parsed := parseBashOutput(result, theme)
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
Foreground(theme.Muted).
|
||||
Render(parsed)
|
||||
}
|
||||
|
||||
// renderMarkdown renders markdown content using glamour
|
||||
func (r *MessageRenderer) renderMarkdown(content string, width int) string {
|
||||
rendered := toMarkdown(content, width)
|
||||
return strings.TrimSuffix(rendered, "\n")
|
||||
}
|
||||
|
||||
// removeBlankLines removes lines that are visually blank from rendered output.
|
||||
// Glamour wraps every character (including padding spaces) with ANSI color
|
||||
// codes, so we must strip escape sequences before checking whether a line is
|
||||
// empty. This collapses paragraph spacing so user messages render without
|
||||
// extra vertical gaps.
|
||||
func removeBlankLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
filtered := lines[:0]
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(ansiEscapeRe.ReplaceAllString(line, "")) != "" {
|
||||
filtered = append(filtered, line)
|
||||
return parseBashOutput(result, GetTheme())
|
||||
}
|
||||
}
|
||||
return strings.Join(filtered, "\n")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// createTypography creates a typography instance from theme
|
||||
func createTypography(theme Theme) *herald.Typography {
|
||||
return herald.New(
|
||||
herald.WithPalette(herald.ColorPalette{
|
||||
Primary: theme.Primary,
|
||||
Secondary: theme.Secondary,
|
||||
Tertiary: theme.Info,
|
||||
Accent: theme.Accent,
|
||||
Highlight: theme.Highlight,
|
||||
Muted: theme.Muted,
|
||||
Text: theme.Text,
|
||||
Surface: theme.Background,
|
||||
Base: theme.CodeBg,
|
||||
}),
|
||||
herald.WithAlertPalette(herald.AlertPalette{
|
||||
Note: theme.Info,
|
||||
Tip: theme.Success,
|
||||
Important: theme.Accent,
|
||||
Warning: theme.Warning,
|
||||
Caution: theme.Error,
|
||||
}),
|
||||
herald.WithCodeLineNumbers(true),
|
||||
// Customize alert labels
|
||||
herald.WithAlertLabel(herald.AlertNote, "Info"),
|
||||
herald.WithAlertLabel(herald.AlertTip, "You"),
|
||||
herald.WithAlertLabel(herald.AlertWarning, "Working"),
|
||||
herald.WithAlertLabel(herald.AlertCaution, "Error"),
|
||||
)
|
||||
}
|
||||
|
||||
// UpdateTheme refreshes the renderer's typography instance with colors from
|
||||
// the current theme. This is called when the user changes themes via /theme.
|
||||
func (r *MessageRenderer) UpdateTheme() {
|
||||
r.ty = createTypography(GetTheme())
|
||||
}
|
||||
|
||||
+815
-364
File diff suppressed because it is too large
Load Diff
@@ -50,7 +50,7 @@ func NewModelSelector(currentModel string, width, height int) *ModelSelectorComp
|
||||
registry := models.GetGlobalRegistry()
|
||||
var allModels []ModelEntry
|
||||
|
||||
for _, providerID := range registry.GetFantasyProviders() {
|
||||
for _, providerID := range registry.GetLLMProviders() {
|
||||
// Only include providers with valid API keys configured.
|
||||
if err := registry.ValidateEnvironment(providerID, ""); err != nil {
|
||||
continue
|
||||
@@ -281,7 +281,14 @@ func (ms *ModelSelectorComponent) View() tea.View {
|
||||
footerStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
|
||||
b.WriteString(footerStyle.Render(strings.Join(footerParts, " ")))
|
||||
|
||||
return tea.NewView(b.String())
|
||||
v := tea.NewView(b.String())
|
||||
v.AltScreen = true
|
||||
v.MouseMode = tea.MouseModeCellMotion
|
||||
v.ReportFocus = true
|
||||
v.KeyboardEnhancements = tea.KeyboardEnhancements{
|
||||
ReportEventTypes: true,
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// IsActive returns whether the selector is still accepting input.
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"testing"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -46,6 +46,10 @@ func (s *stubAppController) ClearMessages() {
|
||||
s.clearMsgCalled++
|
||||
}
|
||||
|
||||
func (s *stubAppController) ReloadMessagesFromTree() {
|
||||
// no-op in tests
|
||||
}
|
||||
|
||||
func (s *stubAppController) CompactConversation(_ string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -54,6 +58,10 @@ func (s *stubAppController) GetTreeSession() *session.TreeManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAppController) SwitchTreeSession(_ *session.TreeManager) {
|
||||
// no-op in tests
|
||||
}
|
||||
|
||||
func (s *stubAppController) SendEvent(_ tea.Msg) {
|
||||
// no-op in tests
|
||||
}
|
||||
@@ -62,7 +70,12 @@ func (s *stubAppController) AddContextMessage(_ string) {
|
||||
// no-op in tests
|
||||
}
|
||||
|
||||
func (s *stubAppController) RunWithFiles(prompt string, _ []fantasy.FilePart) int {
|
||||
func (s *stubAppController) RunWithFiles(prompt string, _ []kit.LLMFilePart) int {
|
||||
s.runCalls = append(s.runCalls, prompt)
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
func (s *stubAppController) Steer(prompt string) int {
|
||||
s.runCalls = append(s.runCalls, prompt)
|
||||
return s.queueLen
|
||||
}
|
||||
@@ -88,9 +101,11 @@ func (s *stubStreamComponent) View() tea.View { return tea.NewView("
|
||||
func (s *stubStreamComponent) Reset() { s.resetCalled++; s.renderedContent = "" }
|
||||
func (s *stubStreamComponent) SetHeight(h int) { s.height = h }
|
||||
func (s *stubStreamComponent) GetRenderedContent() string { return s.renderedContent }
|
||||
func (s *stubStreamComponent) ConsumeOverflow() string { return "" }
|
||||
func (s *stubStreamComponent) SpinnerView() string { return "" }
|
||||
func (s *stubStreamComponent) SetThinkingVisible(bool) {}
|
||||
func (s *stubStreamComponent) HasReasoning() bool { return false }
|
||||
func (s *stubStreamComponent) UpdateTheme() {}
|
||||
|
||||
// stubInputComponent satisfies inputComponentIface without rendering anything.
|
||||
type stubInputComponent struct {
|
||||
@@ -117,7 +132,6 @@ func newTestAppModel(ctrl AppController) (*AppModel, *stubStreamComponent, *stub
|
||||
stream: stream,
|
||||
input: input,
|
||||
renderer: newMessageRenderer(80, false),
|
||||
compactMode: false,
|
||||
modelName: "test-model",
|
||||
width: 80,
|
||||
height: 24,
|
||||
@@ -133,7 +147,11 @@ func newTestAppModel(ctrl AppController) (*AppModel, *stubStreamComponent, *stub
|
||||
// sendMsg calls m.Update once with the given message and returns the updated model.
|
||||
func sendMsg(m *AppModel, msg tea.Msg) *AppModel {
|
||||
updated, _ := m.Update(msg)
|
||||
return updated.(*AppModel)
|
||||
result := updated.(*AppModel)
|
||||
// Simulate BubbleTea's frame cycle: View() is called after every Update().
|
||||
// This flushes any pending layoutDirty work (e.g. distributeHeight).
|
||||
_ = result.View()
|
||||
return result
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -679,6 +697,57 @@ func TestToolResult_clearsStreamingBashOutput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallStarted_extractsBashCommand verifies that ToolCallStartedEvent
|
||||
// extracts the bash command from ToolArgs and stores it for the streaming output header.
|
||||
func TestToolCallStarted_extractsBashCommand(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
// Send ToolCallStartedEvent with bash command.
|
||||
m = sendMsg(m, app.ToolCallStartedEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "bash",
|
||||
ToolArgs: `{"command":"ls -la /home"}`,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "ls -la /home" {
|
||||
t.Fatalf("expected streamingBashCommand='ls -la /home', got %q", m.streamingBashCommand)
|
||||
}
|
||||
|
||||
// ToolResultEvent should clear the command.
|
||||
m = sendMsg(m, app.ToolResultEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "bash",
|
||||
ToolArgs: `{"command":"ls -la /home"}`,
|
||||
Result: "output",
|
||||
IsError: false,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "" {
|
||||
t.Fatalf("expected streamingBashCommand cleared, got %q", m.streamingBashCommand)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallStarted_nonBashTool_doesNotSetCommand verifies that non-bash tools
|
||||
// do not set the streamingBashCommand field.
|
||||
func TestToolCallStarted_nonBashTool_doesNotSetCommand(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
// Send ToolCallStartedEvent with a non-bash tool.
|
||||
m = sendMsg(m, app.ToolCallStartedEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "read",
|
||||
ToolArgs: `{"file":"/etc/passwd"}`,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "" {
|
||||
t.Fatalf("expected streamingBashCommand to remain empty for non-bash tools, got %q", m.streamingBashCommand)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStepError_printCmd verifies that StepErrorEvent with a non-nil error
|
||||
// produces a non-nil cmd (the tea.Println call for the error message).
|
||||
func TestStepError_printCmd(t *testing.T) {
|
||||
|
||||
@@ -118,22 +118,33 @@ func (m ProgressModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// status information and help text. Displays error messages if present or
|
||||
// a completion message when the download finishes.
|
||||
func (m ProgressModel) View() tea.View {
|
||||
var v tea.View
|
||||
v.AltScreen = true
|
||||
v.MouseMode = tea.MouseModeCellMotion
|
||||
v.ReportFocus = true
|
||||
v.KeyboardEnhancements = tea.KeyboardEnhancements{
|
||||
ReportEventTypes: true,
|
||||
}
|
||||
|
||||
if m.err != nil {
|
||||
return tea.NewView(fmt.Sprintf("Error: %s\n", m.err.Error()))
|
||||
v.Content = fmt.Sprintf("Error: %s\n", m.err.Error())
|
||||
return v
|
||||
}
|
||||
|
||||
if m.complete {
|
||||
return tea.NewView(fmt.Sprintf("\n%s%s\n\n%sComplete!\n",
|
||||
v.Content = fmt.Sprintf("\n%s%s\n\n%sComplete!\n",
|
||||
strings.Repeat(" ", padding),
|
||||
m.progress.View(),
|
||||
strings.Repeat(" ", padding)))
|
||||
strings.Repeat(" ", padding))
|
||||
return v
|
||||
}
|
||||
|
||||
pad := strings.Repeat(" ", padding)
|
||||
return tea.NewView(fmt.Sprintf("\n%s%s\n%s%s\n\n%s",
|
||||
v.Content = fmt.Sprintf("\n%s%s\n%s%s\n\n%s",
|
||||
pad, m.progress.View(),
|
||||
pad, m.status,
|
||||
pad+helpStyle("Press 'q' or Ctrl+C to cancel")))
|
||||
pad+helpStyle("Press 'q' or Ctrl+C to cancel"))
|
||||
return v
|
||||
}
|
||||
|
||||
// ProgressReader wraps an io.Reader to intercept and parse Ollama pull operation
|
||||
|
||||
@@ -0,0 +1,629 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
// highlightStyle is lazily initialized to avoid creating it on every render
|
||||
var highlightStyle lipgloss.Style
|
||||
|
||||
// initHighlightStyle creates the highlight style with proper colors
|
||||
func initHighlightStyle() lipgloss.Style {
|
||||
if highlightStyle.String() == "" {
|
||||
theme := GetTheme()
|
||||
highlightStyle = lipgloss.NewStyle().
|
||||
Background(theme.Secondary).
|
||||
Foreground(theme.Background).
|
||||
Bold(true)
|
||||
}
|
||||
return highlightStyle
|
||||
}
|
||||
|
||||
// MessageItem is the interface all scrollback messages must implement.
|
||||
// This allows lazy rendering - messages are only rendered when visible.
|
||||
type MessageItem interface {
|
||||
// Render returns the styled content for this message at the given width.
|
||||
// Implementations should cache the result to avoid re-rendering.
|
||||
Render(width int) string
|
||||
|
||||
// Height returns the number of lines this message occupies when rendered.
|
||||
Height() int
|
||||
|
||||
// ID returns a unique identifier for this message (for tracking).
|
||||
ID() string
|
||||
}
|
||||
|
||||
// ScrollList manages a viewport over a list of MessageItems.
|
||||
// It handles offset-based scrolling and lazy rendering. Only visible
|
||||
// items are rendered on each View() call.
|
||||
type ScrollList struct {
|
||||
items []MessageItem
|
||||
offsetIdx int // Index of first visible item
|
||||
offsetLine int // Lines to skip from first visible item
|
||||
width int
|
||||
height int // Viewport height in lines
|
||||
autoScroll bool // Whether to auto-scroll to bottom on new content
|
||||
itemGap int // Number of blank lines between items (0 = no gap)
|
||||
focusedIdx int // Index of focused/selected item (-1 = none)
|
||||
selectable bool // Whether items can be selected via mouse/keyboard
|
||||
|
||||
// Selection tracking for copy+paste (crush-style)
|
||||
selection CopySelection // Current text selection
|
||||
mouseDown bool // Whether mouse button is currently down
|
||||
mouseDownX int // X coordinate where mouse was pressed
|
||||
mouseDownY int // Y coordinate where mouse was pressed
|
||||
mouseDownItem int // Item index where mouse was pressed
|
||||
}
|
||||
|
||||
// NewScrollList creates a new ScrollList with the given dimensions.
|
||||
func NewScrollList(width, height int) *ScrollList {
|
||||
return &ScrollList{
|
||||
items: []MessageItem{},
|
||||
offsetIdx: 0,
|
||||
offsetLine: 0,
|
||||
width: width,
|
||||
height: height,
|
||||
autoScroll: true, // Start with auto-scroll enabled
|
||||
}
|
||||
}
|
||||
|
||||
// SetItems replaces the items in the scroll list. If auto-scroll is enabled,
|
||||
// the viewport will scroll to the bottom to show the latest content.
|
||||
func (s *ScrollList) SetItems(items []MessageItem) {
|
||||
s.items = items
|
||||
if s.autoScroll {
|
||||
s.GotoBottom()
|
||||
}
|
||||
}
|
||||
|
||||
// SetHeight updates the viewport height. Called when the terminal is resized.
|
||||
func (s *ScrollList) SetHeight(height int) {
|
||||
s.height = height
|
||||
s.clampOffset()
|
||||
}
|
||||
|
||||
// SetWidth updates the viewport width. Called when the terminal is resized.
|
||||
// This may invalidate cached renders in MessageItems.
|
||||
func (s *ScrollList) SetWidth(width int) {
|
||||
s.width = width
|
||||
s.clampOffset()
|
||||
}
|
||||
|
||||
// SetItemGap sets the number of blank lines between items (0 = no gap).
|
||||
func (s *ScrollList) SetItemGap(gap int) {
|
||||
s.itemGap = gap
|
||||
}
|
||||
|
||||
// ItemGap returns the current gap between items.
|
||||
func (s *ScrollList) ItemGap() int {
|
||||
return s.itemGap
|
||||
}
|
||||
|
||||
// SetSelectable enables or disables item selection.
|
||||
func (s *ScrollList) SetSelectable(selectable bool) {
|
||||
s.selectable = selectable
|
||||
}
|
||||
|
||||
// FocusedIdx returns the currently focused item index (-1 if none).
|
||||
func (s *ScrollList) FocusedIdx() int {
|
||||
return s.focusedIdx
|
||||
}
|
||||
|
||||
// SetFocused sets the focused item by index.
|
||||
func (s *ScrollList) SetFocused(idx int) {
|
||||
if idx < -1 {
|
||||
s.focusedIdx = -1
|
||||
} else if idx >= len(s.items) {
|
||||
s.focusedIdx = len(s.items) - 1
|
||||
} else {
|
||||
s.focusedIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
// SelectItemAtY selects the item at the given Y coordinate (relative to viewport).
|
||||
// Returns the selected item index or -1 if no item at that position.
|
||||
func (s *ScrollList) SelectItemAtY(y int) int {
|
||||
if !s.selectable || len(s.items) == 0 || y < 0 || y >= s.height {
|
||||
return -1
|
||||
}
|
||||
|
||||
// Calculate which item is at the given Y position
|
||||
currentY := 0
|
||||
for idx := s.offsetIdx; idx < len(s.items); idx++ {
|
||||
item := s.items[idx]
|
||||
itemHeight := item.Height()
|
||||
|
||||
// Check if y falls within this item
|
||||
if y >= currentY && y < currentY+itemHeight {
|
||||
s.focusedIdx = idx
|
||||
return idx
|
||||
}
|
||||
|
||||
currentY += itemHeight
|
||||
|
||||
// Add gap after item (except last)
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
currentY += s.itemGap
|
||||
}
|
||||
|
||||
// Stop if we've passed the viewport
|
||||
if currentY >= s.height {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
// HandleMouseDown handles mouse button press for selection (crush-style).
|
||||
// Returns true if the click was handled.
|
||||
func (s *ScrollList) HandleMouseDown(x, y int) bool {
|
||||
if !s.selectable || len(s.items) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mouseDown = true
|
||||
s.mouseDownX = x
|
||||
s.mouseDownY = y
|
||||
|
||||
// Find which item and line was clicked
|
||||
itemIdx, lineIdx := s.getItemAndLineAtY(y)
|
||||
s.mouseDownItem = itemIdx
|
||||
|
||||
// Start a new selection at click position
|
||||
if itemIdx >= 0 {
|
||||
s.selection = CopySelection{
|
||||
StartItemIdx: itemIdx,
|
||||
StartLine: lineIdx,
|
||||
StartCol: x,
|
||||
EndItemIdx: itemIdx,
|
||||
EndLine: lineIdx,
|
||||
EndCol: x,
|
||||
Active: true,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// HandleMouseDrag handles mouse drag for selection (crush-style).
|
||||
// Updates the selection end point. Returns true if selection changed.
|
||||
func (s *ScrollList) HandleMouseDrag(x, y int) bool {
|
||||
if !s.mouseDown || !s.selectable {
|
||||
return false
|
||||
}
|
||||
|
||||
// Find which item and line we're dragging over
|
||||
itemIdx, lineIdx := s.getItemAndLineAtY(y)
|
||||
if itemIdx < 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Update selection end point
|
||||
s.selection.EndItemIdx = itemIdx
|
||||
s.selection.EndLine = lineIdx
|
||||
s.selection.EndCol = x
|
||||
s.selection.Active = true
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// getItemAndLineAtY converts a Y coordinate to item index and line index within that item.
|
||||
// Returns (-1, -1) if Y is outside the viewport or beyond all items.
|
||||
func (s *ScrollList) getItemAndLineAtY(y int) (itemIdx, lineIdx int) {
|
||||
if y < 0 || y >= s.height || len(s.items) == 0 {
|
||||
return -1, -1
|
||||
}
|
||||
|
||||
currentY := 0
|
||||
for idx := s.offsetIdx; idx < len(s.items); idx++ {
|
||||
item := s.items[idx]
|
||||
itemHeight := item.Height()
|
||||
|
||||
// Check if y falls within this item
|
||||
if y >= currentY && y < currentY+itemHeight {
|
||||
return idx, y - currentY
|
||||
}
|
||||
|
||||
currentY += itemHeight
|
||||
|
||||
// Add gap after item (except last)
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
currentY += s.itemGap
|
||||
}
|
||||
|
||||
// Stop if we've passed the viewport
|
||||
if currentY >= s.height {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return -1, -1
|
||||
}
|
||||
|
||||
// HandleMouseUp handles mouse button release (crush-style).
|
||||
// Finalizes selection and returns true if there was an active selection.
|
||||
func (s *ScrollList) HandleMouseUp(x, y int) bool {
|
||||
if !s.mouseDown {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mouseDown = false
|
||||
|
||||
// Check if we have a valid selection
|
||||
if s.selection.Active && !s.selection.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetSelection returns the current text selection.
|
||||
func (s *ScrollList) GetSelection() CopySelection {
|
||||
return s.selection
|
||||
}
|
||||
|
||||
// ClearSelection clears the current text selection.
|
||||
func (s *ScrollList) ClearSelection() {
|
||||
s.selection = CopySelection{}
|
||||
s.mouseDown = false
|
||||
}
|
||||
|
||||
// HasSelection returns true if there is an active non-empty selection.
|
||||
func (s *ScrollList) HasSelection() bool {
|
||||
return s.selection.Active && !s.selection.IsEmpty()
|
||||
}
|
||||
|
||||
// ScrollBy scrolls the viewport by the given number of lines.
|
||||
// Positive = scroll down, negative = scroll up.
|
||||
func (s *ScrollList) ScrollBy(lines int) {
|
||||
if lines > 0 {
|
||||
// Scroll down
|
||||
for lines > 0 && s.offsetIdx < len(s.items) {
|
||||
if s.offsetIdx >= len(s.items) {
|
||||
break
|
||||
}
|
||||
currentItem := s.items[s.offsetIdx]
|
||||
itemHeight := currentItem.Height()
|
||||
remainingLines := itemHeight - s.offsetLine
|
||||
|
||||
if lines >= remainingLines {
|
||||
// Move to next item
|
||||
s.offsetIdx++
|
||||
s.offsetLine = 0
|
||||
lines -= remainingLines
|
||||
// Consume gap lines between items
|
||||
if s.itemGap > 0 && s.offsetIdx < len(s.items) {
|
||||
if lines >= s.itemGap {
|
||||
lines -= s.itemGap
|
||||
} else {
|
||||
lines = 0
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Stay on current item, skip more lines
|
||||
s.offsetLine += lines
|
||||
lines = 0
|
||||
}
|
||||
}
|
||||
} else if lines < 0 {
|
||||
// Scroll up
|
||||
lines = -lines
|
||||
for lines > 0 && (s.offsetIdx > 0 || s.offsetLine > 0) {
|
||||
if s.offsetLine > 0 {
|
||||
// Scroll within current item
|
||||
if lines >= s.offsetLine {
|
||||
lines -= s.offsetLine
|
||||
s.offsetLine = 0
|
||||
} else {
|
||||
s.offsetLine -= lines
|
||||
lines = 0
|
||||
}
|
||||
} else if s.offsetIdx > 0 {
|
||||
// Consume gap lines between items
|
||||
if s.itemGap > 0 {
|
||||
if lines > s.itemGap {
|
||||
lines -= s.itemGap
|
||||
} else {
|
||||
lines = 0
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Move to previous item
|
||||
s.offsetIdx--
|
||||
if s.offsetIdx < len(s.items) {
|
||||
currentItem := s.items[s.offsetIdx]
|
||||
itemHeight := currentItem.Height()
|
||||
|
||||
if lines >= itemHeight {
|
||||
lines -= itemHeight
|
||||
s.offsetLine = 0
|
||||
} else {
|
||||
s.offsetLine = itemHeight - lines
|
||||
lines = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.clampOffset()
|
||||
}
|
||||
|
||||
// GotoBottom scrolls to the end of the list.
|
||||
func (s *ScrollList) GotoBottom() {
|
||||
if len(s.items) == 0 {
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate total height including gaps
|
||||
totalHeight := 0
|
||||
for i, item := range s.items {
|
||||
totalHeight += item.Height()
|
||||
// Add gap after each item except the last
|
||||
if s.itemGap > 0 && i < len(s.items)-1 {
|
||||
totalHeight += s.itemGap
|
||||
}
|
||||
}
|
||||
|
||||
// If content fits in viewport, start at top
|
||||
if totalHeight <= s.height {
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, position viewport at bottom
|
||||
remaining := totalHeight - s.height
|
||||
for idx := 0; idx < len(s.items); idx++ {
|
||||
itemHeight := s.items[idx].Height()
|
||||
if remaining < itemHeight {
|
||||
s.offsetIdx = idx
|
||||
s.offsetLine = remaining
|
||||
return
|
||||
}
|
||||
remaining -= itemHeight
|
||||
// Subtract gap after item (except last)
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
remaining -= s.itemGap
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: show last item
|
||||
s.offsetIdx = max(0, len(s.items)-1)
|
||||
s.offsetLine = 0
|
||||
}
|
||||
|
||||
// GotoTop scrolls to the beginning of the list.
|
||||
func (s *ScrollList) GotoTop() {
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
}
|
||||
|
||||
// AtBottom returns true if the viewport is at the bottom of the list.
|
||||
func (s *ScrollList) AtBottom() bool {
|
||||
if len(s.items) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Calculate visible height from current position including gaps
|
||||
visibleHeight := 0
|
||||
for idx := s.offsetIdx; idx < len(s.items); idx++ {
|
||||
item := s.items[idx]
|
||||
itemHeight := item.Height()
|
||||
|
||||
if idx == s.offsetIdx {
|
||||
visibleHeight += itemHeight - s.offsetLine
|
||||
} else {
|
||||
visibleHeight += itemHeight
|
||||
}
|
||||
|
||||
// Add gap after item (except last)
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
visibleHeight += s.itemGap
|
||||
}
|
||||
|
||||
if visibleHeight >= s.height {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AtTop returns true if the viewport is at the top of the list.
|
||||
func (s *ScrollList) AtTop() bool {
|
||||
return s.offsetIdx == 0 && s.offsetLine == 0
|
||||
}
|
||||
|
||||
// View renders the visible portion of the scrollback.
|
||||
// Only items that fit within the viewport height are rendered.
|
||||
// ALWAYS returns exactly s.height lines (padded with empty lines if needed)
|
||||
// to ensure the input/footer stay fixed at the bottom.
|
||||
func (s *ScrollList) View() string {
|
||||
if s.height <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var lines []string
|
||||
remainingHeight := s.height
|
||||
|
||||
// Render visible items
|
||||
if len(s.items) > 0 {
|
||||
for idx := s.offsetIdx; idx < len(s.items) && remainingHeight > 0; idx++ {
|
||||
item := s.items[idx]
|
||||
content := item.Render(s.width)
|
||||
contentLines := strings.Split(content, "\n")
|
||||
|
||||
startLine := 0
|
||||
if idx == s.offsetIdx {
|
||||
startLine = s.offsetLine
|
||||
}
|
||||
|
||||
// Check if this item is focused (for visual indicator)
|
||||
isFocused := idx == s.focusedIdx
|
||||
|
||||
for i := startLine; i < len(contentLines) && remainingHeight > 0; i++ {
|
||||
line := contentLines[i]
|
||||
|
||||
// Apply selection highlighting if this line is within selection
|
||||
if s.selection.Active && s.isLineInSelection(idx, i) {
|
||||
line = s.applyHighlight(line)
|
||||
} else if isFocused && s.selectable {
|
||||
// Apply subtle focus indicator when item is focused but not in selection
|
||||
line = s.applyFocusIndicator(line)
|
||||
}
|
||||
|
||||
lines = append(lines, line)
|
||||
remainingHeight--
|
||||
}
|
||||
|
||||
// Add gap lines between items (but not after the last visible item)
|
||||
if remainingHeight > 0 && idx < len(s.items)-1 && s.itemGap > 0 {
|
||||
for g := 0; g < s.itemGap && remainingHeight > 0; g++ {
|
||||
lines = append(lines, "")
|
||||
remainingHeight--
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pad with empty lines to ensure exactly s.height lines
|
||||
// This keeps the input/footer fixed at the bottom of the screen
|
||||
for remainingHeight > 0 {
|
||||
lines = append(lines, "")
|
||||
remainingHeight--
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// isLineInSelection checks if a specific line within an item is part of the current selection.
|
||||
func (s *ScrollList) isLineInSelection(itemIdx, lineIdx int) bool {
|
||||
if !s.selection.Active {
|
||||
return false
|
||||
}
|
||||
|
||||
// Normalize selection (start <= end)
|
||||
startItem := s.selection.StartItemIdx
|
||||
startLine := s.selection.StartLine
|
||||
endItem := s.selection.EndItemIdx
|
||||
endLine := s.selection.EndLine
|
||||
|
||||
if startItem > endItem || (startItem == endItem && startLine > endLine) {
|
||||
startItem, endItem = endItem, startItem
|
||||
startLine, endLine = endLine, startLine
|
||||
}
|
||||
|
||||
// Check if item is within selection range
|
||||
if itemIdx < startItem || itemIdx > endItem {
|
||||
return false
|
||||
}
|
||||
|
||||
// For single item selection
|
||||
if startItem == endItem {
|
||||
return itemIdx == startItem && lineIdx >= startLine && lineIdx <= endLine
|
||||
}
|
||||
|
||||
// For multi-item selection
|
||||
if itemIdx == startItem {
|
||||
return lineIdx >= startLine
|
||||
}
|
||||
if itemIdx == endItem {
|
||||
return lineIdx <= endLine
|
||||
}
|
||||
// Middle items are fully selected
|
||||
return itemIdx > startItem && itemIdx < endItem
|
||||
}
|
||||
|
||||
// applyHighlight applies the highlight style to a line.
|
||||
// Uses the theme's Highlight color for the background.
|
||||
func (s *ScrollList) applyHighlight(line string) string {
|
||||
if line == "" {
|
||||
return line
|
||||
}
|
||||
// Apply background/foreground color change for selection
|
||||
style := initHighlightStyle()
|
||||
return style.Render(line)
|
||||
}
|
||||
|
||||
// applyFocusIndicator applies a subtle visual indicator for focused items.
|
||||
func (s *ScrollList) applyFocusIndicator(line string) string {
|
||||
if line == "" {
|
||||
return line
|
||||
}
|
||||
// Just return the line as-is - no visual indicator for focus
|
||||
// The selection highlighting is enough
|
||||
return line
|
||||
}
|
||||
|
||||
// ScrollPercent returns the current scroll position as a percentage (0.0-1.0).
|
||||
// 0.0 = at top, 1.0 = at bottom. Useful for scroll indicators.
|
||||
func (s *ScrollList) ScrollPercent() float64 {
|
||||
if len(s.items) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
totalHeight := 0
|
||||
for _, item := range s.items {
|
||||
totalHeight += item.Height()
|
||||
}
|
||||
|
||||
if totalHeight <= s.height {
|
||||
return 1.0 // All content fits, consider it "at bottom"
|
||||
}
|
||||
|
||||
// Calculate how many lines are above the viewport
|
||||
linesAbove := 0
|
||||
for i := 0; i < s.offsetIdx && i < len(s.items); i++ {
|
||||
linesAbove += s.items[i].Height()
|
||||
}
|
||||
linesAbove += s.offsetLine
|
||||
|
||||
scrollableHeight := totalHeight - s.height
|
||||
if scrollableHeight <= 0 {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
percent := float64(linesAbove) / float64(scrollableHeight)
|
||||
if percent > 1.0 {
|
||||
percent = 1.0
|
||||
}
|
||||
if percent < 0.0 {
|
||||
percent = 0.0
|
||||
}
|
||||
return percent
|
||||
}
|
||||
|
||||
// clampOffset ensures the offset values are within valid bounds after
|
||||
// resizing or scrolling operations.
|
||||
func (s *ScrollList) clampOffset() {
|
||||
if len(s.items) == 0 {
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Clamp offsetIdx
|
||||
if s.offsetIdx >= len(s.items) {
|
||||
s.offsetIdx = len(s.items) - 1
|
||||
}
|
||||
if s.offsetIdx < 0 {
|
||||
s.offsetIdx = 0
|
||||
}
|
||||
|
||||
// Clamp offsetLine
|
||||
if s.offsetIdx < len(s.items) {
|
||||
itemHeight := s.items[s.offsetIdx].Height()
|
||||
if s.offsetLine >= itemHeight {
|
||||
s.offsetLine = max(0, itemHeight-1)
|
||||
}
|
||||
}
|
||||
if s.offsetLine < 0 {
|
||||
s.offsetLine = 0
|
||||
}
|
||||
}
|
||||
@@ -325,7 +325,14 @@ func (ss *SessionSelectorComponent) View() tea.View {
|
||||
}
|
||||
}
|
||||
|
||||
return tea.NewView(b.String())
|
||||
v := tea.NewView(b.String())
|
||||
v.AltScreen = true
|
||||
v.MouseMode = tea.MouseModeCellMotion
|
||||
v.ReportFocus = true
|
||||
v.KeyboardEnhancements = tea.KeyboardEnhancements{
|
||||
ReportEventTypes: true,
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// IsActive returns whether the selector is still accepting input.
|
||||
|
||||
@@ -1,352 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"charm.land/bubbles/v2/key"
|
||||
"charm.land/bubbles/v2/textarea"
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
// SlashCommandInput provides an interactive text input field with intelligent
|
||||
// slash command autocomplete functionality. It displays a popup menu of matching
|
||||
// commands as the user types, supporting fuzzy matching and keyboard navigation.
|
||||
type SlashCommandInput struct {
|
||||
textarea textarea.Model
|
||||
commands []SlashCommand
|
||||
showPopup bool
|
||||
filtered []FuzzyMatch
|
||||
selected int
|
||||
width int
|
||||
lastValue string
|
||||
popupHeight int
|
||||
title string
|
||||
quitting bool
|
||||
value string
|
||||
submitNext bool // Flag to submit on next update
|
||||
renderedLines int // Track how many lines were rendered
|
||||
hideHint bool // Suppress the "enter submit · ctrl+j..." hint
|
||||
}
|
||||
|
||||
// NewSlashCommandInput creates and initializes a new slash command input field with
|
||||
// the specified width and title. The input supports multi-line text entry, command
|
||||
// autocomplete, and is styled to match the application's theme.
|
||||
func NewSlashCommandInput(width int, title string) *SlashCommandInput {
|
||||
ta := textarea.New()
|
||||
ta.Placeholder = "Type your message..."
|
||||
ta.ShowLineNumbers = false
|
||||
ta.Prompt = ""
|
||||
ta.CharLimit = 5000
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(3) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
|
||||
// Override InsertNewline so only ctrl+j and shift+enter insert newlines.
|
||||
// Enter always submits the input.
|
||||
ta.KeyMap.InsertNewline = key.NewBinding(
|
||||
key.WithKeys("ctrl+j", "shift+enter"),
|
||||
key.WithHelp("ctrl+j", "insert newline"),
|
||||
)
|
||||
|
||||
// Style the textarea using theme colors.
|
||||
theme := GetTheme()
|
||||
styles := ta.Styles()
|
||||
styles.Focused.Base = lipgloss.NewStyle()
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
styles.Focused.Text = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
styles.Focused.Prompt = lipgloss.NewStyle()
|
||||
styles.Focused.CursorLine = lipgloss.NewStyle()
|
||||
ta.SetStyles(styles)
|
||||
|
||||
return &SlashCommandInput{
|
||||
textarea: ta,
|
||||
commands: SlashCommands,
|
||||
width: width,
|
||||
popupHeight: 7,
|
||||
title: title,
|
||||
}
|
||||
}
|
||||
|
||||
// Init implements the tea.Model interface, returning the initial command to start
|
||||
// the cursor blinking animation for the text input field.
|
||||
func (s *SlashCommandInput) Init() tea.Cmd {
|
||||
return textarea.Blink
|
||||
}
|
||||
|
||||
// Update implements the tea.Model interface, handling keyboard input for text entry,
|
||||
// command selection, and navigation. Manages the autocomplete popup display and
|
||||
// processes submission or cancellation actions.
|
||||
func (s *SlashCommandInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
|
||||
// Check if we need to submit after updating the view
|
||||
if s.submitNext {
|
||||
s.value = s.textarea.Value()
|
||||
s.quitting = true
|
||||
return s, tea.Quit
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyPressMsg: // Check for quit keys first (when popup is not shown)
|
||||
if !s.showPopup {
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "esc":
|
||||
s.quitting = true
|
||||
return s, tea.Quit
|
||||
case "ctrl+d", "enter": // Enter always submits
|
||||
s.value = s.textarea.Value()
|
||||
s.quitting = true
|
||||
return s, tea.Quit
|
||||
}
|
||||
}
|
||||
|
||||
// Handle popup navigation
|
||||
if s.showPopup {
|
||||
switch {
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("up"), key.WithHelp("↑", "up"))):
|
||||
if s.selected > 0 {
|
||||
s.selected--
|
||||
}
|
||||
return s, nil
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("down"), key.WithHelp("↓", "down"))):
|
||||
if s.selected < len(s.filtered)-1 {
|
||||
s.selected++
|
||||
}
|
||||
return s, nil
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("tab"))):
|
||||
if s.selected < len(s.filtered) {
|
||||
// Complete with selected command
|
||||
s.textarea.SetValue(s.filtered[s.selected].Command.Name)
|
||||
s.showPopup = false
|
||||
s.selected = 0
|
||||
// Move cursor to end
|
||||
s.textarea.CursorEnd()
|
||||
}
|
||||
return s, nil
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))):
|
||||
if s.selected < len(s.filtered) {
|
||||
// Populate the field with the selected command
|
||||
s.textarea.SetValue(s.filtered[s.selected].Command.Name)
|
||||
s.textarea.CursorEnd()
|
||||
// Hide the popup
|
||||
s.showPopup = false
|
||||
s.selected = 0
|
||||
// Set flag to submit on next update (after view refresh)
|
||||
s.submitNext = true
|
||||
// Force a refresh
|
||||
return s, nil
|
||||
}
|
||||
return s, nil
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))):
|
||||
s.showPopup = false
|
||||
s.selected = 0
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Update textarea
|
||||
s.textarea, cmd = s.textarea.Update(msg)
|
||||
|
||||
// Check if we should show/update popup
|
||||
value := s.textarea.Value()
|
||||
if value != s.lastValue {
|
||||
s.lastValue = value
|
||||
// Only show popup if we're on the first line and it starts with /
|
||||
lines := strings.Split(value, "\n")
|
||||
if len(lines) > 0 && strings.HasPrefix(lines[0], "/") && !strings.Contains(lines[0], " ") && len(lines) == 1 {
|
||||
// Show and update popup
|
||||
s.showPopup = true
|
||||
s.filtered = FuzzyMatchCommands(lines[0], s.commands)
|
||||
s.selected = 0
|
||||
} else {
|
||||
// Hide popup
|
||||
s.showPopup = false
|
||||
}
|
||||
}
|
||||
return s, cmd
|
||||
|
||||
default:
|
||||
// Pass through other messages
|
||||
s.textarea, cmd = s.textarea.Update(msg)
|
||||
return s, cmd
|
||||
}
|
||||
}
|
||||
|
||||
// View implements the tea.Model interface, rendering the complete input field
|
||||
// including the title, text area, autocomplete popup (when active), and help text.
|
||||
// The view adapts based on whether single or multi-line input is detected.
|
||||
func (s *SlashCommandInput) View() tea.View {
|
||||
containerStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := GetTheme()
|
||||
|
||||
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Text).
|
||||
MarginBottom(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
// Input box with huh-like styling
|
||||
inputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.ThickBorder()).
|
||||
BorderLeft(true).
|
||||
BorderRight(false).
|
||||
BorderTop(false).
|
||||
BorderBottom(false).
|
||||
BorderForeground(theme.Primary).
|
||||
PaddingLeft(2). // match message block paddingLeft
|
||||
Width(s.width - 1) // full width minus left border
|
||||
|
||||
// Build the view
|
||||
var view strings.Builder
|
||||
view.WriteString(titleStyle.Render(s.title))
|
||||
view.WriteString("\n")
|
||||
view.WriteString(inputBoxStyle.Render(s.textarea.View()))
|
||||
// Count rendered lines
|
||||
s.renderedLines = 2 + s.textarea.Height() // title + newline + textarea height
|
||||
|
||||
// Add popup if visible
|
||||
if s.showPopup && len(s.filtered) > 0 {
|
||||
view.WriteString("\n")
|
||||
view.WriteString(s.renderPopup())
|
||||
// Add popup lines
|
||||
visibleItems := min(len(s.filtered), s.popupHeight)
|
||||
scrollIndicators := 0
|
||||
if s.selected >= s.popupHeight {
|
||||
scrollIndicators++ // top indicator
|
||||
}
|
||||
if len(s.filtered) > s.popupHeight {
|
||||
scrollIndicators++ // bottom indicator
|
||||
}
|
||||
popupLines := visibleItems + scrollIndicators + 5 // items + scroll + border + padding + footer
|
||||
s.renderedLines += 1 + popupLines // newline + popup
|
||||
}
|
||||
|
||||
// Add help text at bottom (unless hidden by extension).
|
||||
if !s.hideHint {
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted).
|
||||
MarginTop(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
helpText := "enter submit • ctrl+j / shift+enter new line"
|
||||
|
||||
view.WriteString("\n")
|
||||
view.WriteString(helpStyle.Render(helpText))
|
||||
s.renderedLines += 2 // newline + help text
|
||||
}
|
||||
|
||||
// Apply container padding to entire view
|
||||
return tea.NewView(containerStyle.Render(view.String()))
|
||||
}
|
||||
|
||||
// renderPopup renders the autocomplete popup
|
||||
func (s *SlashCommandInput) renderPopup() string {
|
||||
theme := GetTheme()
|
||||
|
||||
// Popup styling
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.MutedBorder).
|
||||
Padding(1, 2).
|
||||
Width(s.width - 4). // Account for container padding
|
||||
MarginLeft(0) // No extra margin needed due to container padding
|
||||
|
||||
var items []string
|
||||
|
||||
// Calculate visible window
|
||||
visibleItems := min(len(s.filtered), s.popupHeight)
|
||||
startIdx := 0
|
||||
|
||||
// Adjust window to keep selected item visible
|
||||
if s.selected >= s.popupHeight {
|
||||
startIdx = s.selected - s.popupHeight + 1
|
||||
}
|
||||
|
||||
endIdx := min(startIdx+visibleItems, len(s.filtered))
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
match := s.filtered[i]
|
||||
cmd := match.Command
|
||||
// Create the selection indicator
|
||||
var indicator string
|
||||
if i == s.selected {
|
||||
indicator = lipgloss.NewStyle().
|
||||
Foreground(theme.Primary).
|
||||
Render("> ")
|
||||
} else {
|
||||
indicator = " "
|
||||
}
|
||||
|
||||
// Format item
|
||||
nameStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Bold(true)
|
||||
|
||||
descStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted)
|
||||
|
||||
// Highlight selected item
|
||||
if i == s.selected {
|
||||
nameStyle = nameStyle.Foreground(theme.Primary)
|
||||
descStyle = descStyle.Foreground(theme.Text)
|
||||
}
|
||||
|
||||
// Format with proper spacing
|
||||
nameWidth := 15
|
||||
name := nameStyle.Width(nameWidth - 2).Render(cmd.Name)
|
||||
|
||||
// Truncate description if needed
|
||||
desc := cmd.Description
|
||||
maxDescLen := s.width - nameWidth - 14 // Account for padding and indicator
|
||||
if len(desc) > maxDescLen && maxDescLen > 3 {
|
||||
desc = desc[:maxDescLen-3] + "..."
|
||||
}
|
||||
|
||||
line := indicator + name + descStyle.Render(desc)
|
||||
items = append(items, line)
|
||||
}
|
||||
|
||||
// Add scroll indicators if needed
|
||||
if startIdx > 0 {
|
||||
scrollUpStyle := lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
items = append([]string{scrollUpStyle.Render(" ↑ more above")}, items...)
|
||||
}
|
||||
if endIdx < len(s.filtered) {
|
||||
scrollDownStyle := lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
items = append(items, scrollDownStyle.Render(" ↓ more below"))
|
||||
}
|
||||
// Join items
|
||||
content := strings.Join(items, "\n")
|
||||
|
||||
// Add footer hint
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted).
|
||||
Italic(true)
|
||||
footer := footerStyle.Render("↑↓ navigate • tab complete • ↵ select • esc dismiss")
|
||||
|
||||
// Combine content and footer
|
||||
popupContent := content + "\n\n" + footer
|
||||
|
||||
return popupStyle.Render(popupContent)
|
||||
}
|
||||
|
||||
// Value returns the final text value entered by the user after submission.
|
||||
// This will be empty if the input was cancelled.
|
||||
func (s *SlashCommandInput) Value() string {
|
||||
return s.value
|
||||
}
|
||||
|
||||
// Cancelled returns true if the user cancelled the input operation (e.g., by
|
||||
// pressing ESC or Ctrl+C) without submitting any text.
|
||||
func (s *SlashCommandInput) Cancelled() bool {
|
||||
return s.quitting && s.value == ""
|
||||
}
|
||||
|
||||
// RenderedLines returns the total number of terminal lines used by the last
|
||||
// rendered view, including the title, input area, popup, and help text. This
|
||||
// is used for proper screen clearing when the input is dismissed.
|
||||
func (s *SlashCommandInput) RenderedLines() int {
|
||||
return s.renderedLines
|
||||
}
|
||||
+264
-98
@@ -2,14 +2,27 @@ package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/indaco/herald"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
)
|
||||
|
||||
// thinkTagRegex matches ... tags that some models (Qwen, DeepSeek) wrap
|
||||
// reasoning content in. Used to strip these tags from streaming text content.
|
||||
// The (?s) flag makes . match newlines.
|
||||
var thinkTagRegex = regexp.MustCompile(`(?s)` + `` + `think` + `` + `(.*?)` + `` + `/think` + ``)
|
||||
|
||||
// thinkTagOpen and thinkTagClose are the opening and closing think tag strings.
|
||||
const (
|
||||
thinkTagOpen = "<think>"
|
||||
thinkTagClose = "</think>"
|
||||
)
|
||||
|
||||
// knightRiderFrames generates a KITT-style scanning animation where a bright
|
||||
// light bounces back and forth across a row of dots with a trailing glow.
|
||||
// Colors are derived from the active theme. Used by StreamComponent (TUI
|
||||
@@ -79,7 +92,12 @@ func streamSpinnerTickCmd(generation uint64) tea.Cmd {
|
||||
// streamFlushTickMsg fires when it's time to commit pending chunks to the
|
||||
// main content builders and trigger a re-render. This coalesces rapid
|
||||
// streaming chunks into fewer expensive markdown re-renders.
|
||||
type streamFlushTickMsg struct{}
|
||||
//
|
||||
// generation ties the tick to the pending flush session that created it so
|
||||
// stale ticks from a prior Reset() are discarded.
|
||||
type streamFlushTickMsg struct {
|
||||
generation uint64
|
||||
}
|
||||
|
||||
// streamFlushInterval is the coalescing window for stream chunks. Chunks
|
||||
// arriving within this window are batched into a single render pass.
|
||||
@@ -89,9 +107,9 @@ const streamFlushInterval = 16 * time.Millisecond
|
||||
|
||||
// streamFlushTickCmd returns a tea.Cmd that fires streamFlushTickMsg after
|
||||
// the coalescing interval.
|
||||
func streamFlushTickCmd() tea.Cmd {
|
||||
func streamFlushTickCmd(generation uint64) tea.Cmd {
|
||||
return tea.Tick(streamFlushInterval, func(_ time.Time) tea.Msg {
|
||||
return streamFlushTickMsg{}
|
||||
return streamFlushTickMsg{generation: generation}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -149,9 +167,11 @@ type StreamComponent struct {
|
||||
// spinnerFrame is the current frame index.
|
||||
spinnerFrame int
|
||||
|
||||
// activeTools tracks the names of tools currently executing in parallel.
|
||||
// When multiple tools run concurrently, all are displayed in the spinner.
|
||||
activeTools []string
|
||||
// activeTools maps ToolCallID -> display label for currently running tools.
|
||||
activeTools map[string]string
|
||||
|
||||
// activeToolOrder preserves deterministic display order for active tools.
|
||||
activeToolOrder []string
|
||||
|
||||
// streamContent holds committed streaming text (flushed from pending).
|
||||
streamContent strings.Builder
|
||||
@@ -172,6 +192,10 @@ type StreamComponent struct {
|
||||
// the same coalescing window.
|
||||
flushPending bool
|
||||
|
||||
// flushGeneration is incremented when stream state resets so stale flush
|
||||
// ticks from a previous step can be discarded.
|
||||
flushGeneration uint64
|
||||
|
||||
// renderCache holds the last rendered output string. Reused by View()
|
||||
// between flush ticks to avoid redundant markdown re-parsing.
|
||||
renderCache string
|
||||
@@ -181,6 +205,14 @@ type StreamComponent struct {
|
||||
// the cache.
|
||||
renderDirty bool
|
||||
|
||||
// scrollbackFlushedLines is the number of lines from the top of the
|
||||
// rendered content that have already been emitted to the terminal
|
||||
// scrollback buffer. On each flush, lines that overflow the allocated
|
||||
// height and haven't been pushed yet are emitted via tea.Println so
|
||||
// they appear in the terminal's real scrollback (scrollable with the
|
||||
// terminal's own scroll mechanism).
|
||||
scrollbackFlushedLines int
|
||||
|
||||
// thinkingVisible controls whether reasoning blocks are expanded or collapsed.
|
||||
thinkingVisible bool
|
||||
|
||||
@@ -190,14 +222,12 @@ type StreamComponent struct {
|
||||
// reasoningDuration holds the total reasoning time, frozen when streaming text begins.
|
||||
reasoningDuration time.Duration
|
||||
|
||||
// messageRenderer renders assistant messages in standard mode.
|
||||
messageRenderer *MessageRenderer
|
||||
// inThinkTag tracks whether we're currently inside a section
|
||||
// from models that wrap reasoning in XML-like tags (Qwen, DeepSeek).
|
||||
inThinkTag bool
|
||||
|
||||
// compactRenderer renders assistant messages in compact mode.
|
||||
compactRenderer *CompactRenderer
|
||||
|
||||
// compactMode selects which renderer to use.
|
||||
compactMode bool
|
||||
// renderer renders streaming assistant text.
|
||||
renderer Renderer
|
||||
|
||||
// modelName is displayed in the streaming text header.
|
||||
modelName string
|
||||
@@ -211,20 +241,25 @@ type StreamComponent struct {
|
||||
// height constrains the render output to at most this many lines.
|
||||
// 0 means unconstrained.
|
||||
height int
|
||||
|
||||
// ty provides typography functions for rendering text.
|
||||
ty *herald.Typography
|
||||
}
|
||||
|
||||
// NewStreamComponent creates a new StreamComponent ready to be embedded in AppModel.
|
||||
func NewStreamComponent(compactMode bool, width int, modelName string) *StreamComponent {
|
||||
func NewStreamComponent(width int, modelName string) *StreamComponent {
|
||||
if width == 0 {
|
||||
width = 80
|
||||
}
|
||||
|
||||
renderer := newMessageRenderer(width, false)
|
||||
|
||||
return &StreamComponent{
|
||||
spinnerFrames: knightRiderFrames(),
|
||||
compactMode: compactMode,
|
||||
modelName: modelName,
|
||||
messageRenderer: newMessageRenderer(width, false),
|
||||
compactRenderer: NewCompactRenderer(width, false),
|
||||
width: width,
|
||||
spinnerFrames: knightRiderFrames(),
|
||||
modelName: modelName,
|
||||
renderer: renderer,
|
||||
width: width,
|
||||
ty: createTypography(GetTheme()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,16 +286,54 @@ func (s *StreamComponent) Reset() {
|
||||
s.spinnerGeneration++ // invalidate any in-flight tick commands
|
||||
s.spinnerFrame = 0
|
||||
s.activeTools = nil
|
||||
s.activeToolOrder = nil
|
||||
s.streamContent.Reset()
|
||||
s.reasoningContent.Reset()
|
||||
s.pendingStream.Reset()
|
||||
s.pendingReasoning.Reset()
|
||||
s.flushPending = false
|
||||
s.flushGeneration++
|
||||
s.renderCache = ""
|
||||
s.renderDirty = false
|
||||
s.timestamp = time.Time{}
|
||||
s.reasoningStartTime = time.Time{}
|
||||
s.reasoningDuration = 0
|
||||
s.scrollbackFlushedLines = 0
|
||||
}
|
||||
|
||||
// ConsumeOverflow returns any lines from the rendered stream content that have
|
||||
// overflowed the allocated height and have not yet been pushed to the terminal
|
||||
// scrollback buffer. It advances the internal flushed-line pointer so
|
||||
// subsequent calls only return newly overflowed lines.
|
||||
//
|
||||
// Returns "" when there is no overflow or height is unconstrained (0).
|
||||
// The caller should emit the returned string via tea.Println so the content
|
||||
// appears in the terminal's real scrollback (not just discarded).
|
||||
func (s *StreamComponent) ConsumeOverflow() string {
|
||||
if s.height <= 0 {
|
||||
return ""
|
||||
}
|
||||
content := s.render()
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
lines := strings.Split(content, "\n")
|
||||
totalLines := len(lines)
|
||||
// Number of lines that overflow the viewable height.
|
||||
overflowLines := totalLines - s.height
|
||||
if overflowLines <= 0 {
|
||||
return ""
|
||||
}
|
||||
// How many overflow lines are new (not yet flushed to scrollback).
|
||||
newOverflow := overflowLines - s.scrollbackFlushedLines
|
||||
if newOverflow <= 0 {
|
||||
return ""
|
||||
}
|
||||
// The new overflow is lines [s.scrollbackFlushedLines .. overflowLines).
|
||||
start := s.scrollbackFlushedLines
|
||||
end := overflowLines
|
||||
s.scrollbackFlushedLines = overflowLines
|
||||
return strings.Join(lines[start:end], "\n")
|
||||
}
|
||||
|
||||
// GetRenderedContent returns the rendered assistant message from the accumulated
|
||||
@@ -269,6 +342,10 @@ func (s *StreamComponent) Reset() {
|
||||
//
|
||||
// This commits any pending chunks first so the output includes all received
|
||||
// content, not just what has been flushed by the tick.
|
||||
//
|
||||
// Lines already pushed to the terminal scrollback buffer via ConsumeOverflow
|
||||
// are skipped so that callers do not re-emit content that is already visible
|
||||
// in the terminal's real scrollback.
|
||||
func (s *StreamComponent) GetRenderedContent() string {
|
||||
// Commit any pending chunks so the final output is complete.
|
||||
s.commitPending()
|
||||
@@ -282,20 +359,35 @@ func (s *StreamComponent) GetRenderedContent() string {
|
||||
|
||||
text := s.streamContent.String()
|
||||
if text != "" {
|
||||
sections = append(sections, s.renderStreamingText(text))
|
||||
rendered := s.renderStreamingText(text)
|
||||
sections = append(sections, rendered)
|
||||
}
|
||||
|
||||
if len(sections) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(sections, "\n")
|
||||
fullContent := strings.Join(sections, "\n")
|
||||
|
||||
// Skip lines already emitted to the terminal scrollback via ConsumeOverflow
|
||||
// so the caller doesn't re-print content that is already there.
|
||||
if s.scrollbackFlushedLines > 0 {
|
||||
lines := strings.Split(fullContent, "\n")
|
||||
if s.scrollbackFlushedLines >= len(lines) {
|
||||
return "" // everything already in scrollback
|
||||
}
|
||||
return strings.Join(lines[s.scrollbackFlushedLines:], "\n")
|
||||
}
|
||||
|
||||
return fullContent
|
||||
}
|
||||
|
||||
// commitPending moves any pending chunks to the committed content builders.
|
||||
// Called before reading content for scrollback output or on flush tick.
|
||||
func (s *StreamComponent) commitPending() {
|
||||
if s.pendingStream.Len() > 0 {
|
||||
s.streamContent.WriteString(s.pendingStream.String())
|
||||
// Strip ... tags that some models wrap reasoning in
|
||||
cleanedText := thinkTagRegex.ReplaceAllString(s.pendingStream.String(), "")
|
||||
s.streamContent.WriteString(cleanedText)
|
||||
s.pendingStream.Reset()
|
||||
s.renderDirty = true
|
||||
}
|
||||
@@ -322,8 +414,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
case tea.WindowSizeMsg:
|
||||
s.width = msg.Width
|
||||
s.messageRenderer.SetWidth(s.width)
|
||||
s.compactRenderer.SetWidth(s.width)
|
||||
if s.renderer != nil {
|
||||
s.renderer.SetWidth(s.width)
|
||||
}
|
||||
// Invalidate render cache — width change affects wrapping/styling.
|
||||
s.renderCache = ""
|
||||
s.renderDirty = true
|
||||
@@ -359,6 +452,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
case streamFlushTickMsg:
|
||||
if msg.generation != s.flushGeneration {
|
||||
break
|
||||
}
|
||||
s.flushPending = false
|
||||
s.commitPending()
|
||||
|
||||
@@ -373,7 +469,7 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.pendingReasoning.WriteString(msg.Delta)
|
||||
if !s.flushPending {
|
||||
s.flushPending = true
|
||||
return s, streamFlushTickCmd()
|
||||
return s, streamFlushTickCmd(s.flushGeneration)
|
||||
}
|
||||
|
||||
case app.StreamChunkEvent:
|
||||
@@ -385,17 +481,66 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if s.reasoningDuration == 0 && !s.reasoningStartTime.IsZero() {
|
||||
s.reasoningDuration = time.Since(s.reasoningStartTime)
|
||||
}
|
||||
s.pendingStream.WriteString(msg.Content)
|
||||
if !s.flushPending {
|
||||
|
||||
// Handle models that wrap reasoning in tags (Qwen, DeepSeek)
|
||||
// Filter out all content between and tags
|
||||
content := msg.Content
|
||||
|
||||
// Check for opening tag
|
||||
if strings.Contains(content, thinkTagOpen) {
|
||||
parts := strings.SplitN(content, thinkTagOpen, 2)
|
||||
// Content before the tag can be written
|
||||
if !s.inThinkTag && parts[0] != "" {
|
||||
s.pendingStream.WriteString(parts[0])
|
||||
}
|
||||
s.inThinkTag = true
|
||||
// Content after the opening tag is reasoning - don't write it
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
// Check if the same chunk contains the closing tag
|
||||
if strings.Contains(parts[1], thinkTagClose) {
|
||||
innerParts := strings.SplitN(parts[1], thinkTagClose, 2)
|
||||
s.inThinkTag = false
|
||||
// Content after closing tag can be written
|
||||
if len(innerParts) > 1 && innerParts[1] != "" {
|
||||
s.pendingStream.WriteString(innerParts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(content, thinkTagClose) {
|
||||
// Closing tag found
|
||||
parts := strings.SplitN(content, thinkTagClose, 2)
|
||||
s.inThinkTag = false
|
||||
// Content after closing tag can be written
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
s.pendingStream.WriteString(parts[1])
|
||||
}
|
||||
} else if !s.inThinkTag {
|
||||
// Normal content, not inside think tags
|
||||
s.pendingStream.WriteString(content)
|
||||
}
|
||||
// else: inside think tag, don't write this content
|
||||
|
||||
if !s.flushPending && s.pendingStream.Len() > 0 {
|
||||
s.flushPending = true
|
||||
return s, streamFlushTickCmd()
|
||||
return s, streamFlushTickCmd(s.flushGeneration)
|
||||
}
|
||||
|
||||
case app.ToolExecutionEvent:
|
||||
toolID := msg.ToolCallID
|
||||
if toolID == "" {
|
||||
// Defensive fallback for older/third-party emitters that may omit
|
||||
// ToolCallID. Best-effort only: same-name+args concurrent calls can
|
||||
// still collide without a stable ID.
|
||||
toolID = fmt.Sprintf("%s|%s", msg.ToolName, msg.ToolArgs)
|
||||
}
|
||||
if msg.IsStarting {
|
||||
// Add tool to active list for parallel execution display.
|
||||
toolDisplay := formatToolExecutionMessage(msg.ToolName, msg.ToolArgs)
|
||||
s.activeTools = append(s.activeTools, toolDisplay)
|
||||
if s.activeTools == nil {
|
||||
s.activeTools = make(map[string]string)
|
||||
}
|
||||
if _, exists := s.activeTools[toolID]; !exists {
|
||||
s.activeToolOrder = append(s.activeToolOrder, toolID)
|
||||
}
|
||||
s.activeTools[toolID] = formatToolExecutionMessage(msg.ToolName)
|
||||
s.spinnerFrame = 0
|
||||
if !s.spinning {
|
||||
s.phase = streamPhaseActive
|
||||
@@ -404,9 +549,10 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return s, streamSpinnerTickCmd(s.spinnerGeneration)
|
||||
}
|
||||
} else {
|
||||
// Tool finished — remove from active list but keep spinning if others remain.
|
||||
toolDisplay := formatToolExecutionMessage(msg.ToolName, msg.ToolArgs)
|
||||
s.activeTools = removeFromSlice(s.activeTools, toolDisplay)
|
||||
if s.activeTools != nil {
|
||||
delete(s.activeTools, toolID)
|
||||
}
|
||||
s.activeToolOrder = removeToolID(s.activeToolOrder, toolID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -415,7 +561,16 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// View implements tea.Model. Renders the current stream region content.
|
||||
func (s *StreamComponent) View() tea.View {
|
||||
return tea.NewView(s.render())
|
||||
fullContent := s.render()
|
||||
visibleContent := s.viewContent(fullContent)
|
||||
v := tea.NewView(visibleContent)
|
||||
v.AltScreen = true
|
||||
v.MouseMode = tea.MouseModeCellMotion
|
||||
v.ReportFocus = true
|
||||
v.KeyboardEnhancements = tea.KeyboardEnhancements{
|
||||
ReportEventTypes: true,
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -458,54 +613,51 @@ func (s *StreamComponent) render() string {
|
||||
|
||||
content := strings.Join(sections, "\n")
|
||||
|
||||
// Clamp to height if constrained: keep the last h lines so the most
|
||||
// recent output is always visible.
|
||||
if s.height > 0 && content != "" {
|
||||
lines := strings.Split(content, "\n")
|
||||
if len(lines) > s.height {
|
||||
lines = lines[len(lines)-s.height:]
|
||||
content = strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Cache FULL content without height clamping.
|
||||
// Height clamping is applied in View() for display only.
|
||||
s.renderCache = content
|
||||
s.renderDirty = false
|
||||
return content
|
||||
}
|
||||
|
||||
// renderReasoningBlock renders the reasoning/thinking content in a surface-tinted
|
||||
// box. When collapsed, shows the last 10 lines with a truncation hint. When
|
||||
// viewContent returns the visible portion of content based on height constraint.
|
||||
// This is called by View() to get the slice that fits in the terminal.
|
||||
func (s *StreamComponent) viewContent(fullContent string) string {
|
||||
if s.height > 0 && fullContent != "" {
|
||||
lines := strings.Split(fullContent, "\n")
|
||||
if len(lines) > s.height {
|
||||
// Keep only the last h lines so the most recent output is visible.
|
||||
lines = lines[len(lines)-s.height:]
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
return fullContent
|
||||
}
|
||||
|
||||
// renderReasoningBlock renders the reasoning/thinking content using blockquote.
|
||||
// When collapsed, shows the last 10 lines with a truncation hint. When
|
||||
// expanded, shows all lines. Includes a "Thought for Xs" duration footer.
|
||||
func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
theme := GetTheme()
|
||||
maxWidth := max(s.width-4, 20)
|
||||
|
||||
lines := strings.Split(strings.TrimRight(reasoning, "\n"), "\n")
|
||||
|
||||
contentStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.MutedBorder).
|
||||
Italic(true)
|
||||
|
||||
var parts []string
|
||||
|
||||
// When collapsed and content exceeds 10 lines, show only the last 10
|
||||
// with a truncation hint (matching iteratr's thinking block pattern).
|
||||
// with a truncation hint.
|
||||
const maxCollapsedLines = 10
|
||||
if !s.thinkingVisible && len(lines) > maxCollapsedLines {
|
||||
hidden := len(lines) - maxCollapsedLines
|
||||
hintStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted).
|
||||
Background(theme.MutedBorder).
|
||||
Italic(true)
|
||||
parts = append(parts, hintStyle.Render(fmt.Sprintf("... (%d lines hidden)", hidden)))
|
||||
parts = append(parts, s.ty.Italic(fmt.Sprintf("... (%d lines hidden)", hidden)))
|
||||
lines = lines[len(lines)-maxCollapsedLines:]
|
||||
}
|
||||
|
||||
// Render reasoning text.
|
||||
parts = append(parts, contentStyle.Width(maxWidth).Render(strings.Join(lines, "\n")))
|
||||
// Main content using Italic with Muted color for visual distinction.
|
||||
content := strings.TrimLeft(strings.Join(lines, "\n"), " \t\n")
|
||||
theme := GetTheme()
|
||||
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
parts = append(parts, mutedStyle.Render(s.ty.Italic(content)))
|
||||
|
||||
// Duration footer.
|
||||
// Duration footer with VeryMuted label and Accent duration.
|
||||
var duration time.Duration
|
||||
if s.reasoningDuration > 0 {
|
||||
duration = s.reasoningDuration
|
||||
@@ -519,21 +671,21 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.1fs", duration.Seconds())
|
||||
}
|
||||
footer := lipgloss.NewStyle().Foreground(theme.VeryMuted).Background(theme.MutedBorder).Render("Thought for ") +
|
||||
lipgloss.NewStyle().Foreground(theme.Info).Background(theme.MutedBorder).Render(durationStr)
|
||||
parts = append(parts, footer)
|
||||
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
|
||||
durationStyled := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
|
||||
parts = append(parts, label+durationStyled)
|
||||
}
|
||||
|
||||
innerContent := strings.Join(parts, "\n")
|
||||
|
||||
// Wrap in box with surface background for visual distinction.
|
||||
boxStyle := lipgloss.NewStyle().
|
||||
Background(theme.MutedBorder). // Surface0 (#313244)
|
||||
PaddingLeft(1).
|
||||
Width(maxWidth + 2).
|
||||
MarginBottom(1)
|
||||
|
||||
return boxStyle.Render(innerContent)
|
||||
// Concatenate parts with newline between blockquote and footer
|
||||
var result string
|
||||
if len(parts) == 1 {
|
||||
result = parts[0]
|
||||
} else if len(parts) == 2 {
|
||||
result = parts[0] + "\n" + parts[1]
|
||||
} else {
|
||||
result = strings.Join(parts, "\n")
|
||||
}
|
||||
return styleMarginBottom1.Render(result)
|
||||
}
|
||||
|
||||
// SetThinkingVisible sets whether reasoning blocks are shown or collapsed.
|
||||
@@ -559,7 +711,8 @@ func (s *StreamComponent) SpinnerView() string {
|
||||
return ""
|
||||
}
|
||||
frame := s.spinnerFrames[s.spinnerFrame%len(s.spinnerFrames)]
|
||||
if len(s.activeTools) == 0 {
|
||||
tools := s.activeToolDisplays()
|
||||
if len(tools) == 0 {
|
||||
return " " + frame
|
||||
}
|
||||
theme := GetTheme()
|
||||
@@ -569,10 +722,10 @@ func (s *StreamComponent) SpinnerView() string {
|
||||
|
||||
// Format active tools list
|
||||
var toolsMsg string
|
||||
if len(s.activeTools) == 1 {
|
||||
toolsMsg = s.activeTools[0]
|
||||
if len(tools) == 1 {
|
||||
toolsMsg = tools[0]
|
||||
} else {
|
||||
toolsMsg = "Running: " + strings.Join(s.activeTools, ", ")
|
||||
toolsMsg = "Running: " + strings.Join(tools, ", ")
|
||||
}
|
||||
return " " + frame + " " + msgStyle.Render(toolsMsg)
|
||||
}
|
||||
@@ -584,30 +737,43 @@ func (s *StreamComponent) renderStreamingText(text string) string {
|
||||
if ts.IsZero() {
|
||||
ts = time.Now()
|
||||
}
|
||||
|
||||
if s.compactMode {
|
||||
msg := s.compactRenderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
return msg.Content
|
||||
if s.renderer == nil {
|
||||
return text
|
||||
}
|
||||
msg := s.messageRenderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
msg := s.renderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
return msg.Content
|
||||
}
|
||||
|
||||
// removeFromSlice removes the first occurrence of a string from a slice.
|
||||
func removeFromSlice(slice []string, s string) []string {
|
||||
for i, v := range slice {
|
||||
if v == s {
|
||||
return append(slice[:i], slice[i+1:]...)
|
||||
func (s *StreamComponent) activeToolDisplays() []string {
|
||||
if len(s.activeTools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(s.activeToolOrder))
|
||||
for _, id := range s.activeToolOrder {
|
||||
if display, ok := s.activeTools[id]; ok {
|
||||
out = append(out, display)
|
||||
}
|
||||
}
|
||||
return slice
|
||||
return out
|
||||
}
|
||||
|
||||
// removeToolID removes the first occurrence of a tool ID from a slice.
|
||||
func removeToolID(ids []string, id string) []string {
|
||||
for i, v := range ids {
|
||||
if v == id {
|
||||
return append(ids[:i], ids[i+1:]...)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// formatToolExecutionMessage creates a descriptive spinner message for tool execution.
|
||||
// For spawn_subagent, it shows simply as "Subagent" with optional task preview.
|
||||
func formatToolExecutionMessage(toolName, toolArgs string) string {
|
||||
if toolName == "spawn_subagent" {
|
||||
return "Subagent"
|
||||
}
|
||||
func formatToolExecutionMessage(toolName string) string {
|
||||
return toolName
|
||||
}
|
||||
|
||||
// UpdateTheme refreshes the component's typography instance with colors from
|
||||
// the current theme. This is called when the user changes themes via /theme.
|
||||
func (s *StreamComponent) UpdateTheme() {
|
||||
s.ty = createTypography(GetTheme())
|
||||
}
|
||||
|
||||
+75
-248
@@ -1,19 +1,11 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image/color"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/charmbracelet/glamour"
|
||||
"github.com/charmbracelet/glamour/ansi"
|
||||
"github.com/indaco/herald"
|
||||
heraldmd "github.com/indaco/herald-md"
|
||||
)
|
||||
|
||||
// uintPtr returns a pointer to u. Used by ansi.StyleConfig fields.
|
||||
//
|
||||
//go:fix inline
|
||||
func uintPtr(u uint) *uint { return new(u) }
|
||||
|
||||
// BaseStyle returns a new, empty lipgloss style that can be customized with
|
||||
// additional styling methods. This serves as the foundation for building more
|
||||
// complex styled components.
|
||||
@@ -21,248 +13,83 @@ func BaseStyle() lipgloss.Style {
|
||||
return lipgloss.NewStyle()
|
||||
}
|
||||
|
||||
// colorHex converts a color.Color to a hex string suitable for ansi.StyleConfig.
|
||||
func colorHex(c color.Color) string {
|
||||
r, g, b, _ := c.RGBA()
|
||||
return fmt.Sprintf("#%02x%02x%02x", r>>8, g>>8, b>>8)
|
||||
}
|
||||
// markdownTypographyCache holds the last-created Typography instance for
|
||||
// herald-md rendering. It is cached to avoid re-initialization on every
|
||||
// streaming flush tick. The cache is invalidated by SetTheme when the
|
||||
// active theme changes.
|
||||
// This is only accessed from BubbleTea's single-threaded Update/View cycle,
|
||||
// so no mutex is required.
|
||||
var markdownTypographyCache *herald.Typography
|
||||
|
||||
// colorHexPtr returns a pointer to the hex string of a color.Color.
|
||||
func colorHexPtr(c color.Color) *string {
|
||||
s := colorHex(c)
|
||||
return &s
|
||||
}
|
||||
|
||||
// GetMarkdownRenderer creates and returns a configured glamour.TermRenderer for
|
||||
// rendering markdown content with syntax highlighting and proper formatting. The
|
||||
// renderer is customized with our theme colors and adapted to the specified width.
|
||||
func GetMarkdownRenderer(width int) *glamour.TermRenderer {
|
||||
r, _ := glamour.NewTermRenderer(
|
||||
glamour.WithStyles(generateMarkdownStyleConfig()),
|
||||
glamour.WithWordWrap(width),
|
||||
)
|
||||
return r
|
||||
}
|
||||
|
||||
// generateMarkdownStyleConfig creates an ansi.StyleConfig from the active theme.
|
||||
func generateMarkdownStyleConfig() ansi.StyleConfig {
|
||||
md := GetTheme().Markdown
|
||||
text := colorHexPtr(md.Text)
|
||||
muted := colorHexPtr(md.Muted)
|
||||
heading := colorHexPtr(md.Heading)
|
||||
emph := colorHexPtr(md.Emph)
|
||||
strong := colorHexPtr(md.Strong)
|
||||
link := colorHexPtr(md.Link)
|
||||
code := colorHexPtr(md.Code)
|
||||
errClr := colorHexPtr(md.Error)
|
||||
keyword := colorHexPtr(md.Keyword)
|
||||
str := colorHexPtr(md.String)
|
||||
number := colorHexPtr(md.Number)
|
||||
comment := colorHexPtr(md.Comment)
|
||||
|
||||
return ansi.StyleConfig{
|
||||
Document: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
BlockPrefix: "",
|
||||
BlockSuffix: "",
|
||||
Color: text,
|
||||
},
|
||||
Margin: uintPtr(0),
|
||||
},
|
||||
BlockQuote: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Color: muted,
|
||||
Italic: new(true),
|
||||
Prefix: "┃ ",
|
||||
},
|
||||
Indent: uintPtr(1),
|
||||
},
|
||||
List: ansi.StyleList{
|
||||
LevelIndent: 0,
|
||||
StyleBlock: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Color: text,
|
||||
},
|
||||
},
|
||||
},
|
||||
Heading: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
BlockSuffix: "\n",
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H1: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "# ",
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H2: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "## ",
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H3: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "### ",
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H4: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "#### ",
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H5: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "##### ",
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H6: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "###### ",
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
Strikethrough: ansi.StylePrimitive{
|
||||
CrossedOut: new(true),
|
||||
Color: muted,
|
||||
},
|
||||
Emph: ansi.StylePrimitive{
|
||||
Color: emph,
|
||||
Italic: new(true),
|
||||
},
|
||||
Strong: ansi.StylePrimitive{
|
||||
Bold: new(true),
|
||||
Color: strong,
|
||||
},
|
||||
HorizontalRule: ansi.StylePrimitive{
|
||||
Color: muted,
|
||||
Format: "\n─────────────────────────────────────────\n",
|
||||
},
|
||||
Item: ansi.StylePrimitive{
|
||||
BlockPrefix: "• ",
|
||||
Color: text,
|
||||
},
|
||||
Enumeration: ansi.StylePrimitive{
|
||||
BlockPrefix: ". ",
|
||||
Color: text,
|
||||
},
|
||||
Task: ansi.StyleTask{
|
||||
StylePrimitive: ansi.StylePrimitive{},
|
||||
Ticked: "[✓] ",
|
||||
Unticked: "[ ] ",
|
||||
},
|
||||
Link: ansi.StylePrimitive{
|
||||
Color: link,
|
||||
Underline: new(true),
|
||||
},
|
||||
LinkText: ansi.StylePrimitive{
|
||||
Color: link,
|
||||
Bold: new(true),
|
||||
},
|
||||
Image: ansi.StylePrimitive{
|
||||
Color: link,
|
||||
Underline: new(true),
|
||||
Format: "🖼 {{.text}}",
|
||||
},
|
||||
ImageText: ansi.StylePrimitive{
|
||||
Color: link,
|
||||
Format: "{{.text}}",
|
||||
},
|
||||
Code: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Color: code,
|
||||
Prefix: "",
|
||||
Suffix: "",
|
||||
},
|
||||
},
|
||||
CodeBlock: ansi.StyleCodeBlock{
|
||||
StyleBlock: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "",
|
||||
Color: code,
|
||||
},
|
||||
Margin: uintPtr(0),
|
||||
},
|
||||
Chroma: &ansi.Chroma{
|
||||
Text: ansi.StylePrimitive{Color: text},
|
||||
Error: ansi.StylePrimitive{Color: errClr},
|
||||
Comment: ansi.StylePrimitive{Color: comment},
|
||||
CommentPreproc: ansi.StylePrimitive{Color: keyword},
|
||||
Keyword: ansi.StylePrimitive{Color: keyword},
|
||||
KeywordReserved: ansi.StylePrimitive{Color: keyword},
|
||||
KeywordNamespace: ansi.StylePrimitive{Color: keyword},
|
||||
KeywordType: ansi.StylePrimitive{Color: keyword},
|
||||
Operator: ansi.StylePrimitive{Color: text},
|
||||
Punctuation: ansi.StylePrimitive{Color: text},
|
||||
Name: ansi.StylePrimitive{Color: text},
|
||||
NameBuiltin: ansi.StylePrimitive{Color: text},
|
||||
NameTag: ansi.StylePrimitive{Color: keyword},
|
||||
NameAttribute: ansi.StylePrimitive{Color: text},
|
||||
NameClass: ansi.StylePrimitive{Color: keyword},
|
||||
NameConstant: ansi.StylePrimitive{Color: text},
|
||||
NameDecorator: ansi.StylePrimitive{Color: text},
|
||||
NameFunction: ansi.StylePrimitive{Color: text},
|
||||
LiteralNumber: ansi.StylePrimitive{Color: number},
|
||||
LiteralString: ansi.StylePrimitive{Color: str},
|
||||
LiteralStringEscape: ansi.StylePrimitive{
|
||||
Color: keyword,
|
||||
},
|
||||
GenericDeleted: ansi.StylePrimitive{Color: errClr},
|
||||
GenericEmph: ansi.StylePrimitive{
|
||||
Color: emph,
|
||||
Italic: new(true),
|
||||
},
|
||||
GenericInserted: ansi.StylePrimitive{Color: str},
|
||||
GenericStrong: ansi.StylePrimitive{
|
||||
Color: strong,
|
||||
Bold: new(true),
|
||||
},
|
||||
GenericSubheading: ansi.StylePrimitive{
|
||||
Color: heading,
|
||||
},
|
||||
},
|
||||
},
|
||||
Table: ansi.StyleTable{
|
||||
StyleBlock: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
BlockPrefix: "\n",
|
||||
BlockSuffix: "\n",
|
||||
},
|
||||
},
|
||||
CenterSeparator: new("┼"),
|
||||
ColumnSeparator: new("│"),
|
||||
RowSeparator: new("─"),
|
||||
},
|
||||
DefinitionDescription: ansi.StylePrimitive{
|
||||
BlockPrefix: "\n ❯ ",
|
||||
Color: link,
|
||||
},
|
||||
Text: ansi.StylePrimitive{
|
||||
Color: text,
|
||||
},
|
||||
Paragraph: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Color: text,
|
||||
},
|
||||
},
|
||||
// GetMarkdownTypography returns a herald.Typography configured with our
|
||||
// active theme colors. The typography is cached and only rebuilt when
|
||||
// the theme changes via SetTheme.
|
||||
func GetMarkdownTypography() *herald.Typography {
|
||||
if markdownTypographyCache != nil {
|
||||
return markdownTypographyCache
|
||||
}
|
||||
|
||||
theme := GetTheme()
|
||||
md := theme.Markdown
|
||||
|
||||
// Build herald theme from our theme colors
|
||||
hty := herald.Theme{
|
||||
// Headings - use heading color
|
||||
H1: lipgloss.NewStyle().Foreground(md.Heading).Bold(true),
|
||||
H2: lipgloss.NewStyle().Foreground(md.Heading).Bold(true),
|
||||
H3: lipgloss.NewStyle().Foreground(md.Heading).Bold(true),
|
||||
H4: lipgloss.NewStyle().Foreground(md.Heading).Bold(true),
|
||||
H5: lipgloss.NewStyle().Foreground(md.Heading).Bold(true),
|
||||
H6: lipgloss.NewStyle().Foreground(md.Muted).Bold(true),
|
||||
|
||||
// Text blocks
|
||||
Paragraph: lipgloss.NewStyle().Foreground(md.Text),
|
||||
Blockquote: lipgloss.NewStyle().Foreground(md.Muted).Italic(true),
|
||||
CodeInline: lipgloss.NewStyle().Foreground(md.Code),
|
||||
CodeBlock: lipgloss.NewStyle().Foreground(md.Code),
|
||||
HR: lipgloss.NewStyle().Foreground(md.Muted),
|
||||
|
||||
// Lists
|
||||
ListBullet: lipgloss.NewStyle().Foreground(md.Text),
|
||||
ListItem: lipgloss.NewStyle().Foreground(md.Text),
|
||||
|
||||
// Inline styles
|
||||
Bold: lipgloss.NewStyle().Foreground(md.Strong).Bold(true),
|
||||
Italic: lipgloss.NewStyle().Foreground(md.Emph).Italic(true),
|
||||
Strikethrough: lipgloss.NewStyle().Foreground(md.Muted).Strikethrough(true),
|
||||
Link: lipgloss.NewStyle().Foreground(md.Link).Underline(true),
|
||||
|
||||
// Definition lists
|
||||
DT: lipgloss.NewStyle().Foreground(md.Text).Bold(true),
|
||||
DD: lipgloss.NewStyle().Foreground(md.Muted),
|
||||
|
||||
// Key-value
|
||||
KVKey: lipgloss.NewStyle().Foreground(md.Text).Bold(true),
|
||||
KVValue: lipgloss.NewStyle().Foreground(md.Text),
|
||||
|
||||
// Badges/Tags - use semantic colors
|
||||
Badge: lipgloss.NewStyle().Foreground(md.Text).Bold(true),
|
||||
SuccessBadge: lipgloss.NewStyle().Foreground(theme.Success).Bold(true),
|
||||
WarningBadge: lipgloss.NewStyle().Foreground(theme.Warning).Bold(true),
|
||||
ErrorBadge: lipgloss.NewStyle().Foreground(theme.Error).Bold(true),
|
||||
InfoBadge: lipgloss.NewStyle().Foreground(theme.Info).Bold(true),
|
||||
|
||||
// Heading decorations
|
||||
H1UnderlineChar: "═",
|
||||
H2UnderlineChar: "─",
|
||||
H3UnderlineChar: "·",
|
||||
}
|
||||
|
||||
ty := herald.New(herald.WithTheme(hty))
|
||||
markdownTypographyCache = ty
|
||||
return ty
|
||||
}
|
||||
|
||||
// toMarkdown renders markdown content using glamour.
|
||||
// toMarkdown renders markdown content using herald-md.
|
||||
// The width parameter is currently unused as herald handles wrapping
|
||||
// based on terminal width internally.
|
||||
func toMarkdown(content string, width int) string {
|
||||
r := GetMarkdownRenderer(width)
|
||||
rendered, _ := r.Render(content)
|
||||
ty := GetMarkdownTypography()
|
||||
rendered := heraldmd.Render(ty, []byte(content))
|
||||
return rendered
|
||||
}
|
||||
|
||||
+10
-3
@@ -129,13 +129,20 @@ type presetColors struct {
|
||||
}
|
||||
|
||||
func makeTheme(p presetColors) Theme {
|
||||
ac := func(pair [2]string) color.Color { return AdaptiveColor(pair[0], pair[1]) }
|
||||
def := DefaultTheme()
|
||||
acOr := func(pair [2]string, fb color.Color) color.Color {
|
||||
ac := func(pair [2]string) color.Color {
|
||||
c := AdaptiveColor(pair[0], pair[1])
|
||||
if pair[0] == "" && pair[1] == "" {
|
||||
return nil
|
||||
}
|
||||
return c
|
||||
}
|
||||
acOr := func(pair [2]string, fb color.Color) color.Color {
|
||||
c := ac(pair)
|
||||
if c == nil {
|
||||
return fb
|
||||
}
|
||||
return ac(pair)
|
||||
return c
|
||||
}
|
||||
t := Theme{
|
||||
Primary: ac(p.primary),
|
||||
|
||||
@@ -83,9 +83,19 @@ func (t *ToolApprovalInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) View() tea.View {
|
||||
if t.done {
|
||||
return tea.NewView("we are done")
|
||||
v := tea.NewView("")
|
||||
v.AltScreen = true
|
||||
v.MouseMode = tea.MouseModeCellMotion
|
||||
v.ReportFocus = true
|
||||
v.KeyboardEnhancements = tea.KeyboardEnhancements{
|
||||
ReportEventTypes: true,
|
||||
}
|
||||
|
||||
if t.done {
|
||||
v.Content = "we are done"
|
||||
return v
|
||||
}
|
||||
|
||||
containerStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := GetTheme()
|
||||
@@ -135,5 +145,6 @@ func (t *ToolApprovalInput) View() tea.View {
|
||||
}
|
||||
view.WriteString(yesText + "/" + noText + "\n")
|
||||
|
||||
return tea.NewView(containerStyle.Render(inputBoxStyle.Render(view.String())))
|
||||
v.Content = containerStyle.Render(inputBoxStyle.Render(view.String()))
|
||||
return v
|
||||
}
|
||||
|
||||
+157
-336
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"github.com/alecthomas/chroma/v2/styles"
|
||||
udiff "github.com/aymanbagabas/go-udiff"
|
||||
xansi "github.com/charmbracelet/x/ansi"
|
||||
"github.com/indaco/herald"
|
||||
)
|
||||
|
||||
// Maximum visible lines per tool type before truncation.
|
||||
@@ -26,6 +28,13 @@ const (
|
||||
maxLsLines = 20 // lines for Ls directory listings
|
||||
)
|
||||
|
||||
// isShellTool reports if the tool name matches a shell-like tool (bash, grep, find, or
|
||||
// tools with "shell"/"command" in the name). Used by renderToolBody.
|
||||
func isShellTool(toolName string) bool {
|
||||
return toolName == "bash" || toolName == "grep" || toolName == "find" ||
|
||||
strings.Contains(toolName, "shell") || strings.Contains(toolName, "command")
|
||||
}
|
||||
|
||||
// renderToolBody dispatches to tool-specific body renderers based on tool name.
|
||||
// Returns the styled body string, or empty string to fall back to default rendering.
|
||||
func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
@@ -46,12 +55,11 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
if body := renderWriteBody(toolArgs, toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
case toolName == "bash" || toolName == "run_shell_cmd" ||
|
||||
strings.Contains(toolName, "shell") || strings.Contains(toolName, "command"):
|
||||
case isShellTool(toolName):
|
||||
if body := renderBashBody(toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
case toolName == "spawn_subagent":
|
||||
case toolName == "subagent":
|
||||
if body := renderSubagentBody(toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
@@ -64,21 +72,44 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderEditBody renders a side-by-side diff from old_text/new_text in toolArgs.
|
||||
// Supports both single-edit mode and multi-edit mode (edits array).
|
||||
func renderEditBody(toolArgs, toolResult string, width int) string {
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
// Check for multi-edit mode (edits array)
|
||||
if editsArr, ok := args["edits"].([]any); ok && len(editsArr) > 0 {
|
||||
var results []string
|
||||
for _, edit := range editsArr {
|
||||
if e, ok := edit.(map[string]any); ok {
|
||||
oldText, _ := e["old_text"].(string)
|
||||
newText, _ := e["new_text"].(string)
|
||||
if oldText != "" || newText != "" {
|
||||
diff := renderDiffBlock(oldText, newText, startLine, width)
|
||||
if diff != "" {
|
||||
results = append(results, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(results) > 0 {
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Single-edit mode (legacy)
|
||||
oldText, _ := args["old_text"].(string)
|
||||
newText, _ := args["new_text"].(string)
|
||||
if oldText == "" && newText == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
return renderDiffBlock(oldText, newText, startLine, width)
|
||||
}
|
||||
|
||||
@@ -221,7 +252,7 @@ func renderDiffBlock(before, after string, startLine int, width int) string {
|
||||
gutterWidth := max(len(fmt.Sprintf("%d", maxLineNum)), 3)
|
||||
contentWidth := max(panelWidth-gutterWidth-4, 10) // gutter + " - " or " + "
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
|
||||
// Styles for each cell type
|
||||
gutterInsert := lipgloss.NewStyle().Foreground(theme.Muted).Background(theme.DiffInsertBg)
|
||||
@@ -326,7 +357,7 @@ func renderLsBody(toolResult string, width int) string {
|
||||
const indent = " "
|
||||
codeWidth := max(width-len(indent), 20)
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
codeStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
var result []string
|
||||
@@ -351,137 +382,106 @@ func renderLsBody(toolResult string, width int) string {
|
||||
// Read tool — code block with line numbers + syntax highlighting
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderReadBody renders Read tool output with styled line numbers and optional
|
||||
// syntax highlighting based on file extension.
|
||||
// renderReadBody renders Read tool output using herald.CodeBlock with line numbers
|
||||
// and syntax highlighting. Uses WithCodeLineNumberOffset to show correct offsets
|
||||
// based on the Read tool's offset parameter.
|
||||
func renderReadBody(toolArgs, toolResult string, width int) string {
|
||||
if strings.TrimSpace(toolResult) == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Extract file path for syntax highlighting
|
||||
// Extract file path and offset from tool args
|
||||
var fileName string
|
||||
var offset = 1
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err == nil {
|
||||
if p, ok := args["path"].(string); ok {
|
||||
fileName = p
|
||||
}
|
||||
if o, ok := args["offset"].(float64); ok {
|
||||
offset = int(o)
|
||||
}
|
||||
}
|
||||
|
||||
return renderCodeBlock(toolResult, fileName, width)
|
||||
}
|
||||
|
||||
// codeLine holds a parsed line with optional line number.
|
||||
type codeLine struct {
|
||||
lineNum string
|
||||
code string
|
||||
}
|
||||
|
||||
// renderCodeBlock renders content with a styled gutter (line numbers) and
|
||||
// optional syntax highlighting.
|
||||
func renderCodeBlock(content, fileName string, width int) string {
|
||||
rawLines := strings.Split(content, "\n")
|
||||
|
||||
// Parse lines: detect "N: content" format from Read tool
|
||||
var parsed []codeLine
|
||||
maxNumWidth := 0
|
||||
var codeOnly []string
|
||||
// Parse lines to extract pure code content (removing "N: " prefixes)
|
||||
rawLines := strings.Split(toolResult, "\n")
|
||||
var codeLines []string
|
||||
var footerLines []string
|
||||
var codeHiddenCount int
|
||||
|
||||
for _, line := range rawLines {
|
||||
// Detect "N: content" format from Read tool
|
||||
if idx := strings.Index(line, ": "); idx > 0 && idx <= 7 {
|
||||
numPart := line[:idx]
|
||||
if _, err := strconv.Atoi(strings.TrimSpace(numPart)); err == nil {
|
||||
parsed = append(parsed, codeLine{lineNum: numPart, code: line[idx+2:]})
|
||||
if len(numPart) > maxNumWidth {
|
||||
maxNumWidth = len(numPart)
|
||||
}
|
||||
codeOnly = append(codeOnly, line[idx+2:])
|
||||
codeLines = append(codeLines, line[idx+2:])
|
||||
continue
|
||||
}
|
||||
}
|
||||
// No line number — treat as metadata/footer
|
||||
parsed = append(parsed, codeLine{code: line})
|
||||
codeOnly = append(codeOnly, line)
|
||||
// No line number — treat as footer/metadata (e.g., truncation notice)
|
||||
footerLines = append(footerLines, line)
|
||||
}
|
||||
|
||||
if len(parsed) == 0 {
|
||||
return ""
|
||||
// Apply maxCodeLines truncation
|
||||
totalCodeLines := len(codeLines)
|
||||
if totalCodeLines > maxCodeLines {
|
||||
codeHiddenCount = totalCodeLines - maxCodeLines
|
||||
codeLines = codeLines[:maxCodeLines]
|
||||
}
|
||||
|
||||
// Truncate to maxCodeLines visible lines (preserve footer/metadata lines)
|
||||
var codeHiddenCount int
|
||||
totalParsed := len(parsed)
|
||||
if totalParsed > maxCodeLines {
|
||||
// Check if last line is a footer (no line number) — keep it
|
||||
var footerLines []codeLine
|
||||
for totalParsed > 0 && parsed[totalParsed-1].lineNum == "" {
|
||||
footerLines = append([]codeLine{parsed[totalParsed-1]}, footerLines...)
|
||||
totalParsed--
|
||||
}
|
||||
if totalParsed > maxCodeLines {
|
||||
codeHiddenCount = totalParsed - maxCodeLines
|
||||
parsed = append(parsed[:maxCodeLines], footerLines...)
|
||||
codeOnly = codeOnly[:maxCodeLines]
|
||||
for _, fl := range footerLines {
|
||||
codeOnly = append(codeOnly, fl.code)
|
||||
}
|
||||
} else {
|
||||
// Restore — footer trimming was enough
|
||||
parsed = parsed[:totalParsed]
|
||||
parsed = append(parsed, footerLines...)
|
||||
// Build language hint from file extension
|
||||
lang := ""
|
||||
if fileName != "" {
|
||||
// Extract extension without the dot
|
||||
if ext := strings.TrimPrefix(filepath.Ext(fileName), "."); ext != "" {
|
||||
lang = ext
|
||||
}
|
||||
}
|
||||
|
||||
// Syntax highlight the code portion
|
||||
highlighted := syntaxHighlight(strings.Join(codeOnly, "\n"), fileName)
|
||||
highlightedLines := strings.Split(highlighted, "\n")
|
||||
|
||||
// Layout
|
||||
const codeIndent = " "
|
||||
gutterWidth := max(maxNumWidth+2, 5)
|
||||
codeWidth := max(width-gutterWidth-len(codeIndent), 20)
|
||||
|
||||
theme := getTheme()
|
||||
gutterStyle := lipgloss.NewStyle().Foreground(theme.Muted).Background(theme.GutterBg).PaddingRight(1)
|
||||
codeStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
var result []string
|
||||
for i, p := range parsed {
|
||||
// If this line has no line number, it's a metadata/footer line (e.g. truncation notice).
|
||||
if p.lineNum == "" {
|
||||
// Render footer lines with code background but no gutter
|
||||
truncatedFooter := truncateLine(p.code, codeWidth-1) // account for PaddingLeft(1)
|
||||
footer := codeStyle.Width(codeWidth).Render(truncatedFooter)
|
||||
emptyGutter := gutterStyle.Width(gutterWidth).Render("")
|
||||
result = append(result, codeIndent+lipgloss.JoinHorizontal(lipgloss.Top, emptyGutter, footer))
|
||||
continue
|
||||
}
|
||||
|
||||
gutter := gutterStyle.Width(gutterWidth).Render(p.lineNum)
|
||||
|
||||
var codePart string
|
||||
if i < len(highlightedLines) {
|
||||
codePart = highlightedLines[i]
|
||||
} else {
|
||||
codePart = p.code
|
||||
}
|
||||
// Truncate the (possibly ANSI-highlighted) line to fit within
|
||||
// the code column, preventing lipgloss from wrapping it.
|
||||
codePart = truncateLine(codePart, codeWidth-1) // account for PaddingLeft(1)
|
||||
styledCode := codeStyle.Width(codeWidth).Render(codePart)
|
||||
|
||||
result = append(result, codeIndent+lipgloss.JoinHorizontal(lipgloss.Top, gutter, styledCode))
|
||||
// Create typography with line number offset and custom formatter
|
||||
// Match Write tool: GutterBg for line numbers, CodeBg for content
|
||||
codeContent := strings.Join(codeLines, "\n")
|
||||
theme := GetTheme()
|
||||
hty := herald.Theme{
|
||||
CodeBlock: lipgloss.NewStyle().
|
||||
Background(theme.CodeBg).
|
||||
PaddingLeft(1),
|
||||
CodeLineNumber: lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.GutterBg),
|
||||
}
|
||||
ty := herald.New(
|
||||
herald.WithTheme(hty),
|
||||
herald.WithCodeLineNumbers(true),
|
||||
herald.WithCodeLineNumberOffset(offset),
|
||||
herald.WithCodeFormatter(func(code, _ string) string {
|
||||
// Use our syntax highlighter with the filename for lexer detection
|
||||
return syntaxHighlight(code, fileName)
|
||||
}),
|
||||
)
|
||||
|
||||
// Truncation hint
|
||||
// Render the code block
|
||||
result := ty.CodeBlock(codeContent, lang)
|
||||
|
||||
// Add truncation hint if needed
|
||||
if codeHiddenCount > 0 {
|
||||
hint := fmt.Sprintf("...(%d more lines)", codeHiddenCount)
|
||||
emptyGutter := gutterStyle.Width(gutterWidth).Render("")
|
||||
hintContent := codeStyle.Width(codeWidth).
|
||||
Foreground(theme.Muted).Italic(true).Render(hint)
|
||||
result = append(result, codeIndent+lipgloss.JoinHorizontal(lipgloss.Top, emptyGutter, hintContent))
|
||||
result += "\n" + lipgloss.NewStyle().Foreground(GetTheme().Muted).Italic(true).Render(hint)
|
||||
}
|
||||
|
||||
return strings.Join(result, "\n")
|
||||
// Add any footer lines
|
||||
if len(footerLines) > 0 {
|
||||
footer := strings.Join(footerLines, "\n")
|
||||
result += "\n" + lipgloss.NewStyle().Foreground(GetTheme().Muted).Render(footer)
|
||||
}
|
||||
|
||||
// Indent entire block to match Write/Edit tools (2 spaces)
|
||||
const blockIndent = " "
|
||||
lines := strings.Split(result, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = blockIndent + line
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -535,7 +535,7 @@ func renderWriteBlock(content, fileName string, width int) string {
|
||||
gutterWidth := numDigits + 2
|
||||
codeWidth := max(width-gutterWidth-len(codeIndent), 20)
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
gutterStyle := lipgloss.NewStyle().Foreground(theme.Muted).Background(theme.GutterBg).PaddingRight(1)
|
||||
writeStyle := lipgloss.NewStyle().Background(theme.WriteBg).PaddingLeft(1)
|
||||
|
||||
@@ -587,7 +587,7 @@ func renderBashBody(toolResult string, width int) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
outputStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
stderrStyle := lipgloss.NewStyle().Foreground(theme.Error).Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
@@ -604,7 +604,6 @@ func renderBashBody(toolResult string, width int) string {
|
||||
|
||||
const lineIndent = " "
|
||||
// Truncate individual lines to the available width so they never wrap.
|
||||
// This mirrors Crush's approach: truncate, don't wrap.
|
||||
lineWidth := max(width-len(lineIndent), 20)
|
||||
// Account for PaddingLeft(1) on the output/stderr styles
|
||||
maxLineChars := lineWidth - 1
|
||||
@@ -738,188 +737,10 @@ func truncateLine(s string, maxWidth int) string {
|
||||
return xansi.Truncate(s, maxWidth, "…")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compact tool body renderers — one-line summaries for compact mode
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderToolBodyCompact returns a brief summary string for tool results in
|
||||
// compact display mode. Returns empty string to fall back to default.
|
||||
func renderToolBodyCompact(toolName, toolArgs, toolResult string, width int) string {
|
||||
switch {
|
||||
case toolName == "edit":
|
||||
return renderEditCompact(toolArgs, toolResult)
|
||||
case toolName == "ls":
|
||||
return renderLsCompact(toolResult)
|
||||
case toolName == "read":
|
||||
return renderReadCompact(toolResult)
|
||||
case toolName == "write":
|
||||
return renderWriteCompact(toolArgs)
|
||||
case toolName == "bash" || toolName == "run_shell_cmd" ||
|
||||
strings.Contains(toolName, "shell") || strings.Contains(toolName, "command"):
|
||||
return renderBashCompact(toolResult, width)
|
||||
case toolName == "spawn_subagent":
|
||||
return renderSubagentCompact(toolResult)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// renderReadCompact returns a line-count summary for Read tool output.
|
||||
func renderReadCompact(toolResult string) string {
|
||||
content := strings.TrimSpace(toolResult)
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
|
||||
// Count actual code lines (those with "N: " line-number prefix)
|
||||
codeLines := 0
|
||||
for _, line := range lines {
|
||||
if idx := strings.Index(line, ": "); idx > 0 && idx <= 7 {
|
||||
numPart := line[:idx]
|
||||
if _, err := strconv.Atoi(strings.TrimSpace(numPart)); err == nil {
|
||||
codeLines++
|
||||
}
|
||||
}
|
||||
}
|
||||
if codeLines == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
summary := fmt.Sprintf("%d lines", codeLines)
|
||||
return lipgloss.NewStyle().Foreground(theme.Muted).Italic(true).Render(summary)
|
||||
}
|
||||
|
||||
// renderEditCompact returns a change-count summary for Edit tool output.
|
||||
func renderEditCompact(toolArgs, toolResult string) string {
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
oldText, _ := args["old_text"].(string)
|
||||
newText, _ := args["new_text"].(string)
|
||||
if oldText == "" && newText == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
oldCount := len(strings.Split(oldText, "\n"))
|
||||
newCount := len(strings.Split(newText, "\n"))
|
||||
|
||||
theme := getTheme()
|
||||
var summary string
|
||||
if oldCount == newCount {
|
||||
summary = fmt.Sprintf("%d lines modified", oldCount)
|
||||
} else {
|
||||
summary = fmt.Sprintf("-%d/+%d lines", oldCount, newCount)
|
||||
}
|
||||
return lipgloss.NewStyle().Foreground(theme.Muted).Italic(true).Render(summary)
|
||||
}
|
||||
|
||||
// renderWriteCompact returns a line-count summary for Write tool output.
|
||||
func renderWriteCompact(toolArgs string) string {
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
content, _ := args["content"].(string)
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
count := len(strings.Split(content, "\n"))
|
||||
theme := getTheme()
|
||||
summary := fmt.Sprintf("%d lines written", count)
|
||||
return lipgloss.NewStyle().Foreground(theme.Muted).Italic(true).Render(summary)
|
||||
}
|
||||
|
||||
// renderLsCompact returns an entry-count summary for Ls tool output.
|
||||
func renderLsCompact(toolResult string) string {
|
||||
content := strings.TrimSpace(toolResult)
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
entries := strings.Split(content, "\n")
|
||||
theme := getTheme()
|
||||
summary := fmt.Sprintf("%d entries", len(entries))
|
||||
return lipgloss.NewStyle().Foreground(theme.Muted).Italic(true).Render(summary)
|
||||
}
|
||||
|
||||
// renderBashCompact returns the first few lines of bash output as a compact
|
||||
// summary. Shows up to 3 meaningful output lines.
|
||||
func renderBashCompact(toolResult string, width int) string {
|
||||
result := strings.TrimSpace(toolResult)
|
||||
if result == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
lines := strings.Split(result, "\n")
|
||||
|
||||
// Filter to meaningful output lines (skip STDERR: label, keep exit codes separate)
|
||||
var outputLines []string
|
||||
var exitCode string
|
||||
inStderr := false
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "STDERR:" {
|
||||
inStderr = true
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "Exit code:") {
|
||||
exitCode = trimmed
|
||||
continue
|
||||
}
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
outputLines = append(outputLines, line)
|
||||
_ = inStderr // stderr lines are included in output
|
||||
}
|
||||
|
||||
if len(outputLines) == 0 {
|
||||
if exitCode != "" {
|
||||
theme := getTheme()
|
||||
return lipgloss.NewStyle().Foreground(theme.Error).Render(exitCode)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
const maxLines = 3
|
||||
theme := getTheme()
|
||||
|
||||
display := outputLines
|
||||
if len(display) > maxLines {
|
||||
display = display[:maxLines]
|
||||
}
|
||||
|
||||
// Truncate each line to available width (ANSI-aware)
|
||||
lineMax := max(width-4, 20)
|
||||
for i, line := range display {
|
||||
display[i] = truncateLine(line, lineMax)
|
||||
}
|
||||
|
||||
summary := strings.Join(display, "\n")
|
||||
if len(outputLines) > maxLines {
|
||||
summary += fmt.Sprintf("\n...(%d more lines)", len(outputLines)-maxLines)
|
||||
}
|
||||
if exitCode != "" {
|
||||
summary += "\n" + lipgloss.NewStyle().Foreground(theme.Error).Render(exitCode)
|
||||
}
|
||||
|
||||
return lipgloss.NewStyle().Foreground(theme.Muted).Render(summary)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Subagent tool renderers — show only summary, not full output
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderSubagentBody renders a clean summary of subagent results.
|
||||
// Extracts timing/token info and shows only a brief summary instead of raw output.
|
||||
// renderSubagentBody renders a clean summary of subagent results with bash-style
|
||||
// background styling for consistency with other tools.
|
||||
func renderSubagentBody(toolResult string, width int) string {
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
result := strings.TrimSpace(toolResult)
|
||||
if result == "" {
|
||||
return ""
|
||||
@@ -937,9 +758,19 @@ func renderSubagentBody(toolResult string, width int) string {
|
||||
// First line is always the status summary
|
||||
statusLine := lines[0]
|
||||
|
||||
// Build a clean summary
|
||||
var summary strings.Builder
|
||||
summary.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(statusLine))
|
||||
// Build content lines for display with bash-style background
|
||||
outputStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
errorStyle := lipgloss.NewStyle().Foreground(theme.Error).Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
const lineIndent = " "
|
||||
lineWidth := max(width-len(lineIndent), 20)
|
||||
maxLineChars := lineWidth - 1 // account for PaddingLeft(1)
|
||||
|
||||
var contentLines []string
|
||||
|
||||
// Add status line
|
||||
styledStatus := outputStyle.Width(lineWidth).Render(truncateLine(statusLine, maxLineChars))
|
||||
contentLines = append(contentLines, lineIndent+styledStatus)
|
||||
|
||||
// For successful results, extract a brief preview of the actual result
|
||||
if strings.Contains(statusLine, "successfully") {
|
||||
@@ -947,25 +778,45 @@ func renderSubagentBody(toolResult string, width int) string {
|
||||
if _, resultContent, found := strings.Cut(result, "Result:\n"); found {
|
||||
resultContent = strings.TrimSpace(resultContent)
|
||||
if resultContent != "" {
|
||||
// Show first 3 meaningful lines as preview
|
||||
preview := extractSubagentPreview(resultContent, 3, width-4)
|
||||
if preview != "" {
|
||||
summary.WriteString("\n\n")
|
||||
summary.WriteString(lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true).
|
||||
Render(preview))
|
||||
// Show first few meaningful lines as preview
|
||||
previewLines := extractSubagentPreviewLines(resultContent, 5, maxLineChars)
|
||||
if len(previewLines) > 0 {
|
||||
// Add blank separator line
|
||||
blankLine := outputStyle.Width(lineWidth).Render("")
|
||||
contentLines = append(contentLines, lineIndent+blankLine)
|
||||
|
||||
for _, line := range previewLines {
|
||||
styled := outputStyle.Width(lineWidth).Render(line)
|
||||
contentLines = append(contentLines, lineIndent+styled)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For failed results, show error info
|
||||
if _, errorContent, found := strings.Cut(result, "Error:\n"); found {
|
||||
errorContent = strings.TrimSpace(errorContent)
|
||||
if errorContent != "" {
|
||||
previewLines := extractSubagentPreviewLines(errorContent, 3, maxLineChars)
|
||||
if len(previewLines) > 0 {
|
||||
blankLine := outputStyle.Width(lineWidth).Render("")
|
||||
contentLines = append(contentLines, lineIndent+blankLine)
|
||||
|
||||
for _, line := range previewLines {
|
||||
styled := errorStyle.Width(lineWidth).Render(line)
|
||||
contentLines = append(contentLines, lineIndent+styled)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return summary.String()
|
||||
return strings.Join(contentLines, "\n")
|
||||
}
|
||||
|
||||
// extractSubagentPreview extracts the first N non-empty lines from content,
|
||||
// truncating each line to maxWidth.
|
||||
func extractSubagentPreview(content string, maxLines, maxWidth int) string {
|
||||
// extractSubagentPreviewLines extracts the first N non-empty lines from content,
|
||||
// truncating each line to maxWidth. Returns as a slice of strings.
|
||||
func extractSubagentPreviewLines(content string, maxLines, maxWidth int) []string {
|
||||
lines := strings.Split(content, "\n")
|
||||
var preview []string
|
||||
|
||||
@@ -984,12 +835,6 @@ func extractSubagentPreview(content string, maxLines, maxWidth int) string {
|
||||
}
|
||||
}
|
||||
|
||||
if len(preview) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := strings.Join(preview, "\n")
|
||||
|
||||
// Count remaining lines for "more" indicator
|
||||
totalLines := 0
|
||||
for _, line := range lines {
|
||||
@@ -998,32 +843,8 @@ func extractSubagentPreview(content string, maxLines, maxWidth int) string {
|
||||
}
|
||||
}
|
||||
if totalLines > maxLines {
|
||||
result += fmt.Sprintf("\n...(%d more lines)", totalLines-maxLines)
|
||||
preview = append(preview, fmt.Sprintf("...(%d more lines)", totalLines-maxLines))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// renderSubagentCompact returns a brief one-line summary for subagent results.
|
||||
func renderSubagentCompact(toolResult string) string {
|
||||
result := strings.TrimSpace(toolResult)
|
||||
if result == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
|
||||
// Extract just the first line which contains the status
|
||||
lines := strings.Split(result, "\n")
|
||||
if len(lines) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
statusLine := lines[0]
|
||||
|
||||
// Make it more compact by removing redundant words
|
||||
statusLine = strings.Replace(statusLine, "Subagent completed successfully in ", "Completed in ", 1)
|
||||
statusLine = strings.Replace(statusLine, "Subagent failed", "Failed", 1)
|
||||
|
||||
return lipgloss.NewStyle().Foreground(theme.Muted).Italic(true).Render(statusLine)
|
||||
return preview
|
||||
}
|
||||
|
||||
@@ -265,7 +265,14 @@ func (ts *TreeSelectorComponent) View() tea.View {
|
||||
footer := fmt.Sprintf("(%d/%d) [%s]", ts.cursor+1, len(ts.flatNodes), ts.filter)
|
||||
b.WriteString(footerStyle.Render(footer))
|
||||
|
||||
return tea.NewView(b.String())
|
||||
v := tea.NewView(b.String())
|
||||
v.AltScreen = true
|
||||
v.MouseMode = tea.MouseModeCellMotion
|
||||
v.ReportFocus = true
|
||||
v.KeyboardEnhancements = tea.KeyboardEnhancements{
|
||||
ReportEventTypes: true,
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// IsActive returns whether the tree selector is still accepting input.
|
||||
|
||||
@@ -134,13 +134,23 @@ func (ut *UsageTracker) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
}
|
||||
|
||||
// SetContextTokens records the approximate current context window utilization.
|
||||
// This should be set from the final API call's input + output tokens (i.e.
|
||||
// FinalResponse.Usage) rather than the aggregate TotalUsage, because TotalUsage
|
||||
// This should be set from FinalUsage.InputTokens, which already includes the
|
||||
// full conversation history (system prompt + all previous messages). Do NOT
|
||||
// add OutputTokens as that would double-count (output becomes input next turn).
|
||||
// Use FinalResponse.Usage rather than aggregate TotalUsage, because TotalUsage
|
||||
// sums across all tool-calling steps and overstates the actual window fill level.
|
||||
func (ut *UsageTracker) SetContextTokens(tokens int) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
ut.contextTokens = tokens
|
||||
// Track the maximum context seen so far. In multi-step tool calls,
|
||||
// FinalUsage.InputTokens may reflect only the last step's input, which
|
||||
// can be smaller than previous steps. We want to show the largest context
|
||||
// the model has processed in this session.
|
||||
if tokens > ut.contextTokens {
|
||||
ut.contextTokens = tokens
|
||||
}
|
||||
// If tokens < current, we keep the larger value (no-op)
|
||||
// This prevents the display from dropping during multi-step tool calls.
|
||||
}
|
||||
|
||||
// RenderUsageInfo generates a formatted string displaying current usage statistics
|
||||
@@ -151,10 +161,6 @@ func (ut *UsageTracker) RenderUsageInfo() string {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
if ut.sessionStats.RequestCount == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
// Display the current context window token count (from the last API call),
|
||||
|
||||
@@ -67,3 +67,62 @@ func TestUsageTracker_RenderUsageInfo_OAuth(t *testing.T) {
|
||||
t.Errorf("Expected regular rendered output to show actual cost, got: %s", regularRendered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTracker_RenderUsageInfo_StartupState(t *testing.T) {
|
||||
// Create a mock model info with costs and context limit
|
||||
modelInfo := &models.ModelInfo{
|
||||
ID: "claude-3-5-sonnet-20241022",
|
||||
Name: "Claude 3.5 Sonnet v2",
|
||||
Cost: models.Cost{
|
||||
Input: 3.0,
|
||||
Output: 15.0,
|
||||
},
|
||||
Limit: models.Limit{
|
||||
Context: 200000,
|
||||
Output: 8192,
|
||||
},
|
||||
}
|
||||
|
||||
// Test startup state (no requests made yet) - Regular API key
|
||||
regularTracker := NewUsageTracker(modelInfo, "anthropic", 80, false)
|
||||
rendered := stripAnsi(regularTracker.RenderUsageInfo())
|
||||
|
||||
// Should NOT return empty string on startup
|
||||
if rendered == "" {
|
||||
t.Errorf("Expected non-empty output on startup, got empty string")
|
||||
}
|
||||
|
||||
// Should show 0 tokens
|
||||
if !strings.Contains(rendered, "Tokens: 0") {
|
||||
t.Errorf("Expected 'Tokens: 0' on startup, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Should NOT show percentage when tokens are 0
|
||||
if strings.Contains(rendered, "(%") {
|
||||
t.Errorf("Expected no percentage on startup with 0 tokens, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Should show $0.0000 cost for regular API key
|
||||
if !strings.Contains(rendered, "Cost: $0.0000") {
|
||||
t.Errorf("Expected 'Cost: $0.0000' on startup, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Test startup state (no requests made yet) - OAuth
|
||||
oauthTracker := NewUsageTracker(modelInfo, "anthropic", 80, true)
|
||||
oauthRendered := stripAnsi(oauthTracker.RenderUsageInfo())
|
||||
|
||||
// Should NOT return empty string on startup
|
||||
if oauthRendered == "" {
|
||||
t.Errorf("Expected non-empty output on startup for OAuth, got empty string")
|
||||
}
|
||||
|
||||
// Should show 0 tokens for OAuth
|
||||
if !strings.Contains(oauthRendered, "Tokens: 0") {
|
||||
t.Errorf("Expected 'Tokens: 0' on startup for OAuth, got: %s", oauthRendered)
|
||||
}
|
||||
|
||||
// Should show $0.00 cost for OAuth
|
||||
if !strings.Contains(oauthRendered, "Cost: $0.00") {
|
||||
t.Errorf("Expected 'Cost: $0.00' on startup for OAuth, got: %s", oauthRendered)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,6 @@ type Harness struct {
|
||||
t *testing.T
|
||||
runner *extensions.Runner
|
||||
context *MockContext
|
||||
extPath string
|
||||
}
|
||||
|
||||
// New creates a new test harness for the given test.
|
||||
@@ -72,15 +71,9 @@ func New(t *testing.T) *Harness {
|
||||
func (h *Harness) LoadFile(path string) *extensions.LoadedExtension {
|
||||
h.t.Helper()
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
h.t.Fatalf("extension file not found: %s: %v", path, err)
|
||||
}
|
||||
|
||||
// Read extension source
|
||||
src, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
h.t.Fatalf("failed to read extension file: %v", err)
|
||||
h.t.Fatalf("failed to read extension file %s: %v", path, err)
|
||||
}
|
||||
|
||||
return h.loadSource(string(src), path)
|
||||
@@ -144,7 +137,6 @@ func (h *Harness) loadSource(src string, path string) *extensions.LoadedExtensio
|
||||
|
||||
// Create runner with the loaded extension
|
||||
h.runner = extensions.NewRunner([]extensions.LoadedExtension{*ext})
|
||||
h.extPath = path
|
||||
|
||||
// Wire the mock context
|
||||
h.runner.SetContext(h.context.ToContext())
|
||||
@@ -222,11 +214,3 @@ func (h *Harness) RegisteredCommands() []extensions.CommandDef {
|
||||
}
|
||||
return h.runner.RegisteredCommands()
|
||||
}
|
||||
|
||||
// MustLoad is like LoadFile but fails the test immediately on error.
|
||||
// It returns the harness for chaining.
|
||||
func (h *Harness) MustLoad(path string) *Harness {
|
||||
h.t.Helper()
|
||||
h.LoadFile(path)
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -59,29 +59,12 @@ type MockContext struct {
|
||||
Overlays []extensions.OverlayConfig
|
||||
}
|
||||
|
||||
// StatusBarEntry represents a recorded status bar entry
|
||||
type StatusBarEntry struct {
|
||||
Key string
|
||||
Text string
|
||||
Priority int
|
||||
}
|
||||
|
||||
// NewMockContext creates a new mock context with default values.
|
||||
func NewMockContext() *MockContext {
|
||||
return &MockContext{
|
||||
Prints: make([]string, 0),
|
||||
PrintInfos: make([]string, 0),
|
||||
PrintErrors: make([]string, 0),
|
||||
PrintBlocks: make([]extensions.PrintBlockOpts, 0),
|
||||
Messages: make([]string, 0),
|
||||
CancelSends: make([]string, 0),
|
||||
Widgets: make(map[string]extensions.WidgetConfig),
|
||||
RemovedIDs: make([]string, 0),
|
||||
StatusEntries: make(map[string]extensions.StatusBarEntry),
|
||||
RemovedStatus: make([]string, 0),
|
||||
EditorTexts: make([]string, 0),
|
||||
Options: make(map[string]string),
|
||||
Overlays: make([]extensions.OverlayConfig, 0),
|
||||
Interactive: true,
|
||||
SessionID: "test-session",
|
||||
CWD: "/test",
|
||||
|
||||
+80
-41
@@ -1,6 +1,6 @@
|
||||
# KIT SDK
|
||||
|
||||
The KIT SDK allows you to use KIT programmatically from Go applications without spawning OS processes.
|
||||
The KIT SDK (`pkg/kit`) lets you embed Kit's full agent capabilities — LLM interactions, tool execution, session management, streaming, hooks — into any Go application.
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -17,26 +17,26 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
// Create Kit instance with default configuration
|
||||
host, err := kit.New(ctx, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer host.Close()
|
||||
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
// Send a prompt
|
||||
response, err := host.Prompt(ctx, "What is 2+2?")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
fmt.Println(response)
|
||||
}
|
||||
```
|
||||
@@ -56,11 +56,23 @@ You can override specific settings:
|
||||
```go
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "ollama/llama3", // Override model
|
||||
SystemPrompt: "You are a helpful bot", // Override system prompt
|
||||
ConfigFile: "/path/to/config.yml", // Use specific config file
|
||||
MaxSteps: 10, // Override max steps
|
||||
Streaming: true, // Enable streaming
|
||||
Quiet: true, // Suppress debug output
|
||||
SystemPrompt: "You are a helpful bot", // Override system prompt
|
||||
ConfigFile: "/path/to/config.yml", // Use specific config file
|
||||
MaxSteps: 10, // Override max steps
|
||||
Streaming: true, // Enable streaming
|
||||
Quiet: true, // Suppress debug output
|
||||
|
||||
// Session options
|
||||
SessionPath: "./session.jsonl", // Open specific session
|
||||
Continue: true, // Resume most recent session
|
||||
NoSession: true, // Ephemeral mode
|
||||
|
||||
// Tool options
|
||||
Tools: []kit.Tool{kit.NewBashTool()}, // Replace default tool set
|
||||
ExtraTools: []kit.Tool{myTool}, // Add alongside defaults
|
||||
|
||||
// Compaction
|
||||
AutoCompact: true, // Auto-compact near context limit
|
||||
})
|
||||
```
|
||||
|
||||
@@ -71,22 +83,28 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
Monitor tool execution in real-time:
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(
|
||||
unsub := host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
fmt.Printf("Calling tool: %s\n", e.ToolName)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
if e.IsError {
|
||||
fmt.Printf("Tool %s failed: %s\n", e.ToolName, e.Result)
|
||||
} else {
|
||||
fmt.Printf("Tool %s succeeded\n", e.ToolName)
|
||||
}
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
fmt.Print(e.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
|
||||
response, err := host.Prompt(
|
||||
ctx,
|
||||
"List files in the current directory",
|
||||
func(name, args string) {
|
||||
fmt.Printf("Calling tool: %s\n", name)
|
||||
},
|
||||
func(name, args, result string, isError bool) {
|
||||
if isError {
|
||||
fmt.Printf("Tool %s failed: %s\n", name, result)
|
||||
} else {
|
||||
fmt.Printf("Tool %s succeeded\n", name)
|
||||
}
|
||||
},
|
||||
func(chunk string) {
|
||||
fmt.Print(chunk) // Stream output
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
@@ -102,35 +120,56 @@ host.Prompt(ctx, "My name is Alice")
|
||||
response, _ := host.Prompt(ctx, "What's my name?")
|
||||
// Response: "Your name is Alice"
|
||||
|
||||
// Save session
|
||||
host.SaveSession("./session.json")
|
||||
|
||||
// Load session later
|
||||
host.LoadSession("./session.json")
|
||||
|
||||
// Clear session
|
||||
// Clear conversation history
|
||||
host.ClearSession()
|
||||
```
|
||||
|
||||
## Re-exported Types
|
||||
|
||||
The SDK re-exports types so you don't need direct internal imports:
|
||||
|
||||
```go
|
||||
// Message types
|
||||
kit.Message, kit.MessageRole, kit.ContentPart
|
||||
kit.TextContent, kit.ReasoningContent, kit.ToolCall, kit.ToolResult, kit.Finish
|
||||
kit.RoleUser, kit.RoleAssistant, kit.RoleTool, kit.RoleSystem
|
||||
|
||||
// LLM types — concrete Kit-owned structs, no external library dependency
|
||||
kit.LLMMessage // {Role LLMMessageRole, Content string}
|
||||
kit.LLMMessageRole // "user" | "assistant" | "system" | "tool"
|
||||
kit.LLMUsage // {InputTokens, OutputTokens, TotalTokens, ...}
|
||||
kit.LLMResponse // {Content, FinishReason, Usage}
|
||||
kit.LLMFilePart // {Filename, Data []byte, MediaType}
|
||||
|
||||
// Conversion helpers
|
||||
msgs := kit.ConvertToLLMMessages(&msg) // SDK Message → []LLMMessage
|
||||
msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Types
|
||||
|
||||
- `Kit` - Main SDK type
|
||||
- `Options` - Configuration options
|
||||
- `Message` - Conversation message
|
||||
- `ToolCall` - Tool invocation details
|
||||
- `Message` - Conversation message with typed content parts
|
||||
- `Tool` - Agent tool interface
|
||||
- `TurnResult` - Full result from a prompt including usage stats
|
||||
|
||||
### Methods
|
||||
### Key Methods
|
||||
|
||||
- `New(ctx, opts)` - Create new Kit instance
|
||||
- `Prompt(ctx, message)` - Send message and get response
|
||||
- `PromptWithCallbacks(ctx, message, ...)` - Send message with progress callbacks
|
||||
- `LoadSession(path)` - Load session from file
|
||||
- `SaveSession(path)` - Save session to file
|
||||
- `ClearSession()` - Clear conversation history
|
||||
- `GetSessionManager()` - Get session manager for advanced usage
|
||||
- `Prompt(ctx, message)` - Send message and get response string
|
||||
- `PromptResult(ctx, message)` - Send message and get full TurnResult
|
||||
- `PromptWithOptions(ctx, message, opts)` - Prompt with per-call options
|
||||
- `Steer(ctx, instruction)` - System-level steering
|
||||
- `FollowUp(ctx, text)` - Continue without new user input
|
||||
- `SetModel(ctx, model)` - Switch model at runtime
|
||||
- `GetModelString()` - Get current model string
|
||||
- `GetModelInfo()` - Get model capabilities and limits
|
||||
- `ClearSession()` - Clear conversation history
|
||||
- `GetSessionPath()` - Get session file path
|
||||
- `GetSessionID()` - Get session UUID
|
||||
- `Close()` - Clean up resources
|
||||
|
||||
## Environment Variables
|
||||
|
||||
+38
-1
@@ -1,6 +1,10 @@
|
||||
package kit
|
||||
|
||||
import "github.com/mark3labs/kit/internal/auth"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
)
|
||||
|
||||
// CredentialManager manages API keys and OAuth credentials.
|
||||
type CredentialManager = auth.CredentialManager
|
||||
@@ -9,6 +13,10 @@ type CredentialManager = auth.CredentialManager
|
||||
// and API key authentication methods.
|
||||
type AnthropicCredentials = auth.AnthropicCredentials
|
||||
|
||||
// OpenAICredentials holds OpenAI API credentials supporting both OAuth
|
||||
// and API key authentication methods.
|
||||
type OpenAICredentials = auth.OpenAICredentials
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
type CredentialStore = auth.CredentialStore
|
||||
|
||||
@@ -42,3 +50,32 @@ func GetAnthropicAPIKey() string {
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// HasOpenAICredentials checks if valid OpenAI credentials are stored
|
||||
// (either OAuth token or API key).
|
||||
func HasOpenAICredentials() bool {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
has, err := cm.HasOpenAICredentials()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return has
|
||||
}
|
||||
|
||||
// GetOpenAIAPIKey resolves the OpenAI API key using the standard
|
||||
// resolution order: stored credentials -> OPENAI_API_KEY env var.
|
||||
// Returns an empty string if no key is found.
|
||||
func GetOpenAIAPIKey() string {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err == nil {
|
||||
// Try to get valid access token (handles OAuth refresh)
|
||||
if token, err := cm.GetValidOpenAIAccessToken(); err == nil && token != "" {
|
||||
return token
|
||||
}
|
||||
}
|
||||
// Fall back to environment variable
|
||||
return os.Getenv("OPENAI_API_KEY")
|
||||
}
|
||||
|
||||
+61
-49
@@ -2,10 +2,9 @@ package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/compaction"
|
||||
)
|
||||
|
||||
@@ -17,10 +16,14 @@ type ContextStats struct {
|
||||
MessageCount int // Number of messages in the conversation
|
||||
}
|
||||
|
||||
// defaultReserveTokens is the number of tokens to keep free in the context
|
||||
// window as a safety margin during compaction checks.
|
||||
const defaultReserveTokens = 16384
|
||||
|
||||
// EstimateContextTokens returns the estimated token count of the current
|
||||
// conversation based on tree session messages.
|
||||
func (m *Kit) EstimateContextTokens() int {
|
||||
messages := m.treeSession.GetFantasyMessages()
|
||||
messages := m.treeSession.GetLLMMessages()
|
||||
return compaction.EstimateMessageTokens(messages)
|
||||
}
|
||||
|
||||
@@ -34,12 +37,12 @@ func (m *Kit) ShouldCompact() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
reserveTokens := 16384
|
||||
reserveTokens := defaultReserveTokens
|
||||
if m.compactionOpts != nil && m.compactionOpts.ReserveTokens > 0 {
|
||||
reserveTokens = m.compactionOpts.ReserveTokens
|
||||
}
|
||||
|
||||
messages := m.treeSession.GetFantasyMessages()
|
||||
messages := m.treeSession.GetLLMMessages()
|
||||
return compaction.ShouldCompact(messages, info.Limit.Context, reserveTokens)
|
||||
}
|
||||
|
||||
@@ -52,7 +55,7 @@ func (m *Kit) ShouldCompact() bool {
|
||||
// because it includes system prompts, tool definitions, and other overhead
|
||||
// that the heuristic cannot account for.
|
||||
func (m *Kit) GetContextStats() ContextStats {
|
||||
messages := m.treeSession.GetFantasyMessages()
|
||||
messages := m.treeSession.GetLLMMessages()
|
||||
|
||||
// Prefer the real API-reported input token count when available.
|
||||
m.lastInputTokensMu.RLock()
|
||||
@@ -111,7 +114,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
|
||||
}
|
||||
}
|
||||
|
||||
messages := m.treeSession.GetFantasyMessages()
|
||||
messages := m.treeSession.GetLLMMessages()
|
||||
if len(messages) < 2 {
|
||||
return nil, fmt.Errorf("cannot compact: need at least 2 messages")
|
||||
}
|
||||
@@ -131,7 +134,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
|
||||
if reason == "" {
|
||||
reason = "compaction cancelled by extension"
|
||||
}
|
||||
return nil, fmt.Errorf("%s", reason)
|
||||
return nil, errors.New(reason)
|
||||
}
|
||||
// Extension provided a custom summary — use it directly.
|
||||
if hookResult.Summary != "" {
|
||||
@@ -150,7 +153,15 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
|
||||
}
|
||||
|
||||
model := m.agent.GetModel()
|
||||
result, _, err := compaction.Compact(ctx, model, messages, *opts, customInstructions, prev)
|
||||
|
||||
// Create a streaming callback to emit chunks as events.
|
||||
streamCallback := func(delta string) error {
|
||||
// Emit MessageUpdateEvent to the UI for streaming display.
|
||||
m.events.emit(MessageUpdateEvent{Chunk: delta})
|
||||
return nil
|
||||
}
|
||||
|
||||
result, _, err := compaction.Compact(ctx, model, messages, *opts, customInstructions, prev, streamCallback)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -166,34 +177,17 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
|
||||
firstKeptEntryID = entryIDs[result.CutPoint]
|
||||
}
|
||||
|
||||
if _, err := m.treeSession.AppendCompaction(
|
||||
result.Summary,
|
||||
firstKeptEntryID,
|
||||
result.OriginalTokens,
|
||||
result.CompactedTokens,
|
||||
result.MessagesRemoved,
|
||||
result.ReadFiles,
|
||||
result.ModifiedFiles,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist compaction entry: %w", err)
|
||||
if err := m.persistAndEmitCompaction(result.Summary, firstKeptEntryID, result.OriginalTokens, result.CompactedTokens, result.MessagesRemoved, result.ReadFiles, result.ModifiedFiles); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.events.emit(CompactionEvent{
|
||||
Summary: result.Summary,
|
||||
OriginalTokens: result.OriginalTokens,
|
||||
CompactedTokens: result.CompactedTokens,
|
||||
MessagesRemoved: result.MessagesRemoved,
|
||||
ReadFiles: result.ReadFiles,
|
||||
ModifiedFiles: result.ModifiedFiles,
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// applyCustomCompaction handles compaction when an extension provides a
|
||||
// custom summary. It still determines the cut point and persists a
|
||||
// CompactionEntry.
|
||||
func (m *Kit) applyCustomCompaction(summary string, messages []fantasy.Message, opts *CompactionOptions) (*CompactionResult, error) {
|
||||
func (m *Kit) applyCustomCompaction(summary string, messages []LLMMessage, opts *CompactionOptions) (*CompactionResult, error) {
|
||||
originalTokens := compaction.EstimateMessageTokens(messages)
|
||||
|
||||
cutPoint := compaction.FindCutPoint(messages, opts.KeepRecentTokens)
|
||||
@@ -211,24 +205,13 @@ func (m *Kit) applyCustomCompaction(summary string, messages []fantasy.Message,
|
||||
}
|
||||
|
||||
// Estimate new token count.
|
||||
summaryTokens := compaction.EstimateMessageTokens([]fantasy.Message{{
|
||||
Role: "system",
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: summary}},
|
||||
summaryTokens := compaction.EstimateMessageTokens([]LLMMessage{{
|
||||
Role: LLMRoleSystem,
|
||||
Content: []LLMMessagePart{LLMTextPart{Text: summary}},
|
||||
}})
|
||||
recentTokens := compaction.EstimateMessageTokens(messages[cutPoint:])
|
||||
compactedTokens := summaryTokens + recentTokens
|
||||
|
||||
if _, err := m.treeSession.AppendCompaction(
|
||||
summary,
|
||||
firstKeptEntryID,
|
||||
originalTokens,
|
||||
compactedTokens,
|
||||
cutPoint,
|
||||
nil, nil, // no file tracking for custom summaries
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist compaction entry: %w", err)
|
||||
}
|
||||
|
||||
result := &CompactionResult{
|
||||
Summary: summary,
|
||||
OriginalTokens: originalTokens,
|
||||
@@ -236,12 +219,41 @@ func (m *Kit) applyCustomCompaction(summary string, messages []fantasy.Message,
|
||||
MessagesRemoved: cutPoint,
|
||||
}
|
||||
|
||||
m.events.emit(CompactionEvent{
|
||||
Summary: result.Summary,
|
||||
OriginalTokens: result.OriginalTokens,
|
||||
CompactedTokens: result.CompactedTokens,
|
||||
MessagesRemoved: result.MessagesRemoved,
|
||||
})
|
||||
if err := m.persistAndEmitCompaction(summary, firstKeptEntryID, originalTokens, compactedTokens, cutPoint, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// persistAndEmitCompaction writes a CompactionEntry to the session tree and
|
||||
// emits a CompactionEvent. It is the single implementation shared by
|
||||
// compactInternal and applyCustomCompaction.
|
||||
func (m *Kit) persistAndEmitCompaction(
|
||||
summary, firstKeptEntryID string,
|
||||
originalTokens, compactedTokens, messagesRemoved int,
|
||||
readFiles, modifiedFiles []string,
|
||||
) error {
|
||||
if _, err := m.treeSession.AppendCompaction(
|
||||
summary,
|
||||
firstKeptEntryID,
|
||||
originalTokens,
|
||||
compactedTokens,
|
||||
messagesRemoved,
|
||||
readFiles,
|
||||
modifiedFiles,
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to persist compaction entry: %w", err)
|
||||
}
|
||||
m.events.emit(CompactionEvent{
|
||||
Summary: summary,
|
||||
OriginalTokens: originalTokens,
|
||||
CompactedTokens: compactedTokens,
|
||||
MessagesRemoved: messagesRemoved,
|
||||
ReadFiles: readFiles,
|
||||
ModifiedFiles: modifiedFiles,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Conversion helpers are in llm_convert.go.
|
||||
|
||||
+11
-11
@@ -12,6 +12,10 @@ import (
|
||||
// defaultSystemPrompt is the built-in system prompt used when no custom
|
||||
// prompt is configured. It describes the available core tools and provides
|
||||
// usage guidelines.
|
||||
//
|
||||
// NOTE: Keep this in sync with the CLI default in cmd/root.go (search for
|
||||
// defaultSystemPrompt or system-prompt flag default). Changes here should
|
||||
// generally be reflected there, and vice versa.
|
||||
const defaultSystemPrompt = `You are an expert coding assistant operating inside kit, a coding agent harness. You help users by reading files, executing commands, editing code, and writing new files.
|
||||
|
||||
Available tools:
|
||||
@@ -78,20 +82,16 @@ func InitConfig(configFile string, debug bool) error {
|
||||
viper.AddConfigPath(home)
|
||||
|
||||
configLoaded := false
|
||||
configNames := []string{".kit"}
|
||||
|
||||
for _, name := range configNames {
|
||||
viper.SetConfigName(name)
|
||||
if err := viper.ReadInConfig(); err == nil {
|
||||
configPath := viper.ConfigFileUsed()
|
||||
if err := LoadConfigWithEnvSubstitution(configPath); err != nil {
|
||||
if strings.Contains(err.Error(), "environment variable substitution failed") {
|
||||
return fmt.Errorf("error reading config file '%s': %w", configPath, err)
|
||||
}
|
||||
continue
|
||||
viper.SetConfigName(".kit")
|
||||
if err := viper.ReadInConfig(); err == nil {
|
||||
configPath := viper.ConfigFileUsed()
|
||||
if err := LoadConfigWithEnvSubstitution(configPath); err != nil {
|
||||
if strings.Contains(err.Error(), "environment variable substitution failed") {
|
||||
return fmt.Errorf("error reading config file '%s': %w", configPath, err)
|
||||
}
|
||||
} else {
|
||||
configLoaded = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user