mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
58 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4ae03aab7c | |||
| 93895392e6 | |||
| 473070e78b | |||
| 12268a777f | |||
| 351c10d814 | |||
| 9de3843605 | |||
| 1d5473e111 | |||
| b6adcf159e | |||
| b1da4a28e6 | |||
| 95abb6fa6e | |||
| a9970cf346 | |||
| 13060a20f9 | |||
| adf603e944 | |||
| af486133a5 | |||
| a97cd47ced | |||
| 68518a2bdb | |||
| fd61db3e12 | |||
| e49066a119 | |||
| efaff7f44f | |||
| d3c970b607 | |||
| 23254fee64 | |||
| fe072ad2e1 | |||
| 8840cbfabc | |||
| a11b41cda4 | |||
| 8b7be8b735 | |||
| caa6d1c178 | |||
| 001156053d | |||
| 54717e32bc | |||
| 5b214b9fdf | |||
| c5e6ca6e4d | |||
| 419a139137 | |||
| 7b963624c1 | |||
| 66f2ba543b | |||
| 6dd052b990 | |||
| ef8628eecc | |||
| 3167222b72 | |||
| e3b37191b1 | |||
| 41d5f5e0fb | |||
| 3ad0b3616d | |||
| 8831b49b51 | |||
| c94edc929b | |||
| e49194a0d4 | |||
| 46b1acf444 | |||
| 6a6d201a50 | |||
| 930cbcb4f2 | |||
| 12e1ef2036 | |||
| a05da5f3ab | |||
| fefbf19b42 | |||
| 93905d4d77 | |||
| 7268ccdf4d | |||
| 9f59fa42dc | |||
| 8af7ca8455 | |||
| 424847f0db | |||
| 4c126ca41b | |||
| 4bdc4f75cc | |||
| bbd8975ca0 | |||
| e613a07773 | |||
| 1d3b4f8d56 |
@@ -1,64 +0,0 @@
|
||||
---
|
||||
name: btca-cli
|
||||
description: Operate the btca CLI for local resources and source-first answers. Use when setting up btca in a project, connecting a provider, adding or managing resources, and asking questions via btca commands. Invoke this skill when the user says "use btca" or needs to do more detailed research on a specific library or framework.
|
||||
---
|
||||
|
||||
# btca CLI
|
||||
|
||||
`btca` is a source-first research CLI. It hydrates resources (git, local, npm) into searchable context, then answers questions grounded in those sources. Use configured resources for ongoing work, or one-off anonymous resources directly in `btca ask`.
|
||||
|
||||
Full CLI reference: https://docs.btca.dev/guides/cli-reference
|
||||
|
||||
Add resources:
|
||||
|
||||
```bash
|
||||
# Git resource
|
||||
btca add -n svelte-dev https://github.com/sveltejs/svelte.dev
|
||||
|
||||
# Local directory
|
||||
btca add -n my-docs -t local /absolute/path/to/docs
|
||||
|
||||
# npm package
|
||||
btca add npm:@types/node@22.10.1 -n node-types -t npm
|
||||
```
|
||||
|
||||
Verify resources:
|
||||
|
||||
```bash
|
||||
btca resources
|
||||
```
|
||||
|
||||
Ask a question:
|
||||
|
||||
```bash
|
||||
btca ask -r svelte-dev -q "How do I define remote functions?"
|
||||
```
|
||||
|
||||
## Common Tasks
|
||||
|
||||
- Ask with multiple resources:
|
||||
|
||||
```bash
|
||||
btca ask -r react -r typescript -q "How do I type useState?"
|
||||
```
|
||||
|
||||
- Ask with anonymous one-off resources (not saved to config):
|
||||
|
||||
```bash
|
||||
# One-off git repo
|
||||
btca ask -r https://github.com/sveltejs/svelte -q "Where is the implementation of writable stores?"
|
||||
|
||||
# One-off npm package
|
||||
btca ask -r npm:react@19.0.0 -q "How is useTransition exported?"
|
||||
```
|
||||
|
||||
## Config Overview
|
||||
|
||||
- Config lives in `btca.config.jsonc` (project) and `~/.config/btca/btca.config.jsonc` (global).
|
||||
- Project config overrides global and controls provider/model and resources.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- "No resources configured": add resources with `btca add ...` and re-run `btca resources`.
|
||||
- "Provider not connected": run `btca connect` and follow the prompts.
|
||||
- "Unknown resource": use `btca resources` for configured names, or pass a valid HTTPS git URL / `npm:<package>` as an anonymous one-off in `btca ask`.
|
||||
@@ -1,3 +0,0 @@
|
||||
interface:
|
||||
display_name: "BTCA CLI"
|
||||
short_description: "Help with BTCA CLI setup and usage workflows"
|
||||
@@ -0,0 +1,32 @@
|
||||
name: Build and Deploy Docs to GitHub Pages
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-and-deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v1
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- name: Install Dependencies
|
||||
working-directory: ./www
|
||||
run: bun install
|
||||
|
||||
- name: Build
|
||||
working-directory: ./www
|
||||
run: bun run build
|
||||
|
||||
- name: Deploy to GitHub Pages
|
||||
uses: JamesIves/github-pages-deploy-action@v4
|
||||
with:
|
||||
folder: www/out
|
||||
branch: gh-pages
|
||||
+2
-1
@@ -6,9 +6,10 @@ aidocs/
|
||||
*.log
|
||||
/kit
|
||||
.idea
|
||||
test/
|
||||
build/
|
||||
dist/
|
||||
contribute/output/
|
||||
CONTEXT.md
|
||||
output/
|
||||
.agents/
|
||||
skills-lock.json
|
||||
|
||||
@@ -18,20 +18,26 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
|
||||
## Features
|
||||
|
||||
- **Multi-Provider LLM Support**: Anthropic, OpenAI, Google Gemini, Ollama, Azure OpenAI, AWS Bedrock, OpenRouter, and more
|
||||
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls - no MCP overhead
|
||||
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls, spawn_subagent - no MCP overhead
|
||||
- **MCP Integration**: Connect external MCP servers for expanded capabilities
|
||||
- **Extension System**: Write custom tools, commands, widgets, and UI modifications in Go
|
||||
- **Theming**: 22 built-in color themes (KITT, Catppuccin, Dracula, Nord, etc.) with runtime switching and custom theme files
|
||||
- **Interactive TUI**: Rich terminal interface powered by Bubble Tea with streaming, syntax highlighting, and custom rendering
|
||||
- **Session Management**: Tree-based conversation history with branching support
|
||||
- **Non-Interactive Mode**: Script-friendly positional args with JSON output
|
||||
- **ACP Server**: Run Kit as an [Agent Client Protocol](https://agentclientprotocol.com) agent over stdio
|
||||
- **Go SDK**: Embed Kit in your own applications
|
||||
|
||||
## Installation
|
||||
|
||||
### Using npm (recommended)
|
||||
### Using npm / bun / pnpm
|
||||
|
||||
```bash
|
||||
npm install -g @mark3labs/kit
|
||||
# or
|
||||
bun install -g @mark3labs/kit
|
||||
# or
|
||||
pnpm install -g @mark3labs/kit
|
||||
```
|
||||
|
||||
### Using Go
|
||||
@@ -82,14 +88,28 @@ kit "Run tests" --quiet
|
||||
kit "Quick question" --no-session
|
||||
```
|
||||
|
||||
### ACP Server Mode
|
||||
|
||||
Kit can run as an [ACP (Agent Client Protocol)](https://agentclientprotocol.com) agent server, enabling ACP-compatible clients (such as [OpenCode](https://github.com/sst/opencode)) to drive Kit as a remote coding agent over stdio.
|
||||
|
||||
```bash
|
||||
# Start Kit as an ACP server (communicates via JSON-RPC 2.0 on stdin/stdout)
|
||||
kit acp
|
||||
|
||||
# With debug logging to stderr
|
||||
kit acp --debug
|
||||
```
|
||||
|
||||
The ACP server exposes Kit's full capabilities — LLM execution, tool calls (bash, read, write, edit, grep, etc.), and session persistence — over the standard ACP protocol. Sessions are persisted to Kit's normal JSONL session files, so they can be resumed later.
|
||||
|
||||
## Configuration
|
||||
|
||||
Kit looks for configuration in the following locations (in order of priority):
|
||||
|
||||
1. CLI flags
|
||||
2. Environment variables (with `KIT_` prefix)
|
||||
3. `./.kit.yml` (project-local)
|
||||
4. `~/.kit.yml` (global)
|
||||
3. `./.kit.yml` / `./.kit.yaml` / `./.kit.json` (project-local)
|
||||
4. `~/.kit.yml` / `~/.kit.yaml` / `~/.kit.json` (global)
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
@@ -164,6 +184,7 @@ mcpServers:
|
||||
--top-p Nucleus sampling 0.0-1.0 (default: 0.95)
|
||||
--top-k Limit top K tokens (default: 40)
|
||||
--stop-sequences Custom stop sequences (comma-separated)
|
||||
--thinking-level Extended thinking level: off, minimal, low, medium, high (default: off)
|
||||
|
||||
# System
|
||||
--config Config file path (default: ~/.kit.yml)
|
||||
@@ -175,24 +196,61 @@ mcpServers:
|
||||
|
||||
```bash
|
||||
# Authentication (for OAuth-enabled providers)
|
||||
kit auth login # Start OAuth flow
|
||||
kit auth logout # Remove credentials
|
||||
kit auth status # Check authentication status
|
||||
kit auth login [provider] # Start OAuth flow (e.g., anthropic)
|
||||
kit auth logout [provider] # Remove credentials for provider
|
||||
kit auth status # Check authentication status
|
||||
|
||||
# Model database
|
||||
kit models # List available models
|
||||
kit models --all # Show all providers (not just Fantasy-compatible)
|
||||
kit update-models # Update local model database from models.dev
|
||||
kit models [provider] # List available models (optionally filter by provider)
|
||||
kit models --all # Show all providers (not just Fantasy-compatible)
|
||||
kit update-models [source] # Update model database (from models.dev, URL, file, or 'embedded')
|
||||
|
||||
# Extension management
|
||||
kit extensions list # List discovered extensions
|
||||
kit extensions validate # Validate extension files
|
||||
kit extensions init # Generate example extension template
|
||||
kit extensions list # List discovered extensions
|
||||
kit extensions validate # Validate extension files
|
||||
kit extensions init # Generate example extension template
|
||||
kit install <git-url> # Install extensions from git repositories
|
||||
kit install -l <git-url> # Install to project-local .kit/git/ directory
|
||||
kit install -u <git-url> # Update an already-installed package
|
||||
kit install --uninstall <pkg> # Remove an installed package
|
||||
|
||||
# Skills
|
||||
kit skill # Install the Kit extensions skill via skills.sh
|
||||
|
||||
# ACP server
|
||||
kit acp # Start as ACP agent (stdio JSON-RPC)
|
||||
kit acp --debug # With debug logging to stderr
|
||||
```
|
||||
|
||||
## Themes
|
||||
|
||||
Kit ships with 22 built-in color themes that control all UI elements. Switch at runtime:
|
||||
|
||||
```
|
||||
/theme dracula
|
||||
/theme catppuccin
|
||||
/theme tokyonight
|
||||
```
|
||||
|
||||
### Custom themes
|
||||
|
||||
Drop a `.yml` file in `~/.config/kit/themes/` (user) or `.kit/themes/` (project):
|
||||
|
||||
```yaml
|
||||
# ~/.config/kit/themes/my-theme.yml
|
||||
primary:
|
||||
light: "#8839ef"
|
||||
dark: "#cba6f7"
|
||||
success:
|
||||
light: "#40a02b"
|
||||
dark: "#a6e3a1"
|
||||
```
|
||||
|
||||
Built-in themes: `kitt`, `catppuccin`, `dracula`, `tokyonight`, `nord`, `gruvbox`, `monokai`, `solarized`, `github`, `one-dark`, `rose-pine`, `ayu`, `material`, `everforest`, `kanagawa`, `amoled`, `synthwave`, `vesper`, `flexoki`, `matrix`, `vercel`, `zenburn`
|
||||
|
||||
## Extension System
|
||||
|
||||
Extensions are Go source files that run via Yaegi interpreter. They can add custom tools, slash commands, widgets, keyboard shortcuts, and intercept lifecycle events.
|
||||
Extensions are Go source files that run via Yaegi interpreter. They can add custom tools, slash commands, widgets, keyboard shortcuts, themes, and intercept lifecycle events.
|
||||
|
||||
### Minimal Extension
|
||||
|
||||
@@ -220,37 +278,69 @@ kit -e examples/extensions/minimal.go
|
||||
|
||||
### Extension Capabilities
|
||||
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnAgentStart, OnAgentEnd, OnToolCall, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolExecutionStart, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact
|
||||
|
||||
**Custom Components**:
|
||||
- **Tools**: Add new tools the LLM can invoke
|
||||
- **Commands**: Register slash commands (e.g., `/mycommand`)
|
||||
- **Options**: Register configurable extension options
|
||||
- **Widgets**: Persistent status displays above/below input
|
||||
- **Headers/Footers**: Persistent content above/below the conversation
|
||||
- **Status Bar**: Custom status bar entries
|
||||
- **Shortcuts**: Global keyboard shortcuts
|
||||
- **Overlays**: Modal dialogs with markdown content
|
||||
- **Tool Renderers**: Customize how tool calls display
|
||||
- **Message Renderers**: Custom rendering for assistant messages
|
||||
- **Editor Interceptors**: Handle key events and wrap rendering
|
||||
- **Interactive Prompts**: Select, confirm, input, and multi-select dialogs
|
||||
- **Subagents**: Spawn in-process child Kit instances
|
||||
- **LLM Completion**: Direct model calls via `Complete()`
|
||||
- **Themes**: Register and switch color themes via `RegisterTheme`, `SetTheme`, `ListThemes`
|
||||
- **Custom Events**: Inter-extension communication via `EmitCustomEvent`
|
||||
|
||||
### Extension Examples
|
||||
|
||||
See the `examples/extensions/` directory:
|
||||
|
||||
- `minimal.go` - Clean UI with custom footer
|
||||
- `notify.go` - Desktop notifications
|
||||
- `widget-status.go` - Persistent status widgets
|
||||
- `custom-editor-demo.go` - Vim-like modal editor
|
||||
- `prompt-demo.go` - Interactive prompts (select/confirm/input)
|
||||
- `tool-logger.go` - Log all tool calls
|
||||
- `overlay-demo.go` - Modal dialogs
|
||||
- `plan-mode.go` - Read-only planning mode
|
||||
- `subagent-widget.go` - Multi-agent orchestration
|
||||
- `auto-commit.go` - Auto-commit on shutdown
|
||||
- `bookmark.go` - Bookmark conversations
|
||||
- `branded-output.go` - Branded output rendering
|
||||
- `compact-notify.go` - Notification on compaction
|
||||
- `confirm-destructive.go` - Confirm destructive operations
|
||||
- `context-inject.go` - Inject context into conversations
|
||||
- `custom-editor-demo.go` - Vim-like modal editor
|
||||
- `dev-reload.go` - Development live-reload
|
||||
- `header-footer-demo.go` - Custom headers and footers
|
||||
- `inline-bash.go` - Inline bash execution
|
||||
- `interactive-shell.go` - Interactive shell integration
|
||||
- `kit-kit.go` - Kit-in-Kit (sub-agent spawning)
|
||||
- `lsp-diagnostics.go` - LSP diagnostic integration
|
||||
- `notify.go` - Desktop notifications
|
||||
- `overlay-demo.go` - Modal dialogs
|
||||
- `permission-gate.go` - Permission gating for tools
|
||||
- `pirate.go` - Pirate-themed personality
|
||||
- `plan-mode.go` - Read-only planning mode
|
||||
- `project-rules.go` - Project-specific rules
|
||||
- `prompt-demo.go` - Interactive prompts (select/confirm/input)
|
||||
- `protected-paths.go` - Path protection for sensitive files
|
||||
- `subagent-widget.go` - Multi-agent orchestration with status widget
|
||||
- `subagent-test.go` - Subagent testing utilities
|
||||
- `summarize.go` - Conversation summarization
|
||||
- `tool-logger.go` - Log all tool calls
|
||||
- `neon-theme.go` - Custom theme registration and switching
|
||||
- `tool-renderer-demo.go` - Custom tool call rendering
|
||||
- `widget-status.go` - Persistent status widgets
|
||||
|
||||
### Loading Extensions
|
||||
|
||||
**Auto-discovery** (loads automatically):
|
||||
- `./.kit/extensions/*.go` (project-local)
|
||||
- `~/.config/kit/extensions/*.go` (global)
|
||||
- `~/.config/kit/extensions/*.go` (global single files)
|
||||
- `~/.config/kit/extensions/*/main.go` (global subdirectory extensions)
|
||||
- `.kit/extensions/*.go` (project-local single files)
|
||||
- `.kit/extensions/*/main.go` (project-local subdirectory extensions)
|
||||
- `~/.local/share/kit/git/` (global git-installed packages)
|
||||
- `.kit/git/` (project-local git-installed packages)
|
||||
|
||||
**Explicit loading**:
|
||||
```bash
|
||||
@@ -263,13 +353,50 @@ kit -e ext1.go -e ext2.go # Multiple extensions
|
||||
kit --no-extensions
|
||||
```
|
||||
|
||||
### Testing Extensions
|
||||
|
||||
Kit provides a testing package to help you write unit tests for your extensions:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
)
|
||||
|
||||
func TestMyExtension(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Emit events and verify behavior
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the extension printed something
|
||||
test.AssertPrinted(t, harness, "session started")
|
||||
}
|
||||
```
|
||||
|
||||
**Available assertions:**
|
||||
- `AssertBlocked()`, `AssertNotBlocked()` — Verify tool blocking
|
||||
- `AssertWidgetSet()`, `AssertWidgetText()` — Verify widget content
|
||||
- `AssertPrinted()`, `AssertPrintedContains()` — Verify output
|
||||
- `AssertToolRegistered()`, `AssertCommandRegistered()` — Verify registration
|
||||
|
||||
See `examples/extensions/tool-logger_test.go` for a complete example with 14 test cases covering tool calls, input handling, and session lifecycle.
|
||||
|
||||
## Session Management
|
||||
|
||||
Kit uses a tree-based session model that supports branching and forking conversations.
|
||||
|
||||
### Session Locations
|
||||
|
||||
- Default: `~/.local/share/kit/sessions/<cwd-hash>/<uuid>.jsonl`
|
||||
- Default: `~/.kit/sessions/<cwd-path>/<timestamp>_<id>.jsonl`
|
||||
- Path separators in the working directory are replaced with `--` (e.g., `/home/user/project` becomes `home--user--project`)
|
||||
- Each line is a session entry (messages, tool calls, extension data)
|
||||
- Supports branching from any message to explore alternate paths
|
||||
|
||||
@@ -336,6 +463,19 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
MaxSteps: 10,
|
||||
Streaming: true,
|
||||
Quiet: true,
|
||||
|
||||
// Session options
|
||||
SessionPath: "./session.jsonl", // Open specific session
|
||||
Continue: true, // Resume most recent session
|
||||
NoSession: true, // Ephemeral mode
|
||||
|
||||
// Tool options
|
||||
ExtraTools: []kit.Tool{...}, // Additional tools alongside defaults
|
||||
|
||||
// Compaction
|
||||
AutoCompact: true, // Auto-compact near context limit
|
||||
|
||||
Debug: true, // Debug logging
|
||||
})
|
||||
```
|
||||
|
||||
@@ -365,14 +505,29 @@ response, err := host.PromptWithCallbacks(
|
||||
### Session Management
|
||||
|
||||
```go
|
||||
// Multi-turn conversations retain context automatically
|
||||
host.Prompt(ctx, "My name is Alice")
|
||||
response, _ := host.Prompt(ctx, "What's my name?")
|
||||
|
||||
host.SaveSession("./session.json")
|
||||
host.LoadSession("./session.json")
|
||||
// Sessions are persisted automatically to JSONL files.
|
||||
// Access session info:
|
||||
path := host.GetSessionPath()
|
||||
id := host.GetSessionID()
|
||||
|
||||
// Clear conversation history
|
||||
host.ClearSession()
|
||||
```
|
||||
|
||||
Session persistence is configured via `Options`:
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
SessionPath: "./my-session.jsonl", // Open specific session
|
||||
Continue: true, // Resume most recent session
|
||||
NoSession: true, // Ephemeral mode
|
||||
})
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Subagent Pattern
|
||||
@@ -394,12 +549,25 @@ Parse the JSON output:
|
||||
{
|
||||
"response": "Final assistant response text",
|
||||
"model": "anthropic/claude-haiku-3-5-20241022",
|
||||
"stop_reason": "end_turn",
|
||||
"session_id": "a1b2c3d4e5f6",
|
||||
"usage": {
|
||||
"input_tokens": 1024,
|
||||
"output_tokens": 512,
|
||||
"total_tokens": 1536
|
||||
"total_tokens": 1536,
|
||||
"cache_read_tokens": 0,
|
||||
"cache_creation_tokens": 0
|
||||
},
|
||||
"messages": [...]
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"parts": [
|
||||
{"type": "text", "data": "..."},
|
||||
{"type": "tool_call", "data": {"name": "...", "args": "..."}},
|
||||
{"type": "tool_result", "data": {"name": "...", "result": "..."}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
@@ -449,18 +617,27 @@ go fmt ./...
|
||||
### Project Structure
|
||||
|
||||
```
|
||||
cmd/kit/ - CLI entry point
|
||||
cmd/ - CLI command implementations
|
||||
pkg/kit/ - Go SDK
|
||||
internal/agent/ - Agent loop and tool execution
|
||||
internal/ui/ - Bubble Tea TUI components
|
||||
cmd/kit/ - CLI entry point (main.go)
|
||||
cmd/ - CLI command implementations (root, auth, models, etc.)
|
||||
pkg/kit/ - Go SDK for embedding Kit
|
||||
internal/app/ - Application orchestrator (agent loop, message store, queue)
|
||||
internal/agent/ - Agent execution and tool dispatch
|
||||
internal/auth/ - OAuth authentication and credential storage
|
||||
internal/acpserver/ - ACP (Agent Client Protocol) server
|
||||
internal/clipboard/ - Cross-platform clipboard operations
|
||||
internal/compaction/ - Conversation compaction and summarization
|
||||
internal/config/ - Configuration management
|
||||
internal/core/ - Built-in tools (bash, read, write, edit, grep, find, ls)
|
||||
internal/extensions/ - Yaegi extension system
|
||||
internal/core/ - Built-in tools
|
||||
internal/tools/ - MCP tool integration
|
||||
internal/config/ - Configuration management
|
||||
internal/session/ - Session persistence
|
||||
internal/models/ - Provider and model management
|
||||
internal/kitsetup/ - Initial setup wizard
|
||||
internal/message/ - Message content types and structured content blocks
|
||||
internal/models/ - Provider and model management
|
||||
internal/session/ - Session persistence (tree-based JSONL)
|
||||
internal/skills/ - Skill loading and system prompt composition
|
||||
internal/tools/ - MCP tool integration
|
||||
internal/ui/ - Bubble Tea TUI components
|
||||
examples/extensions/ - Example extension files
|
||||
npm/ - NPM package wrapper for distribution
|
||||
```
|
||||
|
||||
## Supported Providers
|
||||
@@ -489,18 +666,23 @@ google/gemini-2.0-flash-exp
|
||||
### Model Aliases
|
||||
|
||||
```bash
|
||||
claude-opus-latest → claude-opus-4-20250514
|
||||
claude-sonnet-latest → claude-sonnet-4-5-20250929
|
||||
claude-3-5-haiku-latest → claude-3-5-haiku-20241022
|
||||
claude-opus-latest → claude-opus-4-20250514
|
||||
claude-sonnet-latest → claude-sonnet-4-5-20250929
|
||||
claude-4-opus-latest → claude-opus-4-20250514
|
||||
claude-4-sonnet-latest → claude-sonnet-4-5-20250929
|
||||
claude-3-7-sonnet-latest → claude-3-7-sonnet-20250219
|
||||
claude-3-5-sonnet-latest → claude-3-5-sonnet-20241022
|
||||
claude-3-5-haiku-latest → claude-3-5-haiku-20241022
|
||||
claude-3-opus-latest → claude-3-opus-20240229
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
|
||||
Contributions are welcome! Please see the [contribution guide](contribute/contribute.md) for guidelines.
|
||||
|
||||
## License
|
||||
|
||||
[Apache 2.0](LICENSE)
|
||||
[MIT](LICENSE)
|
||||
|
||||
## Community
|
||||
|
||||
|
||||
+13
-1
@@ -64,8 +64,20 @@
|
||||
"name": "yaegi",
|
||||
"url": "https://github.com/traefik/yaegi",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
"type": "git",
|
||||
"name": "acp-go-sdk",
|
||||
"url": "https://github.com/coder/acp-go-sdk",
|
||||
"branch": "main"
|
||||
},
|
||||
{
|
||||
"type": "git",
|
||||
"name": "opencode",
|
||||
"url": "https://github.com/anomalyco/opencode",
|
||||
"branch": "dev"
|
||||
}
|
||||
],
|
||||
"model": "claude-haiku-4-5",
|
||||
"provider": "opencode"
|
||||
}
|
||||
}
|
||||
+159
@@ -0,0 +1,159 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
acp "github.com/coder/acp-go-sdk"
|
||||
|
||||
"github.com/mark3labs/kit/internal/acpserver"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var acpCmd = &cobra.Command{
|
||||
Use: "acp",
|
||||
Short: "Start Kit as an ACP agent server",
|
||||
Long: `Start Kit as an ACP (Agent Client Protocol) agent server.
|
||||
|
||||
Communicates over stdio (stdin/stdout) using JSON-RPC 2.0 with
|
||||
newline-delimited JSON, compatible with OpenCode and other ACP clients.
|
||||
|
||||
The server exposes Kit's LLM execution, tool system, and session
|
||||
management via the Agent Client Protocol. Sessions are persisted
|
||||
to Kit's standard JSONL session files.`,
|
||||
RunE: runACP,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(acpCmd)
|
||||
}
|
||||
|
||||
func runACP(cmd *cobra.Command, _ []string) error {
|
||||
// Create the ACP agent implementation.
|
||||
agent := acpserver.NewAgent()
|
||||
defer agent.Close()
|
||||
|
||||
// Create the stdio connection. The SDK reads JSON-RPC from stdin and
|
||||
// writes responses to stdout. We wrap stdin with a normalizer that
|
||||
// fills in optional fields the SDK's generated validation requires
|
||||
// (e.g. mcpServers) so clients that omit them still work.
|
||||
conn := acp.NewAgentSideConnection(agent, os.Stdout, newACPNormalizer(os.Stdin))
|
||||
|
||||
// Wire the connection back to the agent so it can send session updates.
|
||||
agent.SetAgentConnection(conn)
|
||||
|
||||
// Enable debug logging to stderr if requested.
|
||||
if debugMode {
|
||||
conn.SetLogger(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
})))
|
||||
}
|
||||
|
||||
// Wait for either the client to disconnect or a signal.
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-conn.Done():
|
||||
fmt.Fprintln(os.Stderr, "kit: ACP client disconnected")
|
||||
case sig := <-sigCh:
|
||||
fmt.Fprintf(os.Stderr, "kit: received %s, shutting down\n", sig)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acpNormalizer wraps an io.Reader carrying newline-delimited JSON-RPC and
|
||||
// patches incoming messages so that fields the SDK validates as required —
|
||||
// but that some clients (e.g. Zed) omit — are defaulted. This avoids
|
||||
// InvalidParams errors without forking the SDK.
|
||||
type acpNormalizer struct {
|
||||
scanner *bufio.Scanner
|
||||
buf bytes.Buffer // leftover bytes from the last normalized line
|
||||
}
|
||||
|
||||
func newACPNormalizer(r io.Reader) *acpNormalizer {
|
||||
const maxMsg = 10 * 1024 * 1024 // 10 MB, matches SDK buffer
|
||||
s := bufio.NewScanner(r)
|
||||
s.Buffer(make([]byte, 0, 1024*1024), maxMsg)
|
||||
return &acpNormalizer{scanner: s}
|
||||
}
|
||||
|
||||
// Read satisfies io.Reader. It feeds one normalized JSON line (plus newline)
|
||||
// per underlying scan, buffering across short caller reads.
|
||||
func (n *acpNormalizer) Read(p []byte) (int, error) {
|
||||
// Drain any leftover bytes from the previous line first.
|
||||
if n.buf.Len() > 0 {
|
||||
return n.buf.Read(p)
|
||||
}
|
||||
|
||||
if !n.scanner.Scan() {
|
||||
if err := n.scanner.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
line := n.scanner.Bytes()
|
||||
normalized := normalizeACPLine(line)
|
||||
n.buf.Write(normalized)
|
||||
n.buf.WriteByte('\n')
|
||||
return n.buf.Read(p)
|
||||
}
|
||||
|
||||
// normalizeACPLine ensures session/new and session/load params contain an
|
||||
// mcpServers array. Returns the original line unchanged for all other methods.
|
||||
func normalizeACPLine(line []byte) []byte {
|
||||
// Quick check: if it already contains mcpServers, nothing to do.
|
||||
if bytes.Contains(line, []byte(`"mcpServers"`)) {
|
||||
return line
|
||||
}
|
||||
|
||||
// Only bother parsing if the method could be session/new or session/load.
|
||||
if !bytes.Contains(line, []byte(`"session/new"`)) &&
|
||||
!bytes.Contains(line, []byte(`"session/load"`)) {
|
||||
return line
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(line, &msg); err != nil {
|
||||
return line
|
||||
}
|
||||
if msg.Method != "session/new" && msg.Method != "session/load" {
|
||||
return line
|
||||
}
|
||||
|
||||
// Patch params to include mcpServers: [].
|
||||
var params map[string]json.RawMessage
|
||||
if err := json.Unmarshal(msg.Params, ¶ms); err != nil {
|
||||
return line
|
||||
}
|
||||
if _, ok := params["mcpServers"]; ok {
|
||||
return line
|
||||
}
|
||||
params["mcpServers"] = json.RawMessage(`[]`)
|
||||
|
||||
patched, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
msg.Params = patched
|
||||
|
||||
out, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return out
|
||||
}
|
||||
+25
-21
@@ -1,11 +1,11 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"charm.land/huh/v2"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -171,14 +171,15 @@ func loginAnthropic() error {
|
||||
|
||||
// Check if already authenticated
|
||||
if hasAuth, err := cm.HasAnthropicCredentials(); err == nil && hasAuth {
|
||||
fmt.Print("You are already authenticated with Anthropic. Do you want to re-authenticate? (y/N): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
response, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
response = strings.TrimSpace(strings.ToLower(response))
|
||||
if response != "y" && response != "yes" {
|
||||
var reauth bool
|
||||
err := huh.NewConfirm().
|
||||
Title("You are already authenticated with Anthropic").
|
||||
Description("Do you want to re-authenticate?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&reauth).
|
||||
Run()
|
||||
if err != nil || !reauth {
|
||||
fmt.Println("Authentication cancelled.")
|
||||
return nil
|
||||
}
|
||||
@@ -204,10 +205,13 @@ func loginAnthropic() error {
|
||||
|
||||
// Wait for user to complete OAuth flow
|
||||
fmt.Println("After authorizing the application, you'll receive an authorization code.")
|
||||
fmt.Print("Please enter the authorization code: ")
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
code, err := reader.ReadString('\n')
|
||||
var code string
|
||||
err = huh.NewInput().
|
||||
Title("Authorization code").
|
||||
Description("Paste the code from your browser").
|
||||
Value(&code).
|
||||
Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read authorization code: %w", err)
|
||||
}
|
||||
@@ -255,15 +259,15 @@ func logoutAnthropic() error {
|
||||
}
|
||||
|
||||
// Confirm logout
|
||||
fmt.Print("Are you sure you want to remove your Anthropic credentials? (y/N): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
response, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
response = strings.TrimSpace(strings.ToLower(response))
|
||||
if response != "y" && response != "yes" {
|
||||
var confirm bool
|
||||
err = huh.NewConfirm().
|
||||
Title("Remove Anthropic credentials").
|
||||
Description("Are you sure you want to remove your stored credentials?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&confirm).
|
||||
Run()
|
||||
if err != nil || !confirm {
|
||||
fmt.Println("Logout cancelled.")
|
||||
return nil
|
||||
}
|
||||
|
||||
+225
@@ -0,0 +1,225 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
installLocalFlag bool
|
||||
installUpdateFlag bool
|
||||
installUninstallFlag bool
|
||||
installAllFlag bool
|
||||
)
|
||||
|
||||
var installCmd = &cobra.Command{
|
||||
Use: "install <git-url>",
|
||||
Short: "Install extensions from git repositories",
|
||||
Long: `Install extensions from git repositories.
|
||||
|
||||
The install command downloads and installs Kit extensions from git repositories.
|
||||
Extensions are stored in the global extensions directory by default, or in the
|
||||
project's .kit/git/ directory when using the --local flag.
|
||||
|
||||
When a repo contains multiple extensions, an interactive multi-select is shown
|
||||
so you can choose which to install. Use --all to skip selection and install everything.
|
||||
|
||||
Supported URL formats:
|
||||
- github.com/user/repo (shorthand, defaults to HTTPS)
|
||||
- git:github.com/user/repo
|
||||
- https://github.com/user/repo
|
||||
- ssh://git@github.com/user/repo
|
||||
- git@github.com:user/repo
|
||||
|
||||
You can pin to a specific version, tag, or commit using @:
|
||||
- github.com/user/repo@v1.0.0
|
||||
- github.com/user/repo@main
|
||||
- github.com/user/repo@abc1234
|
||||
|
||||
Examples:
|
||||
kit install github.com/user/my-extension
|
||||
kit install github.com/user/my-extension@v1.0.0
|
||||
kit install github.com/user/my-extension --local
|
||||
kit install github.com/user/collection --all`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runInstall,
|
||||
}
|
||||
|
||||
func init() {
|
||||
installCmd.Flags().BoolVarP(&installLocalFlag, "local", "l", false, "Install to project-local .kit/git/ directory")
|
||||
installCmd.Flags().BoolVarP(&installUpdateFlag, "update", "u", false, "Update an already-installed package")
|
||||
installCmd.Flags().BoolVar(&installUninstallFlag, "uninstall", false, "Remove an installed package")
|
||||
installCmd.Flags().BoolVar(&installAllFlag, "all", false, "Install all extensions without prompting")
|
||||
|
||||
rootCmd.AddCommand(installCmd)
|
||||
}
|
||||
|
||||
func runInstall(cmd *cobra.Command, args []string) error {
|
||||
sourceStr := args[0]
|
||||
|
||||
// Check that git is available
|
||||
if _, err := exec.LookPath("git"); err != nil {
|
||||
return fmt.Errorf("git is not installed or not in PATH")
|
||||
}
|
||||
|
||||
// Parse the source
|
||||
source, err := extensions.ParseGitSource(sourceStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid source: %w", err)
|
||||
}
|
||||
|
||||
// Determine scope
|
||||
scope := extensions.ScopeGlobal
|
||||
if installLocalFlag {
|
||||
scope = extensions.ScopeProject
|
||||
}
|
||||
|
||||
installer := extensions.NewInstaller(".")
|
||||
|
||||
// Handle uninstall
|
||||
if installUninstallFlag {
|
||||
return runUninstall(installer, source, scope)
|
||||
}
|
||||
|
||||
// Handle update
|
||||
if installUpdateFlag {
|
||||
return runUpdate(installer, source, scope)
|
||||
}
|
||||
|
||||
// Handle install
|
||||
return runInstallPackage(installer, source, scope)
|
||||
}
|
||||
|
||||
func runInstallPackage(installer *extensions.Installer, source *extensions.GitSource, scope extensions.InstallScope) error {
|
||||
// Check if already installed
|
||||
existingScope, installed := installer.IsInstalled(source)
|
||||
if installed {
|
||||
return fmt.Errorf("extension already installed (scope: %s). Use --update to update or --uninstall to remove", existingScope)
|
||||
}
|
||||
|
||||
// Preview extensions to decide if we need multi-select
|
||||
previews, tempDir, err := installer.PreviewExtensions(source)
|
||||
if err != nil {
|
||||
return fmt.Errorf("previewing extensions: %w", err)
|
||||
}
|
||||
defer extensions.CleanupTempDir(tempDir)
|
||||
|
||||
if len(previews) == 0 {
|
||||
return fmt.Errorf("no extensions found in %s", source.String())
|
||||
}
|
||||
|
||||
scopeStr := "globally"
|
||||
if scope == extensions.ScopeProject {
|
||||
scopeStr = "locally in .kit/git/"
|
||||
}
|
||||
|
||||
// Single extension or --all flag: install everything directly
|
||||
if len(previews) == 1 || installAllFlag {
|
||||
if err := installer.Install(source, scope); err != nil {
|
||||
return fmt.Errorf("install failed: %w", err)
|
||||
}
|
||||
|
||||
if source.Pinned {
|
||||
fmt.Printf("Installed %s at %s %s\n", source.String(), source.Ref, scopeStr)
|
||||
} else {
|
||||
fmt.Printf("Installed %d extension(s) from %s %s\n", len(previews), source.String(), scopeStr)
|
||||
}
|
||||
|
||||
log.Info("extension installed", "source", source.String(), "scope", scope)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Multiple extensions: show interactive selection
|
||||
includePaths, err := multiSelectForInstall(previews)
|
||||
if err != nil {
|
||||
if err.Error() == "selection cancelled" || err.Error() == "no extensions selected" {
|
||||
fmt.Println("Install cancelled.")
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("selection failed: %w", err)
|
||||
}
|
||||
|
||||
if err := installer.InstallWithInclude(source, scope, includePaths); err != nil {
|
||||
return fmt.Errorf("install failed: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Installed %d extension(s) from %s %s\n", len(includePaths), source.String(), scopeStr)
|
||||
for _, path := range includePaths {
|
||||
fmt.Printf(" - %s\n", path)
|
||||
}
|
||||
|
||||
log.Info("extension installed", "source", source.String(), "scope", scope, "selected", len(includePaths))
|
||||
return nil
|
||||
}
|
||||
|
||||
func runUpdate(installer *extensions.Installer, source *extensions.GitSource, scope extensions.InstallScope) error {
|
||||
// Find the installed package
|
||||
existingScope, installed := installer.IsInstalled(source)
|
||||
if !installed {
|
||||
// Try to find with wildcard (no version)
|
||||
entry, foundScope, err := extensions.FindInManifest(source.Identity())
|
||||
if err != nil || entry == nil {
|
||||
return fmt.Errorf("extension not installed: %s", source.Identity())
|
||||
}
|
||||
// Parse the found entry's source
|
||||
foundSource, err := extensions.ParseGitSource(entry.Source)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse installed source: %w", err)
|
||||
}
|
||||
existingScope = foundScope
|
||||
source = foundSource
|
||||
}
|
||||
|
||||
// Override scope if specified
|
||||
if installLocalFlag && scope != existingScope {
|
||||
return fmt.Errorf("extension installed in %s scope, cannot update with --local flag", existingScope)
|
||||
}
|
||||
scope = existingScope
|
||||
|
||||
// Check if pinned
|
||||
if source.Pinned {
|
||||
fmt.Printf("Skipping %s (pinned at %s)\n", source.Identity(), source.Ref)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update
|
||||
if err := installer.Update(source, scope); err != nil {
|
||||
return fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Updated %s\n", source.Identity())
|
||||
log.Info("extension updated", "source", source.Identity(), "scope", scope)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runUninstall(installer *extensions.Installer, source *extensions.GitSource, scope extensions.InstallScope) error {
|
||||
// Find where it's installed (ignore scope flag for uninstall - remove from wherever it exists)
|
||||
existingScope, installed := installer.IsInstalled(source)
|
||||
if !installed {
|
||||
// Try to find in manifests
|
||||
entry, foundScope, err := extensions.FindInManifest(source.Identity())
|
||||
if err != nil || entry == nil {
|
||||
return fmt.Errorf("extension not installed: %s", source.Identity())
|
||||
}
|
||||
existingScope = foundScope
|
||||
// Parse the found entry's source
|
||||
foundSource, err := extensions.ParseGitSource(entry.Source)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse installed source: %w", err)
|
||||
}
|
||||
source = foundSource
|
||||
}
|
||||
|
||||
// Uninstall from the scope where it's installed
|
||||
if err := installer.Uninstall(source, existingScope); err != nil {
|
||||
return fmt.Errorf("uninstall failed: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Uninstalled %s from %s scope\n", source.Identity(), existingScope)
|
||||
log.Info("extension uninstalled", "source", source.Identity(), "scope", existingScope)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"charm.land/huh/v2"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
)
|
||||
|
||||
// multiSelectForInstall runs a multi-select prompt for extension selection.
|
||||
// Returns the selected extension paths, or an error if cancelled.
|
||||
func multiSelectForInstall(previews []extensions.ExtensionPreview) ([]string, error) {
|
||||
if len(previews) == 0 {
|
||||
return nil, fmt.Errorf("no extensions to select")
|
||||
}
|
||||
|
||||
// Non-interactive: select all
|
||||
if !isInteractive() {
|
||||
log.Info("Non-interactive mode, selecting all extensions")
|
||||
paths := make([]string, len(previews))
|
||||
for i, p := range previews {
|
||||
paths[i] = p.Path
|
||||
}
|
||||
return paths, nil
|
||||
}
|
||||
|
||||
// Single extension: just return it
|
||||
if len(previews) == 1 {
|
||||
return []string{previews[0].Path}, nil
|
||||
}
|
||||
|
||||
// Build options for huh MultiSelect
|
||||
options := make([]huh.Option[string], len(previews))
|
||||
for i, p := range previews {
|
||||
label := fmt.Sprintf("%s %s", p.Name, p.Path)
|
||||
options[i] = huh.NewOption(label, p.Path).Selected(true)
|
||||
}
|
||||
|
||||
var selected []string
|
||||
|
||||
form := huh.NewForm(
|
||||
huh.NewGroup(
|
||||
huh.NewMultiSelect[string]().
|
||||
Title("Select extensions to install").
|
||||
Options(options...).
|
||||
Value(&selected),
|
||||
),
|
||||
)
|
||||
|
||||
if err := form.Run(); err != nil {
|
||||
return nil, fmt.Errorf("selection cancelled")
|
||||
}
|
||||
|
||||
if len(selected) == 0 {
|
||||
return nil, fmt.Errorf("no extensions selected")
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// isInteractive checks if the terminal is interactive.
|
||||
func isInteractive() bool {
|
||||
fi, err := os.Stdout.Stat()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return (fi.Mode() & os.ModeCharDevice) != 0
|
||||
}
|
||||
+159
-23
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image/color"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -141,24 +142,58 @@ func LoadConfigWithEnvSubstitution(configPath string) error {
|
||||
return kit.LoadConfigWithEnvSubstitution(configPath)
|
||||
}
|
||||
|
||||
func configToUiTheme(theme config.Theme) ui.Theme {
|
||||
// adaptiveOrDefault converts a config.AdaptiveColor to a resolved color.Color,
|
||||
// falling back to fallback when both Light and Dark are empty.
|
||||
func adaptiveOrDefault(ac config.AdaptiveColor, fallback color.Color) color.Color {
|
||||
if ac.Light == "" && ac.Dark == "" {
|
||||
return fallback
|
||||
}
|
||||
return ui.AdaptiveColor(ac.Light, ac.Dark)
|
||||
}
|
||||
|
||||
func configToUiTheme(cfg config.Theme) ui.Theme {
|
||||
def := ui.DefaultTheme()
|
||||
return ui.Theme{
|
||||
Primary: ui.AdaptiveColor(theme.Primary.Light, theme.Primary.Dark),
|
||||
Secondary: ui.AdaptiveColor(theme.Secondary.Light, theme.Secondary.Dark),
|
||||
Success: ui.AdaptiveColor(theme.Success.Light, theme.Success.Dark),
|
||||
Warning: ui.AdaptiveColor(theme.Warning.Light, theme.Warning.Dark),
|
||||
Error: ui.AdaptiveColor(theme.Error.Light, theme.Error.Dark),
|
||||
Info: ui.AdaptiveColor(theme.Info.Light, theme.Info.Dark),
|
||||
Text: ui.AdaptiveColor(theme.Text.Light, theme.Text.Dark),
|
||||
Muted: ui.AdaptiveColor(theme.Muted.Light, theme.Muted.Dark),
|
||||
VeryMuted: ui.AdaptiveColor(theme.VeryMuted.Light, theme.VeryMuted.Dark),
|
||||
Background: ui.AdaptiveColor(theme.Background.Light, theme.Background.Dark),
|
||||
Border: ui.AdaptiveColor(theme.Border.Light, theme.Border.Dark),
|
||||
MutedBorder: ui.AdaptiveColor(theme.MutedBorder.Light, theme.MutedBorder.Dark),
|
||||
System: ui.AdaptiveColor(theme.System.Light, theme.System.Dark),
|
||||
Tool: ui.AdaptiveColor(theme.Tool.Light, theme.Tool.Dark),
|
||||
Accent: ui.AdaptiveColor(theme.Accent.Light, theme.Accent.Dark),
|
||||
Highlight: ui.AdaptiveColor(theme.Highlight.Light, theme.Highlight.Dark),
|
||||
Primary: adaptiveOrDefault(cfg.Primary, def.Primary),
|
||||
Secondary: adaptiveOrDefault(cfg.Secondary, def.Secondary),
|
||||
Success: adaptiveOrDefault(cfg.Success, def.Success),
|
||||
Warning: adaptiveOrDefault(cfg.Warning, def.Warning),
|
||||
Error: adaptiveOrDefault(cfg.Error, def.Error),
|
||||
Info: adaptiveOrDefault(cfg.Info, def.Info),
|
||||
Text: adaptiveOrDefault(cfg.Text, def.Text),
|
||||
Muted: adaptiveOrDefault(cfg.Muted, def.Muted),
|
||||
VeryMuted: adaptiveOrDefault(cfg.VeryMuted, def.VeryMuted),
|
||||
Background: adaptiveOrDefault(cfg.Background, def.Background),
|
||||
Border: adaptiveOrDefault(cfg.Border, def.Border),
|
||||
MutedBorder: adaptiveOrDefault(cfg.MutedBorder, def.MutedBorder),
|
||||
System: adaptiveOrDefault(cfg.System, def.System),
|
||||
Tool: adaptiveOrDefault(cfg.Tool, def.Tool),
|
||||
Accent: adaptiveOrDefault(cfg.Accent, def.Accent),
|
||||
Highlight: adaptiveOrDefault(cfg.Highlight, def.Highlight),
|
||||
|
||||
DiffInsertBg: adaptiveOrDefault(cfg.DiffInsertBg, def.DiffInsertBg),
|
||||
DiffDeleteBg: adaptiveOrDefault(cfg.DiffDeleteBg, def.DiffDeleteBg),
|
||||
DiffEqualBg: adaptiveOrDefault(cfg.DiffEqualBg, def.DiffEqualBg),
|
||||
DiffMissingBg: adaptiveOrDefault(cfg.DiffMissingBg, def.DiffMissingBg),
|
||||
|
||||
CodeBg: adaptiveOrDefault(cfg.CodeBg, def.CodeBg),
|
||||
GutterBg: adaptiveOrDefault(cfg.GutterBg, def.GutterBg),
|
||||
WriteBg: adaptiveOrDefault(cfg.WriteBg, def.WriteBg),
|
||||
|
||||
Markdown: ui.MarkdownThemeColors{
|
||||
Text: adaptiveOrDefault(cfg.Markdown.Text, def.Markdown.Text),
|
||||
Muted: adaptiveOrDefault(cfg.Markdown.Muted, def.Markdown.Muted),
|
||||
Heading: adaptiveOrDefault(cfg.Markdown.Heading, def.Markdown.Heading),
|
||||
Emph: adaptiveOrDefault(cfg.Markdown.Emph, def.Markdown.Emph),
|
||||
Strong: adaptiveOrDefault(cfg.Markdown.Strong, def.Markdown.Strong),
|
||||
Link: adaptiveOrDefault(cfg.Markdown.Link, def.Markdown.Link),
|
||||
Code: adaptiveOrDefault(cfg.Markdown.Code, def.Markdown.Code),
|
||||
Error: adaptiveOrDefault(cfg.Markdown.Error, def.Markdown.Error),
|
||||
Keyword: adaptiveOrDefault(cfg.Markdown.Keyword, def.Markdown.Keyword),
|
||||
String: adaptiveOrDefault(cfg.Markdown.String, def.Markdown.String),
|
||||
Number: adaptiveOrDefault(cfg.Markdown.Number, def.Markdown.Number),
|
||||
Comment: adaptiveOrDefault(cfg.Markdown.Comment, def.Markdown.Comment),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -901,6 +936,28 @@ func runNormalMode(ctx context.Context) error {
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.SetExtensionActiveTools(names)
|
||||
},
|
||||
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
ui.RegisterThemeFromConfig(name,
|
||||
tc(config.Primary), tc(config.Secondary),
|
||||
tc(config.Success), tc(config.Warning),
|
||||
tc(config.Error), tc(config.Info),
|
||||
tc(config.Text), tc(config.Muted),
|
||||
tc(config.VeryMuted), tc(config.Background),
|
||||
tc(config.Border), tc(config.MutedBorder),
|
||||
tc(config.System), tc(config.Tool),
|
||||
tc(config.Accent), tc(config.Highlight),
|
||||
tc(config.MdHeading), tc(config.MdLink),
|
||||
tc(config.MdKeyword), tc(config.MdString),
|
||||
tc(config.MdNumber), tc(config.MdComment),
|
||||
)
|
||||
},
|
||||
SetTheme: func(name string) error {
|
||||
return ui.ApplyTheme(name)
|
||||
},
|
||||
ListThemes: func() []string {
|
||||
return ui.ListThemes()
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
@@ -924,6 +981,42 @@ func runNormalMode(ctx context.Context) error {
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
// In-process subagent via SDK.
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
// Bridge SDK events to extension SubagentEvents.
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(ctx, sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: result.Error,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
},
|
||||
})
|
||||
kitInstance.EmitSessionStart()
|
||||
}
|
||||
@@ -1083,15 +1176,19 @@ func buildJSONOutput(result *kit.TurnResult, model string) ([]byte, error) {
|
||||
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||
}
|
||||
type jsonEnvelope struct {
|
||||
Response string `json:"response"`
|
||||
Model string `json:"model"`
|
||||
Usage *jsonUsage `json:"usage,omitempty"`
|
||||
Messages []jsonMessage `json:"messages"`
|
||||
Response string `json:"response"`
|
||||
Model string `json:"model"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
Usage *jsonUsage `json:"usage,omitempty"`
|
||||
Messages []jsonMessage `json:"messages"`
|
||||
}
|
||||
|
||||
out := jsonEnvelope{
|
||||
Response: result.Response,
|
||||
Model: model,
|
||||
Response: result.Response,
|
||||
Model: model,
|
||||
StopReason: result.StopReason,
|
||||
SessionID: result.SessionID,
|
||||
}
|
||||
|
||||
if result.TotalUsage != nil {
|
||||
@@ -1202,3 +1299,42 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
|
||||
_, runErr := program.Run()
|
||||
return runErr
|
||||
}
|
||||
|
||||
// sdkEventToSubagentEvent converts an SDK event to an extension-facing
|
||||
// SubagentEvent. Returns a zero-value event (Type=="") for events that
|
||||
// don't map to anything useful.
|
||||
func sdkEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// skillCmd installs the kit-extensions skill via the skills.sh CLI (npx skills).
|
||||
// This teaches AI agents how to create Kit extensions with full knowledge of
|
||||
// the extension API, lifecycle events, widgets, tools, commands, and Yaegi constraints.
|
||||
var skillCmd = &cobra.Command{
|
||||
Use: "skill",
|
||||
Short: "Install the Kit extensions skill via skills.sh",
|
||||
Long: `Install the kit-extensions skill that teaches AI agents how to create
|
||||
Kit extensions. Uses the skills.sh CLI (npx skills) to install the skill
|
||||
from the Kit repository.
|
||||
|
||||
The skill provides comprehensive documentation of Kit's extension API including
|
||||
lifecycle events, custom tools, slash commands, widgets, editor interceptors,
|
||||
tool renderers, and critical Yaegi interpreter constraints.
|
||||
|
||||
Example:
|
||||
kit skill`,
|
||||
RunE: runSkill,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(skillCmd)
|
||||
}
|
||||
|
||||
func runSkill(_ *cobra.Command, _ []string) error {
|
||||
npx, err := exec.LookPath("npx")
|
||||
if err != nil {
|
||||
return fmt.Errorf("npx not found in PATH — install Node.js to use this command: %w", err)
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"skills",
|
||||
"add",
|
||||
"mark3labs/kit",
|
||||
"--skill",
|
||||
"kit-extensions",
|
||||
}
|
||||
|
||||
cmd := exec.Command(npx, args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("skills install failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
# Kit Extension Examples
|
||||
|
||||
A collection of example extensions demonstrating various Kit capabilities. These can be installed individually or as a complete collection.
|
||||
|
||||
## Installation
|
||||
|
||||
### Install all examples
|
||||
```bash
|
||||
kit install github.com/mark3labs/kit/examples/extensions
|
||||
```
|
||||
|
||||
### Install with interactive selection
|
||||
```bash
|
||||
kit install github.com/mark3labs/kit/examples/extensions --select
|
||||
```
|
||||
|
||||
### Install locally in your project
|
||||
```bash
|
||||
kit install github.com/mark3labs/kit/examples/extensions --local
|
||||
```
|
||||
|
||||
## Extension Index
|
||||
|
||||
### Core Concepts
|
||||
|
||||
| Extension | Description | Key API |
|
||||
|-----------|-------------|---------|
|
||||
| `minimal.go` | Minimal viable extension | Basic `Init()` function |
|
||||
| `plan-mode.go` | Restrict agent to read-only tools | `OnBeforeAgentStart`, `SetActiveTools` |
|
||||
| `tool-logger.go` | Log all tool calls to file | `OnToolCall`, `OnToolResult` |
|
||||
| `notify.go` | Display notifications | `PrintInfo`, `PrintBlock` |
|
||||
|
||||
### UI & Widgets
|
||||
|
||||
| Extension | Description | Key API |
|
||||
|-----------|-------------|---------|
|
||||
| `widget-status.go` | Persistent status widget | `SetWidget`, `RemoveWidget` |
|
||||
| `header-footer-demo.go` | Custom header/footer | `SetHeader`, `SetFooter` |
|
||||
| `overlay-demo.go` | Modal overlay dialogs | `ShowOverlay` |
|
||||
| `compact-notify.go` | Compact mode notifications | `PrintBlock` |
|
||||
| `branded-output.go` | Custom styled output | `PrintBlock` with colors |
|
||||
|
||||
### Input & Editor
|
||||
|
||||
| Extension | Description | Key API |
|
||||
|-----------|-------------|---------|
|
||||
| `custom-editor-demo.go` | Custom key handling | `SetEditor`, `EditorKeyAction` |
|
||||
| `pirate.go` | Transform user input | `OnInput`, `InputResult` |
|
||||
| `interactive-shell.go` | Custom command input | Slash commands with prompts |
|
||||
| `inline-bash.go` | Execute bash inline | Input handling, `exec` |
|
||||
|
||||
### Session & Context
|
||||
|
||||
| Extension | Description | Key API |
|
||||
|-----------|-------------|---------|
|
||||
| `context-inject.go` | Inject context into prompts | `OnContextPrepare` |
|
||||
| `bookmark.go` | Bookmark messages | `AppendEntry`, `GetEntries` |
|
||||
| `project-rules.go` | Project-specific rules | Session data, file reading |
|
||||
| `protected-paths.go` | Block dangerous operations | `OnToolCall` with blocking |
|
||||
| `permission-gate.go` | Confirm destructive actions | `OnToolCall` with confirmation |
|
||||
|
||||
### Tools & Commands
|
||||
|
||||
| Extension | Description | Key API |
|
||||
|-----------|-------------|---------|
|
||||
| `auto-commit.go` | Auto-commit changes | Custom tool, git operations |
|
||||
| `summarize.go` | Summarize conversation | Custom tool with parameters |
|
||||
| `confirm-destructive.go` | Confirm destructive commands | `OnToolCall` blocking |
|
||||
| `lsp-diagnostics.go` | LSP integration | Complex extension, external process |
|
||||
|
||||
### Subagents & Background Tasks
|
||||
|
||||
| Extension | Description | Key API |
|
||||
|-----------|-------------|---------|
|
||||
| `kit-kit.go` | Spawn Kit as subagent | Subagent spawning |
|
||||
| `subagent-test.go` | Test subagent functionality | `SpawnSubagent` |
|
||||
| `subagent-widget.go` | Widget with subagent updates | Goroutines + widgets |
|
||||
| `dev-reload.go` | Hot reload extensions | `ReloadExtensions` |
|
||||
|
||||
### Integrations
|
||||
|
||||
| Extension | Description | Key API |
|
||||
|-----------|-------------|---------|
|
||||
| `kit-telegram/` | Telegram relay for remote monitoring & control | `RegisterCommand`, `OnAgentStart/End`, `SetStatus`, `SendMessage` |
|
||||
|
||||
### Themes
|
||||
|
||||
| Extension | Description | Key API |
|
||||
|-----------|-------------|---------|
|
||||
| `neon-theme.go` | Register and switch custom themes | `RegisterTheme`, `SetTheme` |
|
||||
|
||||
### Rendering
|
||||
|
||||
| Extension | Description | Key API |
|
||||
|-----------|-------------|---------|
|
||||
| `tool-renderer-demo.go` | Custom tool output styling | `RegisterToolRenderer` |
|
||||
| `prompt-demo.go` | Interactive prompts | `PromptSelect`, `PromptConfirm` |
|
||||
|
||||
## Extension Details
|
||||
|
||||
### minimal.go
|
||||
The bare minimum extension showing the required structure:
|
||||
- Package `main`
|
||||
- Import `kit/ext`
|
||||
- Export `Init(api ext.API)` function
|
||||
|
||||
### plan-mode.go
|
||||
A complete example demonstrating:
|
||||
- Slash command (`/plan`)
|
||||
- Keyboard shortcut (`ctrl+alt+p`)
|
||||
- Option registration
|
||||
- Status bar indicators
|
||||
- System prompt injection
|
||||
- Tool filtering
|
||||
|
||||
### widget-status.go
|
||||
Shows how to create persistent UI elements:
|
||||
- Create widgets with `SetWidget`
|
||||
- Update content dynamically
|
||||
- Remove when done
|
||||
- Handle session lifecycle
|
||||
|
||||
### context-inject.go
|
||||
Advanced context manipulation:
|
||||
- Read project files
|
||||
- Inject into LLM context
|
||||
- Filter messages
|
||||
- Use negative indices for ephemeral content
|
||||
|
||||
### lsp-diagnostics.go
|
||||
Complex real-world example:
|
||||
- Multi-file extension
|
||||
- External process management (LSP server)
|
||||
- File watching
|
||||
- Diagnostics aggregation
|
||||
|
||||
### kit-telegram/
|
||||
Full-featured Telegram integration:
|
||||
- Slash command with subcommands and tab completion
|
||||
- Interactive guided setup flow with prompts
|
||||
- Background long-polling goroutine
|
||||
- Progress message rendering edited in place
|
||||
- Message queue with edit-before-dispatch
|
||||
- Remote command handling from Telegram
|
||||
- Status bar and widget updates
|
||||
- Config persistence with atomic writes
|
||||
|
||||
## Multi-File Extension Example
|
||||
|
||||
The `kit-kit-agents/` directory demonstrates the multi-file pattern:
|
||||
|
||||
```
|
||||
kit-kit-agents/
|
||||
├── main.go # Entry point with Init()
|
||||
├── agent.go # Agent configuration
|
||||
├── manager.go # Agent lifecycle management
|
||||
└── README.md # Documentation
|
||||
```
|
||||
|
||||
When the repo is installed, all files in subdirectories with `main.go` are loaded as separate extensions.
|
||||
|
||||
## Testing & Validation
|
||||
|
||||
After installing, test the extensions:
|
||||
|
||||
```bash
|
||||
# List all loaded extensions
|
||||
kit extensions list
|
||||
|
||||
# Validate all extensions
|
||||
kit extensions validate
|
||||
|
||||
# Run with a specific extension
|
||||
kit -e ~/.local/share/kit/git/github.com/mark3labs/kit/examples/extensions/plan-mode.go
|
||||
```
|
||||
|
||||
## Creating Your Own
|
||||
|
||||
1. Copy `minimal.go` as a starting point
|
||||
2. Modify the `Init()` function to register your handlers
|
||||
3. Use the other examples for reference on specific APIs
|
||||
4. Test with `kit -e your-extension.go`
|
||||
5. Share by pushing to a git repository!
|
||||
|
||||
## Update
|
||||
|
||||
To get the latest examples:
|
||||
|
||||
```bash
|
||||
kit install github.com/mark3labs/kit/examples/extensions --update
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
- [Kit Extensions Guide](https://github.com/mark3labs/kit/blob/main/.agents/skills/kit-extensions/SKILL.md)
|
||||
- [API Reference](https://github.com/mark3labs/kit/blob/main/internal/extensions/api.go)
|
||||
- [Example Extensions Source](https://github.com/mark3labs/kit/tree/main/examples/extensions)
|
||||
@@ -23,8 +23,7 @@ import (
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionShutdown(func(_ ext.SessionShutdownEvent, ctx ext.Context) {
|
||||
// Check for staged changes.
|
||||
diff, err := exec.Command("git", "diff", "--cached", "--quiet").CombinedOutput()
|
||||
_ = diff
|
||||
err := exec.Command("git", "diff", "--cached", "--quiet").Run()
|
||||
if err == nil {
|
||||
return // exit code 0 means no staged changes
|
||||
}
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
// Extension Test Template
|
||||
//
|
||||
// This is a template for writing tests for your Kit extension.
|
||||
// Copy this file to your extension directory, rename it to something like
|
||||
// "my-ext_test.go", and customize it for your extension.
|
||||
//
|
||||
// Run tests with: go test -v
|
||||
//
|
||||
// IMPORTANT: This file should be in the same directory as your extension
|
||||
// and use package main, NOT package test.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// Test that your extension loads without errors
|
||||
func TestExtension_Loads(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
ext := harness.LoadFile("my-ext.go") // Change to your extension filename
|
||||
|
||||
// Verify the extension was loaded
|
||||
if ext == nil {
|
||||
t.Fatal("extension should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// Test your event handlers are registered
|
||||
func TestExtension_EventHandlers(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Uncomment the handlers your extension uses:
|
||||
// test.AssertHasHandlers(t, harness, extensions.ToolCall)
|
||||
// test.AssertHasHandlers(t, harness, extensions.Input)
|
||||
// test.AssertHasHandlers(t, harness, extensions.SessionStart)
|
||||
// test.AssertHasHandlers(t, harness, extensions.AgentEnd)
|
||||
}
|
||||
|
||||
// Test tool registration
|
||||
func TestExtension_Tools(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Test that your tools are registered
|
||||
// test.AssertToolRegistered(t, harness, "my_tool")
|
||||
|
||||
// Or test all registered tools
|
||||
tools := harness.RegisteredTools()
|
||||
t.Logf("Registered %d tools", len(tools))
|
||||
for _, tool := range tools {
|
||||
t.Logf(" - %s: %s", tool.Name, tool.Description)
|
||||
}
|
||||
}
|
||||
|
||||
// Test command registration
|
||||
func TestExtension_Commands(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Test that your commands are registered
|
||||
// test.AssertCommandRegistered(t, harness, "mycommand")
|
||||
|
||||
// Or test all registered commands
|
||||
cmds := harness.RegisteredCommands()
|
||||
t.Logf("Registered %d commands", len(cmds))
|
||||
for _, cmd := range cmds {
|
||||
t.Logf(" - %s: %s", cmd.Name, cmd.Description)
|
||||
}
|
||||
}
|
||||
|
||||
// Test session start behavior
|
||||
func TestExtension_SessionStart(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Emit session start event
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{
|
||||
SessionID: "test-session",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify expected behavior:
|
||||
// - Did it print something?
|
||||
// test.AssertPrinted(t, harness, "expected output")
|
||||
|
||||
// - Did it set a widget?
|
||||
// test.AssertWidgetSet(t, harness, "my-widget")
|
||||
// test.AssertWidgetText(t, harness, "my-widget", "expected text")
|
||||
|
||||
// - Did it set the header/footer?
|
||||
// test.AssertHeaderSet(t, harness)
|
||||
// test.AssertFooterSet(t, harness)
|
||||
|
||||
// - Did it set a status?
|
||||
// test.AssertStatusSet(t, harness, "myext:status")
|
||||
}
|
||||
|
||||
// Test tool call handling
|
||||
func TestExtension_ToolCall(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Test a specific tool call
|
||||
result, err := harness.Emit(extensions.ToolCallEvent{
|
||||
ToolName: "some_tool",
|
||||
Input: `{"key": "value"}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// If your extension blocks certain tools:
|
||||
// test.AssertNotBlocked(t, result)
|
||||
// OR
|
||||
// test.AssertBlocked(t, result, "expected reason")
|
||||
|
||||
// Suppress unused variable warning (remove this when using result)
|
||||
_ = result
|
||||
|
||||
// Check for print output
|
||||
// test.AssertPrinted(t, harness, "expected message")
|
||||
}
|
||||
|
||||
// Test input handling
|
||||
func TestExtension_InputHandling(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Test input that should be handled
|
||||
result, err := harness.Emit(extensions.InputEvent{
|
||||
Text: "test input",
|
||||
Source: "cli",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// If your extension handles/transforms input:
|
||||
// test.AssertInputHandled(t, result, "handled")
|
||||
// OR
|
||||
// test.AssertInputTransformed(t, result, "transformed text")
|
||||
|
||||
// Suppress unused variable warning (remove this when using result)
|
||||
_ = result
|
||||
}
|
||||
|
||||
// Test with configured prompt results
|
||||
func TestExtension_WithPrompts(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Configure what prompts should return
|
||||
harness.Context().SetPromptSelectResult(extensions.PromptSelectResult{
|
||||
Value: "option1",
|
||||
Index: 0,
|
||||
Cancelled: false,
|
||||
})
|
||||
|
||||
// Now when your extension calls ctx.PromptSelect(), it gets the configured result
|
||||
_, _ = harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
|
||||
// Verify behavior based on the selected options
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
# kit-telegram
|
||||
|
||||
A Kit extension that relays all Kit agent runs to Telegram and lets approved Telegram users reply back into Kit.
|
||||
|
||||
## What it does
|
||||
|
||||
- Relays **all Kit runs** to one Telegram chat while connected
|
||||
- Edits one Telegram progress message in place during a run
|
||||
- Lets approved Telegram users send normal text replies back into Kit
|
||||
- Shows `Telegram Connected` or `Telegram Disconnected` in the status bar
|
||||
- Shows a small spinner animation as `⠋ Telegram Connecting` only while the relay is still connecting
|
||||
- On startup with an already validated enabled config, sends a short Telegram connection message to confirm the relay is up
|
||||
|
||||
## Requirements
|
||||
|
||||
- `kit` installed and working
|
||||
- A Telegram bot token from `@BotFather`
|
||||
- Either:
|
||||
- A Telegram chat where you can message the bot, or
|
||||
- A numeric Telegram chat id you want to enter manually
|
||||
- For group chats, one or more allowed Telegram user ids
|
||||
|
||||
## Quickstart
|
||||
|
||||
### 1. Install the extension
|
||||
|
||||
```bash
|
||||
kit install github.com/mark3labs/kit/examples/extensions/kit-telegram
|
||||
```
|
||||
|
||||
Or run directly:
|
||||
```bash
|
||||
kit -e path/to/kit-telegram/main.go
|
||||
```
|
||||
|
||||
### 2. Start Kit and connect Telegram
|
||||
|
||||
```bash
|
||||
kit
|
||||
```
|
||||
|
||||
Inside Kit, run:
|
||||
|
||||
```
|
||||
/telegram connect
|
||||
```
|
||||
|
||||
You will be prompted for:
|
||||
|
||||
- Bot token from `@BotFather`
|
||||
- Whether to auto-detect the chat by messaging the bot or enter the chat id manually
|
||||
- Allowed user ids when needed
|
||||
|
||||
### 3. Verify the relay
|
||||
|
||||
```
|
||||
/telegram test
|
||||
```
|
||||
|
||||
Reply in Telegram with the code from the test message.
|
||||
|
||||
## Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/telegram` | Human-friendly overview and subcommand list |
|
||||
| `/telegram status` | Raw deterministic relay state |
|
||||
| `/telegram test` | Verify outbound and inbound relay |
|
||||
| `/telegram toggle` | Enable or disable relay without deleting credentials |
|
||||
| `/telegram logout` | Remove saved credentials and disconnect relay |
|
||||
| `/telegram connect` | Run the setup flow again |
|
||||
| `/telegram clear` | Clear Telegram status and working messages from the TUI |
|
||||
|
||||
## Remote commands (from Telegram)
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/telegram` | Sends the overview back to Telegram |
|
||||
| `/telegram status` | Sends the deterministic state report to Telegram |
|
||||
| `/telegram test` | Sends a reply-code test message from Telegram |
|
||||
| `/telegram toggle` | Flips the enabled flag |
|
||||
| `/telegram logout yes` | Logs out (requires `yes` confirmation) |
|
||||
| `/telegram clear` | Clears the TUI footer and working messages |
|
||||
|
||||
## Key APIs Used
|
||||
|
||||
- `RegisterCommand` — Slash command with subcommands and tab completion
|
||||
- `OnSessionStart` / `OnSessionShutdown` — Lifecycle management
|
||||
- `OnAgentStart` / `OnAgentEnd` — Run tracking and progress rendering
|
||||
- `OnToolCall` / `OnToolResult` — Action tracking
|
||||
- `OnMessageEnd` — Capture assistant responses
|
||||
- `OnInput` — Mirror local messages to Telegram
|
||||
- `SetStatus` / `RemoveStatus` — Status bar indicators
|
||||
- `SetWidget` / `RemoveWidget` — Working message display
|
||||
- `PromptInput` / `PromptSelect` / `PromptConfirm` — Interactive setup flow
|
||||
- `SendMessage` — Inject Telegram replies as Kit prompts
|
||||
|
||||
## Architecture
|
||||
|
||||
Single Go file interpreted by Yaegi at runtime. Core components:
|
||||
|
||||
- **Telegram Bot API client** — HTTP calls via `net/http` for getMe, getChat, getChatMember, getUpdates (long-polling), sendMessage, editMessageText
|
||||
- **Config persistence** — JSON file at `.kit/kit-telegram.json` with atomic writes
|
||||
- **Long-polling goroutine** — Background polling for Telegram updates with warmup poll, retry, and client-side timeouts
|
||||
- **Message queue** — In-memory FIFO queue for Telegram prompt input with edit-before-dispatch support
|
||||
- **Progress rendering** — `⏳ elapsed · step N` with action lines, edited in place
|
||||
- **Final rendering** — `✅/❌ elapsed` with response text, split into chunks for long output
|
||||
|
||||
## Debug mode
|
||||
|
||||
Set environment variable `KIT_TELEGRAM_DEBUG=1` to enable verbose debug logging.
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,42 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
// Init registers a "neon" theme and a /neon slash command to apply it.
|
||||
// Demonstrates how extensions can create and set themes programmatically.
|
||||
//
|
||||
// Usage: kit -e examples/extensions/neon-theme.go
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
// Register a cyberpunk neon theme at startup.
|
||||
ctx.RegisterTheme("neon", ext.ThemeColorConfig{
|
||||
Primary: ext.ThemeColor{Light: "#CC00FF", Dark: "#FF00FF"},
|
||||
Secondary: ext.ThemeColor{Light: "#0088CC", Dark: "#00FFFF"},
|
||||
Success: ext.ThemeColor{Light: "#00CC44", Dark: "#00FF66"},
|
||||
Warning: ext.ThemeColor{Light: "#CCAA00", Dark: "#FFFF00"},
|
||||
Error: ext.ThemeColor{Light: "#CC0033", Dark: "#FF0055"},
|
||||
Info: ext.ThemeColor{Light: "#0088CC", Dark: "#00CCFF"},
|
||||
Text: ext.ThemeColor{Light: "#111111", Dark: "#F0F0F0"},
|
||||
Background: ext.ThemeColor{Light: "#F0F0F0", Dark: "#0A0A14"},
|
||||
MdKeyword: ext.ThemeColor{Light: "#CC00FF", Dark: "#FF00FF"},
|
||||
MdString: ext.ThemeColor{Light: "#00CC44", Dark: "#00FF66"},
|
||||
MdComment: ext.ThemeColor{Light: "#888888", Dark: "#555555"},
|
||||
})
|
||||
|
||||
ctx.PrintInfo("Neon theme registered! Use /theme neon to activate.")
|
||||
})
|
||||
|
||||
// Also register a /neon slash command as a shortcut.
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "neon",
|
||||
Description: "Switch to the neon cyberpunk theme",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if err := ctx.SetTheme("neon"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "Neon theme activated!", nil
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// Helper functions for the status-tools extension
|
||||
// These are used by main.go but kept in a separate file
|
||||
// to demonstrate the multi-file extension pattern.
|
||||
|
||||
// formatMemory converts bytes to human-readable format
|
||||
func formatMemory(bytes int64) string {
|
||||
const (
|
||||
KB = 1024
|
||||
MB = 1024 * KB
|
||||
GB = 1024 * MB
|
||||
)
|
||||
|
||||
switch {
|
||||
case bytes >= GB:
|
||||
return fmt.Sprintf("%.2f GB", float64(bytes)/float64(GB))
|
||||
case bytes >= MB:
|
||||
return fmt.Sprintf("%.2f MB", float64(bytes)/float64(MB))
|
||||
case bytes >= KB:
|
||||
return fmt.Sprintf("%.2f KB", float64(bytes)/float64(KB))
|
||||
default:
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
}
|
||||
|
||||
// showMemoryStatus displays memory usage (placeholder)
|
||||
func showMemoryStatus(ctx ext.Context) {
|
||||
// This is a placeholder that would show memory stats
|
||||
// In a real extension, you'd integrate with system metrics
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: "Memory status monitoring not yet implemented",
|
||||
BorderColor: "#f9e2af",
|
||||
Subtitle: "Memory",
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// Init registers the status tools extension.
|
||||
// This extension provides multiple status-related utilities as a
|
||||
// multi-file extension example.
|
||||
func Init(api ext.API) {
|
||||
// Register a status bar widget that shows time
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
ctx.SetStatus("clock", time.Now().Format("15:04:05"), 5)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
// Register a /status command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "status",
|
||||
Description: "Show system status information",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
stats := ctx.GetContextStats()
|
||||
info := fmt.Sprintf(
|
||||
"Model: %s\nTokens: %d/%d (%.1f%%)\nMessages: %d",
|
||||
ctx.Model,
|
||||
stats.EstimatedTokens,
|
||||
stats.ContextLimit,
|
||||
stats.UsagePercent*100,
|
||||
stats.MessageCount,
|
||||
)
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: info,
|
||||
BorderColor: "#89b4fa",
|
||||
Subtitle: "System Status",
|
||||
})
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
//go:build ignore
|
||||
|
||||
// Subagent Test Extension — Tests the new first-class subagent API
|
||||
//
|
||||
// Commands:
|
||||
//
|
||||
// /subtest <task> — spawn a blocking subagent and print result
|
||||
// /subbg <task> — spawn a background subagent with live output
|
||||
//
|
||||
// Usage: kit -e examples/extensions/subagent-test.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
latestCtx ext.Context
|
||||
hasCtx bool
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Keep context fresh
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
mu.Lock()
|
||||
latestCtx = ctx
|
||||
hasCtx = true
|
||||
mu.Unlock()
|
||||
|
||||
ctx.PrintInfo(
|
||||
"Subagent Test Extension loaded\n\n" +
|
||||
"/subtest <task> Spawn blocking subagent\n" +
|
||||
"/subbg <task> Spawn background subagent\n\n" +
|
||||
"The LLM can also use the spawn_subagent tool.")
|
||||
})
|
||||
|
||||
api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) {
|
||||
mu.Lock()
|
||||
latestCtx = ctx
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// Command: /subtest <task> — blocking subagent
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "subtest",
|
||||
Description: "Spawn a blocking subagent: /subtest <task>",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
mu.Lock()
|
||||
latestCtx = ctx
|
||||
hasCtx = true
|
||||
mu.Unlock()
|
||||
|
||||
task := strings.TrimSpace(args)
|
||||
if task == "" {
|
||||
return "Usage: /subtest <task>", nil
|
||||
}
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Spawning blocking subagent for: %s", task))
|
||||
|
||||
start := time.Now()
|
||||
_, result, err := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
Prompt: task,
|
||||
Timeout: 2 * time.Minute,
|
||||
Blocking: true,
|
||||
})
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Spawn error: %v", err), nil
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return "No result returned", nil
|
||||
}
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Sprintf("Subagent failed (exit %d) after %ds: %v\n\nPartial output:\n%s",
|
||||
result.ExitCode, int(elapsed.Seconds()), result.Error, truncate(result.Response, 2000)), nil
|
||||
}
|
||||
|
||||
response := fmt.Sprintf("Subagent completed in %ds", int(elapsed.Seconds()))
|
||||
if result.Usage != nil {
|
||||
response += fmt.Sprintf(" (tokens: %d in / %d out)", result.Usage.InputTokens, result.Usage.OutputTokens)
|
||||
}
|
||||
response += fmt.Sprintf("\n\nResult:\n%s", truncate(result.Response, 4000))
|
||||
|
||||
return response, nil
|
||||
},
|
||||
})
|
||||
|
||||
// Command: /subbg <task> — background subagent with callbacks
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "subbg",
|
||||
Description: "Spawn a background subagent: /subbg <task>",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
mu.Lock()
|
||||
latestCtx = ctx
|
||||
hasCtx = true
|
||||
mu.Unlock()
|
||||
|
||||
task := strings.TrimSpace(args)
|
||||
if task == "" {
|
||||
return "Usage: /subbg <task>", nil
|
||||
}
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Spawning background subagent for: %s", task))
|
||||
|
||||
start := time.Now()
|
||||
handle, _, err := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
Prompt: task,
|
||||
Timeout: 2 * time.Minute,
|
||||
OnOutput: func(chunk string) {
|
||||
// Live output - could update a widget here
|
||||
fmt.Print(chunk)
|
||||
},
|
||||
OnComplete: func(result ext.SubagentResult) {
|
||||
elapsed := time.Since(start)
|
||||
|
||||
mu.Lock()
|
||||
c := latestCtx
|
||||
ok := hasCtx
|
||||
mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if result.Error != nil {
|
||||
c.SendMessage(fmt.Sprintf("Background subagent failed after %ds: %v",
|
||||
int(elapsed.Seconds()), result.Error))
|
||||
return
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Background subagent completed in %ds", int(elapsed.Seconds()))
|
||||
if result.Usage != nil {
|
||||
msg += fmt.Sprintf(" (tokens: %d in / %d out)", result.Usage.InputTokens, result.Usage.OutputTokens)
|
||||
}
|
||||
msg += fmt.Sprintf("\n\nResult:\n%s", truncate(result.Response, 4000))
|
||||
|
||||
c.SendMessage(msg)
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Spawn error: %v", err), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Background subagent spawned (ID: %s). Results will be delivered when complete.", handle.ID), nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max] + "\n\n... [truncated]"
|
||||
}
|
||||
@@ -0,0 +1,358 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// Test that the tool-logger extension loads and registers handlers
|
||||
func TestToolLogger_Loads(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
ext := harness.LoadFile("tool-logger.go")
|
||||
|
||||
if ext == nil {
|
||||
t.Fatal("extension should not be nil")
|
||||
}
|
||||
|
||||
// Verify all expected handlers are registered
|
||||
test.AssertHasHandlers(t, harness, extensions.ToolCall)
|
||||
test.AssertHasHandlers(t, harness, extensions.ToolResult)
|
||||
test.AssertHasHandlers(t, harness, extensions.SessionStart)
|
||||
test.AssertHasHandlers(t, harness, extensions.SessionShutdown)
|
||||
test.AssertHasHandlers(t, harness, extensions.Input)
|
||||
}
|
||||
|
||||
// Test that tool calls are logged (handlers run without errors)
|
||||
func TestToolLogger_ToolCall(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
// Emit a tool call event
|
||||
result, err := harness.Emit(extensions.ToolCallEvent{
|
||||
ToolName: "Read",
|
||||
ToolCallID: "call-123",
|
||||
Input: `{"file": "test.txt"}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Tool logger should not block any tools
|
||||
test.AssertNotBlocked(t, result)
|
||||
}
|
||||
|
||||
// Test that tool results are processed
|
||||
func TestToolLogger_ToolResult(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
content := "Hello, World!"
|
||||
result, err := harness.Emit(extensions.ToolResultEvent{
|
||||
ToolName: "Read",
|
||||
Content: content,
|
||||
IsError: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Tool logger should not modify results
|
||||
if result != nil {
|
||||
t.Error("expected nil result (no modification)")
|
||||
}
|
||||
}
|
||||
|
||||
// Test that error tool results are handled
|
||||
func TestToolLogger_ToolResultError(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
result, err := harness.Emit(extensions.ToolResultEvent{
|
||||
ToolName: "Bash",
|
||||
Content: "command not found",
|
||||
IsError: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Error("expected nil result (no modification)")
|
||||
}
|
||||
}
|
||||
|
||||
// Test session start handler
|
||||
func TestToolLogger_SessionStart(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{
|
||||
SessionID: "test-session-123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Handler should run without errors (logs to file)
|
||||
// Since file logging happens outside our mock, we just verify no errors
|
||||
}
|
||||
|
||||
// Test session shutdown handler
|
||||
func TestToolLogger_SessionShutdown(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test the !time command
|
||||
func TestToolLogger_TimeCommand(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
result, err := harness.Emit(extensions.InputEvent{
|
||||
Text: "!time",
|
||||
Source: "cli",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
test.AssertInputHandled(t, result, "handled")
|
||||
|
||||
// Verify PrintInfo was called with a time message
|
||||
infos := harness.Context().GetPrintInfos()
|
||||
found := false
|
||||
for _, info := range infos {
|
||||
if strings.Contains(info, "Current time:") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected PrintInfo with 'Current time:', got: %v", infos)
|
||||
}
|
||||
}
|
||||
|
||||
// Test the !status command
|
||||
func TestToolLogger_StatusCommand(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
result, err := harness.Emit(extensions.InputEvent{
|
||||
Text: "!status",
|
||||
Source: "cli",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
test.AssertInputHandled(t, result, "handled")
|
||||
|
||||
// Verify PrintBlock was called
|
||||
blocks := harness.Context().PrintBlocks
|
||||
if len(blocks) != 1 {
|
||||
t.Fatalf("expected 1 PrintBlock call, got %d", len(blocks))
|
||||
}
|
||||
|
||||
block := blocks[0]
|
||||
if block.Subtitle != "tool-logger extension" {
|
||||
t.Errorf("expected subtitle 'tool-logger extension', got %q", block.Subtitle)
|
||||
}
|
||||
if block.BorderColor != "#a6e3a1" {
|
||||
t.Errorf("expected border color '#a6e3a1', got %q", block.BorderColor)
|
||||
}
|
||||
if !strings.Contains(block.Text, "Session active") {
|
||||
t.Errorf("expected text to contain 'Session active', got %q", block.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that unknown commands are not handled
|
||||
func TestToolLogger_UnknownCommand(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
result, err := harness.Emit(extensions.InputEvent{
|
||||
Text: "!unknown",
|
||||
Source: "cli",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("expected nil result for unknown command, got %v", result)
|
||||
}
|
||||
|
||||
// Verify no info/block prints for unknown commands
|
||||
if len(harness.Context().GetPrintInfos()) != 0 {
|
||||
t.Error("expected no PrintInfo calls for unknown command")
|
||||
}
|
||||
if len(harness.Context().PrintBlocks) != 0 {
|
||||
t.Error("expected no PrintBlock calls for unknown command")
|
||||
}
|
||||
}
|
||||
|
||||
// Test regular text input (not a command)
|
||||
func TestToolLogger_RegularInput(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
result, err := harness.Emit(extensions.InputEvent{
|
||||
Text: "This is a normal message",
|
||||
Source: "cli",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("expected nil result for regular input, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Test complete session flow
|
||||
func TestToolLogger_FullSession(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
// Simulate a full session
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Several tool calls
|
||||
tools := []string{"Read", "Glob", "Grep", "Bash"}
|
||||
for _, tool := range tools {
|
||||
_, err := harness.Emit(extensions.ToolCallEvent{
|
||||
ToolName: tool,
|
||||
Input: "{}",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("error for tool %s: %v", tool, err)
|
||||
}
|
||||
|
||||
_, err = harness.Emit(extensions.ToolResultEvent{
|
||||
ToolName: tool,
|
||||
Content: "result",
|
||||
IsError: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("error for tool result %s: %v", tool, err)
|
||||
}
|
||||
}
|
||||
|
||||
// User issues a command
|
||||
_, err = harness.Emit(extensions.InputEvent{Text: "!time", Source: "cli"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
_, err = harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the !time command was handled
|
||||
if len(harness.Context().GetPrintInfos()) != 1 {
|
||||
t.Errorf("expected 1 PrintInfo call, got %d", len(harness.Context().GetPrintInfos()))
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the extension handles file write errors gracefully
|
||||
func TestToolLogger_FileError(t *testing.T) {
|
||||
// This test verifies the extension doesn't panic when file operations fail
|
||||
// Since we can't easily mock os.OpenFile, we rely on the extension code
|
||||
// properly checking for errors (which it does)
|
||||
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
// Just verify the handlers run without panicking
|
||||
_, err := harness.Emit(extensions.ToolCallEvent{ToolName: "Read", Input: "{}"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
_, err = harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test concurrent tool calls (race condition check)
|
||||
func TestToolLogger_ConcurrentToolCalls(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
// Run multiple tool calls concurrently
|
||||
done := make(chan bool, 10)
|
||||
for i := range 10 {
|
||||
go func(index int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
toolName := "Tool" + string(rune('0'+index))
|
||||
_, err := harness.Emit(extensions.ToolCallEvent{
|
||||
ToolName: toolName,
|
||||
Input: "{}",
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("error in goroutine %d: %v", index, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for range 10 {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// Test the actual log file is created and written to
|
||||
func TestToolLogger_LogFile(t *testing.T) {
|
||||
logFile := "/tmp/kit-tool-log.txt"
|
||||
|
||||
// Clean up before test
|
||||
_ = os.Remove(logFile)
|
||||
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("tool-logger.go")
|
||||
|
||||
// Emit events
|
||||
_, _ = harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
_, _ = harness.Emit(extensions.ToolCallEvent{ToolName: "Read", Input: "{}"})
|
||||
_, _ = harness.Emit(extensions.ToolResultEvent{ToolName: "Read", Content: "data", IsError: false})
|
||||
|
||||
// Note: Since the extension writes to a real file and the test harness
|
||||
// mocks the context, the file writes actually happen. Let's verify.
|
||||
|
||||
// Give it a moment for file operations
|
||||
if _, err := os.Stat(logFile); err == nil {
|
||||
// File exists - read and verify content
|
||||
content, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Logf("Could not read log file: %v", err)
|
||||
} else {
|
||||
contentStr := string(content)
|
||||
if !strings.Contains(contentStr, "SESSION_START") {
|
||||
t.Error("log file should contain SESSION_START")
|
||||
}
|
||||
if !strings.Contains(contentStr, "CALL tool=Read") {
|
||||
t.Error("log file should contain CALL tool=Read")
|
||||
}
|
||||
if !strings.Contains(contentStr, "RESULT tool=Read") {
|
||||
t.Error("log file should contain RESULT tool=Read")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Log("Note: Log file not created - this is expected since the extension writes directly to disk")
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,7 @@ func Init(api ext.API) {
|
||||
DisplayName: "File",
|
||||
BorderColor: "#89b4fa", // Catppuccin blue
|
||||
RenderHeader: func(toolArgs string, width int) string {
|
||||
var args map[string]interface{}
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
|
||||
return ""
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func Init(api ext.API) {
|
||||
Background: "#1e1e2e", // Dark background
|
||||
BorderColor: "#a6e3a1", // Catppuccin green
|
||||
RenderHeader: func(toolArgs string, width int) string {
|
||||
var args map[string]interface{}
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ go 1.26.0
|
||||
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.0.0
|
||||
charm.land/bubbletea/v2 v2.0.1
|
||||
charm.land/bubbletea/v2 v2.0.2
|
||||
charm.land/fantasy v0.11.1
|
||||
charm.land/lipgloss/v2 v2.0.0
|
||||
charm.land/lipgloss/v2 v2.0.1
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
github.com/aymanbagabas/go-udiff v0.4.0
|
||||
github.com/aymanbagabas/go-udiff v0.4.1
|
||||
github.com/charmbracelet/fang v0.4.4
|
||||
github.com/charmbracelet/log v0.4.2
|
||||
github.com/mark3labs/mcp-go v0.44.1
|
||||
@@ -20,6 +20,7 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
charm.land/huh/v2 v2.0.3 // indirect
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/auth v0.18.2 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
@@ -45,6 +46,7 @@ require (
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/catppuccin/go v0.2.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.2 // indirect
|
||||
@@ -53,13 +55,17 @@ require (
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260303162955-0b88c25f3fff // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260305213658-fe36e8c10185 // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260305213658-fe36e8c10185 // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
github.com/charmbracelet/x/windows v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.11.0 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
||||
github.com/coder/acp-go-sdk v0.6.3 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
|
||||
@@ -83,6 +89,7 @@ require (
|
||||
github.com/kaptinlin/messageformat-go v0.4.18 // indirect
|
||||
github.com/mailru/easyjson v0.9.1 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
github.com/muesli/mango-cobra v1.3.0 // indirect
|
||||
github.com/muesli/mango-pflag v0.2.0 // indirect
|
||||
@@ -137,6 +144,6 @@ require (
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
)
|
||||
|
||||
@@ -2,10 +2,16 @@ charm.land/bubbles/v2 v2.0.0 h1:tE3eK/pHjmtrDiRdoC9uGNLgpopOd8fjhEe31B/ai5s=
|
||||
charm.land/bubbles/v2 v2.0.0/go.mod h1:rCHoleP2XhU8um45NTuOWBPNVHxnkXKTiZqcclL/qOI=
|
||||
charm.land/bubbletea/v2 v2.0.1 h1:B8e9zzK7x9JJ+XvHGF4xnYu9Xa0E0y0MyggY6dbaCfQ=
|
||||
charm.land/bubbletea/v2 v2.0.1/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
|
||||
charm.land/bubbletea/v2 v2.0.2 h1:4CRtRnuZOdFDTWSff9r8QFt/9+z6Emubz3aDMnf/dx0=
|
||||
charm.land/bubbletea/v2 v2.0.2/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
|
||||
charm.land/fantasy v0.11.1 h1:G1dRqkzEQ0RJN1Ls5mte8HOi0wFKxYd5bfnRAmeYvDk=
|
||||
charm.land/fantasy v0.11.1/go.mod h1:C8wNxWlw+b2z54zsTor9r1tG2GE2C4QotvAlgXh9KF8=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.0 h1:sd8N/B3x892oiOjFfBQdXBQp3cAkvjGaU5TvVZC3ivo=
|
||||
charm.land/lipgloss/v2 v2.0.0/go.mod h1:w6SnmsBFBmEFBodiEDurGS/sdUY/u1+v72DqUzc6J14=
|
||||
charm.land/lipgloss/v2 v2.0.1 h1:6Xzrn49+Py1Um5q/wZG1gWgER2+7dUyZ9XMEufqPSys=
|
||||
charm.land/lipgloss/v2 v2.0.1/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
|
||||
@@ -66,12 +72,16 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiE
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.4.0 h1:TKnLPh7IbnizJIBKFWa9mKayRUBQ9Kh1BPCk6w2PnYM=
|
||||
github.com/aymanbagabas/go-udiff v0.4.0/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/catppuccin/go v0.2.0 h1:ktBeIrIP42b/8FGiScP9sgrWOss3lw0Z5SktRoithGA=
|
||||
github.com/catppuccin/go v0.2.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab h1:J7XQLgl9sefgTnTGrmX3xqvp5o6MCiBzEjGv5igAlc4=
|
||||
@@ -98,8 +108,12 @@ github.com/charmbracelet/x/exp/charmtone v0.0.0-20260305213658-fe36e8c10185 h1:/
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260305213658-fe36e8c10185/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260305213658-fe36e8c10185 h1:bloHJLweYZeIkBVgi8AF94DrTdx3eoEB57VOpFuFi3U=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260305213658-fe36e8c10185/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko3AQ4gK1MTS/de7F5hPGx6/k1u0w4TeYmBFwzYVP4=
|
||||
github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
github.com/charmbracelet/x/json v0.2.0/go.mod h1:opFIflx2YgXgi49xVUu8gEQ21teFAxyMwvOiZhIvWNM=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
@@ -114,6 +128,8 @@ github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJ
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 h1:aBangftG7EVZoUb69Os8IaYg++6uMOdKK83QtkkvJik=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2/go.mod h1:qwXFYgsP6T7XnJtbKlf1HP8AjxZZyzxMmc+Lq5GjlU4=
|
||||
github.com/coder/acp-go-sdk v0.6.3 h1:LsXQytehdjKIYJnoVWON/nf7mqbiarnyuyE3rrjBsXQ=
|
||||
github.com/coder/acp-go-sdk v0.6.3/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -121,6 +137,8 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZ
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
|
||||
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.37.0 h1:u3riX6BoYRfF4Dr7dwSOroNfdSbEPe9Yyl09/B6wBrQ=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.37.0/go.mod h1:DReE9MMrmecPy+YvQOAOHNYMALuowAnbjjEMkkWOi6A=
|
||||
@@ -196,6 +214,8 @@ github.com/mattn/go-runewidth v0.0.20 h1:WcT52H91ZUAwy8+HUkdM3THM6gXqXuLJi9O3rjc
|
||||
github.com/mattn/go-runewidth v0.0.20/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/mango v0.2.0 h1:iNNc0c5VLQ6fsMgAqGQofByNUBH2Q2nEbD6TaI+5yyQ=
|
||||
@@ -298,6 +318,8 @@ golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
|
||||
@@ -0,0 +1,260 @@
|
||||
// Package acpserver implements a Kit-backed ACP (Agent Client Protocol) agent.
|
||||
//
|
||||
// It bridges Kit's LLM execution, tool system, and session management to the
|
||||
// ACP protocol over stdio, allowing ACP clients (such as OpenCode) to drive
|
||||
// Kit as a remote coding agent.
|
||||
package acpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
acp "github.com/coder/acp-go-sdk"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// Version is injected at build time; fallback to "dev".
|
||||
var Version = "dev"
|
||||
|
||||
// Agent implements the acp.Agent interface, delegating to Kit for LLM
|
||||
// execution, tool calls, and session management.
|
||||
type Agent struct {
|
||||
conn *acp.AgentSideConnection
|
||||
registry *sessionRegistry
|
||||
|
||||
// toolCallCounter provides unique IDs for tool calls within a turn.
|
||||
toolCallCounter atomic.Int64
|
||||
}
|
||||
|
||||
// NewAgent creates a new ACP agent backed by Kit.
|
||||
func NewAgent() *Agent {
|
||||
return &Agent{
|
||||
registry: newSessionRegistry(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetAgentConnection stores the connection so the agent can send session
|
||||
// updates (streaming, tool calls, etc.) back to the ACP client. This follows
|
||||
// the AgentConnAware duck-typing pattern from the SDK.
|
||||
func (a *Agent) SetAgentConnection(conn *acp.AgentSideConnection) {
|
||||
a.conn = conn
|
||||
}
|
||||
|
||||
// Close shuts down all active sessions.
|
||||
func (a *Agent) Close() {
|
||||
a.registry.closeAll()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// acp.Agent interface implementation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Authenticate handles authentication requests. Kit doesn't require auth for
|
||||
// local stdio usage, so this is a no-op.
|
||||
func (a *Agent) Authenticate(_ context.Context, _ acp.AuthenticateRequest) (acp.AuthenticateResponse, error) {
|
||||
return acp.AuthenticateResponse{}, nil
|
||||
}
|
||||
|
||||
// Initialize negotiates capabilities with the ACP client.
|
||||
func (a *Agent) Initialize(_ context.Context, params acp.InitializeRequest) (acp.InitializeResponse, error) {
|
||||
log.Debug("acp: initialize", "protocol_version", params.ProtocolVersion)
|
||||
|
||||
return acp.InitializeResponse{
|
||||
ProtocolVersion: acp.ProtocolVersion(1),
|
||||
AgentCapabilities: acp.AgentCapabilities{
|
||||
LoadSession: true,
|
||||
PromptCapabilities: acp.PromptCapabilities{
|
||||
EmbeddedContext: true,
|
||||
Image: true,
|
||||
},
|
||||
},
|
||||
AgentInfo: &acp.Implementation{
|
||||
Name: "Kit",
|
||||
Version: Version,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewSession creates a new Kit session for the given working directory.
|
||||
func (a *Agent) NewSession(ctx context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) {
|
||||
cwd := params.Cwd
|
||||
if cwd == "" {
|
||||
return acp.NewSessionResponse{}, acp.NewInvalidParams("cwd is required")
|
||||
}
|
||||
|
||||
log.Debug("acp: new_session", "cwd", cwd)
|
||||
|
||||
sess, err := a.registry.create(ctx, cwd)
|
||||
if err != nil {
|
||||
log.Error("acp: session creation failed", "cwd", cwd, "error", err)
|
||||
return acp.NewSessionResponse{}, fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
|
||||
return acp.NewSessionResponse{
|
||||
SessionId: acp.SessionId(sess.sessionID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Prompt handles the main agent execution. It subscribes to Kit's event bus,
|
||||
// converts events to ACP session updates, and runs the prompt through Kit's
|
||||
// full turn lifecycle (hooks, LLM, tool calls, persistence).
|
||||
func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.PromptResponse, error) {
|
||||
sessionID := string(params.SessionId)
|
||||
sess, ok := a.registry.get(sessionID)
|
||||
if !ok {
|
||||
return acp.PromptResponse{}, acp.NewInvalidParams(
|
||||
fmt.Sprintf("session not found: %s", sessionID),
|
||||
)
|
||||
}
|
||||
|
||||
// Extract text from prompt content blocks.
|
||||
promptText := extractPromptText(params.Prompt)
|
||||
if promptText == "" {
|
||||
return acp.PromptResponse{}, acp.NewInvalidParams("empty prompt")
|
||||
}
|
||||
|
||||
log.Debug("acp: prompt", "session", sessionID, "prompt_len", len(promptText))
|
||||
|
||||
// Create a cancellable context for this prompt turn.
|
||||
promptCtx, cancel := context.WithCancel(ctx)
|
||||
sess.setCancel(cancel)
|
||||
defer sess.clearCancel()
|
||||
|
||||
// Subscribe to Kit events and stream them as ACP session updates.
|
||||
unsub := a.subscribeEvents(promptCtx, sess.kit, params.SessionId)
|
||||
defer unsub()
|
||||
|
||||
// Run the prompt through Kit's full turn lifecycle.
|
||||
_, err := sess.kit.PromptResult(promptCtx, promptText)
|
||||
if err != nil {
|
||||
if promptCtx.Err() != nil {
|
||||
return acp.PromptResponse{
|
||||
StopReason: acp.StopReasonCancelled,
|
||||
}, nil
|
||||
}
|
||||
return acp.PromptResponse{}, fmt.Errorf("prompt failed: %w", err)
|
||||
}
|
||||
|
||||
return acp.PromptResponse{
|
||||
StopReason: acp.StopReasonEndTurn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Cancel cancels the ongoing prompt for a session.
|
||||
func (a *Agent) Cancel(_ context.Context, params acp.CancelNotification) error {
|
||||
sessionID := string(params.SessionId)
|
||||
sess, ok := a.registry.get(sessionID)
|
||||
if !ok {
|
||||
return nil // No-op if session doesn't exist.
|
||||
}
|
||||
|
||||
log.Debug("acp: cancel", "session", sessionID)
|
||||
sess.cancelPrompt()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetSessionMode is a no-op for now — Kit doesn't have built-in session modes.
|
||||
func (a *Agent) SetSessionMode(_ context.Context, _ acp.SetSessionModeRequest) (acp.SetSessionModeResponse, error) {
|
||||
return acp.SetSessionModeResponse{}, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Event streaming: Kit events → ACP SessionUpdate notifications
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// subscribeEvents subscribes to Kit's event bus and forwards events as ACP
|
||||
// session update notifications to the client.
|
||||
func (a *Agent) subscribeEvents(ctx context.Context, k *kit.Kit, sessionID acp.SessionId) func() {
|
||||
return k.Subscribe(func(e kit.Event) {
|
||||
// Don't send updates after the context is cancelled.
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var update *acp.SessionUpdate
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
u := acp.UpdateAgentMessageText(ev.Chunk)
|
||||
update = &u
|
||||
|
||||
case kit.ReasoningDeltaEvent:
|
||||
u := acp.UpdateAgentThoughtText(ev.Delta)
|
||||
update = &u
|
||||
|
||||
case kit.ToolCallEvent:
|
||||
tcID := acp.ToolCallId(ev.ToolCallID)
|
||||
if tcID == "" {
|
||||
tcID = acp.ToolCallId(fmt.Sprintf("tc_%d", a.toolCallCounter.Add(1)))
|
||||
}
|
||||
u := acp.StartToolCall(tcID, ev.ToolName,
|
||||
acp.WithStartStatus(acp.ToolCallStatusInProgress),
|
||||
acp.WithStartRawInput(parseToolArgs(ev.ToolArgs)),
|
||||
)
|
||||
update = &u
|
||||
|
||||
case kit.ToolResultEvent:
|
||||
tcID := acp.ToolCallId(ev.ToolCallID)
|
||||
if tcID == "" {
|
||||
tcID = acp.ToolCallId(fmt.Sprintf("tc_%d", a.toolCallCounter.Load()))
|
||||
}
|
||||
status := acp.ToolCallStatusCompleted
|
||||
if ev.IsError {
|
||||
status = acp.ToolCallStatusFailed
|
||||
}
|
||||
u := acp.UpdateToolCall(tcID,
|
||||
acp.WithUpdateStatus(status),
|
||||
acp.WithUpdateContent([]acp.ToolCallContent{
|
||||
acp.ToolContent(acp.TextBlock(ev.Result)),
|
||||
}),
|
||||
)
|
||||
update = &u
|
||||
|
||||
case kit.ToolCallContentEvent:
|
||||
u := acp.UpdateAgentMessageText(ev.Content)
|
||||
update = &u
|
||||
}
|
||||
|
||||
if update != nil {
|
||||
_ = a.conn.SessionUpdate(ctx, acp.SessionNotification{
|
||||
SessionId: sessionID,
|
||||
Update: *update,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// extractPromptText extracts the concatenated text content from ACP content
|
||||
// blocks. Non-text blocks are ignored for now.
|
||||
func extractPromptText(blocks []acp.ContentBlock) string {
|
||||
var text string
|
||||
for _, block := range blocks {
|
||||
if block.Text != nil {
|
||||
if text != "" {
|
||||
text += "\n"
|
||||
}
|
||||
text += block.Text.Text
|
||||
}
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
// parseToolArgs attempts to parse a JSON tool args string into a map for
|
||||
// structured display. Falls back to a simple string wrapper.
|
||||
func parseToolArgs(args string) any {
|
||||
if args == "" {
|
||||
return nil
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal([]byte(args), &m); err == nil {
|
||||
return m
|
||||
}
|
||||
return map[string]any{"input": args}
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
package acpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// acpSession maps an ACP session to a Kit instance with its own tree session.
|
||||
type acpSession struct {
|
||||
kit *kit.Kit
|
||||
cancelFn context.CancelFunc // cancels the current prompt
|
||||
cancelMu sync.Mutex
|
||||
cwd string
|
||||
sessionID string // Kit-generated session ID (from JSONL header)
|
||||
}
|
||||
|
||||
// sessionRegistry is a thread-safe registry of ACP session ID → Kit sessions.
|
||||
type sessionRegistry struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*acpSession // ACP session ID → session
|
||||
}
|
||||
|
||||
func newSessionRegistry() *sessionRegistry {
|
||||
return &sessionRegistry{
|
||||
sessions: make(map[string]*acpSession),
|
||||
}
|
||||
}
|
||||
|
||||
// create creates a new Kit instance with a persisted tree session for the
|
||||
// given working directory. The Kit-generated session ID is used as the ACP
|
||||
// session ID so the mapping is 1:1.
|
||||
func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession, error) {
|
||||
kitInstance, err := kit.New(ctx, &kit.Options{
|
||||
SessionDir: cwd,
|
||||
Quiet: true,
|
||||
Streaming: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Provide actionable guidance for provider auth errors, which are
|
||||
// the most common failure mode when running via ACP.
|
||||
msg := err.Error()
|
||||
if strings.Contains(msg, "API key") || strings.Contains(msg, "credentials") || strings.Contains(msg, "OAuth") {
|
||||
return nil, fmt.Errorf("provider authentication failed: %w — run 'kit auth login <provider>' or set the appropriate environment variable before starting 'kit acp'", err)
|
||||
}
|
||||
return nil, fmt.Errorf("create kit instance: %w", err)
|
||||
}
|
||||
|
||||
sessionID := kitInstance.GetSessionID()
|
||||
if sessionID == "" {
|
||||
_ = kitInstance.Close()
|
||||
return nil, fmt.Errorf("kit instance has no session ID")
|
||||
}
|
||||
|
||||
// Wire extension context with headless implementations so extensions
|
||||
// work in ACP mode. TUI-dependent features (widgets, prompts, editor)
|
||||
// become no-ops or return cancelled; all data/model/tool APIs work
|
||||
// identically to interactive mode.
|
||||
if kitInstance.HasExtensions() {
|
||||
kitInstance.SetExtensionContext(extensions.Context{
|
||||
SessionID: sessionID,
|
||||
CWD: cwd,
|
||||
Model: kitInstance.GetModelString(),
|
||||
Interactive: false,
|
||||
|
||||
// Output — route through structured logger.
|
||||
Print: func(text string) { log.Debug("extension: print", "text", text) },
|
||||
PrintInfo: func(text string) { log.Info("extension: info", "text", text) },
|
||||
PrintError: func(text string) { log.Error("extension: error", "text", text) },
|
||||
PrintBlock: func(opts extensions.PrintBlockOpts) {
|
||||
log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text)
|
||||
},
|
||||
|
||||
// Message injection — no-ops for now; ACP clients drive prompts.
|
||||
SendMessage: func(string) {},
|
||||
CancelAndSend: func(string) {},
|
||||
Exit: func() {},
|
||||
|
||||
// TUI widgets/chrome — silent no-ops (no TUI in ACP).
|
||||
SetWidget: func(extensions.WidgetConfig) {},
|
||||
RemoveWidget: func(string) {},
|
||||
SetHeader: func(extensions.HeaderFooterConfig) {},
|
||||
RemoveHeader: func() {},
|
||||
SetFooter: func(extensions.HeaderFooterConfig) {},
|
||||
RemoveFooter: func() {},
|
||||
SetEditor: func(extensions.EditorConfig) {},
|
||||
ResetEditor: func() {},
|
||||
SetEditorText: func(string) {},
|
||||
SetUIVisibility: func(extensions.UIVisibility) {},
|
||||
SetStatus: func(string, string, int) {},
|
||||
RemoveStatus: func(string) {},
|
||||
|
||||
// Interactive prompts — return cancelled (no user to prompt).
|
||||
PromptSelect: func(extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
},
|
||||
PromptConfirm: func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
},
|
||||
PromptInput: func(extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
},
|
||||
ShowOverlay: func(extensions.OverlayConfig) extensions.OverlayResult {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
},
|
||||
SuspendTUI: func(callback func()) error { callback(); return nil },
|
||||
|
||||
// Data access — delegate to Kit instance.
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionFilePath() },
|
||||
AppendEntry: func(entryType, data string) (string, error) {
|
||||
return kitInstance.AppendExtensionEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.GetExtensionEntries(entryType)
|
||||
},
|
||||
|
||||
// Options, model, and tool management.
|
||||
GetOption: func(name string) string { return kitInstance.GetExtensionOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.SetExtensionOption(name, value) },
|
||||
SetModel: func(modelString string) error {
|
||||
previousModel := kitInstance.GetExtensionContext().Model
|
||||
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
|
||||
return err
|
||||
}
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
kitInstance.EmitModelChange(modelString, previousModel, "extension")
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry { return kitInstance.GetAvailableModels() },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.EmitExtensionCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.GetExtensionToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.SetExtensionActiveTools(names) },
|
||||
|
||||
// LLM completions and subagents.
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(context.Background(), sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: result.Error,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
},
|
||||
|
||||
// Render — fall back to logging.
|
||||
RenderMessage: func(name, content string) {
|
||||
renderer := kitInstance.GetExtensionMessageRenderer(name)
|
||||
if renderer != nil && renderer.Render != nil {
|
||||
content = renderer.Render(content, 80)
|
||||
}
|
||||
log.Info("extension: message", "renderer", name, "content", content)
|
||||
},
|
||||
ReloadExtensions: func() error { return kitInstance.ReloadExtensions() },
|
||||
})
|
||||
kitInstance.EmitSessionStart()
|
||||
}
|
||||
|
||||
sess := &acpSession{
|
||||
kit: kitInstance,
|
||||
cwd: cwd,
|
||||
sessionID: sessionID,
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
r.sessions[sessionID] = sess
|
||||
r.mu.Unlock()
|
||||
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// get retrieves a session by ACP session ID.
|
||||
func (r *sessionRegistry) get(sessionID string) (*acpSession, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
s, ok := r.sessions[sessionID]
|
||||
return s, ok
|
||||
}
|
||||
|
||||
// closeAll closes all sessions.
|
||||
func (r *sessionRegistry) closeAll() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for id, sess := range r.sessions {
|
||||
if sess.kit != nil {
|
||||
_ = sess.kit.Close()
|
||||
}
|
||||
delete(r.sessions, id)
|
||||
}
|
||||
}
|
||||
|
||||
// cancelPrompt cancels the current prompt for a session, if any.
|
||||
func (s *acpSession) cancelPrompt() {
|
||||
s.cancelMu.Lock()
|
||||
defer s.cancelMu.Unlock()
|
||||
if s.cancelFn != nil {
|
||||
s.cancelFn()
|
||||
s.cancelFn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// setCancel stores a cancel function for the current prompt.
|
||||
func (s *acpSession) setCancel(cancel context.CancelFunc) {
|
||||
s.cancelMu.Lock()
|
||||
defer s.cancelMu.Unlock()
|
||||
s.cancelFn = cancel
|
||||
}
|
||||
|
||||
// clearCancel clears the stored cancel function (called when prompt completes).
|
||||
func (s *acpSession) clearCancel() {
|
||||
s.cancelMu.Lock()
|
||||
defer s.cancelMu.Unlock()
|
||||
s.cancelFn = nil
|
||||
}
|
||||
|
||||
// sdkEventToSubagentEvent converts an SDK event to an extension SubagentEvent.
|
||||
func sdkEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
+12
-7
@@ -41,13 +41,15 @@ type AgentConfig struct {
|
||||
}
|
||||
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
type ToolCallHandler func(toolName, toolArgs string)
|
||||
type ToolCallHandler func(toolCallID, toolName, toolArgs string)
|
||||
|
||||
// ToolExecutionHandler is a function type for handling tool execution start/end events.
|
||||
type ToolExecutionHandler func(toolName string, isStarting bool)
|
||||
type ToolExecutionHandler func(toolCallID, toolName, toolArgs string, isStarting bool)
|
||||
|
||||
// ToolResultHandler is a function type for handling tool results.
|
||||
type ToolResultHandler func(toolName, toolArgs, result string, isError bool)
|
||||
// The metadata parameter carries optional structured data (e.g. file diff
|
||||
// info) from the tool execution, JSON-encoded. It may be empty.
|
||||
type ToolResultHandler func(toolCallID, toolName, toolArgs, result, metadata string, isError bool)
|
||||
|
||||
// ResponseHandler is a function type for handling LLM responses.
|
||||
type ResponseHandler func(content string)
|
||||
@@ -90,6 +92,8 @@ type GenerateWithLoopResult struct {
|
||||
Messages []message.Message
|
||||
// TotalUsage contains aggregate token usage across all steps
|
||||
TotalUsage fantasy.Usage
|
||||
// StopReason is the LLM provider's finish reason for the final response.
|
||||
StopReason string
|
||||
}
|
||||
|
||||
// NewAgent creates a new Agent with core tools and optional MCP tool integration.
|
||||
@@ -283,12 +287,12 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
|
||||
// Notify about the tool call
|
||||
if onToolCall != nil {
|
||||
onToolCall(tc.ToolName, tc.Input)
|
||||
onToolCall(tc.ToolCallID, tc.ToolName, tc.Input)
|
||||
}
|
||||
|
||||
// Notify tool execution starting
|
||||
if onToolExecution != nil {
|
||||
onToolExecution(tc.ToolName, true)
|
||||
onToolExecution(tc.ToolCallID, tc.ToolName, tc.Input, true)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -301,13 +305,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
}
|
||||
// Notify tool execution finished
|
||||
if onToolExecution != nil {
|
||||
onToolExecution(tr.ToolName, false)
|
||||
onToolExecution(tr.ToolCallID, tr.ToolName, currentToolArgs, false)
|
||||
}
|
||||
|
||||
if onToolResult != nil {
|
||||
// Extract result text and error status
|
||||
resultText, isError := extractToolResultText(tr)
|
||||
onToolResult(tr.ToolName, currentToolArgs, resultText, isError)
|
||||
onToolResult(tr.ToolCallID, tr.ToolName, currentToolArgs, resultText, tr.ClientMetadata, isError)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -426,6 +430,7 @@ func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.
|
||||
ConversationMessages: allFantasyMessages,
|
||||
Messages: allMessages,
|
||||
TotalUsage: result.TotalUsage,
|
||||
StopReason: string(result.Response.FinishReason),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+4
-4
@@ -532,14 +532,14 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
unsubs = append(unsubs, k.Subscribe(func(e kit.Event) {
|
||||
switch ev := e.(type) {
|
||||
case kit.ToolCallEvent:
|
||||
sendFn(ToolCallStartedEvent{ToolName: ev.ToolName, ToolArgs: ev.ToolArgs})
|
||||
sendFn(ToolCallStartedEvent{ToolCallID: ev.ToolCallID, ToolName: ev.ToolName, ToolArgs: ev.ToolArgs})
|
||||
case kit.ToolExecutionStartEvent:
|
||||
sendFn(ToolExecutionEvent{ToolName: ev.ToolName, IsStarting: true})
|
||||
sendFn(ToolExecutionEvent{ToolCallID: ev.ToolCallID, ToolName: ev.ToolName, ToolArgs: ev.ToolArgs, IsStarting: true})
|
||||
case kit.ToolExecutionEndEvent:
|
||||
sendFn(ToolExecutionEvent{ToolName: ev.ToolName, IsStarting: false})
|
||||
sendFn(ToolExecutionEvent{ToolCallID: ev.ToolCallID, ToolName: ev.ToolName, IsStarting: false})
|
||||
case kit.ToolResultEvent:
|
||||
sendFn(ToolResultEvent{
|
||||
ToolName: ev.ToolName, ToolArgs: ev.ToolArgs,
|
||||
ToolCallID: ev.ToolCallID, ToolName: ev.ToolName, ToolArgs: ev.ToolArgs,
|
||||
Result: ev.Result, IsError: ev.IsError,
|
||||
})
|
||||
case kit.ToolCallContentEvent:
|
||||
|
||||
@@ -19,6 +19,8 @@ type ReasoningChunkEvent struct {
|
||||
// ToolCallStartedEvent is sent when a tool call has been parsed and is about to execute.
|
||||
// It carries the tool name and its arguments for display purposes.
|
||||
type ToolCallStartedEvent struct {
|
||||
// ToolCallID is the stable identifier for correlating tool lifecycle events.
|
||||
ToolCallID string
|
||||
// ToolName is the name of the tool being called.
|
||||
ToolName string
|
||||
// ToolArgs is the JSON-encoded arguments for the tool call.
|
||||
@@ -28,14 +30,20 @@ type ToolCallStartedEvent struct {
|
||||
// ToolExecutionEvent is sent when a tool starts or finishes executing.
|
||||
// The IsStarting flag distinguishes between the start and end of execution.
|
||||
type ToolExecutionEvent struct {
|
||||
// ToolCallID is the stable identifier for correlating tool lifecycle events.
|
||||
ToolCallID string
|
||||
// ToolName is the name of the tool being executed.
|
||||
ToolName string
|
||||
// ToolArgs is the JSON-encoded arguments for the tool call (only set when IsStarting is true).
|
||||
ToolArgs string
|
||||
// IsStarting is true when execution is beginning, false when it is complete.
|
||||
IsStarting bool
|
||||
}
|
||||
|
||||
// ToolResultEvent is sent after a tool execution completes with its result.
|
||||
type ToolResultEvent struct {
|
||||
// ToolCallID is the stable identifier for correlating tool lifecycle events.
|
||||
ToolCallID string
|
||||
// ToolName is the name of the tool that was executed.
|
||||
ToolName string
|
||||
// ToolArgs is the JSON-encoded arguments that were passed to the tool.
|
||||
|
||||
@@ -51,6 +51,7 @@ func TestCredentialManager(t *testing.T) {
|
||||
}
|
||||
if creds == nil {
|
||||
t.Fatal("Expected credentials to be returned")
|
||||
return
|
||||
}
|
||||
if creds.APIKey != testAPIKey {
|
||||
t.Errorf("Expected API key %s, got %s", testAPIKey, creds.APIKey)
|
||||
@@ -236,6 +237,7 @@ func TestCredentialStorePersistence(t *testing.T) {
|
||||
}
|
||||
if creds == nil {
|
||||
t.Fatal("Expected credentials to persist")
|
||||
return
|
||||
}
|
||||
if creds.APIKey != testAPIKey {
|
||||
t.Errorf("Expected API key %s, got %s", testAPIKey, creds.APIKey)
|
||||
|
||||
@@ -49,12 +49,12 @@ func NewOAuthClient() *OAuthClient {
|
||||
}
|
||||
}
|
||||
|
||||
// GeneratePKCE generates a cryptographically secure PKCE verifier and challenge pair
|
||||
// generatePKCE generates a cryptographically secure PKCE verifier and challenge pair
|
||||
// for the OAuth 2.0 PKCE flow. The verifier is a random 32-byte string encoded as
|
||||
// base64url, and the challenge is the SHA256 hash of the verifier, also base64url encoded.
|
||||
// Returns the verifier (to be stored securely), challenge (to be sent with auth request),
|
||||
// and any error encountered during generation.
|
||||
func GeneratePKCE() (verifier, challenge string, err error) {
|
||||
func generatePKCE() (verifier, challenge string, err error) {
|
||||
// Generate 32 bytes of random data
|
||||
verifierBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(verifierBytes); err != nil {
|
||||
@@ -76,7 +76,7 @@ func GeneratePKCE() (verifier, challenge string, err error) {
|
||||
// and PKCE challenge. Returns an AuthData structure containing the URL for user
|
||||
// authentication and the PKCE verifier for the subsequent code exchange.
|
||||
func (c *OAuthClient) GetAuthorizationURL() (*AuthData, error) {
|
||||
verifier, challenge, err := GeneratePKCE()
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
|
||||
}
|
||||
|
||||
@@ -71,5 +71,5 @@ func DetectMediaType(data []byte) string {
|
||||
// ErrNoImage is returned when the clipboard does not contain image data.
|
||||
var ErrNoImage = fmt.Errorf("no image data on clipboard")
|
||||
|
||||
// ErrNoClipboardTool is returned when no suitable clipboard tool is found.
|
||||
var ErrNoClipboardTool = fmt.Errorf("no clipboard tool available (install xclip, wl-paste, or use macOS)")
|
||||
// errNoClipboardTool is returned when no suitable clipboard tool is found.
|
||||
var errNoClipboardTool = fmt.Errorf("no clipboard tool available (install xclip, wl-paste, or use macOS)")
|
||||
|
||||
@@ -7,9 +7,8 @@ import (
|
||||
)
|
||||
|
||||
// ReadImage reads image data from the system clipboard on macOS.
|
||||
// It uses osascript to check if the clipboard contains an image and then
|
||||
// reads the data using a temporary approach. If the clipboard contains
|
||||
// an image, it writes it to stdout as PNG data.
|
||||
// It uses osascript to check if the clipboard contains an image via
|
||||
// NSPasteboard and writes it to stdout as PNG data.
|
||||
func ReadImage() (*ImageData, error) {
|
||||
// Use osascript to write clipboard image to stdout via a pipe.
|
||||
// The script checks if the clipboard has a «class PNGf» item.
|
||||
|
||||
@@ -41,7 +41,7 @@ func ReadImage() (*ImageData, error) {
|
||||
return nil, ErrNoImage
|
||||
}
|
||||
|
||||
return nil, ErrNoClipboardTool
|
||||
return nil, errNoClipboardTool
|
||||
}
|
||||
|
||||
// readWithXclip reads image data using xclip.
|
||||
|
||||
@@ -5,5 +5,5 @@ package clipboard
|
||||
// ReadImage reads image data from the system clipboard on Windows.
|
||||
// Windows clipboard image support is not yet implemented.
|
||||
func ReadImage() (*ImageData, error) {
|
||||
return nil, ErrNoClipboardTool
|
||||
return nil, errNoClipboardTool
|
||||
}
|
||||
|
||||
@@ -19,8 +19,8 @@ import (
|
||||
// Token estimation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// EstimateTokens provides a rough token count (~4 chars per token).
|
||||
func EstimateTokens(text string) int {
|
||||
// estimateTokens provides a rough token count (~4 chars per token).
|
||||
func estimateTokens(text string) int {
|
||||
return len(text) / 4
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func estimateSingleMessageTokens(msg fantasy.Message) int {
|
||||
total := 0
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(fantasy.TextPart); ok {
|
||||
total += EstimateTokens(tp.Text)
|
||||
total += estimateTokens(tp.Text)
|
||||
}
|
||||
}
|
||||
return total
|
||||
|
||||
@@ -36,9 +36,9 @@ func TestEstimateTokens(t *testing.T) {
|
||||
{"hello world", 2}, // 11 / 4 = 2
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := EstimateTokens(tt.text)
|
||||
got := estimateTokens(tt.text)
|
||||
if got != tt.want {
|
||||
t.Errorf("EstimateTokens(%q) = %d, want %d", tt.text, got, tt.want)
|
||||
t.Errorf("estimateTokens(%q) = %d, want %d", tt.text, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+49
-37
@@ -105,42 +105,56 @@ type AdaptiveColor struct {
|
||||
Dark string `json:"dark,omitempty" yaml:"dark,omitempty"`
|
||||
}
|
||||
|
||||
// MarkdownThemeConfig defines color overrides for markdown rendering and
|
||||
// syntax highlighting.
|
||||
type MarkdownThemeConfig struct {
|
||||
Text AdaptiveColor `json:"text,omitzero" yaml:"text,omitempty"`
|
||||
Muted AdaptiveColor `json:"muted,omitzero" yaml:"muted,omitempty"`
|
||||
Heading AdaptiveColor `json:"heading,omitzero" yaml:"heading,omitempty"`
|
||||
Emph AdaptiveColor `json:"emph,omitzero" yaml:"emph,omitempty"`
|
||||
Strong AdaptiveColor `json:"strong,omitzero" yaml:"strong,omitempty"`
|
||||
Link AdaptiveColor `json:"link,omitzero" yaml:"link,omitempty"`
|
||||
Code AdaptiveColor `json:"code,omitzero" yaml:"code,omitempty"`
|
||||
Error AdaptiveColor `json:"error,omitzero" yaml:"error,omitempty"`
|
||||
Keyword AdaptiveColor `json:"keyword,omitzero" yaml:"keyword,omitempty"`
|
||||
String AdaptiveColor `json:"string,omitzero" yaml:"string,omitempty"`
|
||||
Number AdaptiveColor `json:"number,omitzero" yaml:"number,omitempty"`
|
||||
Comment AdaptiveColor `json:"comment,omitzero" yaml:"comment,omitempty"`
|
||||
}
|
||||
|
||||
// Theme defines the color scheme for the application UI with adaptive colors
|
||||
// that support both light and dark modes.
|
||||
type Theme struct {
|
||||
Primary AdaptiveColor `json:"primary" yaml:"primary"`
|
||||
Secondary AdaptiveColor `json:"secondary" yaml:"secondary"`
|
||||
Success AdaptiveColor `json:"success" yaml:"success"`
|
||||
Warning AdaptiveColor `json:"warning" yaml:"warning"`
|
||||
Error AdaptiveColor `json:"error" yaml:"error"`
|
||||
Info AdaptiveColor `json:"info" yaml:"info"`
|
||||
Text AdaptiveColor `json:"text" yaml:"text"`
|
||||
Muted AdaptiveColor `json:"muted" yaml:"muted"`
|
||||
VeryMuted AdaptiveColor `json:"very-muted" yaml:"very-muted"`
|
||||
Background AdaptiveColor `json:"background" yaml:"background"`
|
||||
Border AdaptiveColor `json:"border" yaml:"border"`
|
||||
MutedBorder AdaptiveColor `json:"muted-border" yaml:"muted-border"`
|
||||
System AdaptiveColor `json:"system" yaml:"system"`
|
||||
Tool AdaptiveColor `json:"tool" yaml:"tool"`
|
||||
Accent AdaptiveColor `json:"accent" yaml:"accent"`
|
||||
Highlight AdaptiveColor `json:"highlight" yaml:"highlight"`
|
||||
}
|
||||
Primary AdaptiveColor `json:"primary,omitzero" yaml:"primary,omitempty"`
|
||||
Secondary AdaptiveColor `json:"secondary,omitzero" yaml:"secondary,omitempty"`
|
||||
Success AdaptiveColor `json:"success,omitzero" yaml:"success,omitempty"`
|
||||
Warning AdaptiveColor `json:"warning,omitzero" yaml:"warning,omitempty"`
|
||||
Error AdaptiveColor `json:"error,omitzero" yaml:"error,omitempty"`
|
||||
Info AdaptiveColor `json:"info,omitzero" yaml:"info,omitempty"`
|
||||
Text AdaptiveColor `json:"text,omitzero" yaml:"text,omitempty"`
|
||||
Muted AdaptiveColor `json:"muted,omitzero" yaml:"muted,omitempty"`
|
||||
VeryMuted AdaptiveColor `json:"very-muted,omitzero" yaml:"very-muted,omitempty"`
|
||||
Background AdaptiveColor `json:"background,omitzero" yaml:"background,omitempty"`
|
||||
Border AdaptiveColor `json:"border,omitzero" yaml:"border,omitempty"`
|
||||
MutedBorder AdaptiveColor `json:"muted-border,omitzero" yaml:"muted-border,omitempty"`
|
||||
System AdaptiveColor `json:"system,omitzero" yaml:"system,omitempty"`
|
||||
Tool AdaptiveColor `json:"tool,omitzero" yaml:"tool,omitempty"`
|
||||
Accent AdaptiveColor `json:"accent,omitzero" yaml:"accent,omitempty"`
|
||||
Highlight AdaptiveColor `json:"highlight,omitzero" yaml:"highlight,omitempty"`
|
||||
|
||||
// MarkdownTheme defines the color scheme for markdown rendering with syntax
|
||||
// highlighting support and adaptive colors for light and dark modes.
|
||||
type MarkdownTheme struct {
|
||||
Text AdaptiveColor `json:"text" yaml:"text"`
|
||||
Muted AdaptiveColor `json:"muted" yaml:"muted"`
|
||||
Heading AdaptiveColor `json:"heading" yaml:"heading"`
|
||||
Emph AdaptiveColor `json:"emph" yaml:"emph"`
|
||||
Strong AdaptiveColor `json:"strong" yaml:"strong"`
|
||||
Link AdaptiveColor `json:"link" yaml:"link"`
|
||||
Code AdaptiveColor `json:"code" yaml:"code"`
|
||||
Error AdaptiveColor `json:"error" yaml:"error"`
|
||||
Keyword AdaptiveColor `json:"keyword" yaml:"keyword"`
|
||||
String AdaptiveColor `json:"string" yaml:"string"`
|
||||
Number AdaptiveColor `json:"number" yaml:"number"`
|
||||
Comment AdaptiveColor `json:"comment" yaml:"comment"`
|
||||
// Diff block backgrounds
|
||||
DiffInsertBg AdaptiveColor `json:"diff-insert-bg,omitzero" yaml:"diff-insert-bg,omitempty"`
|
||||
DiffDeleteBg AdaptiveColor `json:"diff-delete-bg,omitzero" yaml:"diff-delete-bg,omitempty"`
|
||||
DiffEqualBg AdaptiveColor `json:"diff-equal-bg,omitzero" yaml:"diff-equal-bg,omitempty"`
|
||||
DiffMissingBg AdaptiveColor `json:"diff-missing-bg,omitzero" yaml:"diff-missing-bg,omitempty"`
|
||||
|
||||
// Code/output block backgrounds
|
||||
CodeBg AdaptiveColor `json:"code-bg,omitzero" yaml:"code-bg,omitempty"`
|
||||
GutterBg AdaptiveColor `json:"gutter-bg,omitzero" yaml:"gutter-bg,omitempty"`
|
||||
WriteBg AdaptiveColor `json:"write-bg,omitzero" yaml:"write-bg,omitempty"`
|
||||
|
||||
// Markdown rendering and syntax highlighting
|
||||
Markdown MarkdownThemeConfig `json:"markdown,omitzero" yaml:"markdown,omitempty"`
|
||||
}
|
||||
|
||||
// Config represents the complete application configuration including MCP servers,
|
||||
@@ -157,7 +171,6 @@ type Config struct {
|
||||
ProviderURL string `json:"provider-url,omitempty" yaml:"provider-url,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"`
|
||||
Theme any `json:"theme" yaml:"theme"`
|
||||
MarkdownTheme any `json:"markdown-theme" yaml:"markdown-theme"`
|
||||
// Model generation parameters
|
||||
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
@@ -373,11 +386,10 @@ func FilepathOr[T any](key string, value *T) error {
|
||||
fmt.Fprintf(os.Stderr, "%q", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if filepath.Ext(absPath) == ".json" {
|
||||
switch filepath.Ext(absPath) {
|
||||
case ".json":
|
||||
return json.Unmarshal(b, value)
|
||||
}
|
||||
|
||||
if filepath.Ext(absPath) == ".yaml" {
|
||||
case ".yaml", ".yml":
|
||||
return yaml.Unmarshal(b, value)
|
||||
}
|
||||
}
|
||||
|
||||
+21
-3
@@ -76,13 +76,15 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
// If no exact match, try fuzzy matching
|
||||
if count == 0 {
|
||||
if idx, matchLen := fuzzyMatch(normalized, normalizedOld); idx >= 0 {
|
||||
// Apply fuzzy match
|
||||
// Apply fuzzy match — the matched text is the original content slice
|
||||
matchedText := normalized[idx : idx+matchLen]
|
||||
newContent := normalized[:idx] + args.NewText + normalized[idx+matchLen:]
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
diff := generateDiff(absPath, normalized, newContent, idx)
|
||||
return fantasy.NewTextResponse(fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff)), nil
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, matchedText, args.NewText)), nil
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("old_text not found in %s", args.Path)), nil
|
||||
}
|
||||
@@ -100,7 +102,23 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
|
||||
idx := strings.Index(normalized, normalizedOld)
|
||||
diff := generateDiff(absPath, normalized, newContent, idx)
|
||||
return fantasy.NewTextResponse(fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff)), nil
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, normalizedOld, args.NewText)), nil
|
||||
}
|
||||
|
||||
// editDiffMeta builds the structured metadata attached to edit tool responses.
|
||||
func editDiffMeta(path, oldText, newText string) map[string]any {
|
||||
return map[string]any{
|
||||
"file_diffs": []map[string]any{{
|
||||
"path": path,
|
||||
"additions": strings.Count(newText, "\n") + 1,
|
||||
"deletions": strings.Count(oldText, "\n") + 1,
|
||||
"diff_blocks": []map[string]any{{
|
||||
"old_text": oldText,
|
||||
"new_text": newText,
|
||||
}},
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
// fuzzyMatch tries to find old_text with relaxed matching:
|
||||
|
||||
@@ -39,6 +39,7 @@ func NewFindTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
},
|
||||
Required: []string{"pattern"},
|
||||
Parallel: true,
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeFind(ctx, call, cfg.WorkDir)
|
||||
|
||||
@@ -59,6 +59,7 @@ func NewGrepTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
},
|
||||
Required: []string{"pattern"},
|
||||
Parallel: true,
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeGrep(ctx, call, cfg.WorkDir)
|
||||
|
||||
@@ -33,6 +33,7 @@ func NewLsTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
},
|
||||
Required: []string{},
|
||||
Parallel: true,
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeLs(ctx, call, cfg.WorkDir)
|
||||
|
||||
@@ -38,6 +38,7 @@ func NewReadTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
Parallel: true,
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeRead(ctx, call, cfg.WorkDir)
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
const defaultSubagentTimeout = 5 * time.Minute
|
||||
const maxSubagentTimeout = 30 * time.Minute
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context-based subagent spawner
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentSpawnResult carries the outcome of an in-process subagent spawn.
|
||||
type SubagentSpawnResult struct {
|
||||
Response string
|
||||
Error error
|
||||
SessionID string
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
Elapsed time.Duration
|
||||
}
|
||||
|
||||
// SubagentSpawnFunc is a callback that spawns an in-process subagent. The
|
||||
// parent Kit instance injects this into the context so the core tool can
|
||||
// call back without importing pkg/kit (which would create a cycle).
|
||||
type SubagentSpawnFunc func(ctx context.Context, prompt, model, systemPrompt string, timeout time.Duration) (*SubagentSpawnResult, error)
|
||||
|
||||
type subagentCtxKey struct{}
|
||||
|
||||
// WithSubagentSpawner stores a spawn function in the context so that the
|
||||
// spawn_subagent core tool can create in-process subagents.
|
||||
func WithSubagentSpawner(ctx context.Context, fn SubagentSpawnFunc) context.Context {
|
||||
return context.WithValue(ctx, subagentCtxKey{}, fn)
|
||||
}
|
||||
|
||||
// getSubagentSpawner retrieves the spawn function from the context.
|
||||
func getSubagentSpawner(ctx context.Context) SubagentSpawnFunc {
|
||||
if fn, ok := ctx.Value(subagentCtxKey{}).(SubagentSpawnFunc); ok {
|
||||
return fn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// spawn_subagent tool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type subagentArgs struct {
|
||||
Task string `json:"task"`
|
||||
Model string `json:"model,omitempty"`
|
||||
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// NewSubagentTool creates the spawn_subagent core tool.
|
||||
func NewSubagentTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "spawn_subagent",
|
||||
Description: `Spawn a subagent to perform a task autonomously.
|
||||
|
||||
The subagent runs as a separate in-process Kit instance with full tool access
|
||||
(except spawning further subagents). Use this to:
|
||||
- Delegate independent subtasks that can run in parallel
|
||||
- Perform research or analysis without blocking your main work
|
||||
- Execute tasks that benefit from a fresh context window
|
||||
|
||||
The subagent result is returned when it completes. For long-running tasks,
|
||||
consider breaking them into smaller focused subtasks.
|
||||
|
||||
Example use cases:
|
||||
- "Research the authentication patterns in this codebase"
|
||||
- "Write unit tests for the UserService class"
|
||||
- "Analyze the performance bottlenecks in the database queries"`,
|
||||
Parameters: map[string]any{
|
||||
"task": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The complete task description for the subagent to perform",
|
||||
},
|
||||
"model": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional model override (e.g. 'anthropic/claude-haiku-3-5-20241022' for faster/cheaper tasks)",
|
||||
},
|
||||
"system_prompt": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional system prompt for domain-specific guidance",
|
||||
},
|
||||
"timeout_seconds": map[string]any{
|
||||
"type": "number",
|
||||
"description": "Maximum execution time in seconds (default: 300, max: 1800)",
|
||||
},
|
||||
},
|
||||
Required: []string{"task"},
|
||||
Parallel: true,
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeSubagent(ctx, call)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func executeSubagent(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
var args subagentArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("task parameter is required"), nil
|
||||
}
|
||||
if args.Task == "" {
|
||||
return fantasy.NewTextErrorResponse("task parameter is required"), nil
|
||||
}
|
||||
|
||||
// Determine timeout.
|
||||
timeout := defaultSubagentTimeout
|
||||
if args.TimeoutSeconds > 0 {
|
||||
timeout = min(time.Duration(args.TimeoutSeconds)*time.Second, maxSubagentTimeout)
|
||||
}
|
||||
|
||||
// Retrieve in-process spawner from context.
|
||||
spawner := getSubagentSpawner(ctx)
|
||||
if spawner == nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
"Error: subagent spawner not available. " +
|
||||
"Ensure Kit is initialized with subagent support.",
|
||||
), fmt.Errorf("no subagent spawner in context")
|
||||
}
|
||||
|
||||
// Spawn in-process subagent.
|
||||
result, err := spawner(ctx, args.Task, args.Model, args.SystemPrompt, timeout)
|
||||
if err != nil || result.Error != nil {
|
||||
spawnErr := err
|
||||
if spawnErr == nil {
|
||||
spawnErr = result.Error
|
||||
}
|
||||
response := fmt.Sprintf("Subagent failed after %ds.\n\nError: %v",
|
||||
int(result.Elapsed.Seconds()), spawnErr)
|
||||
if result.Response != "" {
|
||||
response += fmt.Sprintf("\n\nPartial output:\n%s", truncateResponse(result.Response, 8000))
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(response), nil
|
||||
}
|
||||
|
||||
// Build successful response.
|
||||
response := fmt.Sprintf("Subagent completed successfully in %ds.", int(result.Elapsed.Seconds()))
|
||||
if result.InputTokens > 0 || result.OutputTokens > 0 {
|
||||
response += fmt.Sprintf(" (tokens: %d in / %d out)", result.InputTokens, result.OutputTokens)
|
||||
}
|
||||
response += fmt.Sprintf("\n\nResult:\n%s", truncateResponse(result.Response, 12000))
|
||||
|
||||
resp := fantasy.NewTextResponse(response)
|
||||
|
||||
// Attach subagent session ID as metadata when available.
|
||||
if result.SessionID != "" {
|
||||
resp = fantasy.WithResponseMetadata(resp, map[string]any{
|
||||
"subagent_session_id": result.SessionID,
|
||||
})
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// truncateResponse limits the response length to avoid overwhelming context windows.
|
||||
func truncateResponse(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "\n\n... [truncated — " + fmt.Sprintf("%d", len(s)-maxLen) + " bytes omitted]"
|
||||
}
|
||||
@@ -86,8 +86,9 @@ func ReadOnlyTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
}
|
||||
}
|
||||
|
||||
// AllTools returns all available core tools.
|
||||
func AllTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
// SubagentTools returns all core tools except spawn_subagent. This prevents
|
||||
// infinite recursion when a subagent is itself a Kit instance.
|
||||
func SubagentTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
return []fantasy.AgentTool{
|
||||
NewBashTool(opts...),
|
||||
NewReadTool(opts...),
|
||||
@@ -98,3 +99,8 @@ func AllTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
NewLsTool(opts...),
|
||||
}
|
||||
}
|
||||
|
||||
// AllTools returns all available core tools.
|
||||
func AllTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
return append(SubagentTools(opts...), NewSubagentTool(opts...))
|
||||
}
|
||||
|
||||
+32
-1
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
@@ -53,6 +54,14 @@ func executeWrite(ctx context.Context, call fantasy.ToolCall, workDir string) (f
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("invalid path: %v", err)), nil
|
||||
}
|
||||
|
||||
// Read existing content before writing (for diff metadata).
|
||||
var beforeContent string
|
||||
isNew := true
|
||||
if existing, readErr := os.ReadFile(absPath); readErr == nil {
|
||||
beforeContent = string(existing)
|
||||
isNew = false
|
||||
}
|
||||
|
||||
// Create parent directories
|
||||
dir := filepath.Dir(absPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
@@ -63,5 +72,27 @@ func executeWrite(ctx context.Context, call fantasy.ToolCall, workDir string) (f
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
|
||||
return fantasy.NewTextResponse(fmt.Sprintf("Wrote %d bytes to %s", len(args.Content), args.Path)), nil
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Wrote %d bytes to %s", len(args.Content), args.Path))
|
||||
return fantasy.WithResponseMetadata(resp, writeDiffMeta(absPath, beforeContent, args.Content, isNew)), nil
|
||||
}
|
||||
|
||||
// writeDiffMeta builds the structured metadata attached to write tool responses.
|
||||
func writeDiffMeta(path, beforeContent, afterContent string, isNew bool) map[string]any {
|
||||
additions := strings.Count(afterContent, "\n") + 1
|
||||
deletions := 0
|
||||
if !isNew {
|
||||
deletions = strings.Count(beforeContent, "\n") + 1
|
||||
}
|
||||
return map[string]any{
|
||||
"file_diffs": []map[string]any{{
|
||||
"path": path,
|
||||
"additions": additions,
|
||||
"deletions": deletions,
|
||||
"is_new": isNew,
|
||||
"diff_blocks": []map[string]any{{
|
||||
"old_text": beforeContent,
|
||||
"new_text": afterContent,
|
||||
}},
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
+161
-7
@@ -174,6 +174,22 @@ type Context struct {
|
||||
// }
|
||||
PromptInput func(PromptInputConfig) PromptInputResult
|
||||
|
||||
// PromptMultiSelect shows a multi-selection list to the user, allowing
|
||||
// them to toggle options with spacebar and confirm with enter. In
|
||||
// non-interactive mode, returns all options as selected.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := ctx.PromptMultiSelect(ext.PromptMultiSelectConfig{
|
||||
// Message: "Select extensions to install:",
|
||||
// Options: []string{"git", "todo", "weather"},
|
||||
// DefaultSelected: []int{0, 1, 2}, // All selected by default
|
||||
// })
|
||||
// if !result.Cancelled {
|
||||
// fmt.Println("Selected:", result.Values)
|
||||
// }
|
||||
PromptMultiSelect func(PromptMultiSelectConfig) PromptMultiSelectResult
|
||||
|
||||
// ShowOverlay displays a modal overlay dialog that blocks until the
|
||||
// user dismisses it or selects an action. The overlay renders as a
|
||||
// centered (or anchored) bordered box over the TUI. Returns a
|
||||
@@ -469,6 +485,36 @@ type Context struct {
|
||||
// ctx.RenderMessage("build-status", "All 42 tests passed.")
|
||||
RenderMessage func(rendererName string, content string)
|
||||
|
||||
// RegisterTheme adds a named theme to the runtime theme registry.
|
||||
// If a theme with the same name already exists it is replaced.
|
||||
// The theme becomes available via /theme and ctx.SetTheme().
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ctx.RegisterTheme("neon", ext.ThemeColorConfig{
|
||||
// Primary: ext.ThemeColor{Dark: "#FF00FF"},
|
||||
// Secondary: ext.ThemeColor{Dark: "#00FFFF"},
|
||||
// Success: ext.ThemeColor{Dark: "#00FF00"},
|
||||
// Warning: ext.ThemeColor{Dark: "#FFFF00"},
|
||||
// Error: ext.ThemeColor{Dark: "#FF0000"},
|
||||
// Info: ext.ThemeColor{Dark: "#00FFFF"},
|
||||
// Text: ext.ThemeColor{Dark: "#FFFFFF"},
|
||||
// Background: ext.ThemeColor{Dark: "#000000"},
|
||||
// })
|
||||
RegisterTheme func(name string, config ThemeColorConfig)
|
||||
|
||||
// SetTheme switches the active color theme by name. The name must
|
||||
// match a built-in theme, a user/project theme file, or a theme
|
||||
// registered via RegisterTheme. Returns an error if not found.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := ctx.SetTheme("neon")
|
||||
SetTheme func(name string) error
|
||||
|
||||
// ListThemes returns the names of all available themes.
|
||||
ListThemes func() []string
|
||||
|
||||
// ReloadExtensions hot-reloads all extensions from disk. Existing
|
||||
// extensions receive a SessionShutdown event, then new code is loaded
|
||||
// and receives a SessionStart event. Event handlers, commands,
|
||||
@@ -491,6 +537,41 @@ type Context struct {
|
||||
// },
|
||||
// })
|
||||
ReloadExtensions func() error
|
||||
|
||||
// SpawnSubagent spawns a child Kit instance to perform a task autonomously.
|
||||
// The subagent runs as a separate subprocess with full tool access but
|
||||
// isolated session and extensions (--no-session --no-extensions).
|
||||
//
|
||||
// When config.Blocking is true, blocks until completion and returns the
|
||||
// result directly (handle is nil). When false, returns immediately with
|
||||
// a handle for monitoring/cancellation.
|
||||
//
|
||||
// Example — blocking call:
|
||||
//
|
||||
// _, result, err := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
// Prompt: "Research authentication patterns in this codebase",
|
||||
// Blocking: true,
|
||||
// Timeout: 2 * time.Minute,
|
||||
// })
|
||||
// if err != nil {
|
||||
// ctx.PrintError("spawn failed: " + err.Error())
|
||||
// return
|
||||
// }
|
||||
// ctx.PrintInfo("Subagent result:\n" + result.Response)
|
||||
//
|
||||
// Example — background spawn with callbacks:
|
||||
//
|
||||
// handle, _, _ := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
// Prompt: "Write unit tests for UserService",
|
||||
// OnOutput: func(chunk string) {
|
||||
// // Live output streaming
|
||||
// },
|
||||
// OnComplete: func(result ext.SubagentResult) {
|
||||
// ctx.SendMessage("Subagent finished:\n" + result.Response)
|
||||
// },
|
||||
// })
|
||||
// // handle.Kill() to cancel, handle.Wait() to block
|
||||
SpawnSubagent func(SubagentConfig) (*SubagentHandle, *SubagentResult, error)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -965,6 +1046,29 @@ type PromptInputResult struct {
|
||||
Cancelled bool
|
||||
}
|
||||
|
||||
// PromptMultiSelectConfig configures a multi-selection prompt that allows
|
||||
// the user to toggle multiple options and confirm their selection.
|
||||
type PromptMultiSelectConfig struct {
|
||||
// Message is the question or instruction displayed to the user.
|
||||
Message string
|
||||
// Options is the list of choices the user can select from.
|
||||
Options []string
|
||||
// DefaultSelected contains indices of options that should be
|
||||
// pre-selected when the prompt appears. If nil, all options are selected.
|
||||
DefaultSelected []int
|
||||
}
|
||||
|
||||
// PromptMultiSelectResult is the response from a multi-selection prompt.
|
||||
type PromptMultiSelectResult struct {
|
||||
// Values contains the text of selected options.
|
||||
Values []string
|
||||
// Indices contains the zero-based indices of selected options.
|
||||
Indices []int
|
||||
// Cancelled is true if the user dismissed the prompt (ESC) or
|
||||
// the prompt was unavailable (non-interactive mode).
|
||||
Cancelled bool
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Header/Footer types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -1397,7 +1501,9 @@ type EditorConfig struct {
|
||||
type ToolCallEvent struct {
|
||||
ToolName string
|
||||
ToolCallID string
|
||||
Input string // JSON-encoded tool parameters
|
||||
ToolKind string // Tool classification: "execute", "edit", "read", "search", "agent"
|
||||
Input string // JSON-encoded tool parameters
|
||||
ParsedArgs map[string]any // Pre-parsed arguments for convenience (nil on parse failure)
|
||||
// Source indicates who initiated the tool call.
|
||||
// Currently always "llm" (all tool calls originate from the LLM agent loop).
|
||||
// Future user-initiated tool features may set this to "user".
|
||||
@@ -1416,24 +1522,31 @@ func (ToolCallResult) isResult() {}
|
||||
|
||||
// ToolExecutionStartEvent fires when a tool begins executing.
|
||||
type ToolExecutionStartEvent struct {
|
||||
ToolName string
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
ToolKind string
|
||||
}
|
||||
|
||||
func (e ToolExecutionStartEvent) Type() EventType { return ToolExecutionStart }
|
||||
|
||||
// ToolExecutionEndEvent fires when a tool finishes executing.
|
||||
type ToolExecutionEndEvent struct {
|
||||
ToolName string
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
ToolKind string
|
||||
}
|
||||
|
||||
func (e ToolExecutionEndEvent) Type() EventType { return ToolExecutionEnd }
|
||||
|
||||
// ToolResultEvent fires after tool execution with the output.
|
||||
type ToolResultEvent struct {
|
||||
ToolName string
|
||||
Input string
|
||||
Content string
|
||||
IsError bool
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
ToolKind string
|
||||
Input string
|
||||
Content string
|
||||
IsError bool
|
||||
Metadata string // Optional JSON-encoded structured metadata (e.g. file diffs)
|
||||
}
|
||||
|
||||
func (e ToolResultEvent) Type() EventType { return ToolResult }
|
||||
@@ -1640,3 +1753,44 @@ type BeforeCompactResult struct {
|
||||
}
|
||||
|
||||
func (BeforeCompactResult) isResult() {}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Theme types (exposed to Yaegi — concrete structs, string hex colors)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ThemeColor is an adaptive color pair with light and dark hex values.
|
||||
// Either field may be empty to inherit from the default theme.
|
||||
type ThemeColor struct {
|
||||
Light string
|
||||
Dark string
|
||||
}
|
||||
|
||||
// ThemeColorConfig defines a complete color theme that extensions can register
|
||||
// programmatically via ctx.RegisterTheme(). Uses plain hex strings (not
|
||||
// color.Color) so the type is safe to pass across the Yaegi boundary.
|
||||
type ThemeColorConfig struct {
|
||||
Primary ThemeColor
|
||||
Secondary ThemeColor
|
||||
Success ThemeColor
|
||||
Warning ThemeColor
|
||||
Error ThemeColor
|
||||
Info ThemeColor
|
||||
Text ThemeColor
|
||||
Muted ThemeColor
|
||||
VeryMuted ThemeColor
|
||||
Background ThemeColor
|
||||
Border ThemeColor
|
||||
MutedBorder ThemeColor
|
||||
System ThemeColor
|
||||
Tool ThemeColor
|
||||
Accent ThemeColor
|
||||
Highlight ThemeColor
|
||||
|
||||
// Markdown/syntax highlighting overrides.
|
||||
MdHeading ThemeColor
|
||||
MdLink ThemeColor
|
||||
MdKeyword ThemeColor
|
||||
MdString ThemeColor
|
||||
MdNumber ThemeColor
|
||||
MdComment ThemeColor
|
||||
}
|
||||
|
||||
@@ -0,0 +1,537 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// InstallScope defines where a package should be installed.
|
||||
type InstallScope string
|
||||
|
||||
const (
|
||||
ScopeGlobal InstallScope = "global"
|
||||
ScopeProject InstallScope = "project"
|
||||
)
|
||||
|
||||
// GitSource represents a parsed git repository URL.
|
||||
type GitSource struct {
|
||||
Repo string // Clone URL (e.g., https://github.com/user/repo.git)
|
||||
Host string // Host (e.g., github.com)
|
||||
Path string // Path (e.g., user/repo)
|
||||
Ref string // Optional ref (tag, branch, commit)
|
||||
Pinned bool // Whether a specific ref is pinned
|
||||
}
|
||||
|
||||
// String returns the canonical string representation.
|
||||
func (g GitSource) String() string {
|
||||
if g.Pinned {
|
||||
return fmt.Sprintf("git:%s/%s@%s", g.Host, g.Path, g.Ref)
|
||||
}
|
||||
return fmt.Sprintf("git:%s/%s", g.Host, g.Path)
|
||||
}
|
||||
|
||||
// Identity returns a normalized identity string for deduplication.
|
||||
func (g GitSource) Identity() string {
|
||||
return fmt.Sprintf("%s/%s", g.Host, g.Path)
|
||||
}
|
||||
|
||||
// ParseGitSource parses a git source string into a GitSource.
|
||||
// Supports formats like:
|
||||
// - git:github.com/user/repo
|
||||
// - git:github.com/user/repo@v1.0.0
|
||||
// - https://github.com/user/repo
|
||||
// - https://github.com/user/repo@v1.0.0
|
||||
// - ssh://git@github.com/user/repo
|
||||
// - git@github.com:user/repo
|
||||
// - github.com/user/repo (shorthand, defaults to https)
|
||||
func ParseGitSource(source string) (*GitSource, error) {
|
||||
source = strings.TrimSpace(source)
|
||||
|
||||
// Check for @ref suffix
|
||||
ref := ""
|
||||
pinned := false
|
||||
if atIdx := strings.LastIndex(source, "@"); atIdx > 0 {
|
||||
// Make sure it's not part of the protocol (e.g., @ in ssh://git@)
|
||||
after := source[atIdx+1:]
|
||||
if !strings.Contains(after, "/") && !strings.Contains(after, ":") {
|
||||
ref = after
|
||||
pinned = true
|
||||
source = source[:atIdx]
|
||||
}
|
||||
}
|
||||
|
||||
// Handle git: prefix
|
||||
source, _ = strings.CutPrefix(source, "git:")
|
||||
|
||||
var repo, host, path string
|
||||
|
||||
// Handle explicit URLs
|
||||
if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") {
|
||||
u, err := url.Parse(source)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
host = u.Host
|
||||
path = strings.TrimPrefix(u.Path, "/")
|
||||
path, _ = strings.CutSuffix(path, ".git")
|
||||
repo = source
|
||||
if !strings.HasSuffix(repo, ".git") {
|
||||
repo += ".git"
|
||||
}
|
||||
} else if strings.HasPrefix(source, "ssh://") {
|
||||
u, err := url.Parse(source)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid SSH URL: %w", err)
|
||||
}
|
||||
host = u.Host
|
||||
path = strings.TrimPrefix(u.Path, "/")
|
||||
path, _ = strings.CutSuffix(path, ".git")
|
||||
repo = source
|
||||
} else if strings.HasPrefix(source, "git@") {
|
||||
// SSH shorthand: git@github.com:user/repo
|
||||
parts := strings.SplitN(source, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid SSH shorthand format")
|
||||
}
|
||||
host = strings.TrimPrefix(parts[0], "git@")
|
||||
path = parts[1]
|
||||
path, _ = strings.CutSuffix(path, ".git")
|
||||
repo = source
|
||||
} else if strings.HasPrefix(source, "github.com/") || strings.HasPrefix(source, "gitlab.com/") || strings.HasPrefix(source, "bitbucket.org/") {
|
||||
// Shorthand for known hosts: host/path
|
||||
parts := strings.SplitN(source, "/", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid shorthand format, expected host/path")
|
||||
}
|
||||
host = parts[0]
|
||||
path = parts[1]
|
||||
repo = fmt.Sprintf("https://%s/%s.git", host, path)
|
||||
} else if strings.HasPrefix(source, ".") || strings.HasPrefix(source, "/") || strings.HasPrefix(source, "~") {
|
||||
// Local paths are not supported
|
||||
return nil, fmt.Errorf("local paths not supported, use explicit extension path with -e flag")
|
||||
} else {
|
||||
// Generic shorthand: host/user/repo (3+ path segments)
|
||||
parts := strings.Split(source, "/")
|
||||
if len(parts) >= 3 {
|
||||
host = parts[0]
|
||||
path = strings.Join(parts[1:], "/")
|
||||
repo = fmt.Sprintf("https://%s/%s.git", host, path)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unrecognized source format: %s", source)
|
||||
}
|
||||
}
|
||||
|
||||
return &GitSource{
|
||||
Repo: repo,
|
||||
Host: host,
|
||||
Path: path,
|
||||
Ref: ref,
|
||||
Pinned: pinned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Installer handles installing, updating, and removing git-based extensions.
|
||||
type Installer struct {
|
||||
// Global packages root: $XDG_DATA_HOME/kit/git/ (default ~/.local/share/kit/git/)
|
||||
globalGitRoot string
|
||||
// Project packages root: .kit/git/
|
||||
projectGitRoot string
|
||||
}
|
||||
|
||||
// NewInstaller creates a new Installer.
|
||||
func NewInstaller(projectDir string) *Installer {
|
||||
return &Installer{
|
||||
globalGitRoot: globalGitInstallRoot(),
|
||||
projectGitRoot: filepath.Join(projectDir, ".kit", "git"),
|
||||
}
|
||||
}
|
||||
|
||||
// Install clones a git repository to the appropriate scope.
|
||||
func (i *Installer) Install(source *GitSource, scope InstallScope) error {
|
||||
targetDir := i.getInstallPath(source, scope)
|
||||
|
||||
// Check if already installed
|
||||
if _, err := os.Stat(targetDir); err == nil {
|
||||
return fmt.Errorf("extension already installed at %s", targetDir)
|
||||
}
|
||||
|
||||
// Ensure parent directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(targetDir), 0755); err != nil {
|
||||
return fmt.Errorf("creating parent directory: %w", err)
|
||||
}
|
||||
|
||||
// Clone the repository
|
||||
cmd := exec.Command("git", "clone", "--depth=1", source.Repo, targetDir)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("git clone failed: %w\n%s", err, string(output))
|
||||
}
|
||||
|
||||
// Checkout specific ref if pinned
|
||||
if source.Pinned && source.Ref != "" {
|
||||
checkoutCmd := exec.Command("git", "checkout", source.Ref)
|
||||
checkoutCmd.Dir = targetDir
|
||||
if output, err := checkoutCmd.CombinedOutput(); err != nil {
|
||||
// Clean up on failed checkout
|
||||
_ = os.RemoveAll(targetDir)
|
||||
return fmt.Errorf("git checkout failed: %w\n%s", err, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that the package contains valid extensions
|
||||
if err := i.validatePackage(targetDir); err != nil {
|
||||
_ = os.RemoveAll(targetDir)
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Add to manifest
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
Host: source.Host,
|
||||
Path: source.Path,
|
||||
Ref: source.Ref,
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
}
|
||||
if err := i.addToManifest(entry, scope); err != nil {
|
||||
// Don't fail the install, just log the error
|
||||
// The package is installed, manifest update failed
|
||||
return fmt.Errorf("installed but failed to update manifest: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Uninstall removes an installed package.
|
||||
func (i *Installer) Uninstall(source *GitSource, scope InstallScope) error {
|
||||
targetDir := i.getInstallPath(source, scope)
|
||||
|
||||
if _, err := os.Stat(targetDir); err != nil {
|
||||
return fmt.Errorf("extension not found at %s", targetDir)
|
||||
}
|
||||
|
||||
// Remove the directory
|
||||
if err := os.RemoveAll(targetDir); err != nil {
|
||||
return fmt.Errorf("removing extension directory: %w", err)
|
||||
}
|
||||
|
||||
// Remove from manifest
|
||||
if err := i.removeFromManifest(source.Identity(), scope); err != nil {
|
||||
return fmt.Errorf("removed but failed to update manifest: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update fetches and resets a git package to the latest.
|
||||
// For pinned packages, this does nothing.
|
||||
func (i *Installer) Update(source *GitSource, scope InstallScope) error {
|
||||
if source.Pinned {
|
||||
return nil // Don't update pinned packages
|
||||
}
|
||||
|
||||
targetDir := i.getInstallPath(source, scope)
|
||||
|
||||
if _, err := os.Stat(targetDir); err != nil {
|
||||
return i.Install(source, scope)
|
||||
}
|
||||
|
||||
// Fetch latest
|
||||
fetchCmd := exec.Command("git", "fetch", "--prune", "origin")
|
||||
fetchCmd.Dir = targetDir
|
||||
if output, err := fetchCmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("git fetch failed: %w\n%s", err, string(output))
|
||||
}
|
||||
|
||||
// Reset to tracking branch or origin/HEAD
|
||||
resetCmd := exec.Command("git", "reset", "--hard", "@{upstream}")
|
||||
resetCmd.Dir = targetDir
|
||||
if _, err := resetCmd.CombinedOutput(); err != nil {
|
||||
// Try alternative: set HEAD and reset to origin/HEAD
|
||||
_ = exec.Command("git", "remote", "set-head", "origin", "-a").Run()
|
||||
resetCmd = exec.Command("git", "reset", "--hard", "origin/HEAD")
|
||||
resetCmd.Dir = targetDir
|
||||
if output, err := resetCmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("git reset failed: %w\n%s", err, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
// Clean untracked files
|
||||
cleanCmd := exec.Command("git", "clean", "-fdx")
|
||||
cleanCmd.Dir = targetDir
|
||||
_ = cleanCmd.Run() // Ignore errors - clean is best effort
|
||||
|
||||
// Update manifest timestamp
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
Host: source.Host,
|
||||
Path: source.Path,
|
||||
Ref: "",
|
||||
Pinned: false,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
Updated: time.Now(),
|
||||
}
|
||||
_ = i.addToManifest(entry, scope) // Best effort - don't fail update if manifest fails
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInstallPath returns the target directory for a source.
|
||||
func (i *Installer) getInstallPath(source *GitSource, scope InstallScope) string {
|
||||
root := i.globalGitRoot
|
||||
if scope == ScopeProject {
|
||||
root = i.projectGitRoot
|
||||
}
|
||||
return filepath.Join(root, source.Host, source.Path)
|
||||
}
|
||||
|
||||
// validatePackage checks that the cloned repo contains valid .go extension files.
|
||||
func (i *Installer) validatePackage(dir string) error {
|
||||
// Find all .go files in the directory
|
||||
var goFiles []string
|
||||
err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() && strings.HasSuffix(info.Name(), ".go") {
|
||||
goFiles = append(goFiles, path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("walking directory: %w", err)
|
||||
}
|
||||
|
||||
if len(goFiles) == 0 {
|
||||
return fmt.Errorf("no .go files found in package")
|
||||
}
|
||||
|
||||
// Try to load the first .go file to validate it's a valid extension
|
||||
// We don't fail if validation fails - the extension might be fine but
|
||||
// have dependencies that aren't available during install time
|
||||
_, err = loadSingleExtension(goFiles[0])
|
||||
if err != nil {
|
||||
// Log but don't fail - the extension might need runtime deps
|
||||
// User can use `kit extensions validate` to check later
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addToManifest adds an entry to the manifest.
|
||||
func (i *Installer) addToManifest(entry ManifestEntry, scope InstallScope) error {
|
||||
manifest, err := i.loadManifest(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove any existing entry with same identity
|
||||
identity := entry.Host + "/" + entry.Path
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Host+"/"+p.Path != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, entry)
|
||||
manifest.Packages = filtered
|
||||
|
||||
return i.saveManifest(manifest, scope)
|
||||
}
|
||||
|
||||
// removeFromManifest removes an entry from the manifest by identity.
|
||||
func (i *Installer) removeFromManifest(identity string, scope InstallScope) error {
|
||||
manifest, err := i.loadManifest(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Host+"/"+p.Path != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
manifest.Packages = filtered
|
||||
|
||||
return i.saveManifest(manifest, scope)
|
||||
}
|
||||
|
||||
// loadManifest loads the manifest for the given scope.
|
||||
func (i *Installer) loadManifest(scope InstallScope) (*Manifest, error) {
|
||||
path := i.manifestPath(scope)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &Manifest{Packages: []ManifestEntry{}}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return nil, fmt.Errorf("parsing manifest: %w", err)
|
||||
}
|
||||
|
||||
return &manifest, nil
|
||||
}
|
||||
|
||||
// saveManifest saves the manifest for the given scope.
|
||||
func (i *Installer) saveManifest(manifest *Manifest, scope InstallScope) error {
|
||||
path := i.manifestPath(scope)
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("creating manifest directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(manifest, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("writing manifest: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// manifestPath returns the path to the manifest file.
|
||||
func (i *Installer) manifestPath(scope InstallScope) string {
|
||||
if scope == ScopeProject {
|
||||
return filepath.Join(i.projectGitRoot, "packages.json")
|
||||
}
|
||||
return filepath.Join(i.globalGitRoot, "packages.json")
|
||||
}
|
||||
|
||||
// globalGitInstallRoot returns the global git install root.
|
||||
func globalGitInstallRoot() string {
|
||||
base := os.Getenv("XDG_DATA_HOME")
|
||||
if base == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
base = filepath.Join(home, ".local", "share")
|
||||
}
|
||||
return filepath.Join(base, "kit", "git")
|
||||
}
|
||||
|
||||
// GetInstalledPackages returns all installed packages from both scopes.
|
||||
func (i *Installer) GetInstalledPackages() ([]ManifestEntry, error) {
|
||||
var all []ManifestEntry
|
||||
|
||||
global, err := i.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading global manifest: %w", err)
|
||||
}
|
||||
all = append(all, global.Packages...)
|
||||
|
||||
project, err := i.loadManifest(ScopeProject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading project manifest: %w", err)
|
||||
}
|
||||
all = append(all, project.Packages...)
|
||||
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// IsInstalled checks if a package is installed in either scope.
|
||||
// Returns (scope, true) if installed, ("", false) otherwise.
|
||||
func (i *Installer) IsInstalled(source *GitSource) (InstallScope, bool) {
|
||||
globalPath := i.getInstallPath(source, ScopeGlobal)
|
||||
if _, err := os.Stat(globalPath); err == nil {
|
||||
return ScopeGlobal, true
|
||||
}
|
||||
|
||||
projectPath := i.getInstallPath(source, ScopeProject)
|
||||
if _, err := os.Stat(projectPath); err == nil {
|
||||
return ScopeProject, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// PreviewExtensions clones a repo to a temporary directory and scans for extensions.
|
||||
// Returns the preview list and the temp directory path (caller should clean up).
|
||||
func (i *Installer) PreviewExtensions(source *GitSource) ([]ExtensionPreview, string, error) {
|
||||
// Create temp directory
|
||||
tempDir, err := os.MkdirTemp("", "kit-install-preview-*")
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("creating temp directory: %w", err)
|
||||
}
|
||||
|
||||
// Clone to temp
|
||||
cloneDir := filepath.Join(tempDir, "repo")
|
||||
cmd := exec.Command("git", "clone", "--depth=1", source.Repo, cloneDir)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return nil, "", fmt.Errorf("git clone failed: %w\n%s", err, string(output))
|
||||
}
|
||||
|
||||
// Checkout specific ref if pinned
|
||||
if source.Pinned && source.Ref != "" {
|
||||
checkoutCmd := exec.Command("git", "checkout", source.Ref)
|
||||
checkoutCmd.Dir = cloneDir
|
||||
if output, err := checkoutCmd.CombinedOutput(); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return nil, "", fmt.Errorf("git checkout failed: %w\n%s", err, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
// Scan for extensions
|
||||
previews, err := ScanForExtensions(cloneDir)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return nil, "", fmt.Errorf("scanning extensions: %w", err)
|
||||
}
|
||||
|
||||
return previews, tempDir, nil
|
||||
}
|
||||
|
||||
// InstallWithInclude clones a repo and installs only the specified extensions.
|
||||
// includePaths are relative paths like "./git/main.go" - if empty, installs all.
|
||||
func (i *Installer) InstallWithInclude(source *GitSource, scope InstallScope, includePaths []string) error {
|
||||
// First, do a regular install
|
||||
if err := i.Install(source, scope); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If specific includes were requested, update the manifest
|
||||
if len(includePaths) > 0 {
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
Host: source.Host,
|
||||
Path: source.Path,
|
||||
Ref: source.Ref,
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Include: includePaths,
|
||||
}
|
||||
|
||||
if err := addEntryToManifest(entry, scope); err != nil {
|
||||
return fmt.Errorf("updating manifest with includes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupTempDir removes a temporary directory used for preview.
|
||||
func CleanupTempDir(tempDir string) {
|
||||
if tempDir != "" {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,392 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseGitSource(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source string
|
||||
wantRepo string
|
||||
wantHost string
|
||||
wantPath string
|
||||
wantRef string
|
||||
wantPinned bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "github shorthand",
|
||||
source: "github.com/user/repo",
|
||||
wantRepo: "https://github.com/user/repo.git",
|
||||
wantHost: "github.com",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "",
|
||||
wantPinned: false,
|
||||
},
|
||||
{
|
||||
name: "github shorthand with version",
|
||||
source: "github.com/user/repo@v1.0.0",
|
||||
wantRepo: "https://github.com/user/repo.git",
|
||||
wantHost: "github.com",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "v1.0.0",
|
||||
wantPinned: true,
|
||||
},
|
||||
{
|
||||
name: "git prefix shorthand",
|
||||
source: "git:github.com/user/repo",
|
||||
wantRepo: "https://github.com/user/repo.git",
|
||||
wantHost: "github.com",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "",
|
||||
wantPinned: false,
|
||||
},
|
||||
{
|
||||
name: "https URL",
|
||||
source: "https://github.com/user/repo",
|
||||
wantRepo: "https://github.com/user/repo.git",
|
||||
wantHost: "github.com",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "",
|
||||
wantPinned: false,
|
||||
},
|
||||
{
|
||||
name: "https URL with .git suffix",
|
||||
source: "https://github.com/user/repo.git",
|
||||
wantRepo: "https://github.com/user/repo.git",
|
||||
wantHost: "github.com",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "",
|
||||
wantPinned: false,
|
||||
},
|
||||
{
|
||||
name: "ssh shorthand",
|
||||
source: "git@github.com:user/repo",
|
||||
wantRepo: "git@github.com:user/repo",
|
||||
wantHost: "github.com",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "",
|
||||
wantPinned: false,
|
||||
},
|
||||
{
|
||||
name: "ssh URL",
|
||||
source: "ssh://git@github.com/user/repo",
|
||||
wantRepo: "ssh://git@github.com/user/repo",
|
||||
wantHost: "github.com",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "",
|
||||
wantPinned: false,
|
||||
},
|
||||
{
|
||||
name: "gitlab shorthand",
|
||||
source: "gitlab.com/user/repo",
|
||||
wantRepo: "https://gitlab.com/user/repo.git",
|
||||
wantHost: "gitlab.com",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "",
|
||||
wantPinned: false,
|
||||
},
|
||||
{
|
||||
name: "bitbucket shorthand",
|
||||
source: "bitbucket.org/user/repo",
|
||||
wantRepo: "https://bitbucket.org/user/repo.git",
|
||||
wantHost: "bitbucket.org",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "",
|
||||
wantPinned: false,
|
||||
},
|
||||
{
|
||||
name: "generic host",
|
||||
source: "gitea.example.com/user/repo",
|
||||
wantRepo: "https://gitea.example.com/user/repo.git",
|
||||
wantHost: "gitea.example.com",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "",
|
||||
wantPinned: false,
|
||||
},
|
||||
{
|
||||
name: "with branch ref",
|
||||
source: "github.com/user/repo@main",
|
||||
wantRepo: "https://github.com/user/repo.git",
|
||||
wantHost: "github.com",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "main",
|
||||
wantPinned: true,
|
||||
},
|
||||
{
|
||||
name: "with commit ref",
|
||||
source: "github.com/user/repo@abc1234",
|
||||
wantRepo: "https://github.com/user/repo.git",
|
||||
wantHost: "github.com",
|
||||
wantPath: "user/repo",
|
||||
wantRef: "abc1234",
|
||||
wantPinned: true,
|
||||
},
|
||||
{
|
||||
name: "local path should error",
|
||||
source: "./local/path",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "absolute path should error",
|
||||
source: "/absolute/path",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseGitSource(tt.source)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseGitSource() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if got.Repo != tt.wantRepo {
|
||||
t.Errorf("ParseGitSource() Repo = %v, want %v", got.Repo, tt.wantRepo)
|
||||
}
|
||||
if got.Host != tt.wantHost {
|
||||
t.Errorf("ParseGitSource() Host = %v, want %v", got.Host, tt.wantHost)
|
||||
}
|
||||
if got.Path != tt.wantPath {
|
||||
t.Errorf("ParseGitSource() Path = %v, want %v", got.Path, tt.wantPath)
|
||||
}
|
||||
if got.Ref != tt.wantRef {
|
||||
t.Errorf("ParseGitSource() Ref = %v, want %v", got.Ref, tt.wantRef)
|
||||
}
|
||||
if got.Pinned != tt.wantPinned {
|
||||
t.Errorf("ParseGitSource() Pinned = %v, want %v", got.Pinned, tt.wantPinned)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitSourceIdentity(t *testing.T) {
|
||||
source := &GitSource{
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
}
|
||||
if got := source.Identity(); got != "github.com/user/repo" {
|
||||
t.Errorf("Identity() = %v, want %v", got, "github.com/user/repo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitSourceString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source GitSource
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "unpinned",
|
||||
source: GitSource{
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
Pinned: false,
|
||||
},
|
||||
want: "git:github.com/user/repo",
|
||||
},
|
||||
{
|
||||
name: "pinned",
|
||||
source: GitSource{
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
Ref: "v1.0.0",
|
||||
Pinned: true,
|
||||
},
|
||||
want: "git:github.com/user/repo@v1.0.0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.source.String(); got != tt.want {
|
||||
t.Errorf("String() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallerGetInstallPath(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
installer := NewInstaller(tempDir)
|
||||
|
||||
source := &GitSource{
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
}
|
||||
|
||||
// Test global scope
|
||||
globalPath := installer.getInstallPath(source, ScopeGlobal)
|
||||
if !filepath.IsAbs(globalPath) {
|
||||
t.Error("Global install path should be absolute")
|
||||
}
|
||||
|
||||
// Test project scope
|
||||
projectPath := installer.getInstallPath(source, ScopeProject)
|
||||
expectedProjectPath := filepath.Join(tempDir, ".kit", "git", "github.com", "user", "repo")
|
||||
if projectPath != expectedProjectPath {
|
||||
t.Errorf("Project path = %v, want %v", projectPath, expectedProjectPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestEntryIdentity(t *testing.T) {
|
||||
entry := ManifestEntry{
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
}
|
||||
if got := entry.Identity(); got != "github.com/user/repo" {
|
||||
t.Errorf("Identity() = %v, want %v", got, "github.com/user/repo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAndSaveManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
manifestPath := filepath.Join(tempDir, "packages.json")
|
||||
|
||||
// Test loading non-existent manifest
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected empty packages, got %d", len(manifest.Packages))
|
||||
}
|
||||
|
||||
// Create a manifest
|
||||
manifest = &Manifest{
|
||||
Packages: []ManifestEntry{
|
||||
{
|
||||
Source: "git:github.com/user/repo",
|
||||
Repo: "https://github.com/user/repo.git",
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
Pinned: false,
|
||||
Scope: ScopeGlobal,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Save it
|
||||
err = saveManifestToPath(manifest, manifestPath)
|
||||
if err != nil {
|
||||
t.Fatalf("saveManifestToPath() error = %v", err)
|
||||
}
|
||||
|
||||
// Load it back
|
||||
loaded, err := loadManifestFromPath(manifestPath)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
}
|
||||
if len(loaded.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(loaded.Packages))
|
||||
}
|
||||
if loaded.Packages[0].Host != "github.com" {
|
||||
t.Errorf("Expected host github.com, got %s", loaded.Packages[0].Host)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddAndRemoveFromManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Set up environment for manifest path
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// The manifest path when XDG_DATA_HOME is set
|
||||
manifestPath := filepath.Join(tempDir, "kit", "git", "packages.json")
|
||||
|
||||
// Add an entry
|
||||
entry := ManifestEntry{
|
||||
Source: "git:github.com/user/repo",
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
Scope: ScopeGlobal,
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was added
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(manifest.Packages))
|
||||
}
|
||||
|
||||
// Remove it
|
||||
err = removeEntryFromManifest("github.com/user/repo", ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("removeEntryFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was removed
|
||||
manifest, err = loadManifestFromPath(manifestPath)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected 0 packages, got %d", len(manifest.Packages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindInManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Add an entry to global manifest
|
||||
entry := ManifestEntry{
|
||||
Source: "git:github.com/user/repo",
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
Scope: ScopeGlobal,
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Find it
|
||||
found, scope, err := FindInManifest("github.com/user/repo")
|
||||
if err != nil {
|
||||
t.Fatalf("FindInManifest() error = %v", err)
|
||||
}
|
||||
if found == nil {
|
||||
t.Fatal("Expected to find entry, got nil")
|
||||
}
|
||||
if scope != ScopeGlobal {
|
||||
t.Errorf("Expected scope global, got %s", scope)
|
||||
}
|
||||
|
||||
// Try to find non-existent
|
||||
notFound, _, err := FindInManifest("github.com/other/repo")
|
||||
if err != nil {
|
||||
t.Fatalf("FindInManifest() error = %v", err)
|
||||
}
|
||||
if notFound != nil {
|
||||
t.Error("Expected nil for non-existent entry")
|
||||
}
|
||||
}
|
||||
@@ -71,12 +71,24 @@ func discoverExtensionPaths(extraPaths []string) []string {
|
||||
add(p)
|
||||
}
|
||||
|
||||
// Global installed git packages: $XDG_DATA_HOME/kit/git/
|
||||
globalGitDir := globalGitInstallRoot()
|
||||
for _, p := range findExtensionsInGitPackages(globalGitDir) {
|
||||
add(p)
|
||||
}
|
||||
|
||||
// Project-local extensions: .kit/extensions/
|
||||
localDir := filepath.Join(".kit", "extensions")
|
||||
for _, p := range findExtensionsInDir(localDir) {
|
||||
add(p)
|
||||
}
|
||||
|
||||
// Project-local installed git packages: .kit/git/
|
||||
projectGitDir := filepath.Join(".kit", "git")
|
||||
for _, p := range findExtensionsInGitPackages(projectGitDir) {
|
||||
add(p)
|
||||
}
|
||||
|
||||
// Explicit paths (highest precedence)
|
||||
for _, p := range extraPaths {
|
||||
info, err := os.Stat(p)
|
||||
@@ -123,6 +135,219 @@ func findExtensionsInDir(dir string) []string {
|
||||
return results
|
||||
}
|
||||
|
||||
// findExtensionsInRepo scans a git repository for extensions using opinionated conventions.
|
||||
// Extensions are ONLY recognized in:
|
||||
// 1. Root-level *.go files
|
||||
// 2. Files in examples/extensions/ or examples/ext/ subdirectories
|
||||
// 3. Files in any top-level ext/ directory
|
||||
// 4. Files in any subdirectory that ends in -ext/ or -extensions/
|
||||
//
|
||||
// Everything else (cmd/, internal/, pkg/, etc.) is ignored.
|
||||
func findExtensionsInRepo(repoPath string) []string {
|
||||
var results []string
|
||||
multiFileDirs := make(map[string]bool)
|
||||
|
||||
_ = filepath.Walk(repoPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
relPath, _ := filepath.Rel(repoPath, path)
|
||||
relPath = filepath.ToSlash(relPath)
|
||||
|
||||
// Skip directories we know don't contain extensions
|
||||
if info.IsDir() {
|
||||
switch info.Name() {
|
||||
case ".git", ".github", "node_modules", "vendor", "dist", "build":
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
// Skip internal code directories
|
||||
if strings.HasPrefix(relPath, "internal/") ||
|
||||
strings.HasPrefix(relPath, "cmd/") ||
|
||||
strings.HasPrefix(relPath, "pkg/") ||
|
||||
strings.HasPrefix(relPath, "test/") ||
|
||||
strings.HasPrefix(relPath, "tests/") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
// Root directory - scan it
|
||||
if relPath == "." {
|
||||
return nil
|
||||
}
|
||||
|
||||
base := info.Name()
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
if relPath == base { // Top-level directory
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
results = append(results, mainPath)
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
results = append(results, mainPath)
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
// Check for main.go
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
results = append(results, mainPath)
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// It's a file
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
if info.Name() == "main.go" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parentDir := filepath.Dir(relPath)
|
||||
if parentDir == "." {
|
||||
// Root-level .go file - valid extension
|
||||
results = append(results, path)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Must be in valid extension directory
|
||||
isValidExtDir := false
|
||||
if strings.HasPrefix(parentDir, "examples/extensions/") ||
|
||||
parentDir == "examples/extensions" {
|
||||
isValidExtDir = true
|
||||
} else if strings.HasPrefix(parentDir, "examples/ext/") ||
|
||||
parentDir == "examples/ext" {
|
||||
isValidExtDir = true
|
||||
} else if strings.HasPrefix(parentDir, "ext/") ||
|
||||
parentDir == "ext" {
|
||||
isValidExtDir = true
|
||||
} else if strings.Contains(parentDir, "-extensions/") ||
|
||||
strings.HasSuffix(parentDir, "-extensions") {
|
||||
isValidExtDir = true
|
||||
} else if strings.Contains(parentDir, "-ext/") ||
|
||||
strings.HasSuffix(parentDir, "-ext") {
|
||||
isValidExtDir = true
|
||||
}
|
||||
|
||||
if !isValidExtDir {
|
||||
return nil
|
||||
}
|
||||
|
||||
results = append(results, path)
|
||||
return nil
|
||||
})
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// Each git package is stored at <gitRoot>/<host>/<owner>/<repo>/ and can contain
|
||||
// .go files or a main.go in subdirectories.
|
||||
// If a package has a manifest with Include field, only those paths are loaded.
|
||||
func findExtensionsInGitPackages(gitRoot string) []string {
|
||||
info, err := os.Stat(gitRoot)
|
||||
if err != nil || !info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results []string
|
||||
|
||||
// Load the manifest if it exists
|
||||
manifestPath := filepath.Join(gitRoot, "packages.json")
|
||||
manifest, _ := loadManifestFromPath(manifestPath)
|
||||
// Build a map of package identity -> include list
|
||||
includeMap := make(map[string][]string)
|
||||
if manifest != nil {
|
||||
for _, entry := range manifest.Packages {
|
||||
if len(entry.Include) > 0 {
|
||||
identity := fmt.Sprintf("%s/%s", entry.Host, entry.Path)
|
||||
includeMap[identity] = entry.Include
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Walk through host directories (e.g., github.com/)
|
||||
hosts, err := os.ReadDir(gitRoot)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
if !host.IsDir() {
|
||||
continue
|
||||
}
|
||||
hostPath := filepath.Join(gitRoot, host.Name())
|
||||
|
||||
// Walk through owner directories (e.g., github.com/user/)
|
||||
owners, err := os.ReadDir(hostPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, owner := range owners {
|
||||
if !owner.IsDir() {
|
||||
continue
|
||||
}
|
||||
ownerPath := filepath.Join(hostPath, owner.Name())
|
||||
|
||||
// Walk through repo directories (e.g., github.com/user/repo/)
|
||||
repos, err := os.ReadDir(ownerPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, repo := range repos {
|
||||
if !repo.IsDir() {
|
||||
continue
|
||||
}
|
||||
repoPath := filepath.Join(ownerPath, repo.Name())
|
||||
|
||||
// Check if there's an include filter for this package
|
||||
identity := fmt.Sprintf("%s/%s/%s", host.Name(), owner.Name(), repo.Name())
|
||||
includes, hasFilter := includeMap[identity]
|
||||
|
||||
if hasFilter {
|
||||
// Only include specific paths
|
||||
for _, include := range includes {
|
||||
// Convert relative path to absolute
|
||||
include = strings.TrimPrefix(include, "./")
|
||||
fullPath := filepath.Join(repoPath, filepath.FromSlash(include))
|
||||
if _, err := os.Stat(fullPath); err == nil {
|
||||
results = append(results, fullPath)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Find all extensions within this repo using convention-based scanning
|
||||
results = append(results, findExtensionsInRepo(repoPath)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// globalExtensionsDir returns the global extensions directory, respecting
|
||||
// $XDG_CONFIG_HOME. Defaults to ~/.config/kit/extensions.
|
||||
func globalExtensionsDir() string {
|
||||
|
||||
@@ -304,6 +304,15 @@ func Init(api ext.API) {
|
||||
func TestLoadExtensions_SkipsBadFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Isolate from host environment so globally-installed extensions
|
||||
// are not discovered alongside the test fixtures.
|
||||
isolated := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", filepath.Join(isolated, "config"))
|
||||
t.Setenv("XDG_DATA_HOME", filepath.Join(isolated, "data"))
|
||||
origWd, _ := os.Getwd()
|
||||
_ = os.Chdir(isolated)
|
||||
t.Cleanup(func() { _ = os.Chdir(origWd) })
|
||||
|
||||
// Good extension
|
||||
good := `package main
|
||||
import "kit/ext"
|
||||
|
||||
@@ -0,0 +1,398 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// Manifest tracks installed git packages.
|
||||
type Manifest struct {
|
||||
Packages []ManifestEntry `json:"packages"`
|
||||
}
|
||||
|
||||
// ManifestEntry represents a single installed package.
|
||||
type ManifestEntry struct {
|
||||
// Source is the canonical string representation (e.g., "git:github.com/user/repo@v1.0.0")
|
||||
Source string `json:"source"`
|
||||
// Repo is the clone URL
|
||||
Repo string `json:"repo"`
|
||||
// Host is the git host (e.g., github.com)
|
||||
Host string `json:"host"`
|
||||
// Path is the path on the host (e.g., user/repo)
|
||||
Path string `json:"path"`
|
||||
// Ref is the optional pinned ref (tag/branch/commit)
|
||||
Ref string `json:"ref,omitempty"`
|
||||
// Pinned indicates if the ref is pinned
|
||||
Pinned bool `json:"pinned"`
|
||||
// Scope is where the package is installed (global or project)
|
||||
Scope InstallScope `json:"scope"`
|
||||
// Installed is when the package was first installed
|
||||
Installed time.Time `json:"installed"`
|
||||
// Updated is when the package was last updated (only for unpinned, zero time means never updated)
|
||||
Updated time.Time `json:"updated,omitzero"`
|
||||
// Include is a list of relative paths to extensions that should be loaded.
|
||||
// If empty, all extensions in the package are loaded.
|
||||
// Paths are relative to the package root (e.g., "./git/main.go", "./weather.go")
|
||||
Include []string `json:"include,omitempty"`
|
||||
}
|
||||
|
||||
// Identity returns the normalized identity for deduplication.
|
||||
func (e ManifestEntry) Identity() string {
|
||||
return fmt.Sprintf("%s/%s", e.Host, e.Path)
|
||||
}
|
||||
|
||||
// loadManifest loads the manifest from the given scope.
|
||||
func loadManifestFromScope(scope InstallScope) (*Manifest, error) {
|
||||
path := manifestPathForScope(scope)
|
||||
return loadManifestFromPath(path)
|
||||
}
|
||||
|
||||
// loadManifestFromPath loads a manifest from a specific file path.
|
||||
func loadManifestFromPath(path string) (*Manifest, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &Manifest{Packages: []ManifestEntry{}}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("reading manifest: %w", err)
|
||||
}
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return nil, fmt.Errorf("parsing manifest: %w", err)
|
||||
}
|
||||
|
||||
return &manifest, nil
|
||||
}
|
||||
|
||||
// saveManifestToScope saves the manifest to the given scope.
|
||||
func saveManifestToScope(manifest *Manifest, scope InstallScope) error {
|
||||
path := manifestPathForScope(scope)
|
||||
return saveManifestToPath(manifest, path)
|
||||
}
|
||||
|
||||
// saveManifestToPath saves a manifest to a specific file path.
|
||||
func saveManifestToPath(manifest *Manifest, path string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("creating manifest directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(manifest, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("writing manifest: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// manifestPathForScope returns the manifest file path for a scope.
|
||||
func manifestPathForScope(scope InstallScope) string {
|
||||
if scope == ScopeProject {
|
||||
return filepath.Join(".kit", "git", "packages.json")
|
||||
}
|
||||
|
||||
base := os.Getenv("XDG_DATA_HOME")
|
||||
if base == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
base = filepath.Join(home, ".local", "share")
|
||||
}
|
||||
return filepath.Join(base, "kit", "git", "packages.json")
|
||||
}
|
||||
|
||||
// GetGlobalManifest returns the global manifest.
|
||||
func GetGlobalManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeGlobal)
|
||||
}
|
||||
|
||||
// GetProjectManifest returns the project manifest.
|
||||
func GetProjectManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeProject)
|
||||
}
|
||||
|
||||
// addEntryToManifest adds or replaces an entry in the manifest for a scope.
|
||||
func addEntryToManifest(entry ManifestEntry, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove any existing entry with same identity
|
||||
identity := entry.Identity()
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, entry)
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// removeEntryFromManifest removes an entry by identity from the manifest for a scope.
|
||||
func removeEntryFromManifest(identity string, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// FindInManifest finds an entry by identity in either global or project manifest.
|
||||
// Returns the entry and its scope, or nil if not found.
|
||||
func FindInManifest(identity string) (*ManifestEntry, InstallScope, error) {
|
||||
global, err := loadManifestFromScope(ScopeGlobal)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("loading global manifest: %w", err)
|
||||
}
|
||||
for _, p := range global.Packages {
|
||||
if p.Identity() == identity {
|
||||
return &p, ScopeGlobal, nil
|
||||
}
|
||||
}
|
||||
|
||||
project, err := loadManifestFromScope(ScopeProject)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("loading project manifest: %w", err)
|
||||
}
|
||||
for _, p := range project.Packages {
|
||||
if p.Identity() == identity {
|
||||
return &p, ScopeProject, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
// ExtensionPreview represents a discovered extension in a package before installation.
|
||||
type ExtensionPreview struct {
|
||||
// Path is the relative path from the package root (e.g., "./git/main.go")
|
||||
Path string `json:"path"`
|
||||
// Name is a display name for the extension (derived from path or metadata)
|
||||
Name string `json:"name"`
|
||||
// Description is an optional description (could be extracted from comments)
|
||||
Description string `json:"description,omitempty"`
|
||||
// IsMain indicates if this is a main.go in a subdirectory
|
||||
IsMain bool `json:"is_main"`
|
||||
}
|
||||
|
||||
// ScanForExtensions discovers all extensions in a directory using opinionated conventions.
|
||||
// Extensions are ONLY recognized in these specific locations:
|
||||
// 1. Root-level *.go files
|
||||
// 2. Files in examples/extensions/ or examples/ext/ subdirectories
|
||||
// 3. Files in any top-level ext/ directory
|
||||
// 4. Files in any subdirectory that ends in -ext/ or -extensions/
|
||||
//
|
||||
// Everything else (cmd/, internal/, pkg/, etc.) is ignored.
|
||||
func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil || !info.IsDir() {
|
||||
return nil, fmt.Errorf("not a directory: %s", dir)
|
||||
}
|
||||
|
||||
var previews []ExtensionPreview
|
||||
multiFileDirs := make(map[string]bool)
|
||||
|
||||
err = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
relPath, _ := filepath.Rel(dir, path)
|
||||
relPath = filepath.ToSlash(relPath)
|
||||
|
||||
// Skip directories we know don't contain extensions
|
||||
if info.IsDir() {
|
||||
// Never scan these directories
|
||||
switch info.Name() {
|
||||
case ".git", ".github", "node_modules", "vendor", "dist", "build":
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
// Skip internal code directories
|
||||
if strings.HasPrefix(relPath, "internal/") ||
|
||||
strings.HasPrefix(relPath, "cmd/") ||
|
||||
strings.HasPrefix(relPath, "pkg/") ||
|
||||
strings.HasPrefix(relPath, "test/") ||
|
||||
strings.HasPrefix(relPath, "tests/") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
// Root directory - scan it
|
||||
if relPath == "." {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if this directory is an extension location by name
|
||||
// Pattern: must be named "extensions", "ext", or end with those
|
||||
base := info.Name()
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
// Or check if it's a subdirectory of examples/ that might contain extensions
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
// Check for main.go before skipping
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
// This is a package with main.go at root level
|
||||
if relPath == base { // Top-level directory
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
previews = append(previews, ExtensionPreview{
|
||||
Path: "./" + relPath + "/main.go",
|
||||
Name: deriveExtensionName(relPath+"/main.go", true),
|
||||
IsMain: true,
|
||||
})
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
// Inside a valid extensions directory
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
previews = append(previews, ExtensionPreview{
|
||||
Path: "./" + relPath + "/main.go",
|
||||
Name: deriveExtensionName(relPath+"/main.go", true),
|
||||
IsMain: true,
|
||||
})
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
|
||||
// Not an extension location
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
// Check for main.go in this directory
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
previews = append(previews, ExtensionPreview{
|
||||
Path: "./" + relPath + "/main.go",
|
||||
Name: deriveExtensionName(relPath+"/main.go", true),
|
||||
IsMain: true,
|
||||
})
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
// Scan this extensions directory
|
||||
return nil
|
||||
}
|
||||
|
||||
// It's a file - check if it's a valid extension
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
if info.Name() == "main.go" {
|
||||
return nil // Already handled above
|
||||
}
|
||||
|
||||
// Check if parent is a valid extension location
|
||||
parentDir := filepath.Dir(relPath)
|
||||
if parentDir == "." {
|
||||
// Root-level .go file - valid extension
|
||||
previews = append(previews, ExtensionPreview{
|
||||
Path: "./" + relPath,
|
||||
Name: deriveExtensionName(relPath, false),
|
||||
IsMain: false,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we're in a valid extension directory
|
||||
// Valid locations are:
|
||||
// - examples/extensions/*
|
||||
// - examples/ext/*
|
||||
// - ext/* (top-level)
|
||||
// - Any *-extensions/* or *-ext/* directory
|
||||
isValidExtDir := false
|
||||
if strings.HasPrefix(parentDir, "examples/extensions/") ||
|
||||
parentDir == "examples/extensions" {
|
||||
isValidExtDir = true
|
||||
} else if strings.HasPrefix(parentDir, "examples/ext/") ||
|
||||
parentDir == "examples/ext" {
|
||||
isValidExtDir = true
|
||||
} else if strings.HasPrefix(parentDir, "ext/") ||
|
||||
parentDir == "ext" {
|
||||
isValidExtDir = true
|
||||
} else if strings.Contains(parentDir, "-extensions/") ||
|
||||
strings.HasSuffix(parentDir, "-extensions") {
|
||||
isValidExtDir = true
|
||||
} else if strings.Contains(parentDir, "-ext/") ||
|
||||
strings.HasSuffix(parentDir, "-ext") {
|
||||
isValidExtDir = true
|
||||
}
|
||||
|
||||
if !isValidExtDir {
|
||||
return nil
|
||||
}
|
||||
|
||||
previews = append(previews, ExtensionPreview{
|
||||
Path: "./" + relPath,
|
||||
Name: deriveExtensionName(relPath, false),
|
||||
IsMain: false,
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return previews, nil
|
||||
}
|
||||
|
||||
// deriveExtensionName creates a display name from a file path.
|
||||
func deriveExtensionName(relPath string, isMain bool) string {
|
||||
// Convert path to a readable name
|
||||
// e.g., "git/main.go" -> "Git Extension"
|
||||
// e.g., "weather.go" -> "Weather"
|
||||
|
||||
dir := filepath.Dir(relPath)
|
||||
base := filepath.Base(relPath)
|
||||
|
||||
if isMain && dir != "." {
|
||||
// Use immediate parent directory name for main.go files
|
||||
name := filepath.Base(dir)
|
||||
name = strings.ReplaceAll(name, "_", " ")
|
||||
name = strings.ReplaceAll(name, "-", " ")
|
||||
return cases.Title(language.English).String(name) + " Extension"
|
||||
}
|
||||
|
||||
// Use filename without extension
|
||||
name := strings.TrimSuffix(base, ".go")
|
||||
name = strings.ReplaceAll(name, "_", " ")
|
||||
name = strings.ReplaceAll(name, "-", " ")
|
||||
return cases.Title(language.English).String(name)
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
// Package extensions provides subagent spawning capabilities for Kit extensions.
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Subagent types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentConfig configures a subagent spawn.
|
||||
type SubagentConfig struct {
|
||||
// Prompt is the task/instruction for the subagent (required).
|
||||
Prompt string
|
||||
|
||||
// Model overrides the parent's model (e.g. "anthropic/claude-haiku-3-5-20241022").
|
||||
// Empty string uses the parent's current model.
|
||||
Model string
|
||||
|
||||
// SystemPrompt provides domain-specific instructions.
|
||||
// Empty string uses the default system prompt.
|
||||
SystemPrompt string
|
||||
|
||||
// Timeout limits execution time. Zero means 5 minute default.
|
||||
Timeout time.Duration
|
||||
|
||||
// OnOutput streams stderr output chunks as the subagent runs.
|
||||
// Called from a goroutine; must be safe for concurrent use.
|
||||
OnOutput func(chunk string)
|
||||
|
||||
// OnEvent receives real-time events from the subagent's execution:
|
||||
// text chunks, tool calls, tool results, reasoning deltas, etc.
|
||||
// Called synchronously from the subagent's event loop.
|
||||
OnEvent func(SubagentEvent)
|
||||
|
||||
// OnComplete is called when the subagent finishes (success or error).
|
||||
// Called from a goroutine; must be safe for concurrent use.
|
||||
OnComplete func(result SubagentResult)
|
||||
|
||||
// Blocking, when true, makes SpawnSubagent wait for completion and
|
||||
// return the result directly. When false (default), spawns in background
|
||||
// and returns immediately with a handle.
|
||||
Blocking bool
|
||||
|
||||
// NoSession, when true, runs the subagent without persisting a session
|
||||
// file. By default (false), subagent sessions are persisted so they can
|
||||
// be loaded for replay/inspection. Set to true for ephemeral tasks
|
||||
// where session history is not needed.
|
||||
NoSession bool
|
||||
|
||||
// ParentSessionID links the subagent's session to the parent (optional).
|
||||
// When set, the subagent's session header includes a parent reference
|
||||
// so viewers can navigate the session tree.
|
||||
ParentSessionID string
|
||||
}
|
||||
|
||||
// SubagentEvent carries a real-time event from a running subagent. Extensions
|
||||
// use the Type field to determine what happened and read the relevant fields.
|
||||
// This is a concrete struct (not an interface) for Yaegi compatibility.
|
||||
type SubagentEvent struct {
|
||||
// Type identifies the event: "text", "reasoning", "tool_call",
|
||||
// "tool_result", "tool_execution_start", "tool_execution_end",
|
||||
// "turn_start", "turn_end".
|
||||
Type string
|
||||
|
||||
// Content carries text for "text" and "reasoning" events.
|
||||
Content string
|
||||
|
||||
// ToolCallID is set on tool_call, tool_result, tool_execution_start,
|
||||
// and tool_execution_end events.
|
||||
ToolCallID string
|
||||
// ToolName is set on tool-related events.
|
||||
ToolName string
|
||||
// ToolKind is set on tool-related events.
|
||||
ToolKind string
|
||||
// ToolArgs is set on tool_call events (JSON-encoded).
|
||||
ToolArgs string
|
||||
// ToolResult is set on tool_result events.
|
||||
ToolResult string
|
||||
// IsError is set on tool_result events.
|
||||
IsError bool
|
||||
}
|
||||
|
||||
// SubagentResult contains the outcome of a subagent execution.
|
||||
type SubagentResult struct {
|
||||
// Response is the subagent's final text response.
|
||||
Response string
|
||||
|
||||
// Error is set if the subagent failed (nil on success).
|
||||
Error error
|
||||
|
||||
// ExitCode is the subprocess exit code (0 = success).
|
||||
ExitCode int
|
||||
|
||||
// Elapsed is the total execution time.
|
||||
Elapsed time.Duration
|
||||
|
||||
// Usage contains token usage if available.
|
||||
Usage *SubagentUsage
|
||||
|
||||
// SessionID is the subagent's session identifier, if available.
|
||||
// Populated when the subagent persists its session (requires running
|
||||
// without --no-session). Empty for ephemeral sessions.
|
||||
SessionID string
|
||||
}
|
||||
|
||||
// SubagentUsage contains token usage from the subagent's run.
|
||||
type SubagentUsage struct {
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
}
|
||||
|
||||
// SubagentHandle provides control over a running subagent.
|
||||
type SubagentHandle struct {
|
||||
// ID is a unique identifier for this subagent instance.
|
||||
ID string
|
||||
|
||||
proc *os.Process
|
||||
done chan struct{}
|
||||
result *SubagentResult
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Kill terminates the subagent process.
|
||||
func (h *SubagentHandle) Kill() error {
|
||||
h.mu.Lock()
|
||||
proc := h.proc
|
||||
h.mu.Unlock()
|
||||
if proc != nil {
|
||||
return proc.Kill()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait blocks until the subagent completes and returns the result.
|
||||
func (h *SubagentHandle) Wait() SubagentResult {
|
||||
<-h.done
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if h.result != nil {
|
||||
return *h.result
|
||||
}
|
||||
return SubagentResult{Error: fmt.Errorf("subagent completed without result")}
|
||||
}
|
||||
|
||||
// Done returns a channel that closes when the subagent completes.
|
||||
func (h *SubagentHandle) Done() <-chan struct{} {
|
||||
return h.done
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// subagentJSONOutput matches the JSON envelope produced by `kit --json`.
|
||||
type subagentJSONOutput struct {
|
||||
Response string `json:"response"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
Usage *struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
} `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
var subagentCounter uint64
|
||||
|
||||
func generateSubagentID() string {
|
||||
n := atomic.AddUint64(&subagentCounter, 1)
|
||||
return fmt.Sprintf("sub-%d-%d", time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
func findKitBinary() string {
|
||||
// Try the current process executable first.
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
if _, err := os.Stat(exe); err == nil {
|
||||
return exe
|
||||
}
|
||||
}
|
||||
// Fall back to PATH lookup.
|
||||
if p, err := exec.LookPath("kit"); err == nil {
|
||||
return p
|
||||
}
|
||||
return "kit"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SpawnSubagent implementation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SpawnSubagent spawns a child Kit instance to perform a task.
|
||||
//
|
||||
// When config.Blocking is true, blocks until completion and returns the result
|
||||
// directly (handle is nil). When false, returns immediately with a handle for
|
||||
// monitoring/cancellation.
|
||||
//
|
||||
// The subagent runs with --json --no-session --no-extensions flags by default,
|
||||
// ensuring isolation from the parent's extensions and session state.
|
||||
func SpawnSubagent(cfg SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
if cfg.Prompt == "" {
|
||||
return nil, nil, fmt.Errorf("prompt is required")
|
||||
}
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
kitBinary := findKitBinary()
|
||||
|
||||
// Build subprocess arguments.
|
||||
args := []string{
|
||||
"--json",
|
||||
"--no-extensions",
|
||||
}
|
||||
if cfg.NoSession {
|
||||
args = append(args, "--no-session")
|
||||
}
|
||||
if cfg.Model != "" {
|
||||
args = append(args, "--model", cfg.Model)
|
||||
}
|
||||
|
||||
// Handle system prompt - write to temp file if provided.
|
||||
var tmpFile *os.File
|
||||
if cfg.SystemPrompt != "" {
|
||||
var err error
|
||||
tmpFile, err = os.CreateTemp("", "kit-subagent-*.txt")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
if _, err := tmpFile.WriteString(cfg.SystemPrompt); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
return nil, nil, fmt.Errorf("write system prompt: %w", err)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
args = append(args, "--system-prompt", tmpFile.Name())
|
||||
}
|
||||
|
||||
// Add the prompt as a positional argument.
|
||||
args = append(args, cfg.Prompt)
|
||||
|
||||
// Create command with timeout context.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
cmd := exec.CommandContext(ctx, kitBinary, args...)
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
handle := &SubagentHandle{
|
||||
ID: generateSubagentID(),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start the subprocess.
|
||||
start := time.Now()
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("start subprocess: %w", err)
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.proc = cmd.Process
|
||||
handle.mu.Unlock()
|
||||
|
||||
// Run the subprocess monitoring in a goroutine.
|
||||
go func() {
|
||||
defer close(handle.done)
|
||||
defer cancel()
|
||||
if tmpFile != nil {
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var stdoutBuf strings.Builder
|
||||
|
||||
// Read stderr (live output).
|
||||
wg.Go(func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if cfg.OnOutput != nil && strings.TrimSpace(line) != "" {
|
||||
cfg.OnOutput(line + "\n")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Read stdout (JSON output).
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
stdoutBuf.WriteString(scanner.Text() + "\n")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
waitErr := cmd.Wait()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Build result.
|
||||
result := SubagentResult{Elapsed: elapsed}
|
||||
if waitErr != nil {
|
||||
result.Error = waitErr
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Parse JSON output.
|
||||
raw := strings.TrimSpace(stdoutBuf.String())
|
||||
var parsed subagentJSONOutput
|
||||
if raw != "" && json.Unmarshal([]byte(raw), &parsed) == nil {
|
||||
result.Response = parsed.Response
|
||||
result.SessionID = parsed.SessionID
|
||||
if parsed.Usage != nil {
|
||||
result.Usage = &SubagentUsage{
|
||||
InputTokens: parsed.Usage.InputTokens,
|
||||
OutputTokens: parsed.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: use raw stdout.
|
||||
result.Response = raw
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.result = &result
|
||||
handle.proc = nil
|
||||
handle.mu.Unlock()
|
||||
|
||||
if cfg.OnComplete != nil {
|
||||
cfg.OnComplete(result)
|
||||
}
|
||||
}()
|
||||
|
||||
if cfg.Blocking {
|
||||
// Wait for completion and return result directly.
|
||||
<-handle.done
|
||||
handle.mu.Lock()
|
||||
r := handle.result
|
||||
handle.mu.Unlock()
|
||||
return nil, r, nil
|
||||
}
|
||||
|
||||
return handle, nil, nil
|
||||
}
|
||||
@@ -90,12 +90,14 @@ func Symbols() interp.Exports {
|
||||
"EditorConfig": reflect.ValueOf((*EditorConfig)(nil)),
|
||||
|
||||
// Prompt types
|
||||
"PromptSelectConfig": reflect.ValueOf((*PromptSelectConfig)(nil)),
|
||||
"PromptSelectResult": reflect.ValueOf((*PromptSelectResult)(nil)),
|
||||
"PromptConfirmConfig": reflect.ValueOf((*PromptConfirmConfig)(nil)),
|
||||
"PromptConfirmResult": reflect.ValueOf((*PromptConfirmResult)(nil)),
|
||||
"PromptInputConfig": reflect.ValueOf((*PromptInputConfig)(nil)),
|
||||
"PromptInputResult": reflect.ValueOf((*PromptInputResult)(nil)),
|
||||
"PromptSelectConfig": reflect.ValueOf((*PromptSelectConfig)(nil)),
|
||||
"PromptSelectResult": reflect.ValueOf((*PromptSelectResult)(nil)),
|
||||
"PromptConfirmConfig": reflect.ValueOf((*PromptConfirmConfig)(nil)),
|
||||
"PromptConfirmResult": reflect.ValueOf((*PromptConfirmResult)(nil)),
|
||||
"PromptInputConfig": reflect.ValueOf((*PromptInputConfig)(nil)),
|
||||
"PromptInputResult": reflect.ValueOf((*PromptInputResult)(nil)),
|
||||
"PromptMultiSelectConfig": reflect.ValueOf((*PromptMultiSelectConfig)(nil)),
|
||||
"PromptMultiSelectResult": reflect.ValueOf((*PromptMultiSelectResult)(nil)),
|
||||
|
||||
// Context filtering types
|
||||
"ContextMessage": reflect.ValueOf((*ContextMessage)(nil)),
|
||||
@@ -110,6 +112,17 @@ func Symbols() interp.Exports {
|
||||
"BeforeCompactEvent": reflect.ValueOf((*BeforeCompactEvent)(nil)),
|
||||
"BeforeCompactResult": reflect.ValueOf((*BeforeCompactResult)(nil)),
|
||||
|
||||
// Subagent types
|
||||
"SubagentConfig": reflect.ValueOf((*SubagentConfig)(nil)),
|
||||
"SubagentResult": reflect.ValueOf((*SubagentResult)(nil)),
|
||||
"SubagentUsage": reflect.ValueOf((*SubagentUsage)(nil)),
|
||||
"SubagentHandle": reflect.ValueOf((*SubagentHandle)(nil)),
|
||||
"SubagentEvent": reflect.ValueOf((*SubagentEvent)(nil)),
|
||||
|
||||
// Theme types
|
||||
"ThemeColor": reflect.ValueOf((*ThemeColor)(nil)),
|
||||
"ThemeColorConfig": reflect.ValueOf((*ThemeColorConfig)(nil)),
|
||||
|
||||
// Event structs
|
||||
"ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)),
|
||||
"ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)),
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
package extensions
|
||||
|
||||
// NewTestAPI creates an API object wired for testing.
|
||||
// This is used by the test harness to load extensions and verify behavior.
|
||||
// The registration functions wire handlers directly to the provided extension.
|
||||
func NewTestAPI(ext *LoadedExtension) API {
|
||||
reg := func(event EventType, fn HandlerFunc) {
|
||||
ext.Handlers[event] = append(ext.Handlers[event], fn)
|
||||
}
|
||||
|
||||
return API{
|
||||
onToolCall: func(h func(ToolCallEvent, Context) *ToolCallResult) {
|
||||
reg(ToolCall, func(e Event, c Context) Result {
|
||||
r := h(e.(ToolCallEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onToolExecStart: func(h func(ToolExecutionStartEvent, Context)) {
|
||||
reg(ToolExecutionStart, func(e Event, c Context) Result {
|
||||
h(e.(ToolExecutionStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onToolExecEnd: func(h func(ToolExecutionEndEvent, Context)) {
|
||||
reg(ToolExecutionEnd, func(e Event, c Context) Result {
|
||||
h(e.(ToolExecutionEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onToolResult: func(h func(ToolResultEvent, Context) *ToolResultResult) {
|
||||
reg(ToolResult, func(e Event, c Context) Result {
|
||||
r := h(e.(ToolResultEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onInput: func(h func(InputEvent, Context) *InputResult) {
|
||||
reg(Input, func(e Event, c Context) Result {
|
||||
r := h(e.(InputEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onBeforeAgentStart: func(h func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult) {
|
||||
reg(BeforeAgentStart, func(e Event, c Context) Result {
|
||||
r := h(e.(BeforeAgentStartEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onAgentStart: func(h func(AgentStartEvent, Context)) {
|
||||
reg(AgentStart, func(e Event, c Context) Result {
|
||||
h(e.(AgentStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onAgentEnd: func(h func(AgentEndEvent, Context)) {
|
||||
reg(AgentEnd, func(e Event, c Context) Result {
|
||||
h(e.(AgentEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onMessageStart: func(h func(MessageStartEvent, Context)) {
|
||||
reg(MessageStart, func(e Event, c Context) Result {
|
||||
h(e.(MessageStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onMessageUpdate: func(h func(MessageUpdateEvent, Context)) {
|
||||
reg(MessageUpdate, func(e Event, c Context) Result {
|
||||
h(e.(MessageUpdateEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onMessageEnd: func(h func(MessageEndEvent, Context)) {
|
||||
reg(MessageEnd, func(e Event, c Context) Result {
|
||||
h(e.(MessageEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSessionStart: func(h func(SessionStartEvent, Context)) {
|
||||
reg(SessionStart, func(e Event, c Context) Result {
|
||||
h(e.(SessionStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSessionShutdown: func(h func(SessionShutdownEvent, Context)) {
|
||||
reg(SessionShutdown, func(e Event, c Context) Result {
|
||||
h(e.(SessionShutdownEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onModelChange: func(h func(ModelChangeEvent, Context)) {
|
||||
reg(ModelChange, func(e Event, c Context) Result {
|
||||
h(e.(ModelChangeEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onContextPrepare: func(h func(ContextPrepareEvent, Context) *ContextPrepareResult) {
|
||||
reg(ContextPrepare, func(e Event, c Context) Result {
|
||||
r := h(e.(ContextPrepareEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onBeforeFork: func(h func(BeforeForkEvent, Context) *BeforeForkResult) {
|
||||
reg(BeforeFork, func(e Event, c Context) Result {
|
||||
r := h(e.(BeforeForkEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onBeforeSessionSwitch: func(h func(BeforeSessionSwitchEvent, Context) *BeforeSessionSwitchResult) {
|
||||
reg(BeforeSessionSwitch, func(e Event, c Context) Result {
|
||||
r := h(e.(BeforeSessionSwitchEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onBeforeCompact: func(h func(BeforeCompactEvent, Context) *BeforeCompactResult) {
|
||||
reg(BeforeCompact, func(e Event, c Context) Result {
|
||||
r := h(e.(BeforeCompactEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
registerToolFn: func(tool ToolDef) {
|
||||
ext.Tools = append(ext.Tools, tool)
|
||||
},
|
||||
registerCmdFn: func(cmd CommandDef) {
|
||||
ext.Commands = append(ext.Commands, cmd)
|
||||
},
|
||||
registerToolRendererFn: func(config ToolRenderConfig) {
|
||||
ext.ToolRenderers = append(ext.ToolRenderers, config)
|
||||
},
|
||||
onCustomEvent: func(name string, handler func(string)) {
|
||||
if ext.CustomEventHandlers == nil {
|
||||
ext.CustomEventHandlers = make(map[string][]func(string))
|
||||
}
|
||||
ext.CustomEventHandlers[name] = append(ext.CustomEventHandlers[name], handler)
|
||||
},
|
||||
registerOption: func(opt OptionDef) {
|
||||
ext.Options = append(ext.Options, opt)
|
||||
},
|
||||
registerShortcutFn: func(def ShortcutDef, handler func(Context)) {
|
||||
ext.Shortcuts = append(ext.Shortcuts, ShortcutEntry{Def: def, Handler: handler})
|
||||
},
|
||||
registerMessageRendererFn: func(config MessageRendererConfig) {
|
||||
ext.MessageRenderers = append(ext.MessageRenderers, config)
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -40,6 +40,37 @@ func ExtensionToolsAsFantasy(defs []ToolDef, runner *Runner) []fantasy.AgentTool
|
||||
return tools
|
||||
}
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind classification.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"spawn_subagent": "agent",
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
// "execute" for unknown tools (including MCP tools).
|
||||
func toolKindFor(toolName string) string {
|
||||
if kind, ok := coreToolKinds[toolName]; ok {
|
||||
return kind
|
||||
}
|
||||
return "execute"
|
||||
}
|
||||
|
||||
// parseToolArgsJSON attempts to parse JSON-encoded tool args into a map.
|
||||
// Returns nil on failure (non-fatal convenience parsing).
|
||||
func parseToolArgsJSON(input string) map[string]any {
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal([]byte(input), &parsed) == nil {
|
||||
return parsed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// wrappedTool — intercepts tool calls through the extension runner
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -63,12 +94,16 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
fmt.Errorf("tool %q disabled by extension", toolName)
|
||||
}
|
||||
|
||||
kind := toolKindFor(toolName)
|
||||
|
||||
// 1. Emit ToolCall — extensions can block execution.
|
||||
if w.runner.HasHandlers(ToolCall) {
|
||||
result, _ := w.runner.Emit(ToolCallEvent{
|
||||
ToolName: toolName,
|
||||
ToolCallID: call.ID,
|
||||
ToolKind: kind,
|
||||
Input: call.Input,
|
||||
ParsedArgs: parseToolArgsJSON(call.Input),
|
||||
Source: "llm",
|
||||
})
|
||||
if r, ok := result.(ToolCallResult); ok && r.Block {
|
||||
@@ -83,7 +118,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
|
||||
// 2. Emit ToolExecutionStart.
|
||||
if w.runner.HasHandlers(ToolExecutionStart) {
|
||||
_, _ = w.runner.Emit(ToolExecutionStartEvent{ToolName: toolName})
|
||||
_, _ = w.runner.Emit(ToolExecutionStartEvent{ToolCallID: call.ID, ToolName: toolName, ToolKind: kind})
|
||||
}
|
||||
|
||||
// 3. Execute the actual tool.
|
||||
@@ -91,16 +126,19 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
|
||||
// 4. Emit ToolExecutionEnd.
|
||||
if w.runner.HasHandlers(ToolExecutionEnd) {
|
||||
_, _ = w.runner.Emit(ToolExecutionEndEvent{ToolName: toolName})
|
||||
_, _ = w.runner.Emit(ToolExecutionEndEvent{ToolCallID: call.ID, ToolName: toolName, ToolKind: kind})
|
||||
}
|
||||
|
||||
// 5. Emit ToolResult — extensions can modify output.
|
||||
if w.runner.HasHandlers(ToolResult) {
|
||||
result, _ := w.runner.Emit(ToolResultEvent{
|
||||
ToolName: toolName,
|
||||
Input: call.Input,
|
||||
Content: resp.Content,
|
||||
IsError: err != nil || resp.IsError,
|
||||
ToolCallID: call.ID,
|
||||
ToolName: toolName,
|
||||
ToolKind: kind,
|
||||
Input: call.Input,
|
||||
Content: resp.Content,
|
||||
IsError: err != nil || resp.IsError,
|
||||
Metadata: resp.Metadata,
|
||||
})
|
||||
if r, ok := result.(ToolResultResult); ok {
|
||||
if r.Content != nil {
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// ProviderPool manages reusable LLM provider instances to reduce overhead
|
||||
// when spawning multiple subagents or making repeated completion calls.
|
||||
type ProviderPool struct {
|
||||
mu sync.RWMutex
|
||||
providers map[string]*pooledProvider
|
||||
ttl time.Duration
|
||||
closed bool
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
type pooledProvider struct {
|
||||
model fantasy.LanguageModel
|
||||
closer func() error
|
||||
providerOpts fantasy.ProviderOptions
|
||||
created time.Time
|
||||
lastUsed time.Time
|
||||
refs int32
|
||||
}
|
||||
|
||||
// DefaultPoolTTL is the default time-to-live for idle pooled providers.
|
||||
const DefaultPoolTTL = 5 * time.Minute
|
||||
|
||||
// globalPool is the singleton provider pool instance.
|
||||
var globalPool *ProviderPool
|
||||
var poolOnce sync.Once
|
||||
|
||||
// GetGlobalPool returns the singleton provider pool instance.
|
||||
func GetGlobalPool() *ProviderPool {
|
||||
poolOnce.Do(func() {
|
||||
globalPool = NewProviderPool(DefaultPoolTTL)
|
||||
})
|
||||
return globalPool
|
||||
}
|
||||
|
||||
// NewProviderPool creates a provider pool with the given TTL for idle providers.
|
||||
func NewProviderPool(ttl time.Duration) *ProviderPool {
|
||||
p := &ProviderPool{
|
||||
providers: make(map[string]*pooledProvider),
|
||||
ttl: ttl,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
go p.cleanupLoop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Get returns a provider for the model string, creating one if needed.
|
||||
// The returned release function must be called when the provider is no longer
|
||||
// needed. The provider may be reused by subsequent Get calls.
|
||||
func (p *ProviderPool) Get(ctx context.Context, modelString string) (fantasy.LanguageModel, fantasy.ProviderOptions, func(), error) {
|
||||
p.mu.Lock()
|
||||
|
||||
// Check if we have an existing provider.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
p.mu.Unlock()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
// Create a new provider outside the lock.
|
||||
config := &ProviderConfig{ModelString: modelString}
|
||||
result, err := CreateProvider(ctx, config)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Double-check: another goroutine may have created one while we were unlocked.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
// Close the one we just created and use the existing one.
|
||||
if result.Closer != nil {
|
||||
_ = result.Closer.Close()
|
||||
}
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
var closerFn func() error
|
||||
if result.Closer != nil {
|
||||
closerFn = result.Closer.Close
|
||||
}
|
||||
|
||||
pp := &pooledProvider{
|
||||
model: result.Model,
|
||||
closer: closerFn,
|
||||
providerOpts: result.ProviderOptions,
|
||||
created: time.Now(),
|
||||
lastUsed: time.Now(),
|
||||
refs: 1,
|
||||
}
|
||||
p.providers[modelString] = pp
|
||||
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
func (p *ProviderPool) release(modelString string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs--
|
||||
pp.lastUsed = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanupLoop() {
|
||||
ticker := time.NewTicker(p.ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanup() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, pp := range p.providers {
|
||||
// Only clean up providers with no active references and past TTL.
|
||||
if pp.refs <= 0 && now.Sub(pp.lastUsed) > p.ttl {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the pool and releases all providers.
|
||||
func (p *ProviderPool) Close() {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
close(p.closeCh)
|
||||
|
||||
for key, pp := range p.providers {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func resolveModelAlias(provider, modelName string) string {
|
||||
}
|
||||
|
||||
if resolved, exists := aliasMap[modelName]; exists {
|
||||
if _, err := registry.ValidateModel(provider, resolved); err == nil {
|
||||
if registry.LookupModel(provider, resolved) != nil {
|
||||
return resolved
|
||||
}
|
||||
}
|
||||
@@ -73,8 +73,8 @@ func ThinkingLevels() []ThinkingLevel {
|
||||
return []ThinkingLevel{ThinkingOff, ThinkingMinimal, ThinkingLow, ThinkingMedium, ThinkingHigh}
|
||||
}
|
||||
|
||||
// ThinkingBudgetTokens returns the token budget for a thinking level, or 0 for "off".
|
||||
func ThinkingBudgetTokens(level ThinkingLevel) int64 {
|
||||
// thinkingBudgetTokens returns the token budget for a thinking level, or 0 for "off".
|
||||
func thinkingBudgetTokens(level ThinkingLevel) int64 {
|
||||
switch level {
|
||||
case ThinkingMinimal:
|
||||
return 1024
|
||||
@@ -162,16 +162,6 @@ func ParseModelString(modelString string) (provider, model string, err error) {
|
||||
return "", "", fmt.Errorf("invalid model format %q: expected provider/model (e.g. anthropic/claude-sonnet-4-5)", modelString)
|
||||
}
|
||||
|
||||
// Legacy colon-separated format
|
||||
if strings.Contains(modelString, ":") {
|
||||
parts := strings.SplitN(modelString, ":", 2)
|
||||
if len(parts) == 2 && parts[0] != "" && parts[1] != "" {
|
||||
fmt.Fprintf(os.Stderr, "Warning: model format %q uses deprecated colon separator. Use %s/%s instead.\n",
|
||||
modelString, parts[0], parts[1])
|
||||
return parts[0], parts[1], nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("invalid model format %q: expected provider/model (e.g. anthropic/claude-sonnet-4-5)", modelString)
|
||||
}
|
||||
|
||||
@@ -210,10 +200,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
}
|
||||
}
|
||||
|
||||
// Validate environment variables
|
||||
if err := registry.ValidateEnvironment(provider, config.ProviderAPIKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// NOTE: We intentionally skip registry.ValidateEnvironment() here.
|
||||
// Each create*Provider function handles its own auth resolution and
|
||||
// produces provider-specific error messages. The early env-var check
|
||||
// was too narrow — it didn't account for stored credentials (e.g.
|
||||
// OAuth tokens from 'kit auth login') and blocked valid auth paths.
|
||||
|
||||
// Validate config against known model limits when metadata is available
|
||||
if modelInfo != nil {
|
||||
@@ -488,7 +479,7 @@ func buildAnthropicProviderOptions(config *ProviderConfig, modelName string) fan
|
||||
return nil
|
||||
}
|
||||
|
||||
budget := ThinkingBudgetTokens(config.ThinkingLevel)
|
||||
budget := thinkingBudgetTokens(config.ThinkingLevel)
|
||||
if budget == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -1042,9 +1033,21 @@ type oauthTransport struct {
|
||||
}
|
||||
|
||||
func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Resolve the freshest available token. The credential manager
|
||||
// automatically refreshes tokens nearing expiry (5-minute buffer).
|
||||
// This keeps long-lived sessions (e.g. ACP) working across token
|
||||
// renewals. Falls back to the originally-provided token if the
|
||||
// credential manager is unavailable.
|
||||
token := t.accessToken
|
||||
if cm, err := auth.NewCredentialManager(); err == nil {
|
||||
if fresh, err := cm.GetValidAccessToken(); err == nil && fresh != "" {
|
||||
token = fresh
|
||||
}
|
||||
}
|
||||
|
||||
newReq := req.Clone(req.Context())
|
||||
newReq.Header.Del("x-api-key")
|
||||
newReq.Header.Set("Authorization", "Bearer "+t.accessToken)
|
||||
newReq.Header.Set("Authorization", "Bearer "+token)
|
||||
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
newReq.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
|
||||
@@ -78,6 +78,7 @@ func TestCreateOAuthHTTPClient(t *testing.T) {
|
||||
|
||||
if client == nil {
|
||||
t.Fatal("expected non-nil client")
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the transport is an oauthTransport
|
||||
|
||||
+21
-22
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
)
|
||||
|
||||
//go:embed embedded_models.json
|
||||
@@ -145,24 +147,8 @@ func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
|
||||
return &modelInfo
|
||||
}
|
||||
|
||||
// ValidateModel validates if a model exists and returns detailed information.
|
||||
// Deprecated: Use LookupModel instead — it returns nil for unknown models
|
||||
// rather than an error, letting the provider API be the authority.
|
||||
func (r *ModelsRegistry) ValidateModel(provider, modelID string) (*ModelInfo, error) {
|
||||
if info := r.LookupModel(provider, modelID); info != nil {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("model %s not found for provider %s", modelID, providerInfo.ID)
|
||||
}
|
||||
|
||||
// GetRequiredEnvVars returns the required environment variables for a provider.
|
||||
func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) {
|
||||
// getRequiredEnvVars returns the required environment variables for a provider.
|
||||
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
||||
@@ -171,15 +157,28 @@ func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) {
|
||||
return providerInfo.Env, nil
|
||||
}
|
||||
|
||||
// ValidateEnvironment checks if required environment variables are set.
|
||||
// Returns nil for providers not in the registry (unknown providers are
|
||||
// assumed to handle auth themselves or via --provider-api-key).
|
||||
// ValidateEnvironment checks if required credentials are available for a
|
||||
// provider. It checks the explicit API key, stored credentials (for
|
||||
// providers that support them, such as Anthropic OAuth), and environment
|
||||
// variables. Returns nil for providers not in the registry (unknown
|
||||
// providers are assumed to handle auth themselves or via --provider-api-key).
|
||||
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
|
||||
if apiKey != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
envVars, err := r.GetRequiredEnvVars(provider)
|
||||
// For anthropic, also check stored credentials (OAuth / API key)
|
||||
// since auth resolution goes through the credential manager, not
|
||||
// just environment variables.
|
||||
if provider == "anthropic" {
|
||||
if cm, err := auth.NewCredentialManager(); err == nil {
|
||||
if has, _ := cm.HasAnthropicCredentials(); has {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
envVars, err := r.getRequiredEnvVars(provider)
|
||||
if err != nil {
|
||||
// Unknown provider — nothing to validate
|
||||
return nil
|
||||
|
||||
@@ -38,6 +38,10 @@ type SessionHeader struct {
|
||||
Timestamp time.Time `json:"timestamp"` // creation time
|
||||
Cwd string `json:"cwd"` // working directory
|
||||
ParentSession string `json:"parent_session,omitempty"` // path to parent if forked
|
||||
|
||||
// Subagent fields (set when session is created by a subagent)
|
||||
ParentSessionID string `json:"parent_session_id,omitempty"` // UUID of parent session
|
||||
SubagentTask string `json:"subagent_task,omitempty"` // original task prompt
|
||||
}
|
||||
|
||||
// Entry is the common structure shared by all tree entries (everything except
|
||||
@@ -140,17 +144,6 @@ func NewMessageEntry(parentID string, msg message.Message) (*MessageEntry, error
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewMessageEntryFromRaw creates a MessageEntry with pre-marshaled parts.
|
||||
func NewMessageEntryFromRaw(parentID, role string, parts json.RawMessage, model, provider string) *MessageEntry {
|
||||
return &MessageEntry{
|
||||
Entry: NewEntry(EntryTypeMessage, parentID),
|
||||
Role: role,
|
||||
Parts: parts,
|
||||
Model: model,
|
||||
Provider: provider,
|
||||
}
|
||||
}
|
||||
|
||||
// NewModelChangeEntry creates a ModelChangeEntry.
|
||||
func NewModelChangeEntry(parentID, provider, modelID string) *ModelChangeEntry {
|
||||
return &ModelChangeEntry{
|
||||
|
||||
@@ -29,6 +29,12 @@ type SessionInfo struct {
|
||||
// ParentSessionPath is the parent session path if this session was forked.
|
||||
ParentSessionPath string
|
||||
|
||||
// ParentSessionID is the UUID of the parent session (for subagent sessions).
|
||||
ParentSessionID string
|
||||
|
||||
// SubagentTask is the original task prompt (for subagent sessions).
|
||||
SubagentTask string
|
||||
|
||||
// Created is when the session was first created.
|
||||
Created time.Time
|
||||
|
||||
@@ -162,6 +168,8 @@ func extractSessionInfo(path string) (*SessionInfo, error) {
|
||||
info.Created = h.Timestamp
|
||||
info.Modified = h.Timestamp
|
||||
info.ParentSessionPath = h.ParentSession
|
||||
info.ParentSessionID = h.ParentSessionID
|
||||
info.SubagentTask = h.SubagentTask
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ type blockRenderer struct {
|
||||
align *lipgloss.Position
|
||||
borderColor *color.Color
|
||||
background *color.Color
|
||||
foreground *color.Color
|
||||
fullWidth bool
|
||||
noBorder bool
|
||||
paddingTop int
|
||||
@@ -123,6 +124,15 @@ func WithBackground(c color.Color) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithForeground returns a renderingOption that overrides the default text
|
||||
// foreground color (theme.Text) for the block. Useful for muted or
|
||||
// de-emphasized content blocks.
|
||||
func WithForeground(c color.Color) renderingOption {
|
||||
return func(br *blockRenderer) {
|
||||
br.foreground = &c
|
||||
}
|
||||
}
|
||||
|
||||
// WithWidth returns a renderingOption that sets a specific width for the block
|
||||
// in characters. This overrides the default container width and allows precise
|
||||
// control over the block's horizontal dimensions.
|
||||
@@ -167,13 +177,19 @@ func renderContentBlock(content string, containerWidth int, options ...rendering
|
||||
|
||||
theme := GetTheme()
|
||||
|
||||
// Resolve foreground color: caller override or theme default.
|
||||
fgColor := theme.Text
|
||||
if renderer.foreground != nil {
|
||||
fgColor = *renderer.foreground
|
||||
}
|
||||
|
||||
// Single-pass render: padding, border, and foreground in one style.
|
||||
style := lipgloss.NewStyle().
|
||||
PaddingLeft(renderer.paddingLeft).
|
||||
PaddingRight(renderer.paddingRight).
|
||||
PaddingTop(renderer.paddingTop).
|
||||
PaddingBottom(renderer.paddingBottom).
|
||||
Foreground(theme.Text)
|
||||
Foreground(fgColor)
|
||||
|
||||
if hasBorder {
|
||||
style = style.BorderStyle(lipgloss.ThickBorder())
|
||||
|
||||
@@ -348,6 +348,9 @@ func TestStreamComponent_SpinnerKeepsRunningDuringStreaming(t *testing.T) {
|
||||
// Receive first chunk — spinner should keep running.
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "hello"})
|
||||
|
||||
// Flush pending chunks (simulates the 16ms tick firing).
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{})
|
||||
|
||||
if !c.spinning {
|
||||
t.Fatal("expected spinning=true after first chunk")
|
||||
}
|
||||
@@ -372,6 +375,9 @@ func TestStreamComponent_ChunkAccumulation(t *testing.T) {
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: chunk})
|
||||
}
|
||||
|
||||
// Flush pending chunks (simulates the 16ms tick firing).
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{})
|
||||
|
||||
got := c.streamContent.String()
|
||||
want := "Hello, world!"
|
||||
if got != want {
|
||||
@@ -397,8 +403,8 @@ func TestStreamComponent_ToolExecution_IsStarting_ShowsSpinner(t *testing.T) {
|
||||
if !c.spinning {
|
||||
t.Fatal("expected spinning=true during tool execution")
|
||||
}
|
||||
if !strings.Contains(c.spinnerMsg, "exec_tool") {
|
||||
t.Fatalf("expected spinnerMsg to contain tool name, got %q", c.spinnerMsg)
|
||||
if len(c.activeTools) != 1 || !strings.Contains(c.activeTools[0], "exec_tool") {
|
||||
t.Fatalf("expected activeTools to contain tool name, got %v", c.activeTools)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Fatal("expected tick cmd from ToolExecutionEvent{IsStarting:true}")
|
||||
@@ -410,7 +416,11 @@ func TestStreamComponent_ToolExecution_NotStarting_KeepsSpinning(t *testing.T) {
|
||||
c := newTestStream()
|
||||
// Start spinning first (simulating execution in progress).
|
||||
c = sendStreamMsg(c, app.SpinnerEvent{Show: true})
|
||||
c.spinnerMsg = "Executing some_tool…"
|
||||
// Simulate a tool starting
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{
|
||||
ToolName: "some_tool",
|
||||
IsStarting: true,
|
||||
})
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{
|
||||
ToolName: "some_tool",
|
||||
@@ -420,8 +430,41 @@ func TestStreamComponent_ToolExecution_NotStarting_KeepsSpinning(t *testing.T) {
|
||||
if !c.spinning {
|
||||
t.Fatal("expected spinning=true after tool execution finished (spinner keeps running)")
|
||||
}
|
||||
if c.spinnerMsg != "" {
|
||||
t.Fatalf("expected spinnerMsg cleared after tool finished, got %q", c.spinnerMsg)
|
||||
if len(c.activeTools) != 0 {
|
||||
t.Fatalf("expected activeTools cleared after tool finished, got %v", c.activeTools)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_ParallelToolExecution verifies multiple tools can run concurrently.
|
||||
func TestStreamComponent_ParallelToolExecution(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
// Start three tools in parallel
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "grep", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "find", IsStarting: true})
|
||||
|
||||
if len(c.activeTools) != 3 {
|
||||
t.Fatalf("expected 3 active tools, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
}
|
||||
|
||||
// Check SpinnerView shows all tools
|
||||
view := c.SpinnerView()
|
||||
if !strings.Contains(view, "Running:") {
|
||||
t.Fatalf("expected spinner view to contain 'Running:' for multiple tools, got %q", view)
|
||||
}
|
||||
|
||||
// Finish one tool
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "grep", IsStarting: false})
|
||||
if len(c.activeTools) != 2 {
|
||||
t.Fatalf("expected 2 active tools after one finished, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
}
|
||||
|
||||
// Finish remaining tools
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "read", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "find", IsStarting: false})
|
||||
if len(c.activeTools) != 0 {
|
||||
t.Fatalf("expected 0 active tools after all finished, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -480,8 +523,8 @@ func TestStreamComponent_Reset(t *testing.T) {
|
||||
if !c.timestamp.IsZero() {
|
||||
t.Fatal("expected zero timestamp after Reset()")
|
||||
}
|
||||
if c.spinnerMsg != "" {
|
||||
t.Fatalf("expected spinnerMsg empty after Reset(), got %q", c.spinnerMsg)
|
||||
if len(c.activeTools) != 0 {
|
||||
t.Fatalf("expected activeTools empty after Reset(), got %v", c.activeTools)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -517,9 +560,10 @@ func TestStreamComponent_SpinnerTick_AdvancesFrame(t *testing.T) {
|
||||
// Start spinning first.
|
||||
c = sendStreamMsg(c, app.SpinnerEvent{Show: true})
|
||||
initialFrame := c.spinnerFrame
|
||||
gen := c.spinnerGeneration
|
||||
|
||||
// Send a tick.
|
||||
_, cmd := c.Update(streamSpinnerTickMsg{})
|
||||
// Send a tick with the current generation.
|
||||
_, cmd := c.Update(streamSpinnerTickMsg{generation: gen})
|
||||
|
||||
if c.spinnerFrame != initialFrame+1 {
|
||||
t.Fatalf("expected spinnerFrame=%d, got %d", initialFrame+1, c.spinnerFrame)
|
||||
@@ -540,3 +584,40 @@ func TestStreamComponent_SpinnerTick_NoReschedule_WhenNotSpinning(t *testing.T)
|
||||
t.Fatal("expected no tick reschedule when not spinning")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_StaleTick_Discarded verifies that a tick from a previous
|
||||
// spinner generation is silently discarded, preventing duplicate concurrent
|
||||
// tick loops that would double the animation speed.
|
||||
func TestStreamComponent_StaleTick_Discarded(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
// Start spinner → generation 1.
|
||||
c = sendStreamMsg(c, app.SpinnerEvent{Show: true})
|
||||
staleGen := c.spinnerGeneration
|
||||
|
||||
// Stop spinner → generation bumped to 2.
|
||||
c = sendStreamMsg(c, app.SpinnerEvent{Show: false})
|
||||
|
||||
// Restart spinner → generation bumped to 3.
|
||||
c = sendStreamMsg(c, app.SpinnerEvent{Show: true})
|
||||
currentGen := c.spinnerGeneration
|
||||
frameBefore := c.spinnerFrame
|
||||
|
||||
// Simulate a stale tick from the first spinner session arriving.
|
||||
_, cmd := c.Update(streamSpinnerTickMsg{generation: staleGen})
|
||||
if c.spinnerFrame != frameBefore {
|
||||
t.Fatalf("stale tick should not advance frame: expected %d, got %d", frameBefore, c.spinnerFrame)
|
||||
}
|
||||
if cmd != nil {
|
||||
t.Fatal("stale tick should not reschedule")
|
||||
}
|
||||
|
||||
// A tick from the current generation should still work.
|
||||
_, cmd = c.Update(streamSpinnerTickMsg{generation: currentGen})
|
||||
if c.spinnerFrame != frameBefore+1 {
|
||||
t.Fatalf("current-gen tick should advance frame: expected %d, got %d", frameBefore+1, c.spinnerFrame)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Fatal("current-gen tick should reschedule")
|
||||
}
|
||||
}
|
||||
|
||||
+2
-9
@@ -36,7 +36,7 @@ func NewCLI(debug bool, compact bool) (*CLI, error) {
|
||||
if compact {
|
||||
cli.renderer = NewCompactRenderer(cli.width, debug)
|
||||
} else {
|
||||
cli.renderer = NewMessageRenderer(cli.width, debug)
|
||||
cli.renderer = newMessageRenderer(cli.width, debug)
|
||||
}
|
||||
|
||||
return cli, nil
|
||||
@@ -108,13 +108,6 @@ func (c *CLI) DisplayAssistantMessageWithModel(message, modelName string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisplayToolCallMessage is a no-op retained for backward compatibility. Tool
|
||||
// calls are now rendered as part of the unified tool block in DisplayToolMessage,
|
||||
// which combines the invocation header with the execution result.
|
||||
func (c *CLI) DisplayToolCallMessage(toolName, toolArgs string) {
|
||||
// No-op: unified tool blocks are rendered in DisplayToolMessage.
|
||||
}
|
||||
|
||||
// DisplayToolMessage renders and displays the complete result of a tool execution,
|
||||
// including the tool name, arguments, and result. The isError parameter determines
|
||||
// whether the result should be displayed as an error or success message.
|
||||
@@ -141,7 +134,7 @@ func (c *CLI) DisplayInfo(message string) {
|
||||
func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) {
|
||||
theme := GetTheme()
|
||||
|
||||
var borderClr = lipgloss.Color("#89b4fa")
|
||||
borderClr := theme.Info
|
||||
if borderColor != "" {
|
||||
borderClr = lipgloss.Color(borderColor)
|
||||
}
|
||||
|
||||
@@ -94,6 +94,24 @@ var SlashCommands = []SlashCommand{
|
||||
return matches
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "/theme",
|
||||
Description: "Switch color theme (e.g. /theme catppuccin)",
|
||||
Category: "System",
|
||||
Complete: func(prefix string) []string {
|
||||
names := ListThemes()
|
||||
if prefix == "" {
|
||||
return names
|
||||
}
|
||||
var matches []string
|
||||
for _, n := range names {
|
||||
if strings.HasPrefix(n, strings.ToLower(prefix)) {
|
||||
matches = append(matches, n)
|
||||
}
|
||||
}
|
||||
return matches
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "/quit",
|
||||
Description: "Exit the application",
|
||||
|
||||
@@ -44,15 +44,20 @@ func (r *CompactRenderer) SetWidth(width int) {
|
||||
// and metadata.
|
||||
func (r *CompactRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Secondary).Render(">")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Secondary).Bold(true).Render("User")
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Info).Render(">")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Info).Bold(true).Render("User")
|
||||
|
||||
// Convert single newlines to paragraph breaks so they survive glamour's
|
||||
// markdown rendering (glamour treats single \n as a soft break).
|
||||
content = strings.ReplaceAll(content, "\n", "\n\n")
|
||||
|
||||
// Format content for user messages (preserve formatting, no truncation)
|
||||
compactContent := r.formatUserAssistantContent(content)
|
||||
// Only run markdown rendering when the message contains code spans or
|
||||
// fenced code blocks. Plain text is rendered directly so that newlines
|
||||
// are preserved without the extra paragraph spacing glamour adds.
|
||||
var compactContent string
|
||||
if strings.Contains(content, "`") {
|
||||
mdContent := strings.ReplaceAll(content, "\n", "\n\n")
|
||||
compactContent = r.formatUserAssistantContent(mdContent)
|
||||
compactContent = removeBlankLines(compactContent)
|
||||
} else {
|
||||
compactContent = content
|
||||
}
|
||||
|
||||
// Handle multi-line content
|
||||
lines := strings.Split(compactContent, "\n")
|
||||
@@ -170,7 +175,7 @@ func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
if extRd != nil && extRd.DisplayName != "" {
|
||||
displayName = extRd.DisplayName
|
||||
}
|
||||
nameStr := lipgloss.NewStyle().Foreground(theme.Tool).Bold(true).Render(displayName)
|
||||
nameStr := lipgloss.NewStyle().Foreground(theme.Info).Bold(true).Render(displayName)
|
||||
|
||||
// Format params — check extension renderer first.
|
||||
paramBudget := max(r.width-10-len(displayName), 20)
|
||||
@@ -235,8 +240,8 @@ func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
// formatted to fit on a single line for minimal space usage.
|
||||
func (r *CompactRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.System).Render("*")
|
||||
label := lipgloss.NewStyle().Foreground(theme.System).Bold(true).Render("System")
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Muted).Render("◇")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Muted).Bold(true).Render("System")
|
||||
|
||||
compactContent := r.formatCompactContent(content)
|
||||
|
||||
|
||||
@@ -39,9 +39,26 @@ func SetTheme(theme Theme) {
|
||||
currentTheme = theme
|
||||
}
|
||||
|
||||
// MarkdownThemeColors defines colors for markdown rendering and syntax highlighting.
|
||||
type MarkdownThemeColors struct {
|
||||
Text color.Color
|
||||
Muted color.Color
|
||||
Heading color.Color
|
||||
Emph color.Color
|
||||
Strong color.Color
|
||||
Link color.Color
|
||||
Code color.Color
|
||||
Error color.Color
|
||||
Keyword color.Color
|
||||
String color.Color
|
||||
Number color.Color
|
||||
Comment color.Color
|
||||
}
|
||||
|
||||
// Theme defines a comprehensive color scheme for the application's UI, supporting
|
||||
// both light and dark terminal modes through adaptive colors. It includes semantic
|
||||
// colors for different message types and UI elements, based on the Catppuccin color palette.
|
||||
// both light and dark terminal modes through adaptive colors. Inspired by the
|
||||
// Knight Rider KITT aesthetic — scanner reds, amber dashboard glows, and dark
|
||||
// cockpit tones.
|
||||
type Theme struct {
|
||||
Primary color.Color
|
||||
Secondary color.Color
|
||||
@@ -70,40 +87,60 @@ type Theme struct {
|
||||
CodeBg color.Color // Background for code blocks (Read tool)
|
||||
GutterBg color.Color // Line-number gutter background
|
||||
WriteBg color.Color // Green-tinted bg for Write tool content
|
||||
|
||||
// Markdown rendering and syntax highlighting colors
|
||||
Markdown MarkdownThemeColors
|
||||
}
|
||||
|
||||
// DefaultTheme creates and returns the default KIT theme based on the Catppuccin
|
||||
// Mocha (dark) and Latte (light) color palettes. This theme provides a cohesive,
|
||||
// pleasant visual experience with carefully selected colors for different UI elements.
|
||||
// DefaultTheme creates and returns the default KIT theme inspired by the
|
||||
// Knight Rider KITT aesthetic — scanner reds, amber dashboard glows, and a
|
||||
// dark cockpit. No blues or bright greens; everything stays in the warm
|
||||
// red/amber/gray family of KITT's instrument panel.
|
||||
func DefaultTheme() Theme {
|
||||
return Theme{
|
||||
Primary: AdaptiveColor("#8839ef", "#cba6f7"), // Latte/Mocha Mauve
|
||||
Secondary: AdaptiveColor("#04a5e5", "#89dceb"), // Latte/Mocha Sky
|
||||
Success: AdaptiveColor("#40a02b", "#a6e3a1"), // Latte/Mocha Green
|
||||
Warning: AdaptiveColor("#df8e1d", "#f9e2af"), // Latte/Mocha Yellow
|
||||
Error: AdaptiveColor("#d20f39", "#f38ba8"), // Latte/Mocha Red
|
||||
Info: AdaptiveColor("#1e66f5", "#89b4fa"), // Latte/Mocha Blue
|
||||
Text: AdaptiveColor("#4c4f69", "#cdd6f4"), // Latte/Mocha Text
|
||||
Muted: AdaptiveColor("#6c6f85", "#a6adc8"), // Latte/Mocha Subtext 0
|
||||
VeryMuted: AdaptiveColor("#9ca0b0", "#6c7086"), // Latte/Mocha Overlay 0
|
||||
Background: AdaptiveColor("#eff1f5", "#1e1e2e"), // Latte/Mocha Base
|
||||
Border: AdaptiveColor("#acb0be", "#585b70"), // Latte/Mocha Surface 2
|
||||
MutedBorder: AdaptiveColor("#ccd0da", "#313244"), // Latte/Mocha Surface 0
|
||||
System: AdaptiveColor("#179299", "#94e2d5"), // Latte/Mocha Teal
|
||||
Tool: AdaptiveColor("#fe640b", "#fab387"), // Latte/Mocha Peach
|
||||
Accent: AdaptiveColor("#ea76cb", "#f5c2e7"), // Latte/Mocha Pink
|
||||
Highlight: AdaptiveColor("#e6e9ef", "#181825"), // Latte Mantle / Mocha Mantle
|
||||
Primary: AdaptiveColor("#CC1100", "#FF2200"), // KITT scanner red
|
||||
Secondary: AdaptiveColor("#CC6600", "#FF8800"), // Amber dashboard glow
|
||||
Success: AdaptiveColor("#998800", "#CCAA00"), // Warm gold — system OK
|
||||
Warning: AdaptiveColor("#CC8800", "#FFB800"), // Amber caution light
|
||||
Error: AdaptiveColor("#CC0000", "#FF3333"), // Alert red
|
||||
Info: AdaptiveColor("#BB6600", "#DD8833"), // Warm amber readout
|
||||
Text: AdaptiveColor("#1A1A1A", "#E0E0E0"), // Console text
|
||||
Muted: AdaptiveColor("#707070", "#808080"), // Dimmed readout
|
||||
VeryMuted: AdaptiveColor("#A0A0A0", "#505050"), // Inactive element
|
||||
Background: AdaptiveColor("#F0F0F0", "#0D0D0D"), // Cockpit interior
|
||||
Border: AdaptiveColor("#B0B0B0", "#3A3A3A"), // Panel edge
|
||||
MutedBorder: AdaptiveColor("#D0D0D0", "#222222"), // Subtle divider
|
||||
System: AdaptiveColor("#CC6600", "#FF8800"), // Amber system status
|
||||
Tool: AdaptiveColor("#CC6600", "#FF8800"), // Amber instrument
|
||||
Accent: AdaptiveColor("#DD2222", "#FF4444"), // Secondary scanner glow
|
||||
Highlight: AdaptiveColor("#FFF0F0", "#1A1010"), // Red-tinted mantle
|
||||
|
||||
// Diff backgrounds — subtle tinted variants of the base palette
|
||||
DiffInsertBg: AdaptiveColor("#d5f0d5", "#1a3a2a"), // Green tint
|
||||
DiffDeleteBg: AdaptiveColor("#f5d5d5", "#3a1a2a"), // Red tint
|
||||
DiffEqualBg: AdaptiveColor("#eceef3", "#232336"), // Neutral
|
||||
DiffMissingBg: AdaptiveColor("#e4e6eb", "#1a1a2e"), // Darker neutral
|
||||
// Diff backgrounds
|
||||
DiffInsertBg: AdaptiveColor("#F0E8D0", "#2A2410"), // Warm amber tint (added)
|
||||
DiffDeleteBg: AdaptiveColor("#F5D5D5", "#2E1A1A"), // Red tint (removed)
|
||||
DiffEqualBg: AdaptiveColor("#E8E8E8", "#161616"), // Neutral
|
||||
DiffMissingBg: AdaptiveColor("#E0E0E0", "#111111"), // Darker neutral
|
||||
|
||||
// Code & output backgrounds
|
||||
CodeBg: AdaptiveColor("#eceef3", "#232336"), // Matches DiffEqualBg
|
||||
GutterBg: AdaptiveColor("#e4e6eb", "#1a1a2e"), // Slightly darker
|
||||
WriteBg: AdaptiveColor("#d5f0d5", "#1a3a2a"), // Matches DiffInsertBg (green tint)
|
||||
CodeBg: AdaptiveColor("#E8E8E8", "#161616"), // Matches DiffEqualBg
|
||||
GutterBg: AdaptiveColor("#E0E0E0", "#111111"), // Slightly darker
|
||||
WriteBg: AdaptiveColor("#F0E8D0", "#2A2410"), // Warm amber tint
|
||||
|
||||
// Markdown & syntax highlighting — all warm tones
|
||||
Markdown: MarkdownThemeColors{
|
||||
Text: AdaptiveColor("#1A1A1A", "#E0E0E0"), // Console text
|
||||
Muted: AdaptiveColor("#707070", "#808080"), // Dimmed readout
|
||||
Heading: AdaptiveColor("#CC1100", "#FF4444"), // Scanner red accent
|
||||
Emph: AdaptiveColor("#CC8800", "#FFB800"), // Amber emphasis
|
||||
Strong: AdaptiveColor("#1A1A1A", "#E0E0E0"), // Bright text
|
||||
Link: AdaptiveColor("#CC4400", "#FF7744"), // Warm orange link
|
||||
Code: AdaptiveColor("#333333", "#CCCCCC"), // Inline code
|
||||
Error: AdaptiveColor("#CC0000", "#FF3333"), // Alert red
|
||||
Keyword: AdaptiveColor("#CC3300", "#FF6644"), // Orange-red keyword
|
||||
String: AdaptiveColor("#BB7700", "#DDAA33"), // Amber string
|
||||
Number: AdaptiveColor("#CC8800", "#FFB800"), // Amber number
|
||||
Comment: AdaptiveColor("#909090", "#606060"), // Dark gray comment
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,8 +51,8 @@ func CreateUsageTracker(modelString, providerAPIKey string) *UsageTracker {
|
||||
}
|
||||
|
||||
registry := models.GetGlobalRegistry()
|
||||
modelInfo, err := registry.ValidateModel(provider, model)
|
||||
if err != nil {
|
||||
modelInfo := registry.LookupModel(provider, model)
|
||||
if modelInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
|
||||
// Skip usage tracking for ollama as it's not in models.dev
|
||||
if provider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo, err := registry.ValidateModel(provider, model); err == nil {
|
||||
if modelInfo := registry.LookupModel(provider, model); modelInfo != nil {
|
||||
// Check if OAuth credentials are being used for Anthropic models
|
||||
isOAuth := false
|
||||
if provider == "anthropic" {
|
||||
|
||||
+83
-31
@@ -89,18 +89,19 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
|
||||
ta.SetHeight(3) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
|
||||
// Override InsertNewline so only ctrl+j and alt+enter insert newlines.
|
||||
// Override InsertNewline so only ctrl+j and shift+enter insert newlines.
|
||||
// Enter always submits the input.
|
||||
ta.KeyMap.InsertNewline = key.NewBinding(
|
||||
key.WithKeys("ctrl+j", "alt+enter"),
|
||||
key.WithKeys("ctrl+j", "shift+enter"),
|
||||
key.WithHelp("ctrl+j", "insert newline"),
|
||||
)
|
||||
|
||||
// Style the textarea to match huh theme
|
||||
// Style the textarea using theme colors.
|
||||
theme := GetTheme()
|
||||
styles := ta.Styles()
|
||||
styles.Focused.Base = lipgloss.NewStyle()
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(lipgloss.Color("240"))
|
||||
styles.Focused.Text = lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
styles.Focused.Text = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
styles.Focused.Prompt = lipgloss.NewStyle()
|
||||
styles.Focused.CursorLine = lipgloss.NewStyle()
|
||||
ta.SetStyles(styles)
|
||||
@@ -376,9 +377,11 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
|
||||
func (s *InputComponent) View() tea.View {
|
||||
containerStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := GetTheme()
|
||||
|
||||
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("252")).
|
||||
Foreground(theme.Text).
|
||||
MarginBottom(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
@@ -388,7 +391,7 @@ func (s *InputComponent) View() tea.View {
|
||||
BorderRight(false).
|
||||
BorderTop(false).
|
||||
BorderBottom(false).
|
||||
BorderForeground(lipgloss.Color("39")).
|
||||
BorderForeground(theme.Primary).
|
||||
PaddingLeft(2). // match message block paddingLeft
|
||||
Width(s.width - 1) // full width minus left border
|
||||
|
||||
@@ -405,7 +408,7 @@ func (s *InputComponent) View() tea.View {
|
||||
// Show image attachment indicator when images are pending.
|
||||
if len(s.pendingImages) > 0 {
|
||||
imgStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("39")).
|
||||
Foreground(theme.Secondary).
|
||||
PaddingLeft(3)
|
||||
|
||||
label := fmt.Sprintf("[%d image(s) attached] ctrl+u to clear", len(s.pendingImages))
|
||||
@@ -415,11 +418,22 @@ func (s *InputComponent) View() tea.View {
|
||||
|
||||
if !s.hideHint {
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("240")).
|
||||
Foreground(theme.VeryMuted).
|
||||
MarginTop(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
hint := "enter submit • ctrl+j / alt+enter new line • ctrl+v paste image"
|
||||
// Adapt hint text to available width (accounting for left padding of 3).
|
||||
var hint string
|
||||
availableHintWidth := s.width - 3
|
||||
if availableHintWidth >= 67 {
|
||||
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
|
||||
} else if availableHintWidth >= 40 {
|
||||
hint = "↵ submit • ctrl+j newline • ctrl+v image"
|
||||
} else if availableHintWidth >= 20 {
|
||||
hint = "↵ submit • ctrl+j"
|
||||
} else {
|
||||
hint = "↵ submit"
|
||||
}
|
||||
view.WriteString("\n")
|
||||
view.WriteString(helpStyle.Render(hint))
|
||||
}
|
||||
@@ -429,13 +443,18 @@ func (s *InputComponent) View() tea.View {
|
||||
|
||||
// renderPopup renders the autocomplete popup for slash command suggestions.
|
||||
func (s *InputComponent) renderPopup() string {
|
||||
theme := GetTheme()
|
||||
popupWidth := max(s.width-4, 20)
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("236")).
|
||||
BorderForeground(theme.MutedBorder).
|
||||
Padding(1, 2).
|
||||
Width(s.width - 4).
|
||||
Width(popupWidth).
|
||||
MarginLeft(0)
|
||||
|
||||
// Inner content width: popup minus border (2) and horizontal padding (4).
|
||||
innerWidth := max(popupWidth-6, 10)
|
||||
|
||||
var items []string
|
||||
|
||||
visibleItems := min(len(s.filtered), s.popupHeight)
|
||||
@@ -451,56 +470,89 @@ func (s *InputComponent) renderPopup() string {
|
||||
|
||||
var indicator string
|
||||
if i == s.selected {
|
||||
indicator = lipgloss.NewStyle().Foreground(lipgloss.Color("39")).Render("> ")
|
||||
indicator = lipgloss.NewStyle().Foreground(theme.Primary).Render("> ")
|
||||
} else {
|
||||
indicator = " "
|
||||
}
|
||||
|
||||
nameStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("39")).Bold(true)
|
||||
descStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("243"))
|
||||
nameStyle := lipgloss.NewStyle().Foreground(theme.Secondary).Bold(true)
|
||||
descStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
if i == s.selected {
|
||||
nameStyle = nameStyle.Foreground(lipgloss.Color("87"))
|
||||
descStyle = descStyle.Foreground(lipgloss.Color("250"))
|
||||
nameStyle = nameStyle.Foreground(theme.Primary)
|
||||
descStyle = descStyle.Foreground(theme.Text)
|
||||
}
|
||||
|
||||
if s.fileMode {
|
||||
// File mode: use full width for the path, show description
|
||||
// (e.g. "directory") inline after a gap.
|
||||
maxNameLen := s.width - 24
|
||||
maxNameLen := max(innerWidth-16, 8)
|
||||
displayName := sc.Name
|
||||
if len(displayName) > maxNameLen && maxNameLen > 3 {
|
||||
displayName = displayName[:maxNameLen-3] + "..."
|
||||
}
|
||||
name := nameStyle.Render(displayName)
|
||||
if sc.Description != "" {
|
||||
if sc.Description != "" && innerWidth > 30 {
|
||||
items = append(items, indicator+name+" "+descStyle.Render(sc.Description))
|
||||
} else {
|
||||
items = append(items, indicator+name)
|
||||
}
|
||||
} else {
|
||||
nameWidth := 15
|
||||
name := nameStyle.Width(nameWidth - 2).Render(sc.Name)
|
||||
// Line layout: indicator(2) + name(nameWidth-2 visual) + desc.
|
||||
if innerWidth < 20 {
|
||||
// Very narrow: show truncated name only, no fixed column.
|
||||
displayName := sc.Name
|
||||
maxName := max(innerWidth-2, 3)
|
||||
if len(displayName) > maxName {
|
||||
displayName = displayName[:maxName-1] + "…"
|
||||
}
|
||||
items = append(items, indicator+nameStyle.Render(displayName))
|
||||
} else {
|
||||
nameWidth := 15
|
||||
if innerWidth < 25 {
|
||||
nameWidth = max(innerWidth*2/5+1, 8)
|
||||
}
|
||||
maxNameChars := nameWidth - 2
|
||||
displayName := sc.Name
|
||||
if len(displayName) > maxNameChars {
|
||||
displayName = displayName[:maxNameChars-1] + "…"
|
||||
}
|
||||
name := nameStyle.Width(maxNameChars).Render(displayName)
|
||||
|
||||
desc := sc.Description
|
||||
maxDescLen := s.width - nameWidth - 14
|
||||
if len(desc) > maxDescLen && maxDescLen > 3 {
|
||||
desc = desc[:maxDescLen-3] + "..."
|
||||
// Description gets remaining space.
|
||||
maxDescLen := max(innerWidth-nameWidth, 0)
|
||||
desc := sc.Description
|
||||
if maxDescLen < 4 {
|
||||
items = append(items, indicator+name)
|
||||
} else {
|
||||
if len(desc) > maxDescLen {
|
||||
desc = desc[:maxDescLen-3] + "..."
|
||||
}
|
||||
items = append(items, indicator+name+descStyle.Render(desc))
|
||||
}
|
||||
}
|
||||
|
||||
items = append(items, indicator+name+descStyle.Render(desc))
|
||||
}
|
||||
}
|
||||
|
||||
if startIdx > 0 {
|
||||
items = append([]string{lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(" ↑ more above")}, items...)
|
||||
items = append([]string{lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" ↑ more above")}, items...)
|
||||
}
|
||||
if endIdx < len(s.filtered) {
|
||||
items = append(items, lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(" ↓ more below"))
|
||||
items = append(items, lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" ↓ more below"))
|
||||
}
|
||||
|
||||
content := strings.Join(items, "\n")
|
||||
footer := lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Italic(true).
|
||||
Render("↑↓ navigate • tab complete • ↵ select • esc dismiss")
|
||||
|
||||
// Adapt footer text to available width.
|
||||
var footerText string
|
||||
if innerWidth >= 50 {
|
||||
footerText = "↑↓ navigate • tab complete • ↵ select • esc dismiss"
|
||||
} else if innerWidth >= 30 {
|
||||
footerText = "↑↓ nav • tab • ↵ select • esc"
|
||||
} else {
|
||||
footerText = "↑↓ tab ↵ esc"
|
||||
}
|
||||
footer := lipgloss.NewStyle().Foreground(theme.VeryMuted).Italic(true).
|
||||
Render(footerText)
|
||||
|
||||
return popupStyle.Render(content + "\n\n" + footer)
|
||||
}
|
||||
|
||||
+57
-117
@@ -3,8 +3,7 @@ package ui
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -12,6 +11,9 @@ import (
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
// ansiEscapeRe matches ANSI escape sequences used for terminal styling.
|
||||
var ansiEscapeRe = regexp.MustCompile(`\x1b\[[0-9;]*m`)
|
||||
|
||||
// MessageType represents different categories of messages displayed in the UI,
|
||||
// each with distinct visual styling and formatting rules.
|
||||
type MessageType int
|
||||
@@ -154,25 +156,10 @@ type MessageRenderer struct {
|
||||
getToolRenderer func(toolName string) *ToolRendererData
|
||||
}
|
||||
|
||||
// getSystemUsername returns the current system username, fallback to "User"
|
||||
func getSystemUsername() string {
|
||||
if currentUser, err := user.Current(); err == nil && currentUser.Username != "" {
|
||||
return currentUser.Username
|
||||
}
|
||||
// Fallback to environment variable
|
||||
if username := os.Getenv("USER"); username != "" {
|
||||
return username
|
||||
}
|
||||
if username := os.Getenv("USERNAME"); username != "" {
|
||||
return username
|
||||
}
|
||||
return "User"
|
||||
}
|
||||
|
||||
// NewMessageRenderer creates and initializes a new MessageRenderer with the specified
|
||||
// newMessageRenderer creates and initializes a new MessageRenderer with the specified
|
||||
// terminal width and debug mode setting. The width parameter determines line wrapping
|
||||
// and layout calculations.
|
||||
func NewMessageRenderer(width int, debug bool) *MessageRenderer {
|
||||
func newMessageRenderer(width int, debug bool) *MessageRenderer {
|
||||
return &MessageRenderer{
|
||||
width: width,
|
||||
debug: debug,
|
||||
@@ -189,31 +176,30 @@ func (r *MessageRenderer) SetWidth(width int) {
|
||||
// formatting, including the system username, timestamp, and markdown-rendered content.
|
||||
// The message is displayed with a colored right border for visual distinction.
|
||||
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
// Format timestamp and username
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
username := getSystemUsername()
|
||||
|
||||
// Convert single newlines to paragraph breaks so they survive glamour's
|
||||
// markdown rendering (glamour treats single \n as a soft break).
|
||||
content = strings.ReplaceAll(content, "\n", "\n\n")
|
||||
|
||||
theme := getTheme()
|
||||
|
||||
messageContent := r.renderMarkdown(content, r.width-8) // Account for padding and borders
|
||||
// Only run markdown rendering when the message contains code spans or
|
||||
// fenced code blocks. Plain text is rendered directly so that newlines
|
||||
// are preserved without the extra paragraph spacing glamour adds.
|
||||
var messageContent string
|
||||
if strings.Contains(content, "`") {
|
||||
// Glamour treats single \n as a soft break, so convert to paragraph
|
||||
// breaks and collapse the resulting blank lines after rendering.
|
||||
mdContent := strings.ReplaceAll(content, "\n", "\n\n")
|
||||
messageContent = r.renderMarkdown(mdContent, r.width-8)
|
||||
messageContent = removeBlankLines(messageContent)
|
||||
} else {
|
||||
messageContent = content
|
||||
}
|
||||
|
||||
// Create info line
|
||||
info := fmt.Sprintf(" %s (%s)", username, timeStr)
|
||||
fullContent := strings.TrimSuffix(messageContent, "\n")
|
||||
|
||||
// Combine content and info
|
||||
fullContent := strings.TrimSuffix(messageContent, "\n") + "\n" +
|
||||
lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(info)
|
||||
|
||||
// Use the block renderer — left border with Primary color, no background.
|
||||
// Left border with Blue color for user messages.
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Primary),
|
||||
WithBorderColor(theme.Info),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
|
||||
@@ -230,14 +216,8 @@ func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time)
|
||||
// are displayed with a special "Finished without output" message. The message features
|
||||
// a colored left border for visual distinction.
|
||||
func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
|
||||
// Format timestamp and model info with better defaults
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
if modelName == "" {
|
||||
modelName = "Assistant"
|
||||
}
|
||||
|
||||
// Handle empty content with better styling
|
||||
theme := getTheme()
|
||||
|
||||
var messageContent string
|
||||
if strings.TrimSpace(content) == "" {
|
||||
messageContent = lipgloss.NewStyle().
|
||||
@@ -246,21 +226,16 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
Align(lipgloss.Center).
|
||||
Render("Finished without output")
|
||||
} else {
|
||||
messageContent = r.renderMarkdown(content, r.width-8) // Account for padding and borders
|
||||
messageContent = r.renderMarkdown(content, r.width-8)
|
||||
}
|
||||
|
||||
// Create info line
|
||||
info := fmt.Sprintf(" %s (%s)", modelName, timeStr)
|
||||
fullContent := strings.TrimSuffix(messageContent, "\n")
|
||||
|
||||
// Combine content and info
|
||||
fullContent := strings.TrimSuffix(messageContent, "\n") + "\n" +
|
||||
lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(info)
|
||||
|
||||
// Use the new block renderer — no borders for agent messages.
|
||||
// Left border with Primary (Mauve) color for assistant messages.
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithNoBorder(),
|
||||
WithBorderColor(theme.Primary),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
|
||||
@@ -276,35 +251,24 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
// and informational notifications. These messages are displayed with a distinctive system
|
||||
// color border and "KIT System" label to differentiate them from user and AI content.
|
||||
func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
// Format timestamp
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
|
||||
// Handle empty content with better styling
|
||||
theme := getTheme()
|
||||
|
||||
var messageContent string
|
||||
if strings.TrimSpace(content) == "" {
|
||||
messageContent = lipgloss.NewStyle().
|
||||
Italic(true).
|
||||
Foreground(theme.Muted).
|
||||
Align(lipgloss.Center).
|
||||
Render("No content available")
|
||||
messageContent = "No content available"
|
||||
} else if strings.Contains(content, "`") {
|
||||
messageContent = r.renderMarkdown(content, r.width-8)
|
||||
} else {
|
||||
messageContent = r.renderMarkdown(content, r.width-8) // Account for padding and borders
|
||||
messageContent = content
|
||||
}
|
||||
|
||||
// Create info line
|
||||
info := fmt.Sprintf(" KIT System (%s)", timeStr)
|
||||
fullContent := "◇ " + strings.TrimSuffix(messageContent, "\n")
|
||||
|
||||
// Combine content and info
|
||||
fullContent := strings.TrimSuffix(messageContent, "\n") + "\n" +
|
||||
lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(info)
|
||||
|
||||
// Use the new block renderer
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.System),
|
||||
WithNoBorder(),
|
||||
WithForeground(theme.Muted),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
|
||||
@@ -322,29 +286,22 @@ func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Tim
|
||||
func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
// Create the main message style with border using tool color
|
||||
theme := getTheme()
|
||||
style := baseStyle.
|
||||
Width(r.width - 3). // Account for left margin
|
||||
Width(r.width - 3).
|
||||
BorderLeft(true).
|
||||
Foreground(theme.Muted).
|
||||
BorderForeground(theme.Tool).
|
||||
BorderStyle(lipgloss.ThickBorder()).
|
||||
PaddingLeft(1).
|
||||
MarginLeft(2). // Add left margin like other messages
|
||||
MarginBottom(1) // Add bottom margin
|
||||
MarginLeft(2).
|
||||
MarginBottom(1)
|
||||
|
||||
// Format timestamp
|
||||
timeStr := timestamp.Local().Format("02 Jan 2006 03:04 PM")
|
||||
|
||||
// Create header with debug icon
|
||||
header := baseStyle.
|
||||
Foreground(theme.Tool).
|
||||
Bold(true).
|
||||
Render("🔍 Debug Output")
|
||||
|
||||
// Process and format the message content
|
||||
// Split into lines and format each one
|
||||
lines := strings.Split(message, "\n")
|
||||
var formattedLines []string
|
||||
for _, line := range lines {
|
||||
@@ -357,17 +314,9 @@ func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time
|
||||
Foreground(theme.Muted).
|
||||
Render(strings.Join(formattedLines, "\n"))
|
||||
|
||||
// Create info line
|
||||
info := baseStyle.
|
||||
Width(r.width - 5). // Account for margins and padding
|
||||
Foreground(theme.Muted).
|
||||
Render(fmt.Sprintf(" KIT (%s)", timeStr))
|
||||
|
||||
// Combine all parts
|
||||
fullContent := lipgloss.JoinVertical(lipgloss.Left,
|
||||
header,
|
||||
content,
|
||||
info,
|
||||
)
|
||||
|
||||
return UIMessage{
|
||||
@@ -382,7 +331,6 @@ func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time
|
||||
func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
// Create the main message style with border using tool color
|
||||
theme := getTheme()
|
||||
style := baseStyle.
|
||||
Width(r.width - 1).
|
||||
@@ -392,16 +340,11 @@ func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timest
|
||||
BorderStyle(lipgloss.ThickBorder()).
|
||||
PaddingLeft(1)
|
||||
|
||||
// Format timestamp
|
||||
timeStr := timestamp.Local().Format("02 Jan 2006 03:04 PM")
|
||||
|
||||
// Create header with debug icon
|
||||
header := baseStyle.
|
||||
Foreground(theme.Tool).
|
||||
Bold(true).
|
||||
Render("🔧 Debug Configuration")
|
||||
|
||||
// Format configuration settings
|
||||
var configLines []string
|
||||
for key, value := range config {
|
||||
if value != nil {
|
||||
@@ -413,18 +356,10 @@ func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timest
|
||||
Foreground(theme.Muted).
|
||||
Render(strings.Join(configLines, "\n"))
|
||||
|
||||
// Create info line
|
||||
info := baseStyle.
|
||||
Width(r.width - 1).
|
||||
Foreground(theme.Muted).
|
||||
Render(fmt.Sprintf(" KIT (%s)", timeStr))
|
||||
|
||||
// Combine parts
|
||||
parts := []string{header}
|
||||
if len(configLines) > 0 {
|
||||
parts = append(parts, configContent)
|
||||
}
|
||||
parts = append(parts, info)
|
||||
|
||||
rendered := style.Render(
|
||||
lipgloss.JoinVertical(lipgloss.Left, parts...),
|
||||
@@ -442,26 +377,15 @@ func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timest
|
||||
// bold text to ensure visibility. Error messages include timestamp information and
|
||||
// are displayed with an error-colored border for immediate recognition.
|
||||
func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
|
||||
// Format timestamp
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
|
||||
// Format error content
|
||||
theme := getTheme()
|
||||
|
||||
errorContent := lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true).
|
||||
Render(errorMsg)
|
||||
|
||||
// Create info line
|
||||
info := fmt.Sprintf(" Error (%s)", timeStr)
|
||||
|
||||
// Combine content and info
|
||||
fullContent := errorContent + "\n" +
|
||||
lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(info)
|
||||
|
||||
// Use the new block renderer
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
errorContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Error),
|
||||
@@ -559,7 +483,7 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
if extRd != nil && extRd.DisplayName != "" {
|
||||
displayName = extRd.DisplayName
|
||||
}
|
||||
nameStr := lipgloss.NewStyle().Foreground(theme.Tool).Bold(true).Render(displayName)
|
||||
nameStr := lipgloss.NewStyle().Foreground(theme.Info).Bold(true).Render(displayName)
|
||||
|
||||
// Format params with width budget for the header line.
|
||||
// Check extension renderer for custom header params first.
|
||||
@@ -710,3 +634,19 @@ func (r *MessageRenderer) renderMarkdown(content string, width int) string {
|
||||
rendered := toMarkdown(content, width)
|
||||
return strings.TrimSuffix(rendered, "\n")
|
||||
}
|
||||
|
||||
// removeBlankLines removes lines that are visually blank from rendered output.
|
||||
// Glamour wraps every character (including padding spaces) with ANSI color
|
||||
// codes, so we must strip escape sequences before checking whether a line is
|
||||
// empty. This collapses paragraph spacing so user messages render without
|
||||
// extra vertical gaps.
|
||||
func removeBlankLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
filtered := lines[:0]
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(ansiEscapeRe.ReplaceAllString(line, "")) != "" {
|
||||
filtered = append(filtered, line)
|
||||
}
|
||||
}
|
||||
return strings.Join(filtered, "\n")
|
||||
}
|
||||
|
||||
+331
-130
@@ -372,11 +372,9 @@ type AppModel struct {
|
||||
appCtrl AppController
|
||||
|
||||
// input is the child input component (slash commands + autocomplete).
|
||||
// Placeholder until InputComponent is implemented in TAS-15.
|
||||
input inputComponentIface
|
||||
|
||||
// stream is the child streaming display component (spinner + streaming text).
|
||||
// Placeholder until StreamComponent is implemented in TAS-16.
|
||||
stream streamComponentIface
|
||||
|
||||
// renderer renders completed messages for tea.Println output. It is either
|
||||
@@ -396,6 +394,20 @@ type AppModel struct {
|
||||
// the input and move to scrollback when the agent picks them up.
|
||||
queuedMessages []string
|
||||
|
||||
// pendingUserPrints holds user messages that have been consumed from the
|
||||
// queue but not yet printed to scrollback. They are deferred until
|
||||
// SpinnerEvent{Show: true} so the previous assistant response can be
|
||||
// flushed first, preserving chronological order.
|
||||
pendingUserPrints []string
|
||||
|
||||
// scrollbackBuf collects rendered content during a single Update() call.
|
||||
// All print helpers append here instead of returning tea.Println directly.
|
||||
// The buffer is drained into a single atomic tea.Println at the end of
|
||||
// each Update call via drainScrollback(). If the stream component has
|
||||
// unflushed content, it is automatically prepended so that new messages
|
||||
// always appear below the previous assistant response.
|
||||
scrollbackBuf []string
|
||||
|
||||
// canceling tracks whether the user has pressed ESC once during stateWorking.
|
||||
// A second ESC within 2 seconds will cancel the current step.
|
||||
canceling bool
|
||||
@@ -579,7 +591,7 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
|
||||
cr.getToolRenderer = opts.GetToolRenderer
|
||||
rdr = cr
|
||||
} else {
|
||||
mr := NewMessageRenderer(width, false)
|
||||
mr := newMessageRenderer(width, false)
|
||||
mr.getToolRenderer = opts.GetToolRenderer
|
||||
rdr = mr
|
||||
}
|
||||
@@ -829,7 +841,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.setModel != nil {
|
||||
previousModel := m.providerName + "/" + m.modelName
|
||||
if err := m.setModel(msg.ModelString); err != nil {
|
||||
cmds = append(cmds, m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err)))
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
|
||||
} else {
|
||||
// Update display state directly — we cannot use
|
||||
// NotifyModelChanged (prog.Send) from inside Update()
|
||||
@@ -839,7 +851,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.providerName = parts[0]
|
||||
m.modelName = parts[1]
|
||||
}
|
||||
cmds = append(cmds, m.printSystemMessage(fmt.Sprintf("Switched to %s", msg.ModelString)))
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to %s", msg.ModelString))
|
||||
if m.emitModelChange != nil {
|
||||
emit := m.emitModelChange
|
||||
newModel := msg.ModelString
|
||||
@@ -848,6 +860,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
}
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
|
||||
case ModelSelectorCancelledMsg:
|
||||
@@ -1018,6 +1031,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if cmd := m.handleSlashCommand(sc); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
@@ -1031,16 +1045,25 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if cmd := m.handleCompactCommand(strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
case "/model":
|
||||
if cmd := m.handleModelCommand(strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
case "/thinking":
|
||||
if cmd := m.handleThinkingCommand(strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
case "/theme":
|
||||
if cmd := m.handleThemeCommand(strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
}
|
||||
@@ -1091,15 +1114,19 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if qLen > 0 {
|
||||
// Queued: anchor the message text above the input with a
|
||||
// "queued" badge. It will be printed to scrollback when
|
||||
// the agent picks it up (on QueueUpdatedEvent).
|
||||
// the agent picks it up (via SpinnerEvent).
|
||||
m.queuedMessages = append(m.queuedMessages, displayText)
|
||||
m.distributeHeight()
|
||||
} else {
|
||||
// Started immediately: print to scrollback now.
|
||||
cmds = append(cmds, m.printUserMessage(displayText))
|
||||
// Started immediately. Flush any leftover stream content
|
||||
// from the previous step first, then print the user
|
||||
// message — combined via the scrollback buffer so
|
||||
// scrollback stays in chronological order.
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
}
|
||||
} else {
|
||||
cmds = append(cmds, m.printUserMessage(displayText))
|
||||
m.printUserMessage(displayText)
|
||||
}
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
@@ -1107,10 +1134,22 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// ── Shell command (! / !!) ───────────────────────────────────────────────
|
||||
case shellCommandMsg:
|
||||
// Show spinner while the shell command runs.
|
||||
m.state = stateWorking
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: true})
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
// Execute the shell command asynchronously so the TUI stays responsive.
|
||||
cmds = append(cmds, m.executeShellCommand(msg))
|
||||
|
||||
case shellCommandResultMsg:
|
||||
// Stop spinner now that the command has finished.
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: false})
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
m.state = stateInput
|
||||
cmds = append(cmds, m.handleShellCommandResult(msg))
|
||||
|
||||
// ── App layer events ─────────────────────────────────────────────────────
|
||||
@@ -1119,10 +1158,11 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// SpinnerEvent{Show: true} means a new agent step has started (either
|
||||
// freshly or from the queue after a previous step completed). Flush
|
||||
// any leftover stream content from the previous step to scrollback
|
||||
// before starting the new one. This deferred flush avoids shrinking
|
||||
// the view at step-completion time (which leaves blank lines).
|
||||
// before starting the new one, followed by any pending user messages
|
||||
// from the queue. Everything goes through the scrollback buffer to
|
||||
// guarantee chronological ordering.
|
||||
if msg.Show {
|
||||
cmds = append(cmds, m.flushStreamContent())
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
m.state = stateWorking
|
||||
m.distributeHeight()
|
||||
}
|
||||
@@ -1148,7 +1188,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// always completes before tool calls fire). The tool call itself is
|
||||
// NOT printed here — a unified block (header + result) will be
|
||||
// rendered when the ToolResultEvent arrives.
|
||||
cmds = append(cmds, m.flushStreamContent())
|
||||
m.flushStreamContent()
|
||||
|
||||
case app.ToolExecutionEvent:
|
||||
// Pass to stream component for execution spinner display.
|
||||
@@ -1158,8 +1198,8 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
case app.ToolResultEvent:
|
||||
// Print tool result immediately to scrollback.
|
||||
cmds = append(cmds, m.printToolResult(msg))
|
||||
// Buffer tool result for scrollback.
|
||||
m.printToolResult(msg)
|
||||
// Start spinner again while waiting for the next LLM response.
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: true})
|
||||
@@ -1179,7 +1219,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// In non-streaming mode (no stream content accumulated), print the text.
|
||||
hasStreamContent := m.stream != nil && m.stream.GetRenderedContent() != ""
|
||||
if !hasStreamContent && msg.Content != "" {
|
||||
cmds = append(cmds, m.printAssistantMessage(msg.Content))
|
||||
m.printAssistantMessage(msg.Content)
|
||||
if m.stream != nil {
|
||||
m.stream.Reset()
|
||||
}
|
||||
@@ -1189,13 +1229,14 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Informational — no action needed by parent.
|
||||
|
||||
case app.QueueUpdatedEvent:
|
||||
// drainQueue popped item(s) from the queue. Move consumed messages
|
||||
// from the anchored display to scrollback (they are now being processed
|
||||
// or about to be).
|
||||
// drainQueue popped item(s) from the queue. Move consumed
|
||||
// messages to pendingUserPrints — they will be printed to
|
||||
// scrollback in the next SpinnerEvent{Show: true} after the
|
||||
// previous assistant response is flushed.
|
||||
for len(m.queuedMessages) > msg.Length {
|
||||
text := m.queuedMessages[0]
|
||||
m.queuedMessages = m.queuedMessages[1:]
|
||||
cmds = append(cmds, m.printUserMessage(text))
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, text)
|
||||
}
|
||||
m.distributeHeight()
|
||||
|
||||
@@ -1232,7 +1273,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
if msg.Err != nil {
|
||||
cmds = append(cmds, m.printErrorResponse(msg))
|
||||
m.printErrorResponse(msg)
|
||||
}
|
||||
m.state = stateInput
|
||||
m.canceling = false
|
||||
@@ -1242,14 +1283,14 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.stream.Reset()
|
||||
}
|
||||
m.state = stateInput
|
||||
cmds = append(cmds, m.printCompactResult(msg))
|
||||
m.printCompactResult(msg)
|
||||
|
||||
case app.CompactErrorEvent:
|
||||
if m.stream != nil {
|
||||
m.stream.Reset()
|
||||
}
|
||||
m.state = stateInput
|
||||
cmds = append(cmds, m.printSystemMessage(fmt.Sprintf("Compaction failed: %v", msg.Err)))
|
||||
m.printSystemMessage(fmt.Sprintf("Compaction failed: %v", msg.Err))
|
||||
|
||||
case app.ModelChangedEvent:
|
||||
// Extension changed the model — update display name in status bar
|
||||
@@ -1357,17 +1398,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case extensionCmdResultMsg:
|
||||
// Async extension slash command completed. Render output/error.
|
||||
if msg.err != nil {
|
||||
cmds = append(cmds, m.printSystemMessage(
|
||||
fmt.Sprintf("Command %s error: %v", msg.name, msg.err)))
|
||||
m.printSystemMessage(fmt.Sprintf("Command %s error: %v", msg.name, msg.err))
|
||||
} else if msg.output != "" {
|
||||
cmds = append(cmds, m.printSystemMessage(msg.output))
|
||||
m.printSystemMessage(msg.output)
|
||||
}
|
||||
|
||||
case beforeSessionSwitchResultMsg:
|
||||
// Async before-session-switch hook completed. Proceed with the
|
||||
// session reset if the hook did not cancel.
|
||||
if msg.cancelled {
|
||||
cmds = append(cmds, m.printSystemMessage(msg.reason))
|
||||
m.printSystemMessage(msg.reason)
|
||||
} else {
|
||||
cmds = append(cmds, m.performNewSession())
|
||||
}
|
||||
@@ -1376,7 +1416,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Async before-fork hook completed. Proceed with the fork if the
|
||||
// hook did not cancel.
|
||||
if msg.cancelled {
|
||||
cmds = append(cmds, m.printSystemMessage(msg.reason))
|
||||
m.printSystemMessage(msg.reason)
|
||||
} else {
|
||||
cmds = append(cmds, m.performFork(msg.targetID, msg.isUser, msg.userText))
|
||||
}
|
||||
@@ -1385,15 +1425,15 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Extension output — route through styled renderers when a level is set.
|
||||
switch msg.Level {
|
||||
case "info":
|
||||
cmds = append(cmds, m.printSystemMessage(msg.Text))
|
||||
m.printSystemMessage(msg.Text)
|
||||
case "error":
|
||||
cmds = append(cmds, m.printErrorResponse(app.StepErrorEvent{
|
||||
m.printErrorResponse(app.StepErrorEvent{
|
||||
Err: fmt.Errorf("%s", msg.Text),
|
||||
}))
|
||||
})
|
||||
case "block":
|
||||
cmds = append(cmds, m.printExtensionBlock(msg))
|
||||
m.printExtensionBlock(msg)
|
||||
default:
|
||||
cmds = append(cmds, tea.Println(msg.Text))
|
||||
m.appendScrollback(msg.Text)
|
||||
}
|
||||
|
||||
default:
|
||||
@@ -1408,6 +1448,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
}
|
||||
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
@@ -1510,8 +1551,9 @@ func (m *AppModel) renderStream() string {
|
||||
|
||||
// Show canceling warning if set.
|
||||
if m.canceling {
|
||||
theme := GetTheme()
|
||||
warning := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("214")).
|
||||
Foreground(theme.Warning).
|
||||
Bold(true).
|
||||
Render(" ⚠ Press ESC again to cancel")
|
||||
return lipgloss.JoinVertical(lipgloss.Left,
|
||||
@@ -1581,10 +1623,31 @@ func (m *AppModel) renderStatusBar() string {
|
||||
|
||||
rightSide := strings.Join(rightParts, " ")
|
||||
|
||||
// Fill the gap between left+middle and right with spaces.
|
||||
usedWidth := lipgloss.Width(leftSide) + lipgloss.Width(middleSide) + lipgloss.Width(rightSide)
|
||||
gap := max(m.width-usedWidth, 1)
|
||||
// Progressive truncation to keep the status bar on one line.
|
||||
// When content exceeds terminal width, drop sections in order:
|
||||
// middle (extensions/thinking) → usage stats → model label → right side.
|
||||
leftW := lipgloss.Width(leftSide)
|
||||
middleW := lipgloss.Width(middleSide)
|
||||
rightW := lipgloss.Width(rightSide)
|
||||
|
||||
// Need at least 1 space gap between left+middle and right.
|
||||
if leftW+middleW+rightW+1 > m.width {
|
||||
// Drop middle section first (extensions/thinking status).
|
||||
middleSide = ""
|
||||
middleW = 0
|
||||
}
|
||||
if leftW+rightW+1 > m.width && len(rightParts) > 1 {
|
||||
// Drop usage stats, keep model label.
|
||||
rightSide = rightParts[0]
|
||||
rightW = lipgloss.Width(rightSide)
|
||||
}
|
||||
if leftW+rightW+1 > m.width {
|
||||
// Drop right side entirely.
|
||||
rightSide = ""
|
||||
rightW = 0
|
||||
}
|
||||
|
||||
gap := max(m.width-leftW-middleW-rightW, 1)
|
||||
return leftSide + middleSide + strings.Repeat(" ", gap) + rightSide
|
||||
}
|
||||
|
||||
@@ -1753,30 +1816,28 @@ func (m *AppModel) renderQueuedMessages() string {
|
||||
// Print helpers — emit content to scrollback via tea.Println
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// printUserMessage renders a user message and emits it above the BT region.
|
||||
func (m *AppModel) printUserMessage(text string) tea.Cmd {
|
||||
return tea.Println(m.renderer.RenderUserMessage(text, time.Now()).Content)
|
||||
// printUserMessage renders a user message into the scrollback buffer.
|
||||
func (m *AppModel) printUserMessage(text string) {
|
||||
m.appendScrollback(m.renderer.RenderUserMessage(text, time.Now()).Content)
|
||||
}
|
||||
|
||||
// printAssistantMessage renders an assistant message and emits it above the BT region.
|
||||
func (m *AppModel) printAssistantMessage(text string) tea.Cmd {
|
||||
if text == "" {
|
||||
return nil
|
||||
// printAssistantMessage renders an assistant message into the scrollback buffer.
|
||||
func (m *AppModel) printAssistantMessage(text string) {
|
||||
if text != "" {
|
||||
m.appendScrollback(m.renderer.RenderAssistantMessage(text, time.Now(), m.modelName).Content)
|
||||
}
|
||||
return tea.Println(m.renderer.RenderAssistantMessage(text, time.Now(), m.modelName).Content)
|
||||
}
|
||||
|
||||
// printToolResult renders a tool result message and emits it above the BT region.
|
||||
func (m *AppModel) printToolResult(evt app.ToolResultEvent) tea.Cmd {
|
||||
return tea.Println(m.renderer.RenderToolMessage(evt.ToolName, evt.ToolArgs, evt.Result, evt.IsError).Content)
|
||||
// printToolResult renders a tool result message into the scrollback buffer.
|
||||
func (m *AppModel) printToolResult(evt app.ToolResultEvent) {
|
||||
m.appendScrollback(m.renderer.RenderToolMessage(evt.ToolName, evt.ToolArgs, evt.Result, evt.IsError).Content)
|
||||
}
|
||||
|
||||
// printErrorResponse renders an error message and emits it above the BT region.
|
||||
func (m *AppModel) printErrorResponse(evt app.StepErrorEvent) tea.Cmd {
|
||||
if evt.Err == nil {
|
||||
return nil
|
||||
// printErrorResponse renders an error message into the scrollback buffer.
|
||||
func (m *AppModel) printErrorResponse(evt app.StepErrorEvent) {
|
||||
if evt.Err != nil {
|
||||
m.appendScrollback(m.renderer.RenderErrorMessage(evt.Err.Error(), time.Now()).Content)
|
||||
}
|
||||
return tea.Println(m.renderer.RenderErrorMessage(evt.Err.Error(), time.Now()).Content)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -1791,17 +1852,19 @@ func (m *AppModel) handleSlashCommand(sc *SlashCommand) tea.Cmd {
|
||||
case "/quit":
|
||||
return tea.Quit
|
||||
case "/help":
|
||||
return m.printHelpMessage()
|
||||
m.printHelpMessage()
|
||||
case "/tools":
|
||||
return m.printToolsMessage()
|
||||
m.printToolsMessage()
|
||||
case "/servers":
|
||||
return m.printServersMessage()
|
||||
m.printServersMessage()
|
||||
case "/usage":
|
||||
return m.printUsageMessage()
|
||||
m.printUsageMessage()
|
||||
case "/reset-usage":
|
||||
return m.printResetUsage()
|
||||
m.printResetUsage()
|
||||
case "/model":
|
||||
return m.handleModelCommand("")
|
||||
case "/theme":
|
||||
return m.handleThemeCommand("")
|
||||
case "/thinking":
|
||||
return m.handleThinkingCommand("")
|
||||
case "/compact":
|
||||
@@ -1810,14 +1873,13 @@ func (m *AppModel) handleSlashCommand(sc *SlashCommand) tea.Cmd {
|
||||
if m.appCtrl != nil {
|
||||
m.appCtrl.ClearMessages()
|
||||
}
|
||||
return m.printSystemMessage("Conversation cleared. Starting fresh.")
|
||||
m.printSystemMessage("Conversation cleared. Starting fresh.")
|
||||
case "/clear-queue":
|
||||
if m.appCtrl != nil {
|
||||
m.appCtrl.ClearQueue()
|
||||
}
|
||||
m.queuedMessages = m.queuedMessages[:0]
|
||||
m.distributeHeight()
|
||||
return nil
|
||||
|
||||
case "/tree":
|
||||
return m.handleTreeCommand()
|
||||
@@ -1831,22 +1893,23 @@ func (m *AppModel) handleSlashCommand(sc *SlashCommand) tea.Cmd {
|
||||
return m.handleSessionInfoCommand()
|
||||
|
||||
default:
|
||||
return m.printSystemMessage(fmt.Sprintf("Unknown command: %s", sc.Name))
|
||||
m.printSystemMessage(fmt.Sprintf("Unknown command: %s", sc.Name))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// printSystemMessage renders a system-level message and emits it above the BT region.
|
||||
func (m *AppModel) printSystemMessage(text string) tea.Cmd {
|
||||
return tea.Println(m.renderer.RenderSystemMessage(text, time.Now()).Content)
|
||||
// printSystemMessage renders a system-level message into the scrollback buffer.
|
||||
func (m *AppModel) printSystemMessage(text string) {
|
||||
m.appendScrollback(m.renderer.RenderSystemMessage(text, time.Now()).Content)
|
||||
}
|
||||
|
||||
// printExtensionBlock renders a custom styled block from an extension with
|
||||
// caller-chosen border color and optional subtitle, then emits it to scrollback.
|
||||
func (m *AppModel) printExtensionBlock(evt app.ExtensionPrintEvent) tea.Cmd {
|
||||
// caller-chosen border color and optional subtitle into the scrollback buffer.
|
||||
func (m *AppModel) printExtensionBlock(evt app.ExtensionPrintEvent) {
|
||||
theme := GetTheme()
|
||||
|
||||
// Resolve border color: use the extension's hex value, fall back to theme accent.
|
||||
var borderClr = lipgloss.Color("#89b4fa") // default blue
|
||||
// Resolve border color: use the extension's hex value, fall back to theme info.
|
||||
borderClr := theme.Info
|
||||
if evt.BorderColor != "" {
|
||||
borderClr = lipgloss.Color(evt.BorderColor)
|
||||
}
|
||||
@@ -1865,7 +1928,7 @@ func (m *AppModel) printExtensionBlock(evt app.ExtensionPrintEvent) tea.Cmd {
|
||||
WithBorderColor(borderClr),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
return tea.Println(rendered)
|
||||
m.appendScrollback(rendered)
|
||||
}
|
||||
|
||||
// handleExtensionCommand checks if the submitted text matches an extension-
|
||||
@@ -1916,7 +1979,7 @@ func (m *AppModel) handleExtensionCommand(text string) tea.Cmd {
|
||||
}
|
||||
|
||||
// printHelpMessage renders the help text listing all available slash commands.
|
||||
func (m *AppModel) printHelpMessage() tea.Cmd {
|
||||
func (m *AppModel) printHelpMessage() {
|
||||
help := "## Available Commands\n\n" +
|
||||
"**Info:**\n" +
|
||||
"- `/help`: Show this help message\n" +
|
||||
@@ -1966,11 +2029,11 @@ func (m *AppModel) printHelpMessage() tea.Cmd {
|
||||
"- `Ctrl+C`: Exit at any time\n" +
|
||||
"- `ESC` (x2): Cancel ongoing LLM generation\n\n" +
|
||||
"You can also just type your message to chat with the AI assistant."
|
||||
return m.printSystemMessage(help)
|
||||
m.printSystemMessage(help)
|
||||
}
|
||||
|
||||
// printToolsMessage renders the list of available tools.
|
||||
func (m *AppModel) printToolsMessage() tea.Cmd {
|
||||
func (m *AppModel) printToolsMessage() {
|
||||
var content string
|
||||
content = "## Available Tools\n\n"
|
||||
if len(m.toolNames) == 0 {
|
||||
@@ -1980,11 +2043,11 @@ func (m *AppModel) printToolsMessage() tea.Cmd {
|
||||
content += fmt.Sprintf("%d. `%s`\n", i+1, tool)
|
||||
}
|
||||
}
|
||||
return m.printSystemMessage(content)
|
||||
m.printSystemMessage(content)
|
||||
}
|
||||
|
||||
// printServersMessage renders the list of configured MCP servers.
|
||||
func (m *AppModel) printServersMessage() tea.Cmd {
|
||||
func (m *AppModel) printServersMessage() {
|
||||
var content string
|
||||
content = "## Configured MCP Servers\n\n"
|
||||
if len(m.serverNames) == 0 {
|
||||
@@ -1994,13 +2057,14 @@ func (m *AppModel) printServersMessage() tea.Cmd {
|
||||
content += fmt.Sprintf("%d. `%s`\n", i+1, server)
|
||||
}
|
||||
}
|
||||
return m.printSystemMessage(content)
|
||||
m.printSystemMessage(content)
|
||||
}
|
||||
|
||||
// printUsageMessage renders token usage statistics.
|
||||
func (m *AppModel) printUsageMessage() tea.Cmd {
|
||||
func (m *AppModel) printUsageMessage() {
|
||||
if m.usageTracker == nil {
|
||||
return m.printSystemMessage("Usage tracking is not available for this model.")
|
||||
m.printSystemMessage("Usage tracking is not available for this model.")
|
||||
return
|
||||
}
|
||||
|
||||
sessionStats := m.usageTracker.GetSessionStats()
|
||||
@@ -2014,16 +2078,17 @@ func (m *AppModel) printUsageMessage() tea.Cmd {
|
||||
content += fmt.Sprintf("**Session Total:** %d input + %d output tokens = $%.6f (%d requests)\n",
|
||||
sessionStats.TotalInputTokens, sessionStats.TotalOutputTokens, sessionStats.TotalCost, sessionStats.RequestCount)
|
||||
|
||||
return m.printSystemMessage(content)
|
||||
m.printSystemMessage(content)
|
||||
}
|
||||
|
||||
// printResetUsage resets usage statistics and prints a confirmation.
|
||||
func (m *AppModel) printResetUsage() tea.Cmd {
|
||||
func (m *AppModel) printResetUsage() {
|
||||
if m.usageTracker == nil {
|
||||
return m.printSystemMessage("Usage tracking is not available for this model.")
|
||||
m.printSystemMessage("Usage tracking is not available for this model.")
|
||||
return
|
||||
}
|
||||
m.usageTracker.Reset()
|
||||
return m.printSystemMessage("Usage statistics have been reset.")
|
||||
m.printSystemMessage("Usage statistics have been reset.")
|
||||
}
|
||||
|
||||
// handleCompactCommand starts an async compaction. It returns a tea.Cmd that
|
||||
@@ -2033,23 +2098,26 @@ func (m *AppModel) printResetUsage() tea.Cmd {
|
||||
// prompt (e.g. "Focus on the API design decisions").
|
||||
func (m *AppModel) handleCompactCommand(customInstructions string) tea.Cmd {
|
||||
if m.appCtrl == nil {
|
||||
return m.printSystemMessage("Compaction is not available.")
|
||||
m.printSystemMessage("Compaction is not available.")
|
||||
return nil
|
||||
}
|
||||
if err := m.appCtrl.CompactConversation(customInstructions); err != nil {
|
||||
return m.printSystemMessage(fmt.Sprintf("Cannot compact: %v", err))
|
||||
m.printSystemMessage(fmt.Sprintf("Cannot compact: %v", err))
|
||||
return nil
|
||||
}
|
||||
// Transition to working state so the spinner shows while compaction runs.
|
||||
m.state = stateWorking
|
||||
m.printSystemMessage("Compacting conversation...")
|
||||
var spinnerCmd tea.Cmd
|
||||
if m.stream != nil {
|
||||
_, spinnerCmd = m.stream.Update(app.SpinnerEvent{Show: true})
|
||||
}
|
||||
return tea.Batch(m.printSystemMessage("Compacting conversation..."), spinnerCmd)
|
||||
return spinnerCmd
|
||||
}
|
||||
|
||||
// printCompactResult renders the compaction summary in a styled block with
|
||||
// a distinct border color and a stats subtitle.
|
||||
func (m *AppModel) printCompactResult(evt app.CompactCompleteEvent) tea.Cmd {
|
||||
// a distinct border color and a stats subtitle into the scrollback buffer.
|
||||
func (m *AppModel) printCompactResult(evt app.CompactCompleteEvent) {
|
||||
theme := GetTheme()
|
||||
|
||||
saved := evt.OriginalTokens - evt.CompactedTokens
|
||||
@@ -2071,32 +2139,89 @@ func (m *AppModel) printCompactResult(evt app.CompactCompleteEvent) tea.Cmd {
|
||||
WithBorderColor(theme.Secondary),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
return tea.Println(rendered)
|
||||
m.appendScrollback(rendered)
|
||||
}
|
||||
|
||||
// flushStreamContent gets the rendered content from the stream component,
|
||||
// emits it above the BT region via tea.Println, and resets the stream. This
|
||||
// is called before printing tool calls (streaming completes before tools fire)
|
||||
// and on step completion.
|
||||
//
|
||||
// After flushing, a ClearScreen is issued to force a full terminal redraw.
|
||||
// When
|
||||
// the stream content is moved to scrollback the view height shrinks, and
|
||||
// bubbletea's inline renderer doesn't clear the orphaned terminal rows
|
||||
// below the managed region. ClearScreen ensures a clean redraw.
|
||||
func (m *AppModel) flushStreamContent() tea.Cmd {
|
||||
// flushStreamContent moves rendered content from the stream component into the
|
||||
// scrollback buffer and resets the stream. Called before tool calls (streaming
|
||||
// completes before tools fire). The actual tea.Println is deferred to
|
||||
// drainScrollback() at the end of the Update cycle.
|
||||
func (m *AppModel) flushStreamContent() {
|
||||
if m.stream == nil {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
content := m.stream.GetRenderedContent()
|
||||
if content == "" {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
m.stream.Reset()
|
||||
return tea.Sequence(
|
||||
tea.Println(content),
|
||||
func() tea.Msg { return tea.ClearScreen() },
|
||||
)
|
||||
m.appendScrollback(content)
|
||||
}
|
||||
|
||||
// flushStreamAndPendingUserMessages moves the previous assistant response and
|
||||
// any pending queued user messages into the scrollback buffer. Called from
|
||||
// SpinnerEvent{Show: true} where all previous stream chunks are guaranteed to
|
||||
// have been processed. The actual tea.Println is deferred to drainScrollback().
|
||||
func (m *AppModel) flushStreamAndPendingUserMessages() {
|
||||
// 1. Flush previous stream content (assistant response).
|
||||
if m.stream != nil {
|
||||
if content := m.stream.GetRenderedContent(); content != "" {
|
||||
m.stream.Reset()
|
||||
m.appendScrollback(content)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Render pending user messages from the queue.
|
||||
for _, text := range m.pendingUserPrints {
|
||||
rendered := m.renderer.RenderUserMessage(text, time.Now()).Content
|
||||
m.appendScrollback(rendered)
|
||||
}
|
||||
m.pendingUserPrints = nil
|
||||
}
|
||||
|
||||
// appendScrollback adds rendered content to the scrollback buffer. The content
|
||||
// will be emitted via tea.Println when drainScrollback is called at the end of
|
||||
// the current Update cycle.
|
||||
func (m *AppModel) appendScrollback(content string) {
|
||||
if content != "" {
|
||||
m.scrollbackBuf = append(m.scrollbackBuf, content)
|
||||
}
|
||||
}
|
||||
|
||||
// drainScrollback flushes the scrollback buffer into a single tea.Println. If
|
||||
// the stream component has unflushed content, it is automatically prepended so
|
||||
// that new messages always appear below the previous assistant response. When
|
||||
// stream content is flushed a ClearScreen follows to clean up orphaned terminal
|
||||
// rows left after the view height shrinks. Returns nil if there is nothing to
|
||||
// print.
|
||||
func (m *AppModel) drainScrollback() tea.Cmd {
|
||||
if len(m.scrollbackBuf) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parts []string
|
||||
needsClear := false
|
||||
|
||||
// Auto-flush any stream content so it appears before new messages.
|
||||
if m.stream != nil {
|
||||
if content := m.stream.GetRenderedContent(); content != "" {
|
||||
m.stream.Reset()
|
||||
parts = append(parts, content)
|
||||
needsClear = true
|
||||
}
|
||||
}
|
||||
|
||||
parts = append(parts, m.scrollbackBuf...)
|
||||
m.scrollbackBuf = m.scrollbackBuf[:0]
|
||||
|
||||
printCmd := tea.Println(strings.Join(parts, "\n"))
|
||||
if needsClear {
|
||||
return tea.Sequence(
|
||||
printCmd,
|
||||
func() tea.Msg { return tea.ClearScreen() },
|
||||
)
|
||||
}
|
||||
return printCmd
|
||||
}
|
||||
|
||||
// distributeHeight recalculates child component heights after a window resize,
|
||||
@@ -2109,7 +2234,7 @@ func (m *AppModel) flushStreamContent() tea.Cmd {
|
||||
// stream region = total - header - separator(1) - widgets - queued(N*5) - input(measured) - widgets - statusBar(1) - footer
|
||||
// separator = 1 line
|
||||
// above widgets = measured dynamically
|
||||
// queued msgs = ~5 lines per message (padding + text + badge + padding)
|
||||
// queued msgs = measured dynamically via lipgloss.Height()
|
||||
// input region = measured dynamically via lipgloss.Height()
|
||||
// below widgets = measured dynamically
|
||||
// status bar = 1 line (always present)
|
||||
@@ -2125,8 +2250,12 @@ func (m *AppModel) distributeHeight() {
|
||||
if vis.HideStatusBar {
|
||||
statusBarLines = 0
|
||||
}
|
||||
const linesPerQueuedMsg = 5
|
||||
queuedLines := len(m.queuedMessages) * linesPerQueuedMsg
|
||||
// Measure actual queued message height instead of using a fixed estimate,
|
||||
// since text wrapping at different widths changes the rendered line count.
|
||||
var queuedLines int
|
||||
if queuedView := m.renderQueuedMessages(); queuedView != "" {
|
||||
queuedLines = lipgloss.Height(queuedView)
|
||||
}
|
||||
|
||||
// Propagate hint visibility before measuring input height.
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
@@ -2173,6 +2302,17 @@ func (m *AppModel) distributeHeight() {
|
||||
}
|
||||
}
|
||||
|
||||
// clamp constrains v to the range [lo, hi].
|
||||
func clamp(v, lo, hi int) int {
|
||||
if v < lo {
|
||||
return lo
|
||||
}
|
||||
if v > hi {
|
||||
return hi
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// repeatRune returns a string consisting of n repetitions of r.
|
||||
func repeatRune(r rune, n int) string {
|
||||
if n <= 0 {
|
||||
@@ -2242,7 +2382,8 @@ func remapKey(name string) (tea.KeyPressMsg, bool) {
|
||||
// to that model directly.
|
||||
func (m *AppModel) handleModelCommand(args string) tea.Cmd {
|
||||
if m.setModel == nil {
|
||||
return m.printSystemMessage("Model switching is not available.")
|
||||
m.printSystemMessage("Model switching is not available.")
|
||||
return nil
|
||||
}
|
||||
|
||||
if args == "" {
|
||||
@@ -2256,7 +2397,8 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
|
||||
// Direct model switch with the provided model string.
|
||||
previousModel := m.providerName + "/" + m.modelName
|
||||
if err := m.setModel(args); err != nil {
|
||||
return m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update display state directly (cannot use prog.Send from Update).
|
||||
@@ -2273,7 +2415,50 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
|
||||
go emit(newModel, prev, "user")
|
||||
}
|
||||
|
||||
return m.printSystemMessage(fmt.Sprintf("Switched to %s", args))
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to %s", args))
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Theme command handler
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// handleThemeCommand switches the active color theme. With no arguments it
|
||||
// lists available themes and highlights the active one. With a name argument
|
||||
// (e.g. "/theme catppuccin") it switches immediately.
|
||||
func (m *AppModel) handleThemeCommand(args string) tea.Cmd {
|
||||
if args == "" {
|
||||
// List available themes.
|
||||
names := ListThemes()
|
||||
active := ActiveThemeName()
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, "Available themes:")
|
||||
for _, name := range names {
|
||||
if name == active {
|
||||
lines = append(lines, fmt.Sprintf(" * %s (active)", name))
|
||||
} else {
|
||||
lines = append(lines, fmt.Sprintf(" %s", name))
|
||||
}
|
||||
}
|
||||
lines = append(lines, "")
|
||||
lines = append(lines, fmt.Sprintf("User themes: %s", userThemesDir()))
|
||||
if pdir := projectThemesDir(); pdir != "" {
|
||||
lines = append(lines, fmt.Sprintf("Project themes: %s", pdir))
|
||||
} else {
|
||||
lines = append(lines, "Project themes: .kit/themes/ (not found)")
|
||||
}
|
||||
m.printSystemMessage(strings.Join(lines, "\n"))
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := ApplyTheme(args); err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Theme error: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to theme: %s", args))
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -2285,7 +2470,8 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
|
||||
// minimal, low, medium, high) it switches to that level.
|
||||
func (m *AppModel) handleThinkingCommand(args string) tea.Cmd {
|
||||
if !m.isReasoningModel {
|
||||
return m.printSystemMessage("Current model does not support thinking/reasoning.")
|
||||
m.printSystemMessage("Current model does not support thinking/reasoning.")
|
||||
return nil
|
||||
}
|
||||
|
||||
if args == "" {
|
||||
@@ -2300,13 +2486,15 @@ func (m *AppModel) handleThinkingCommand(args string) tea.Cmd {
|
||||
lines = append(lines, fmt.Sprintf("%s%s — %s", marker, l, models.ThinkingLevelDescription(l)))
|
||||
}
|
||||
header := fmt.Sprintf("Current thinking level: %s\n\nAvailable levels:", m.thinkingLevel)
|
||||
return m.printSystemMessage(header + "\n" + strings.Join(lines, "\n"))
|
||||
m.printSystemMessage(header + "\n" + strings.Join(lines, "\n"))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse and validate the level.
|
||||
level := models.ParseThinkingLevel(args)
|
||||
if string(level) != strings.ToLower(args) {
|
||||
return m.printSystemMessage(fmt.Sprintf("Unknown thinking level: %q. Use: off, minimal, low, medium, high", args))
|
||||
m.printSystemMessage(fmt.Sprintf("Unknown thinking level: %q. Use: off, minimal, low, medium, high", args))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apply the change.
|
||||
@@ -2316,7 +2504,8 @@ func (m *AppModel) handleThinkingCommand(args string) tea.Cmd {
|
||||
_ = m.setThinkingLevel(string(level))
|
||||
}()
|
||||
}
|
||||
return m.printSystemMessage(fmt.Sprintf("Thinking level set to: %s — %s", level, models.ThinkingLevelDescription(level)))
|
||||
m.printSystemMessage(fmt.Sprintf("Thinking level set to: %s — %s", level, models.ThinkingLevelDescription(level)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -2327,10 +2516,12 @@ func (m *AppModel) handleThinkingCommand(args string) tea.Cmd {
|
||||
func (m *AppModel) handleTreeCommand() tea.Cmd {
|
||||
ts := m.appCtrl.GetTreeSession()
|
||||
if ts == nil {
|
||||
return m.printSystemMessage("No tree session active. Start with `--continue` or `--resume` to enable tree sessions.")
|
||||
m.printSystemMessage("No tree session active. Start with `--continue` or `--resume` to enable tree sessions.")
|
||||
return nil
|
||||
}
|
||||
if ts.EntryCount() == 0 {
|
||||
return m.printSystemMessage("No entries in session yet.")
|
||||
m.printSystemMessage("No entries in session yet.")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.treeSelector = NewTreeSelector(ts, m.width, m.height)
|
||||
@@ -2343,10 +2534,12 @@ func (m *AppModel) handleTreeCommand() tea.Cmd {
|
||||
func (m *AppModel) handleForkCommand() tea.Cmd {
|
||||
ts := m.appCtrl.GetTreeSession()
|
||||
if ts == nil {
|
||||
return m.printSystemMessage("No tree session active. Start with `--continue` or `--resume` to enable tree sessions.")
|
||||
m.printSystemMessage("No tree session active. Start with `--continue` or `--resume` to enable tree sessions.")
|
||||
return nil
|
||||
}
|
||||
if ts.EntryCount() == 0 {
|
||||
return m.printSystemMessage("No entries to fork from.")
|
||||
m.printSystemMessage("No entries to fork from.")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.treeSelector = NewTreeSelector(ts, m.width, m.height)
|
||||
@@ -2384,14 +2577,16 @@ func (m *AppModel) performNewSession() tea.Cmd {
|
||||
if m.appCtrl != nil {
|
||||
m.appCtrl.ClearMessages()
|
||||
}
|
||||
return m.printSystemMessage("Conversation cleared. Starting fresh.")
|
||||
m.printSystemMessage("Conversation cleared. Starting fresh.")
|
||||
return nil
|
||||
}
|
||||
|
||||
ts.ResetLeaf()
|
||||
if m.appCtrl != nil {
|
||||
m.appCtrl.ClearMessages()
|
||||
}
|
||||
return m.printSystemMessage("New branch started. Previous conversation is preserved in the tree.")
|
||||
m.printSystemMessage("New branch started. Previous conversation is preserved in the tree.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// performFork performs the actual tree branch. Called either directly (when no
|
||||
@@ -2399,7 +2594,8 @@ func (m *AppModel) performNewSession() tea.Cmd {
|
||||
func (m *AppModel) performFork(targetID string, isUser bool, userText string) tea.Cmd {
|
||||
ts := m.appCtrl.GetTreeSession()
|
||||
if ts == nil {
|
||||
return m.printSystemMessage("No tree session active.")
|
||||
m.printSystemMessage("No tree session active.")
|
||||
return nil
|
||||
}
|
||||
|
||||
_ = ts.Branch(targetID)
|
||||
@@ -2413,7 +2609,7 @@ func (m *AppModel) performFork(targetID string, isUser bool, userText string) te
|
||||
}
|
||||
}
|
||||
|
||||
return m.printSystemMessage(
|
||||
m.printSystemMessage(
|
||||
fmt.Sprintf("Navigated to branch point. %s",
|
||||
func() string {
|
||||
if isUser {
|
||||
@@ -2421,29 +2617,34 @@ func (m *AppModel) performFork(targetID string, isUser bool, userText string) te
|
||||
}
|
||||
return "Continue from this point."
|
||||
}()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleNameCommand sets a display name for the current session.
|
||||
func (m *AppModel) handleNameCommand() tea.Cmd {
|
||||
ts := m.appCtrl.GetTreeSession()
|
||||
if ts == nil {
|
||||
return m.printSystemMessage("No tree session active.")
|
||||
m.printSystemMessage("No tree session active.")
|
||||
return nil
|
||||
}
|
||||
// For now, prompt user to provide name via input. We print instructions
|
||||
// and the next non-command input starting with "name:" will be captured.
|
||||
// TODO: inline input dialog.
|
||||
currentName := ts.GetSessionName()
|
||||
if currentName != "" {
|
||||
return m.printSystemMessage(fmt.Sprintf("Current session name: %q\nTo rename, type: `/name <new name>` (not yet implemented — use the session file directly).", currentName))
|
||||
m.printSystemMessage(fmt.Sprintf("Current session name: %q\nTo rename, type: `/name <new name>` (not yet implemented — use the session file directly).", currentName))
|
||||
return nil
|
||||
}
|
||||
return m.printSystemMessage("To name this session, use: `/name <new name>` (not yet implemented — use the session file directly).")
|
||||
m.printSystemMessage("To name this session, use: `/name <new name>` (not yet implemented — use the session file directly).")
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSessionInfoCommand shows session statistics.
|
||||
func (m *AppModel) handleSessionInfoCommand() tea.Cmd {
|
||||
ts := m.appCtrl.GetTreeSession()
|
||||
if ts == nil {
|
||||
return m.printSystemMessage("No tree session active.")
|
||||
m.printSystemMessage("No tree session active.")
|
||||
return nil
|
||||
}
|
||||
|
||||
header := ts.GetHeader()
|
||||
@@ -2468,7 +2669,8 @@ func (m *AppModel) handleSessionInfoCommand() tea.Cmd {
|
||||
info += fmt.Sprintf("- **Name:** %s\n", name)
|
||||
}
|
||||
|
||||
return m.printSystemMessage(info)
|
||||
m.printSystemMessage(info)
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -2779,8 +2981,7 @@ func (m *AppModel) handleShellCommandResult(msg shellCommandResultMsg) tea.Cmd {
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
|
||||
var cmds []tea.Cmd
|
||||
cmds = append(cmds, tea.Println(rendered))
|
||||
m.appendScrollback(rendered)
|
||||
|
||||
// For ! (included in context): inject the command output into the
|
||||
// conversation as a user message so the LLM can reference it on the
|
||||
@@ -2800,5 +3001,5 @@ func (m *AppModel) handleShellCommandResult(msg shellCommandResultMsg) tea.Cmd {
|
||||
m.appCtrl.AddContextMessage(contextMsg)
|
||||
}
|
||||
|
||||
return tea.Batch(cmds...)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -208,9 +208,20 @@ func (ms *ModelSelectorComponent) View() tea.View {
|
||||
// Header.
|
||||
b.WriteString(headerStyle.Render("Model Selector"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(helpStyle.Render("↑/↓: move enter: select esc: cancel type to filter"))
|
||||
// Adapt help text to terminal width.
|
||||
if ms.width >= 56 {
|
||||
b.WriteString(helpStyle.Render("↑/↓: move enter: select esc: cancel type to filter"))
|
||||
} else if ms.width >= 35 {
|
||||
b.WriteString(helpStyle.Render("↑↓ move ↵ select esc type"))
|
||||
} else {
|
||||
b.WriteString(helpStyle.Render("↑↓ ↵ esc"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
b.WriteString(infoStyle.Render("Only showing models with configured API keys"))
|
||||
if ms.width >= 48 {
|
||||
b.WriteString(infoStyle.Render("Only showing models with configured API keys"))
|
||||
} else {
|
||||
b.WriteString(infoStyle.Render("Models with API keys"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
// Search input.
|
||||
@@ -281,9 +292,9 @@ func (ms *ModelSelectorComponent) IsActive() bool {
|
||||
// --- Internal helpers ---
|
||||
|
||||
func (ms *ModelSelectorComponent) visibleHeight() int {
|
||||
// Reserve: header(1) + help(1) + info(1) + search(1) + separator(1) + footer(2) = 7
|
||||
h := max(ms.height-7, 5)
|
||||
return h
|
||||
// Reserve: header(1) + help(1) + info(1) + search(1) + separator(1) + footer(2) = 7.
|
||||
// Minimum 3 entries so the selector is still usable on short terminals.
|
||||
return max(ms.height-7, 3)
|
||||
}
|
||||
|
||||
func (ms *ModelSelectorComponent) rebuildFiltered() {
|
||||
@@ -396,8 +407,37 @@ func (ms *ModelSelectorComponent) renderEntry(entry ModelEntry, isCursor bool) s
|
||||
|
||||
// Active model checkmark.
|
||||
var active string
|
||||
activeWidth := 0
|
||||
if entry.Provider+"/"+entry.ModelID == ms.currentModel {
|
||||
active = lipgloss.NewStyle().Foreground(theme.Success).Render(" \u2713")
|
||||
activeWidth = 2 // " ✓"
|
||||
}
|
||||
|
||||
// Truncate model ID and provider tag to fit terminal width.
|
||||
// Layout: cursor(3) + model + " " + provider + active.
|
||||
// Use rune length for display-width accuracy (the "…" suffix is 1 rune / 1 column).
|
||||
const cursorWidth = 3
|
||||
available := max(ms.width-cursorWidth-activeWidth-1, 10) // 1 for space between model and provider
|
||||
provDisplayLen := len([]rune(providerStr))
|
||||
modelDisplayLen := len([]rune(modelStr))
|
||||
|
||||
if modelDisplayLen+1+provDisplayLen > available {
|
||||
// Prioritize model name — truncate it, but keep provider visible.
|
||||
maxModel := max(available-provDisplayLen-1, 6)
|
||||
if maxModel < modelDisplayLen {
|
||||
if maxModel > 3 {
|
||||
runes := []rune(modelStr)
|
||||
modelStr = string(runes[:maxModel-1]) + "…"
|
||||
} else {
|
||||
runes := []rune(modelStr)
|
||||
modelStr = string(runes[:maxModel])
|
||||
}
|
||||
}
|
||||
// If provider itself is too long, drop it.
|
||||
modelDisplayLen = len([]rune(modelStr))
|
||||
if modelDisplayLen+1+provDisplayLen > available {
|
||||
providerStr = ""
|
||||
}
|
||||
}
|
||||
|
||||
// Style the model ID.
|
||||
@@ -409,5 +449,9 @@ func (ms *ModelSelectorComponent) renderEntry(entry ModelEntry, isCursor bool) s
|
||||
// Style the provider tag.
|
||||
providerStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
|
||||
return cursor + modelStyle.Render(modelStr) + " " + providerStyle.Render(providerStr) + active
|
||||
result := cursor + modelStyle.Render(modelStr)
|
||||
if providerStr != "" {
|
||||
result += " " + providerStyle.Render(providerStr)
|
||||
}
|
||||
return result + active
|
||||
}
|
||||
|
||||
@@ -116,7 +116,7 @@ func newTestAppModel(ctrl AppController) (*AppModel, *stubStreamComponent, *stub
|
||||
appCtrl: ctrl,
|
||||
stream: stream,
|
||||
input: input,
|
||||
renderer: NewMessageRenderer(80, false),
|
||||
renderer: newMessageRenderer(80, false),
|
||||
compactMode: false,
|
||||
modelName: "test-model",
|
||||
width: 80,
|
||||
@@ -405,14 +405,16 @@ func TestQueuedMessages_storedOnQueuedSubmit(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestQueuedMessages_poppedOnQueueUpdated verifies that QueueUpdatedEvent pops
|
||||
// consumed messages from queuedMessages and prints them to scrollback.
|
||||
// consumed messages from queuedMessages and moves them to pendingUserPrints.
|
||||
// The actual printing is deferred to SpinnerEvent{Show: true} to preserve
|
||||
// chronological order with the preceding assistant response.
|
||||
func TestQueuedMessages_poppedOnQueueUpdated(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.queuedMessages = []string{"first", "second", "third"}
|
||||
|
||||
// Simulate drainQueue popping one item (length goes from 3 to 2).
|
||||
_, cmd := m.Update(app.QueueUpdatedEvent{Length: 2})
|
||||
m = sendMsg(m, app.QueueUpdatedEvent{Length: 2})
|
||||
|
||||
if len(m.queuedMessages) != 2 {
|
||||
t.Fatalf("expected 2 queued messages after pop, got %d", len(m.queuedMessages))
|
||||
@@ -420,14 +422,17 @@ func TestQueuedMessages_poppedOnQueueUpdated(t *testing.T) {
|
||||
if m.queuedMessages[0] != "second" {
|
||||
t.Fatalf("expected first remaining message 'second', got %q", m.queuedMessages[0])
|
||||
}
|
||||
// Should produce a cmd (tea.Println for the popped user message).
|
||||
if cmd == nil {
|
||||
t.Fatal("expected non-nil cmd (tea.Println) for popped message")
|
||||
// Popped message should be deferred to pendingUserPrints.
|
||||
if len(m.pendingUserPrints) != 1 {
|
||||
t.Fatalf("expected 1 pending user print, got %d", len(m.pendingUserPrints))
|
||||
}
|
||||
if m.pendingUserPrints[0] != "first" {
|
||||
t.Fatalf("expected pending message 'first', got %q", m.pendingUserPrints[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueuedMessages_allPoppedOnDrain verifies that QueueUpdatedEvent with
|
||||
// Length=0 pops all remaining queued messages.
|
||||
// Length=0 pops all remaining queued messages into pendingUserPrints.
|
||||
func TestQueuedMessages_allPoppedOnDrain(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
@@ -438,6 +443,9 @@ func TestQueuedMessages_allPoppedOnDrain(t *testing.T) {
|
||||
if len(m.queuedMessages) != 0 {
|
||||
t.Fatalf("expected 0 queued messages after drain, got %d", len(m.queuedMessages))
|
||||
}
|
||||
if len(m.pendingUserPrints) != 2 {
|
||||
t.Fatalf("expected 2 pending user prints, got %d", len(m.pendingUserPrints))
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
+28
-26
@@ -135,31 +135,24 @@ func (o *overlayDialog) handleKey(msg tea.KeyPressMsg) (*overlayResult, tea.Cmd)
|
||||
func (o *overlayDialog) Render() string {
|
||||
theme := GetTheme()
|
||||
|
||||
// Calculate dialog dimensions.
|
||||
// Calculate dialog dimensions, clamped to terminal bounds.
|
||||
termW := max(o.width, 10)
|
||||
termH := max(o.height, 5)
|
||||
|
||||
dw := o.dialogWidth
|
||||
if dw == 0 {
|
||||
dw = o.width * 60 / 100
|
||||
}
|
||||
if dw < 30 {
|
||||
dw = 30
|
||||
}
|
||||
if dw > o.width-4 {
|
||||
dw = o.width - 4
|
||||
dw = termW * 60 / 100
|
||||
}
|
||||
dw = clamp(dw, min(24, termW), termW-2)
|
||||
|
||||
mh := o.maxHeight
|
||||
if mh == 0 {
|
||||
mh = o.height * 80 / 100
|
||||
}
|
||||
if mh < 8 {
|
||||
mh = 8
|
||||
}
|
||||
if mh > o.height-2 {
|
||||
mh = o.height - 2
|
||||
mh = termH * 80 / 100
|
||||
}
|
||||
mh = clamp(mh, min(6, termH), termH)
|
||||
|
||||
// Inner width accounts for border (2) + horizontal padding (2 left + 1 right).
|
||||
innerWidth := max(dw-5, 10)
|
||||
innerWidth := max(dw-5, 6)
|
||||
|
||||
// Render body text (potentially as markdown).
|
||||
bodyText := o.content
|
||||
@@ -249,7 +242,7 @@ func (o *overlayDialog) Render() string {
|
||||
innerContent := strings.Join(parts, "\n")
|
||||
|
||||
// Resolve border color.
|
||||
borderClr := lipgloss.Color("#89b4fa") // default blue
|
||||
borderClr := theme.Info
|
||||
if o.borderColor != "" {
|
||||
borderClr = lipgloss.Color(o.borderColor)
|
||||
}
|
||||
@@ -268,18 +261,27 @@ func (o *overlayDialog) Render() string {
|
||||
|
||||
dialog := dialogStyle.Render(innerContent)
|
||||
|
||||
// Key hints below the dialog.
|
||||
// Key hints below the dialog, adapted to width.
|
||||
var hints []string
|
||||
if scrollable {
|
||||
hints = append(hints, "↑/↓ scroll")
|
||||
}
|
||||
if len(o.actions) > 0 {
|
||||
hints = append(hints, "←/→ switch")
|
||||
hints = append(hints, "Enter select")
|
||||
if termW >= 50 {
|
||||
if scrollable {
|
||||
hints = append(hints, "↑/↓ scroll")
|
||||
}
|
||||
if len(o.actions) > 0 {
|
||||
hints = append(hints, "←/→ switch")
|
||||
hints = append(hints, "Enter select")
|
||||
} else {
|
||||
hints = append(hints, "Enter dismiss")
|
||||
}
|
||||
hints = append(hints, "Esc cancel")
|
||||
} else {
|
||||
hints = append(hints, "Enter dismiss")
|
||||
if len(o.actions) > 0 {
|
||||
hints = append(hints, "↵ select")
|
||||
} else {
|
||||
hints = append(hints, "↵ ok")
|
||||
}
|
||||
hints = append(hints, "esc")
|
||||
}
|
||||
hints = append(hints, "Esc cancel")
|
||||
hintText := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Render(" " + strings.Join(hints, " "))
|
||||
|
||||
@@ -83,7 +83,7 @@ func newInputPrompt(message, placeholder, defaultValue string, width, height int
|
||||
|
||||
// Prevent Enter from inserting a newline — we intercept it for submit.
|
||||
ta.KeyMap.InsertNewline = key.NewBinding(
|
||||
key.WithKeys("ctrl+j", "alt+enter"),
|
||||
key.WithKeys("ctrl+j", "shift+enter"),
|
||||
)
|
||||
|
||||
if defaultValue != "" {
|
||||
|
||||
@@ -42,18 +42,19 @@ func NewSlashCommandInput(width int, title string) *SlashCommandInput {
|
||||
ta.SetHeight(3) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
|
||||
// Override InsertNewline so only ctrl+j and alt+enter insert newlines.
|
||||
// Override InsertNewline so only ctrl+j and shift+enter insert newlines.
|
||||
// Enter always submits the input.
|
||||
ta.KeyMap.InsertNewline = key.NewBinding(
|
||||
key.WithKeys("ctrl+j", "alt+enter"),
|
||||
key.WithKeys("ctrl+j", "shift+enter"),
|
||||
key.WithHelp("ctrl+j", "insert newline"),
|
||||
)
|
||||
|
||||
// Style the textarea to match huh theme
|
||||
// Style the textarea using theme colors.
|
||||
theme := GetTheme()
|
||||
styles := ta.Styles()
|
||||
styles.Focused.Base = lipgloss.NewStyle()
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(lipgloss.Color("240"))
|
||||
styles.Focused.Text = lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
styles.Focused.Text = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
styles.Focused.Prompt = lipgloss.NewStyle()
|
||||
styles.Focused.CursorLine = lipgloss.NewStyle()
|
||||
ta.SetStyles(styles)
|
||||
@@ -178,9 +179,11 @@ func (s *SlashCommandInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
func (s *SlashCommandInput) View() tea.View {
|
||||
containerStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := GetTheme()
|
||||
|
||||
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("252")).
|
||||
Foreground(theme.Text).
|
||||
MarginBottom(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
@@ -191,7 +194,7 @@ func (s *SlashCommandInput) View() tea.View {
|
||||
BorderRight(false).
|
||||
BorderTop(false).
|
||||
BorderBottom(false).
|
||||
BorderForeground(lipgloss.Color("39")).
|
||||
BorderForeground(theme.Primary).
|
||||
PaddingLeft(2). // match message block paddingLeft
|
||||
Width(s.width - 1) // full width minus left border
|
||||
|
||||
@@ -223,11 +226,11 @@ func (s *SlashCommandInput) View() tea.View {
|
||||
// Add help text at bottom (unless hidden by extension).
|
||||
if !s.hideHint {
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("240")).
|
||||
Foreground(theme.VeryMuted).
|
||||
MarginTop(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
helpText := "enter submit • ctrl+j / alt+enter new line"
|
||||
helpText := "enter submit • ctrl+j / shift+enter new line"
|
||||
|
||||
view.WriteString("\n")
|
||||
view.WriteString(helpStyle.Render(helpText))
|
||||
@@ -240,10 +243,12 @@ func (s *SlashCommandInput) View() tea.View {
|
||||
|
||||
// renderPopup renders the autocomplete popup
|
||||
func (s *SlashCommandInput) renderPopup() string {
|
||||
theme := GetTheme()
|
||||
|
||||
// Popup styling
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("236")).
|
||||
BorderForeground(theme.MutedBorder).
|
||||
Padding(1, 2).
|
||||
Width(s.width - 4). // Account for container padding
|
||||
MarginLeft(0) // No extra margin needed due to container padding
|
||||
@@ -268,7 +273,7 @@ func (s *SlashCommandInput) renderPopup() string {
|
||||
var indicator string
|
||||
if i == s.selected {
|
||||
indicator = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("39")).
|
||||
Foreground(theme.Primary).
|
||||
Render("> ")
|
||||
} else {
|
||||
indicator = " "
|
||||
@@ -276,16 +281,16 @@ func (s *SlashCommandInput) renderPopup() string {
|
||||
|
||||
// Format item
|
||||
nameStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("39")).
|
||||
Foreground(theme.Secondary).
|
||||
Bold(true)
|
||||
|
||||
descStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("243"))
|
||||
Foreground(theme.Muted)
|
||||
|
||||
// Highlight selected item
|
||||
if i == s.selected {
|
||||
nameStyle = nameStyle.Foreground(lipgloss.Color("87"))
|
||||
descStyle = descStyle.Foreground(lipgloss.Color("250"))
|
||||
nameStyle = nameStyle.Foreground(theme.Primary)
|
||||
descStyle = descStyle.Foreground(theme.Text)
|
||||
}
|
||||
|
||||
// Format with proper spacing
|
||||
@@ -305,11 +310,11 @@ func (s *SlashCommandInput) renderPopup() string {
|
||||
|
||||
// Add scroll indicators if needed
|
||||
if startIdx > 0 {
|
||||
scrollUpStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("238"))
|
||||
scrollUpStyle := lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
items = append([]string{scrollUpStyle.Render(" ↑ more above")}, items...)
|
||||
}
|
||||
if endIdx < len(s.filtered) {
|
||||
scrollDownStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("238"))
|
||||
scrollDownStyle := lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
items = append(items, scrollDownStyle.Render(" ↓ more below"))
|
||||
}
|
||||
// Join items
|
||||
@@ -317,7 +322,7 @@ func (s *SlashCommandInput) renderPopup() string {
|
||||
|
||||
// Add footer hint
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("238")).
|
||||
Foreground(theme.VeryMuted).
|
||||
Italic(true)
|
||||
footer := footerStyle.Render("↑↓ navigate • tab complete • ↵ select • esc dismiss")
|
||||
|
||||
|
||||
+258
-47
@@ -1,6 +1,7 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -58,14 +59,39 @@ func knightRiderFrames() []string {
|
||||
}
|
||||
|
||||
// streamSpinnerTickMsg is the internal tick message that drives the KITT-style
|
||||
// spinner animation inside StreamComponent.
|
||||
type streamSpinnerTickMsg struct{}
|
||||
// spinner animation inside StreamComponent. The generation field ties each tick
|
||||
// to the spinner session that created it so that stale ticks from a previous
|
||||
// start/stop cycle are silently discarded instead of creating a second
|
||||
// concurrent tick loop (which doubles the animation speed).
|
||||
type streamSpinnerTickMsg struct {
|
||||
generation uint64
|
||||
}
|
||||
|
||||
// streamSpinnerTickCmd returns a tea.Cmd that fires streamSpinnerTickMsg at the
|
||||
// KITT animation frame rate (14 fps).
|
||||
func streamSpinnerTickCmd() tea.Cmd {
|
||||
// KITT animation frame rate (14 fps). The generation parameter is embedded in
|
||||
// the message so the receiver can verify it matches the current spinner session.
|
||||
func streamSpinnerTickCmd(generation uint64) tea.Cmd {
|
||||
return tea.Tick(time.Second/14, func(_ time.Time) tea.Msg {
|
||||
return streamSpinnerTickMsg{}
|
||||
return streamSpinnerTickMsg{generation: generation}
|
||||
})
|
||||
}
|
||||
|
||||
// streamFlushTickMsg fires when it's time to commit pending chunks to the
|
||||
// main content builders and trigger a re-render. This coalesces rapid
|
||||
// streaming chunks into fewer expensive markdown re-renders.
|
||||
type streamFlushTickMsg struct{}
|
||||
|
||||
// streamFlushInterval is the coalescing window for stream chunks. Chunks
|
||||
// arriving within this window are batched into a single render pass.
|
||||
// 16ms ≈ 60 fps — fast enough to appear smooth, slow enough to coalesce
|
||||
// bursts from the LLM provider.
|
||||
const streamFlushInterval = 16 * time.Millisecond
|
||||
|
||||
// streamFlushTickCmd returns a tea.Cmd that fires streamFlushTickMsg after
|
||||
// the coalescing interval.
|
||||
func streamFlushTickCmd() tea.Cmd {
|
||||
return tea.Tick(streamFlushInterval, func(_ time.Time) tea.Msg {
|
||||
return streamFlushTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -108,25 +134,62 @@ type StreamComponent struct {
|
||||
// remains visible alongside streaming text until Reset().
|
||||
spinning bool
|
||||
|
||||
// spinnerGeneration is incremented each time a new spinner tick loop
|
||||
// is started. Tick messages carry the generation they were created for;
|
||||
// if a tick's generation doesn't match the current one, it is a stale
|
||||
// tick from a previous start/stop cycle and is silently discarded.
|
||||
// This prevents multiple concurrent tick loops from accumulating when
|
||||
// the spinner is rapidly stopped and restarted (e.g. SpinnerEvent
|
||||
// hide → ToolExecutionEvent start before the old tick fires).
|
||||
spinnerGeneration uint64
|
||||
|
||||
// spinnerFrames are the pre-rendered KITT animation frames.
|
||||
spinnerFrames []string
|
||||
|
||||
// spinnerFrame is the current frame index.
|
||||
spinnerFrame int
|
||||
|
||||
// spinnerMsg is the label shown next to the KITT animation (e.g.
|
||||
// "Executing tool_name…"). Empty string means no label.
|
||||
spinnerMsg string
|
||||
// activeTools tracks the names of tools currently executing in parallel.
|
||||
// When multiple tools run concurrently, all are displayed in the spinner.
|
||||
activeTools []string
|
||||
|
||||
// streamContent accumulates all streaming text chunks.
|
||||
// streamContent holds committed streaming text (flushed from pending).
|
||||
streamContent strings.Builder
|
||||
|
||||
// reasoningContent accumulates reasoning/thinking text chunks.
|
||||
// reasoningContent holds committed reasoning text (flushed from pending).
|
||||
reasoningContent strings.Builder
|
||||
|
||||
// thinkingVisible controls whether reasoning blocks are shown or collapsed.
|
||||
// pendingStream accumulates streaming text chunks between flush ticks.
|
||||
// Chunks are written here immediately on arrival, then moved to
|
||||
// streamContent when the flush tick fires.
|
||||
pendingStream strings.Builder
|
||||
|
||||
// pendingReasoning accumulates reasoning chunks between flush ticks.
|
||||
pendingReasoning strings.Builder
|
||||
|
||||
// flushPending is true while a flush tick is in-flight. Prevents
|
||||
// scheduling duplicate ticks when multiple chunks arrive within
|
||||
// the same coalescing window.
|
||||
flushPending bool
|
||||
|
||||
// renderCache holds the last rendered output string. Reused by View()
|
||||
// between flush ticks to avoid redundant markdown re-parsing.
|
||||
renderCache string
|
||||
|
||||
// renderDirty is true when committed content has changed since the
|
||||
// last render. Set on flush tick; cleared after render() rebuilds
|
||||
// the cache.
|
||||
renderDirty bool
|
||||
|
||||
// thinkingVisible controls whether reasoning blocks are expanded or collapsed.
|
||||
thinkingVisible bool
|
||||
|
||||
// reasoningStartTime records when the first reasoning chunk was received.
|
||||
reasoningStartTime time.Time
|
||||
|
||||
// reasoningDuration holds the total reasoning time, frozen when streaming text begins.
|
||||
reasoningDuration time.Duration
|
||||
|
||||
// messageRenderer renders assistant messages in standard mode.
|
||||
messageRenderer *MessageRenderer
|
||||
|
||||
@@ -159,7 +222,7 @@ func NewStreamComponent(compactMode bool, width int, modelName string) *StreamCo
|
||||
spinnerFrames: knightRiderFrames(),
|
||||
compactMode: compactMode,
|
||||
modelName: modelName,
|
||||
messageRenderer: NewMessageRenderer(width, false),
|
||||
messageRenderer: newMessageRenderer(width, false),
|
||||
compactRenderer: NewCompactRenderer(width, false),
|
||||
width: width,
|
||||
}
|
||||
@@ -172,7 +235,12 @@ func (s *StreamComponent) SetHeight(h int) {
|
||||
if h < 0 {
|
||||
h = 0
|
||||
}
|
||||
s.height = h
|
||||
if s.height != h {
|
||||
s.height = h
|
||||
// Invalidate cache — height clamp affects output.
|
||||
s.renderCache = ""
|
||||
s.renderDirty = true
|
||||
}
|
||||
}
|
||||
|
||||
// Reset clears all accumulated state so the component is ready for the next
|
||||
@@ -180,17 +248,31 @@ func (s *StreamComponent) SetHeight(h int) {
|
||||
func (s *StreamComponent) Reset() {
|
||||
s.phase = streamPhaseIdle
|
||||
s.spinning = false
|
||||
s.spinnerGeneration++ // invalidate any in-flight tick commands
|
||||
s.spinnerFrame = 0
|
||||
s.spinnerMsg = ""
|
||||
s.activeTools = nil
|
||||
s.streamContent.Reset()
|
||||
s.reasoningContent.Reset()
|
||||
s.pendingStream.Reset()
|
||||
s.pendingReasoning.Reset()
|
||||
s.flushPending = false
|
||||
s.renderCache = ""
|
||||
s.renderDirty = false
|
||||
s.timestamp = time.Time{}
|
||||
s.reasoningStartTime = time.Time{}
|
||||
s.reasoningDuration = 0
|
||||
}
|
||||
|
||||
// GetRenderedContent returns the rendered assistant message from the accumulated
|
||||
// streaming text. Returns empty string if no text has been accumulated. Used by
|
||||
// the parent AppModel to flush content via tea.Println() before resetting.
|
||||
//
|
||||
// This commits any pending chunks first so the output includes all received
|
||||
// content, not just what has been flushed by the tick.
|
||||
func (s *StreamComponent) GetRenderedContent() string {
|
||||
// Commit any pending chunks so the final output is complete.
|
||||
s.commitPending()
|
||||
|
||||
var sections []string
|
||||
|
||||
// Include rendered reasoning block if present.
|
||||
@@ -209,6 +291,21 @@ func (s *StreamComponent) GetRenderedContent() string {
|
||||
return strings.Join(sections, "\n")
|
||||
}
|
||||
|
||||
// commitPending moves any pending chunks to the committed content builders.
|
||||
// Called before reading content for scrollback output or on flush tick.
|
||||
func (s *StreamComponent) commitPending() {
|
||||
if s.pendingStream.Len() > 0 {
|
||||
s.streamContent.WriteString(s.pendingStream.String())
|
||||
s.pendingStream.Reset()
|
||||
s.renderDirty = true
|
||||
}
|
||||
if s.pendingReasoning.Len() > 0 {
|
||||
s.reasoningContent.WriteString(s.pendingReasoning.String())
|
||||
s.pendingReasoning.Reset()
|
||||
s.renderDirty = true
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// tea.Model interface
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -227,13 +324,20 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.width = msg.Width
|
||||
s.messageRenderer.SetWidth(s.width)
|
||||
s.compactRenderer.SetWidth(s.width)
|
||||
// Invalidate render cache — width change affects wrapping/styling.
|
||||
s.renderCache = ""
|
||||
s.renderDirty = true
|
||||
|
||||
case streamSpinnerTickMsg:
|
||||
if s.spinning {
|
||||
// Only continue the tick loop if this tick belongs to the current
|
||||
// spinner session. Stale ticks from a previous start/stop cycle
|
||||
// are silently dropped, preventing duplicate concurrent tick loops
|
||||
// that would double (or worse) the animation speed.
|
||||
if s.spinning && msg.generation == s.spinnerGeneration {
|
||||
s.spinnerFrame++
|
||||
return s, streamSpinnerTickCmd()
|
||||
return s, streamSpinnerTickCmd(s.spinnerGeneration)
|
||||
}
|
||||
// Spinning stopped; let the tick loop die naturally.
|
||||
// Spinning stopped or generation mismatch; let the tick loop die.
|
||||
|
||||
// ── App-layer events ──────────────────────────────────────────────────
|
||||
|
||||
@@ -241,42 +345,68 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if msg.Show && !s.spinning {
|
||||
s.phase = streamPhaseActive
|
||||
s.spinning = true
|
||||
s.spinnerGeneration++ // new session; invalidate any stale ticks
|
||||
s.spinnerFrame = 0
|
||||
if s.timestamp.IsZero() {
|
||||
s.timestamp = time.Now()
|
||||
}
|
||||
return s, streamSpinnerTickCmd()
|
||||
return s, streamSpinnerTickCmd(s.spinnerGeneration)
|
||||
} else if !msg.Show && s.spinning {
|
||||
s.spinning = false
|
||||
// Bump generation so any in-flight tick from this session is
|
||||
// discarded if spinning is restarted before it fires.
|
||||
s.spinnerGeneration++
|
||||
}
|
||||
|
||||
case streamFlushTickMsg:
|
||||
s.flushPending = false
|
||||
s.commitPending()
|
||||
|
||||
case app.ReasoningChunkEvent:
|
||||
s.phase = streamPhaseActive
|
||||
if s.timestamp.IsZero() {
|
||||
s.timestamp = time.Now()
|
||||
}
|
||||
s.reasoningContent.WriteString(msg.Delta)
|
||||
if s.reasoningStartTime.IsZero() {
|
||||
s.reasoningStartTime = time.Now()
|
||||
}
|
||||
s.pendingReasoning.WriteString(msg.Delta)
|
||||
if !s.flushPending {
|
||||
s.flushPending = true
|
||||
return s, streamFlushTickCmd()
|
||||
}
|
||||
|
||||
case app.StreamChunkEvent:
|
||||
s.phase = streamPhaseActive
|
||||
if s.timestamp.IsZero() {
|
||||
s.timestamp = time.Now()
|
||||
}
|
||||
s.streamContent.WriteString(msg.Content)
|
||||
// Freeze reasoning duration on transition from reasoning to streaming.
|
||||
if s.reasoningDuration == 0 && !s.reasoningStartTime.IsZero() {
|
||||
s.reasoningDuration = time.Since(s.reasoningStartTime)
|
||||
}
|
||||
s.pendingStream.WriteString(msg.Content)
|
||||
if !s.flushPending {
|
||||
s.flushPending = true
|
||||
return s, streamFlushTickCmd()
|
||||
}
|
||||
|
||||
case app.ToolExecutionEvent:
|
||||
if msg.IsStarting {
|
||||
// Show the tool name on the spinner while the tool executes.
|
||||
s.spinnerMsg = "Executing " + msg.ToolName + "…"
|
||||
// Add tool to active list for parallel execution display.
|
||||
toolDisplay := formatToolExecutionMessage(msg.ToolName, msg.ToolArgs)
|
||||
s.activeTools = append(s.activeTools, toolDisplay)
|
||||
s.spinnerFrame = 0
|
||||
if !s.spinning {
|
||||
s.phase = streamPhaseActive
|
||||
s.spinning = true
|
||||
return s, streamSpinnerTickCmd()
|
||||
s.spinnerGeneration++ // new session; invalidate stale ticks
|
||||
return s, streamSpinnerTickCmd(s.spinnerGeneration)
|
||||
}
|
||||
} else {
|
||||
// Tool finished — clear execution label but keep spinning.
|
||||
s.spinnerMsg = ""
|
||||
// Tool finished — remove from active list but keep spinning if others remain.
|
||||
toolDisplay := formatToolExecutionMessage(msg.ToolName, msg.ToolArgs)
|
||||
s.activeTools = removeFromSlice(s.activeTools, toolDisplay)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,12 +422,20 @@ func (s *StreamComponent) View() tea.View {
|
||||
// Internal rendering
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// render builds the full content string for the stream region.
|
||||
// render builds the full content string for the stream region. Uses a render
|
||||
// cache to avoid redundant markdown re-parsing between flush ticks. The cache
|
||||
// is invalidated when committed content changes (flush tick), terminal width
|
||||
// changes, or height/thinking visibility changes.
|
||||
func (s *StreamComponent) render() string {
|
||||
if s.phase == streamPhaseIdle {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Return cached render if committed content hasn't changed.
|
||||
if !s.renderDirty {
|
||||
return s.renderCache
|
||||
}
|
||||
|
||||
var sections []string
|
||||
|
||||
// Render reasoning/thinking block above the main text if present.
|
||||
@@ -313,6 +451,8 @@ func (s *StreamComponent) render() string {
|
||||
}
|
||||
|
||||
if len(sections) == 0 {
|
||||
s.renderCache = ""
|
||||
s.renderDirty = false
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -328,42 +468,86 @@ func (s *StreamComponent) render() string {
|
||||
}
|
||||
}
|
||||
|
||||
s.renderCache = content
|
||||
s.renderDirty = false
|
||||
return content
|
||||
}
|
||||
|
||||
// renderReasoningBlock renders the reasoning/thinking content. When thinking
|
||||
// is visible, the full reasoning text is shown in muted italic style. When
|
||||
// collapsed, a "Thinking..." label is shown instead.
|
||||
// renderReasoningBlock renders the reasoning/thinking content in a surface-tinted
|
||||
// box. When collapsed, shows the last 10 lines with a truncation hint. When
|
||||
// expanded, shows all lines. Includes a "Thought for Xs" duration footer.
|
||||
func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
theme := GetTheme()
|
||||
maxWidth := max(s.width-4, 20)
|
||||
|
||||
if !s.thinkingVisible {
|
||||
// Show collapsed "Thinking..." label.
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true).
|
||||
Render("Thinking...")
|
||||
}
|
||||
lines := strings.Split(strings.TrimRight(reasoning, "\n"), "\n")
|
||||
|
||||
// Render full reasoning text in muted italic style.
|
||||
style := lipgloss.NewStyle().
|
||||
contentStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true)
|
||||
|
||||
// Wrap to terminal width.
|
||||
maxWidth := max(s.width-4, 20) // leave some margin
|
||||
styled := style.Width(maxWidth).Render(reasoning)
|
||||
return styled
|
||||
var parts []string
|
||||
|
||||
// When collapsed and content exceeds 10 lines, show only the last 10
|
||||
// with a truncation hint (matching iteratr's thinking block pattern).
|
||||
const maxCollapsedLines = 10
|
||||
if !s.thinkingVisible && len(lines) > maxCollapsedLines {
|
||||
hidden := len(lines) - maxCollapsedLines
|
||||
hintStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted).
|
||||
Italic(true)
|
||||
parts = append(parts, hintStyle.Render(fmt.Sprintf("... (%d lines hidden)", hidden)))
|
||||
lines = lines[len(lines)-maxCollapsedLines:]
|
||||
}
|
||||
|
||||
// Render reasoning text.
|
||||
parts = append(parts, contentStyle.Width(maxWidth).Render(strings.Join(lines, "\n")))
|
||||
|
||||
// Duration footer.
|
||||
var duration time.Duration
|
||||
if s.reasoningDuration > 0 {
|
||||
duration = s.reasoningDuration
|
||||
} else if !s.reasoningStartTime.IsZero() {
|
||||
duration = time.Since(s.reasoningStartTime)
|
||||
}
|
||||
if duration > 0 {
|
||||
var durationStr string
|
||||
if duration < time.Second {
|
||||
durationStr = fmt.Sprintf("%dms", duration.Milliseconds())
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.1fs", duration.Seconds())
|
||||
}
|
||||
footer := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ") +
|
||||
lipgloss.NewStyle().Foreground(theme.Info).Render(durationStr)
|
||||
parts = append(parts, footer)
|
||||
}
|
||||
|
||||
innerContent := strings.Join(parts, "\n")
|
||||
|
||||
// Wrap in box with surface background for visual distinction.
|
||||
boxStyle := lipgloss.NewStyle().
|
||||
Background(theme.MutedBorder). // Surface0 (#313244)
|
||||
PaddingLeft(1).
|
||||
Width(maxWidth + 2).
|
||||
MarginBottom(1)
|
||||
|
||||
return boxStyle.Render(innerContent)
|
||||
}
|
||||
|
||||
// SetThinkingVisible sets whether reasoning blocks are shown or collapsed.
|
||||
func (s *StreamComponent) SetThinkingVisible(visible bool) {
|
||||
s.thinkingVisible = visible
|
||||
if s.thinkingVisible != visible {
|
||||
s.thinkingVisible = visible
|
||||
// Invalidate cache — thinking visibility affects rendered output.
|
||||
s.renderCache = ""
|
||||
s.renderDirty = true
|
||||
}
|
||||
}
|
||||
|
||||
// HasReasoning returns true if any reasoning content has been accumulated.
|
||||
// HasReasoning returns true if any reasoning content has been accumulated
|
||||
// (committed or pending).
|
||||
func (s *StreamComponent) HasReasoning() bool {
|
||||
return s.reasoningContent.Len() > 0
|
||||
return s.reasoningContent.Len() > 0 || s.pendingReasoning.Len() > 0
|
||||
}
|
||||
|
||||
// SpinnerView returns the rendered spinner line for the parent to embed in the
|
||||
@@ -373,14 +557,22 @@ func (s *StreamComponent) SpinnerView() string {
|
||||
return ""
|
||||
}
|
||||
frame := s.spinnerFrames[s.spinnerFrame%len(s.spinnerFrames)]
|
||||
if s.spinnerMsg == "" {
|
||||
if len(s.activeTools) == 0 {
|
||||
return " " + frame
|
||||
}
|
||||
theme := GetTheme()
|
||||
msgStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Text).
|
||||
Italic(true)
|
||||
return " " + frame + " " + msgStyle.Render(s.spinnerMsg)
|
||||
|
||||
// Format active tools list
|
||||
var toolsMsg string
|
||||
if len(s.activeTools) == 1 {
|
||||
toolsMsg = s.activeTools[0]
|
||||
} else {
|
||||
toolsMsg = "Running: " + strings.Join(s.activeTools, ", ")
|
||||
}
|
||||
return " " + frame + " " + msgStyle.Render(toolsMsg)
|
||||
}
|
||||
|
||||
// renderStreamingText renders the accumulated streaming text as a live assistant
|
||||
@@ -398,3 +590,22 @@ func (s *StreamComponent) renderStreamingText(text string) string {
|
||||
msg := s.messageRenderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
return msg.Content
|
||||
}
|
||||
|
||||
// removeFromSlice removes the first occurrence of a string from a slice.
|
||||
func removeFromSlice(slice []string, s string) []string {
|
||||
for i, v := range slice {
|
||||
if v == s {
|
||||
return append(slice[:i], slice[i+1:]...)
|
||||
}
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
// formatToolExecutionMessage creates a descriptive spinner message for tool execution.
|
||||
// For spawn_subagent, it shows simply as "Subagent" with optional task preview.
|
||||
func formatToolExecutionMessage(toolName, toolArgs string) string {
|
||||
if toolName == "spawn_subagent" {
|
||||
return "Subagent"
|
||||
}
|
||||
return toolName
|
||||
}
|
||||
|
||||
+83
-124
@@ -1,11 +1,12 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image/color"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/charmbracelet/glamour"
|
||||
"github.com/charmbracelet/glamour/ansi"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// uintPtr returns a pointer to u. Used by ansi.StyleConfig fields.
|
||||
@@ -20,6 +21,18 @@ func BaseStyle() lipgloss.Style {
|
||||
return lipgloss.NewStyle()
|
||||
}
|
||||
|
||||
// colorHex converts a color.Color to a hex string suitable for ansi.StyleConfig.
|
||||
func colorHex(c color.Color) string {
|
||||
r, g, b, _ := c.RGBA()
|
||||
return fmt.Sprintf("#%02x%02x%02x", r>>8, g>>8, b>>8)
|
||||
}
|
||||
|
||||
// colorHexPtr returns a pointer to the hex string of a color.Color.
|
||||
func colorHexPtr(c color.Color) *string {
|
||||
s := colorHex(c)
|
||||
return &s
|
||||
}
|
||||
|
||||
// GetMarkdownRenderer creates and returns a configured glamour.TermRenderer for
|
||||
// rendering markdown content with syntax highlighting and proper formatting. The
|
||||
// renderer is customized with our theme colors and adapted to the specified width.
|
||||
@@ -31,169 +44,119 @@ func GetMarkdownRenderer(width int) *glamour.TermRenderer {
|
||||
return r
|
||||
}
|
||||
|
||||
// colorScheme holds resolved color values for markdown rendering.
|
||||
type colorScheme struct {
|
||||
text string
|
||||
muted string
|
||||
heading string
|
||||
emph string
|
||||
strong string
|
||||
link string
|
||||
code string
|
||||
err string
|
||||
keyword string
|
||||
str string
|
||||
number string
|
||||
comment string
|
||||
}
|
||||
|
||||
// resolveColorScheme determines the color palette based on user config and background.
|
||||
func resolveColorScheme() colorScheme {
|
||||
var mdTheme config.MarkdownTheme
|
||||
err := config.FilepathOr("markdown-theme", &mdTheme)
|
||||
fromConfig := err == nil && viper.InConfig("markdown-theme")
|
||||
|
||||
if fromConfig && IsDarkBackground() {
|
||||
return colorScheme{
|
||||
text: mdTheme.Text.Light, muted: mdTheme.Muted.Light,
|
||||
heading: mdTheme.Heading.Light, emph: mdTheme.Emph.Light,
|
||||
strong: mdTheme.Strong.Light, link: mdTheme.Link.Light,
|
||||
code: mdTheme.Code.Light, err: mdTheme.Error.Light,
|
||||
keyword: mdTheme.Keyword.Light, str: mdTheme.String.Light,
|
||||
number: mdTheme.Number.Light, comment: mdTheme.Comment.Light,
|
||||
}
|
||||
}
|
||||
if fromConfig {
|
||||
return colorScheme{
|
||||
text: mdTheme.Text.Dark, muted: mdTheme.Muted.Dark,
|
||||
heading: mdTheme.Heading.Dark, emph: mdTheme.Emph.Dark,
|
||||
strong: mdTheme.Strong.Dark, link: mdTheme.Link.Dark,
|
||||
code: mdTheme.Code.Dark, err: mdTheme.Error.Dark,
|
||||
keyword: mdTheme.Keyword.Dark, str: mdTheme.String.Dark,
|
||||
number: mdTheme.Number.Dark, comment: mdTheme.Comment.Dark,
|
||||
}
|
||||
}
|
||||
if IsDarkBackground() {
|
||||
return colorScheme{
|
||||
text: "#F9FAFB", muted: "#9CA3AF",
|
||||
heading: "#22D3EE", emph: "#FDE047",
|
||||
strong: "#F9FAFB", link: "#60A5FA",
|
||||
code: "#D1D5DB", err: "#F87171",
|
||||
keyword: "#C084FC", str: "#34D399",
|
||||
number: "#FBBF24", comment: "#9CA3AF",
|
||||
}
|
||||
}
|
||||
return colorScheme{
|
||||
text: "#1F2937", muted: "#6B7280",
|
||||
heading: "#0891B2", emph: "#D97706",
|
||||
strong: "#1F2937", link: "#2563EB",
|
||||
code: "#374151", err: "#DC2626",
|
||||
keyword: "#7C3AED", str: "#059669",
|
||||
number: "#D97706", comment: "#6B7280",
|
||||
}
|
||||
}
|
||||
|
||||
// generateMarkdownStyleConfig creates an ansi.StyleConfig for markdown rendering.
|
||||
// generateMarkdownStyleConfig creates an ansi.StyleConfig from the active theme.
|
||||
func generateMarkdownStyleConfig() ansi.StyleConfig {
|
||||
cs := resolveColorScheme()
|
||||
md := GetTheme().Markdown
|
||||
text := colorHexPtr(md.Text)
|
||||
muted := colorHexPtr(md.Muted)
|
||||
heading := colorHexPtr(md.Heading)
|
||||
emph := colorHexPtr(md.Emph)
|
||||
strong := colorHexPtr(md.Strong)
|
||||
link := colorHexPtr(md.Link)
|
||||
code := colorHexPtr(md.Code)
|
||||
errClr := colorHexPtr(md.Error)
|
||||
keyword := colorHexPtr(md.Keyword)
|
||||
str := colorHexPtr(md.String)
|
||||
number := colorHexPtr(md.Number)
|
||||
comment := colorHexPtr(md.Comment)
|
||||
|
||||
return ansi.StyleConfig{
|
||||
Document: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
BlockPrefix: "",
|
||||
BlockSuffix: "",
|
||||
Color: &cs.text,
|
||||
Color: text,
|
||||
},
|
||||
Margin: uintPtr(0), // Remove margin to prevent spacing
|
||||
Margin: uintPtr(0),
|
||||
},
|
||||
BlockQuote: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Color: &cs.muted,
|
||||
Color: muted,
|
||||
Italic: new(true),
|
||||
Prefix: "┃ ",
|
||||
},
|
||||
Indent: uintPtr(1),
|
||||
},
|
||||
List: ansi.StyleList{
|
||||
LevelIndent: 0, // Remove list indentation
|
||||
LevelIndent: 0,
|
||||
StyleBlock: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Color: &cs.text,
|
||||
Color: text,
|
||||
},
|
||||
},
|
||||
},
|
||||
Heading: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
BlockSuffix: "\n",
|
||||
Color: &cs.heading,
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H1: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "# ",
|
||||
Color: &cs.heading,
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H2: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "## ",
|
||||
Color: &cs.heading,
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H3: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "### ",
|
||||
Color: &cs.heading,
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H4: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "#### ",
|
||||
Color: &cs.heading,
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H5: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "##### ",
|
||||
Color: &cs.heading,
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
H6: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "###### ",
|
||||
Color: &cs.heading,
|
||||
Color: heading,
|
||||
Bold: new(true),
|
||||
},
|
||||
},
|
||||
Strikethrough: ansi.StylePrimitive{
|
||||
CrossedOut: new(true),
|
||||
Color: &cs.muted,
|
||||
Color: muted,
|
||||
},
|
||||
Emph: ansi.StylePrimitive{
|
||||
Color: &cs.emph,
|
||||
Color: emph,
|
||||
Italic: new(true),
|
||||
},
|
||||
Strong: ansi.StylePrimitive{
|
||||
Bold: new(true),
|
||||
Color: &cs.strong,
|
||||
Color: strong,
|
||||
},
|
||||
HorizontalRule: ansi.StylePrimitive{
|
||||
Color: &cs.muted,
|
||||
Color: muted,
|
||||
Format: "\n─────────────────────────────────────────\n",
|
||||
},
|
||||
Item: ansi.StylePrimitive{
|
||||
BlockPrefix: "• ",
|
||||
Color: &cs.text,
|
||||
Color: text,
|
||||
},
|
||||
Enumeration: ansi.StylePrimitive{
|
||||
BlockPrefix: ". ",
|
||||
Color: &cs.text,
|
||||
Color: text,
|
||||
},
|
||||
Task: ansi.StyleTask{
|
||||
StylePrimitive: ansi.StylePrimitive{},
|
||||
@@ -201,25 +164,25 @@ func generateMarkdownStyleConfig() ansi.StyleConfig {
|
||||
Unticked: "[ ] ",
|
||||
},
|
||||
Link: ansi.StylePrimitive{
|
||||
Color: &cs.link,
|
||||
Color: link,
|
||||
Underline: new(true),
|
||||
},
|
||||
LinkText: ansi.StylePrimitive{
|
||||
Color: &cs.link,
|
||||
Color: link,
|
||||
Bold: new(true),
|
||||
},
|
||||
Image: ansi.StylePrimitive{
|
||||
Color: &cs.link,
|
||||
Color: link,
|
||||
Underline: new(true),
|
||||
Format: "🖼 {{.text}}",
|
||||
},
|
||||
ImageText: ansi.StylePrimitive{
|
||||
Color: &cs.link,
|
||||
Color: link,
|
||||
Format: "{{.text}}",
|
||||
},
|
||||
Code: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Color: &cs.code,
|
||||
Color: code,
|
||||
Prefix: "",
|
||||
Suffix: "",
|
||||
},
|
||||
@@ -228,50 +191,46 @@ func generateMarkdownStyleConfig() ansi.StyleConfig {
|
||||
StyleBlock: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Prefix: "",
|
||||
Color: &cs.code,
|
||||
Color: code,
|
||||
},
|
||||
Margin: uintPtr(0), // Remove margin
|
||||
Margin: uintPtr(0),
|
||||
},
|
||||
Chroma: &ansi.Chroma{
|
||||
Text: ansi.StylePrimitive{Color: &cs.text},
|
||||
Error: ansi.StylePrimitive{Color: &cs.err},
|
||||
Comment: ansi.StylePrimitive{Color: &cs.comment},
|
||||
CommentPreproc: ansi.StylePrimitive{Color: &cs.keyword},
|
||||
Keyword: ansi.StylePrimitive{Color: &cs.keyword},
|
||||
KeywordReserved: ansi.StylePrimitive{
|
||||
Color: &cs.keyword,
|
||||
},
|
||||
KeywordNamespace: ansi.StylePrimitive{
|
||||
Color: &cs.keyword,
|
||||
},
|
||||
KeywordType: ansi.StylePrimitive{Color: &cs.keyword},
|
||||
Operator: ansi.StylePrimitive{Color: &cs.text},
|
||||
Punctuation: ansi.StylePrimitive{Color: &cs.text},
|
||||
Name: ansi.StylePrimitive{Color: &cs.text},
|
||||
NameBuiltin: ansi.StylePrimitive{Color: &cs.text},
|
||||
NameTag: ansi.StylePrimitive{Color: &cs.keyword},
|
||||
NameAttribute: ansi.StylePrimitive{Color: &cs.text},
|
||||
NameClass: ansi.StylePrimitive{Color: &cs.keyword},
|
||||
NameConstant: ansi.StylePrimitive{Color: &cs.text},
|
||||
NameDecorator: ansi.StylePrimitive{Color: &cs.text},
|
||||
NameFunction: ansi.StylePrimitive{Color: &cs.text},
|
||||
LiteralNumber: ansi.StylePrimitive{Color: &cs.number},
|
||||
LiteralString: ansi.StylePrimitive{Color: &cs.str},
|
||||
Text: ansi.StylePrimitive{Color: text},
|
||||
Error: ansi.StylePrimitive{Color: errClr},
|
||||
Comment: ansi.StylePrimitive{Color: comment},
|
||||
CommentPreproc: ansi.StylePrimitive{Color: keyword},
|
||||
Keyword: ansi.StylePrimitive{Color: keyword},
|
||||
KeywordReserved: ansi.StylePrimitive{Color: keyword},
|
||||
KeywordNamespace: ansi.StylePrimitive{Color: keyword},
|
||||
KeywordType: ansi.StylePrimitive{Color: keyword},
|
||||
Operator: ansi.StylePrimitive{Color: text},
|
||||
Punctuation: ansi.StylePrimitive{Color: text},
|
||||
Name: ansi.StylePrimitive{Color: text},
|
||||
NameBuiltin: ansi.StylePrimitive{Color: text},
|
||||
NameTag: ansi.StylePrimitive{Color: keyword},
|
||||
NameAttribute: ansi.StylePrimitive{Color: text},
|
||||
NameClass: ansi.StylePrimitive{Color: keyword},
|
||||
NameConstant: ansi.StylePrimitive{Color: text},
|
||||
NameDecorator: ansi.StylePrimitive{Color: text},
|
||||
NameFunction: ansi.StylePrimitive{Color: text},
|
||||
LiteralNumber: ansi.StylePrimitive{Color: number},
|
||||
LiteralString: ansi.StylePrimitive{Color: str},
|
||||
LiteralStringEscape: ansi.StylePrimitive{
|
||||
Color: &cs.keyword,
|
||||
Color: keyword,
|
||||
},
|
||||
GenericDeleted: ansi.StylePrimitive{Color: &cs.err},
|
||||
GenericDeleted: ansi.StylePrimitive{Color: errClr},
|
||||
GenericEmph: ansi.StylePrimitive{
|
||||
Color: &cs.emph,
|
||||
Color: emph,
|
||||
Italic: new(true),
|
||||
},
|
||||
GenericInserted: ansi.StylePrimitive{Color: &cs.str},
|
||||
GenericInserted: ansi.StylePrimitive{Color: str},
|
||||
GenericStrong: ansi.StylePrimitive{
|
||||
Color: &cs.strong,
|
||||
Color: strong,
|
||||
Bold: new(true),
|
||||
},
|
||||
GenericSubheading: ansi.StylePrimitive{
|
||||
Color: &cs.heading,
|
||||
Color: heading,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -288,14 +247,14 @@ func generateMarkdownStyleConfig() ansi.StyleConfig {
|
||||
},
|
||||
DefinitionDescription: ansi.StylePrimitive{
|
||||
BlockPrefix: "\n ❯ ",
|
||||
Color: &cs.link,
|
||||
Color: link,
|
||||
},
|
||||
Text: ansi.StylePrimitive{
|
||||
Color: &cs.text,
|
||||
Color: text,
|
||||
},
|
||||
Paragraph: ansi.StyleBlock{
|
||||
StylePrimitive: ansi.StylePrimitive{
|
||||
Color: &cs.text,
|
||||
Color: text,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,637 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image/color"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ThemeEntry is a named, loadable theme — either built-in or discovered from disk.
|
||||
type ThemeEntry struct {
|
||||
Name string // Display name (filename stem or preset name)
|
||||
Source string // "builtin" or absolute file path
|
||||
theme Theme // Resolved theme (lazy-loaded for file-based)
|
||||
loaded bool
|
||||
}
|
||||
|
||||
// Theme returns the resolved ui.Theme, loading from disk on first access.
|
||||
func (e *ThemeEntry) Theme() (Theme, error) {
|
||||
if e.loaded {
|
||||
return e.theme, nil
|
||||
}
|
||||
if e.Source == "builtin" {
|
||||
// Already set at registration time.
|
||||
return e.theme, nil
|
||||
}
|
||||
t, err := loadThemeFile(e.Source)
|
||||
if err != nil {
|
||||
return Theme{}, fmt.Errorf("loading theme %q: %w", e.Name, err)
|
||||
}
|
||||
e.theme = t
|
||||
e.loaded = true
|
||||
return e.theme, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Built-in presets
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// builtinThemes returns the set of themes shipped with Kit.
|
||||
// makeTheme builds a full Theme from a compact palette spec. Fields left as
|
||||
// zero color.Color inherit from the KITT default theme, keeping the preset
|
||||
// definitions focused on what differs.
|
||||
type presetColors struct {
|
||||
primary, secondary, success, warning, error_, info [2]string // [light, dark]
|
||||
text, muted, veryMuted, background, border, mutedBorder [2]string
|
||||
system, tool, accent, highlight [2]string
|
||||
mdKeyword, mdString, mdNumber, mdComment, mdHeading, mdLink [2]string
|
||||
}
|
||||
|
||||
func makeTheme(p presetColors) Theme {
|
||||
ac := func(pair [2]string) color.Color { return AdaptiveColor(pair[0], pair[1]) }
|
||||
def := DefaultTheme()
|
||||
acOr := func(pair [2]string, fb color.Color) color.Color {
|
||||
if pair[0] == "" && pair[1] == "" {
|
||||
return fb
|
||||
}
|
||||
return ac(pair)
|
||||
}
|
||||
t := Theme{
|
||||
Primary: ac(p.primary),
|
||||
Secondary: acOr(p.secondary, ac(p.primary)),
|
||||
Success: ac(p.success),
|
||||
Warning: ac(p.warning),
|
||||
Error: ac(p.error_),
|
||||
Info: ac(p.info),
|
||||
Text: ac(p.text),
|
||||
Muted: acOr(p.muted, def.Muted),
|
||||
VeryMuted: acOr(p.veryMuted, def.VeryMuted),
|
||||
Background: ac(p.background),
|
||||
Border: acOr(p.border, def.Border),
|
||||
MutedBorder: acOr(p.mutedBorder, def.MutedBorder),
|
||||
System: acOr(p.system, ac(p.info)),
|
||||
Tool: acOr(p.tool, ac(p.warning)),
|
||||
Accent: acOr(p.accent, ac(p.primary)),
|
||||
Highlight: acOr(p.highlight, def.Highlight),
|
||||
}
|
||||
// Derive diff/code backgrounds from the base background.
|
||||
t.DiffInsertBg = def.DiffInsertBg
|
||||
t.DiffDeleteBg = def.DiffDeleteBg
|
||||
t.DiffEqualBg = def.DiffEqualBg
|
||||
t.DiffMissingBg = def.DiffMissingBg
|
||||
t.CodeBg = def.CodeBg
|
||||
t.GutterBg = def.GutterBg
|
||||
t.WriteBg = def.WriteBg
|
||||
// Markdown colors.
|
||||
t.Markdown = MarkdownThemeColors{
|
||||
Text: t.Text,
|
||||
Muted: t.Muted,
|
||||
Heading: acOr(p.mdHeading, t.Primary),
|
||||
Emph: t.Warning,
|
||||
Strong: t.Text,
|
||||
Link: acOr(p.mdLink, t.Info),
|
||||
Code: t.Muted,
|
||||
Error: t.Error,
|
||||
Keyword: acOr(p.mdKeyword, t.Primary),
|
||||
String: acOr(p.mdString, t.Success),
|
||||
Number: acOr(p.mdNumber, t.Warning),
|
||||
Comment: acOr(p.mdComment, t.VeryMuted),
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// builtinThemes returns the set of themes shipped with Kit.
|
||||
// Inspired by the OpenCode theme collection.
|
||||
func builtinThemes() map[string]Theme {
|
||||
return map[string]Theme{
|
||||
"kitt": DefaultTheme(),
|
||||
|
||||
"catppuccin": makeTheme(presetColors{
|
||||
primary: [2]string{"#8839ef", "#cba6f7"}, secondary: [2]string{"#04a5e5", "#89dceb"},
|
||||
success: [2]string{"#40a02b", "#a6e3a1"}, warning: [2]string{"#df8e1d", "#f9e2af"},
|
||||
error_: [2]string{"#d20f39", "#f38ba8"}, info: [2]string{"#1e66f5", "#89b4fa"},
|
||||
text: [2]string{"#4c4f69", "#cdd6f4"}, muted: [2]string{"#6c6f85", "#a6adc8"},
|
||||
veryMuted: [2]string{"#9ca0b0", "#6c7086"}, background: [2]string{"#eff1f5", "#1e1e2e"},
|
||||
border: [2]string{"#acb0be", "#585b70"}, mutedBorder: [2]string{"#ccd0da", "#313244"},
|
||||
system: [2]string{"#179299", "#94e2d5"}, tool: [2]string{"#fe640b", "#fab387"},
|
||||
accent: [2]string{"#ea76cb", "#f5c2e7"}, highlight: [2]string{"#e6e9ef", "#181825"},
|
||||
mdKeyword: [2]string{"#8839ef", "#cba6f7"}, mdString: [2]string{"#40a02b", "#a6e3a1"},
|
||||
mdNumber: [2]string{"#fe640b", "#fab387"}, mdComment: [2]string{"#9ca0b0", "#6c7086"},
|
||||
}),
|
||||
|
||||
"dracula": makeTheme(presetColors{
|
||||
primary: [2]string{"#7c6bf5", "#bd93f9"}, secondary: [2]string{"#d16090", "#ff79c6"},
|
||||
success: [2]string{"#2fbf71", "#50fa7b"}, warning: [2]string{"#f7a14d", "#ffb86c"},
|
||||
error_: [2]string{"#d9536f", "#ff5555"}, info: [2]string{"#1d7fc5", "#8be9fd"},
|
||||
text: [2]string{"#1f1f2f", "#f8f8f2"}, background: [2]string{"#f8f8f2", "#1d1e28"},
|
||||
accent: [2]string{"#d16090", "#ff79c6"},
|
||||
mdKeyword: [2]string{"#7c6bf5", "#bd93f9"}, mdString: [2]string{"#2fbf71", "#50fa7b"},
|
||||
mdComment: [2]string{"#6272a4", "#6272a4"},
|
||||
}),
|
||||
|
||||
"tokyonight": makeTheme(presetColors{
|
||||
primary: [2]string{"#2e7de9", "#7aa2f7"}, secondary: [2]string{"#b15c00", "#ff9e64"},
|
||||
success: [2]string{"#587539", "#9ece6a"}, warning: [2]string{"#8c6c3e", "#e0af68"},
|
||||
error_: [2]string{"#c94060", "#f7768e"}, info: [2]string{"#007197", "#7dcfff"},
|
||||
text: [2]string{"#273153", "#c0caf5"}, background: [2]string{"#e1e2e7", "#1a1b26"},
|
||||
mdKeyword: [2]string{"#2e7de9", "#7aa2f7"}, mdString: [2]string{"#587539", "#9ece6a"},
|
||||
mdComment: [2]string{"#848cb5", "#565f89"},
|
||||
}),
|
||||
|
||||
"nord": makeTheme(presetColors{
|
||||
primary: [2]string{"#5e81ac", "#88c0d0"}, secondary: [2]string{"#bf616a", "#d57780"},
|
||||
success: [2]string{"#8fbcbb", "#a3be8c"}, warning: [2]string{"#d08770", "#d08770"},
|
||||
error_: [2]string{"#bf616a", "#bf616a"}, info: [2]string{"#81a1c1", "#81a1c1"},
|
||||
text: [2]string{"#2e3440", "#e5e9f0"}, background: [2]string{"#eceff4", "#2e3440"},
|
||||
mdKeyword: [2]string{"#5e81ac", "#81a1c1"}, mdString: [2]string{"#8fbcbb", "#a3be8c"},
|
||||
mdComment: [2]string{"#616e88", "#616e88"},
|
||||
}),
|
||||
|
||||
"gruvbox": makeTheme(presetColors{
|
||||
primary: [2]string{"#076678", "#83a598"}, secondary: [2]string{"#9d0006", "#fb4934"},
|
||||
success: [2]string{"#79740e", "#b8bb26"}, warning: [2]string{"#b57614", "#fabd2f"},
|
||||
error_: [2]string{"#9d0006", "#fb4934"}, info: [2]string{"#8f3f71", "#d3869b"},
|
||||
text: [2]string{"#3c3836", "#ebdbb2"}, background: [2]string{"#fbf1c7", "#282828"},
|
||||
mdKeyword: [2]string{"#9d0006", "#fb4934"}, mdString: [2]string{"#79740e", "#b8bb26"},
|
||||
mdComment: [2]string{"#928374", "#928374"},
|
||||
}),
|
||||
|
||||
"monokai": makeTheme(presetColors{
|
||||
primary: [2]string{"#bf7bff", "#ae81ff"}, secondary: [2]string{"#d9487c", "#f92672"},
|
||||
success: [2]string{"#4fb54b", "#a6e22e"}, warning: [2]string{"#f1a948", "#fd971f"},
|
||||
error_: [2]string{"#e54b4b", "#f92672"}, info: [2]string{"#2d9ad7", "#66d9ef"},
|
||||
text: [2]string{"#292318", "#f8f8f2"}, background: [2]string{"#fdf8ec", "#272822"},
|
||||
mdKeyword: [2]string{"#d9487c", "#f92672"}, mdString: [2]string{"#4fb54b", "#a6e22e"},
|
||||
mdComment: [2]string{"#888888", "#75715e"},
|
||||
}),
|
||||
|
||||
"solarized": makeTheme(presetColors{
|
||||
primary: [2]string{"#268bd2", "#6c71c4"}, secondary: [2]string{"#d33682", "#d33682"},
|
||||
success: [2]string{"#859900", "#859900"}, warning: [2]string{"#b58900", "#b58900"},
|
||||
error_: [2]string{"#dc322f", "#dc322f"}, info: [2]string{"#2aa198", "#2aa198"},
|
||||
text: [2]string{"#586e75", "#93a1a1"}, background: [2]string{"#fdf6e3", "#002b36"},
|
||||
mdKeyword: [2]string{"#268bd2", "#6c71c4"}, mdString: [2]string{"#859900", "#859900"},
|
||||
mdComment: [2]string{"#93a1a1", "#586e75"},
|
||||
}),
|
||||
|
||||
"github": makeTheme(presetColors{
|
||||
primary: [2]string{"#0969da", "#58a6ff"}, secondary: [2]string{"#1b7c83", "#39c5cf"},
|
||||
success: [2]string{"#1a7f37", "#3fb950"}, warning: [2]string{"#9a6700", "#e3b341"},
|
||||
error_: [2]string{"#cf222e", "#f85149"}, info: [2]string{"#bc4c00", "#d29922"},
|
||||
text: [2]string{"#24292f", "#c9d1d9"}, background: [2]string{"#ffffff", "#0d1117"},
|
||||
mdKeyword: [2]string{"#0969da", "#58a6ff"}, mdString: [2]string{"#1a7f37", "#3fb950"},
|
||||
mdComment: [2]string{"#6e7781", "#8b949e"},
|
||||
}),
|
||||
|
||||
"one-dark": makeTheme(presetColors{
|
||||
primary: [2]string{"#4078f2", "#61afef"}, secondary: [2]string{"#0184bc", "#56b6c2"},
|
||||
success: [2]string{"#50a14f", "#98c379"}, warning: [2]string{"#c18401", "#e5c07b"},
|
||||
error_: [2]string{"#e45649", "#e06c75"}, info: [2]string{"#986801", "#d19a66"},
|
||||
text: [2]string{"#383a42", "#abb2bf"}, background: [2]string{"#fafafa", "#282c34"},
|
||||
mdKeyword: [2]string{"#a626a4", "#c678dd"}, mdString: [2]string{"#50a14f", "#98c379"},
|
||||
mdComment: [2]string{"#a0a1a7", "#5c6370"},
|
||||
}),
|
||||
|
||||
"rose-pine": makeTheme(presetColors{
|
||||
primary: [2]string{"#31748f", "#9ccfd8"}, secondary: [2]string{"#d7827e", "#ebbcba"},
|
||||
success: [2]string{"#286983", "#31748f"}, warning: [2]string{"#ea9d34", "#f6c177"},
|
||||
error_: [2]string{"#b4637a", "#eb6f92"}, info: [2]string{"#56949f", "#9ccfd8"},
|
||||
text: [2]string{"#575279", "#e0def4"}, background: [2]string{"#faf4ed", "#191724"},
|
||||
mdKeyword: [2]string{"#31748f", "#9ccfd8"}, mdString: [2]string{"#ea9d34", "#f6c177"},
|
||||
mdComment: [2]string{"#9893a5", "#6e6a86"},
|
||||
}),
|
||||
|
||||
"ayu": makeTheme(presetColors{
|
||||
primary: [2]string{"#4aa8c8", "#3fb7e3"}, secondary: [2]string{"#ef7d71", "#f2856f"},
|
||||
success: [2]string{"#5fb978", "#78d05c"}, warning: [2]string{"#ea9f41", "#e4a75c"},
|
||||
error_: [2]string{"#e6656a", "#f58572"}, info: [2]string{"#2f9bce", "#66c6f1"},
|
||||
text: [2]string{"#4f5964", "#d6dae0"}, background: [2]string{"#fdfaf4", "#0f1419"},
|
||||
mdKeyword: [2]string{"#4aa8c8", "#3fb7e3"}, mdString: [2]string{"#5fb978", "#78d05c"},
|
||||
mdComment: [2]string{"#abb0b6", "#5c6773"},
|
||||
}),
|
||||
|
||||
"material": makeTheme(presetColors{
|
||||
primary: [2]string{"#6182b8", "#82aaff"}, secondary: [2]string{"#39adb5", "#89ddff"},
|
||||
success: [2]string{"#91b859", "#c3e88d"}, warning: [2]string{"#ffb300", "#ffcb6b"},
|
||||
error_: [2]string{"#e53935", "#f07178"}, info: [2]string{"#f4511e", "#ffcb6b"},
|
||||
text: [2]string{"#263238", "#eeffff"}, background: [2]string{"#fafafa", "#263238"},
|
||||
mdKeyword: [2]string{"#6182b8", "#82aaff"}, mdString: [2]string{"#91b859", "#c3e88d"},
|
||||
mdComment: [2]string{"#aabfc5", "#546e7a"},
|
||||
}),
|
||||
|
||||
"everforest": makeTheme(presetColors{
|
||||
primary: [2]string{"#8da101", "#a7c080"}, secondary: [2]string{"#df69ba", "#d699b6"},
|
||||
success: [2]string{"#8da101", "#a7c080"}, warning: [2]string{"#f57d26", "#e69875"},
|
||||
error_: [2]string{"#f85552", "#e67e80"}, info: [2]string{"#35a77c", "#83c092"},
|
||||
text: [2]string{"#5c6a72", "#d3c6aa"}, background: [2]string{"#fdf6e3", "#2d353b"},
|
||||
mdKeyword: [2]string{"#8da101", "#a7c080"}, mdString: [2]string{"#35a77c", "#83c092"},
|
||||
mdComment: [2]string{"#939b84", "#859289"},
|
||||
}),
|
||||
|
||||
"kanagawa": makeTheme(presetColors{
|
||||
primary: [2]string{"#2D4F67", "#7E9CD8"}, secondary: [2]string{"#D27E99", "#D27E99"},
|
||||
success: [2]string{"#98BB6C", "#98BB6C"}, warning: [2]string{"#D7A657", "#D7A657"},
|
||||
error_: [2]string{"#E82424", "#E82424"}, info: [2]string{"#76946A", "#76946A"},
|
||||
text: [2]string{"#54433A", "#DCD7BA"}, background: [2]string{"#F2E9DE", "#1F1F28"},
|
||||
mdKeyword: [2]string{"#2D4F67", "#7E9CD8"}, mdString: [2]string{"#98BB6C", "#98BB6C"},
|
||||
mdComment: [2]string{"#A09D98", "#727169"},
|
||||
}),
|
||||
|
||||
"amoled": makeTheme(presetColors{
|
||||
primary: [2]string{"#6200ff", "#b388ff"}, secondary: [2]string{"#ff0080", "#ff4081"},
|
||||
success: [2]string{"#00e676", "#00ff88"}, warning: [2]string{"#ffab00", "#ffea00"},
|
||||
error_: [2]string{"#ff1744", "#ff1744"}, info: [2]string{"#00b0ff", "#18ffff"},
|
||||
text: [2]string{"#0a0a0a", "#ffffff"}, background: [2]string{"#f0f0f0", "#000000"},
|
||||
mdKeyword: [2]string{"#6200ff", "#b388ff"}, mdString: [2]string{"#00e676", "#00ff88"},
|
||||
mdComment: [2]string{"#757575", "#424242"},
|
||||
}),
|
||||
|
||||
"synthwave": makeTheme(presetColors{
|
||||
primary: [2]string{"#00bcd4", "#36f9f6"}, secondary: [2]string{"#9c27b0", "#b084eb"},
|
||||
success: [2]string{"#4caf50", "#72f1b8"}, warning: [2]string{"#ff9800", "#fede5d"},
|
||||
error_: [2]string{"#f44336", "#fe4450"}, info: [2]string{"#ff5722", "#ff8b39"},
|
||||
text: [2]string{"#262335", "#ffffff"}, background: [2]string{"#fafafa", "#262335"},
|
||||
mdKeyword: [2]string{"#9c27b0", "#b084eb"}, mdString: [2]string{"#4caf50", "#72f1b8"},
|
||||
mdComment: [2]string{"#848bbd", "#848bbd"},
|
||||
}),
|
||||
|
||||
"vesper": makeTheme(presetColors{
|
||||
primary: [2]string{"#FFC799", "#FFC799"}, secondary: [2]string{"#B30000", "#FF8080"},
|
||||
success: [2]string{"#99FFE4", "#99FFE4"}, warning: [2]string{"#FFC799", "#FFC799"},
|
||||
error_: [2]string{"#FF8080", "#FF8080"}, info: [2]string{"#FFC799", "#FFC799"},
|
||||
text: [2]string{"#1a1a1a", "#FFF"}, background: [2]string{"#F0F0F0", "#101010"},
|
||||
mdKeyword: [2]string{"#FFC799", "#FFC799"}, mdString: [2]string{"#99FFE4", "#99FFE4"},
|
||||
mdComment: [2]string{"#7a7a7a", "#505050"},
|
||||
}),
|
||||
|
||||
"flexoki": makeTheme(presetColors{
|
||||
primary: [2]string{"#205EA6", "#DA702C"}, secondary: [2]string{"#BC5215", "#8B7EC8"},
|
||||
success: [2]string{"#66800B", "#879A39"}, warning: [2]string{"#BC5215", "#DA702C"},
|
||||
error_: [2]string{"#AF3029", "#D14D41"}, info: [2]string{"#24837B", "#3AA99F"},
|
||||
text: [2]string{"#100F0F", "#CECDC3"}, background: [2]string{"#FFFCF0", "#100F0F"},
|
||||
mdKeyword: [2]string{"#205EA6", "#DA702C"}, mdString: [2]string{"#66800B", "#879A39"},
|
||||
mdComment: [2]string{"#878580", "#878580"},
|
||||
}),
|
||||
|
||||
"matrix": makeTheme(presetColors{
|
||||
primary: [2]string{"#1cc24b", "#2eff6a"}, secondary: [2]string{"#c770ff", "#c770ff"},
|
||||
success: [2]string{"#1cc24b", "#62ff94"}, warning: [2]string{"#e6ff57", "#e6ff57"},
|
||||
error_: [2]string{"#ff4b4b", "#ff4b4b"}, info: [2]string{"#30b3ff", "#30b3ff"},
|
||||
text: [2]string{"#203022", "#62ff94"}, background: [2]string{"#eef3ea", "#0a0e0a"},
|
||||
mdKeyword: [2]string{"#1cc24b", "#2eff6a"}, mdString: [2]string{"#1cc24b", "#62ff94"},
|
||||
mdComment: [2]string{"#5a7a5e", "#3a5a3e"},
|
||||
}),
|
||||
|
||||
"vercel": makeTheme(presetColors{
|
||||
primary: [2]string{"#0070F3", "#0070F3"}, secondary: [2]string{"#8E4EC6", "#8E4EC6"},
|
||||
success: [2]string{"#388E3C", "#46A758"}, warning: [2]string{"#FF9500", "#FFB224"},
|
||||
error_: [2]string{"#DC3545", "#E5484D"}, info: [2]string{"#0070F3", "#52A8FF"},
|
||||
text: [2]string{"#171717", "#EDEDED"}, background: [2]string{"#FFFFFF", "#000000"},
|
||||
mdKeyword: [2]string{"#0070F3", "#0070F3"}, mdString: [2]string{"#388E3C", "#46A758"},
|
||||
mdComment: [2]string{"#6B6B6B", "#666666"},
|
||||
}),
|
||||
|
||||
"zenburn": makeTheme(presetColors{
|
||||
primary: [2]string{"#5f7f8f", "#8cd0d3"}, secondary: [2]string{"#5f8f8f", "#93e0e3"},
|
||||
success: [2]string{"#5f8f5f", "#7f9f7f"}, warning: [2]string{"#8f8f5f", "#f0dfaf"},
|
||||
error_: [2]string{"#8f5f5f", "#cc9393"}, info: [2]string{"#8f7f5f", "#dfaf8f"},
|
||||
text: [2]string{"#3f3f3f", "#dcdccc"}, background: [2]string{"#ffffef", "#3f3f3f"},
|
||||
mdKeyword: [2]string{"#5f7f8f", "#8cd0d3"}, mdString: [2]string{"#5f8f5f", "#cc9393"},
|
||||
mdComment: [2]string{"#7f7f7f", "#7f9f7f"},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Theme registry (global)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var themeRegistry []ThemeEntry
|
||||
|
||||
// initThemeRegistry populates the registry from built-ins, user themes, and
|
||||
// project-local themes. Later sources override earlier ones with the same name:
|
||||
// 1. Built-in presets
|
||||
// 2. User themes (~/.config/kit/themes/)
|
||||
// 3. Project-local (.kit/themes/ in the working directory)
|
||||
func initThemeRegistry() {
|
||||
themeRegistry = nil
|
||||
|
||||
// 1. Built-in presets.
|
||||
for name, t := range builtinThemes() {
|
||||
themeRegistry = append(themeRegistry, ThemeEntry{
|
||||
Name: name,
|
||||
Source: "builtin",
|
||||
theme: t,
|
||||
loaded: true,
|
||||
})
|
||||
}
|
||||
|
||||
// 2. User themes from ~/.config/kit/themes/
|
||||
scanThemesDir(userThemesDir())
|
||||
|
||||
// 3. Project-local themes from .kit/themes/
|
||||
scanThemesDir(projectThemesDir())
|
||||
|
||||
sortRegistry()
|
||||
}
|
||||
|
||||
// scanThemesDir adds all .yml/.yaml/.json theme files from dir to the registry.
|
||||
// Files override any existing entry with the same stem name.
|
||||
func scanThemesDir(dir string) {
|
||||
if dir == "" {
|
||||
return
|
||||
}
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
ext := strings.ToLower(filepath.Ext(entry.Name()))
|
||||
if ext != ".yml" && ext != ".yaml" && ext != ".json" {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSuffix(entry.Name(), filepath.Ext(entry.Name()))
|
||||
removeFromRegistry(name)
|
||||
themeRegistry = append(themeRegistry, ThemeEntry{
|
||||
Name: name,
|
||||
Source: filepath.Join(dir, entry.Name()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func sortRegistry() {
|
||||
sort.Slice(themeRegistry, func(i, j int) bool {
|
||||
return themeRegistry[i].Name < themeRegistry[j].Name
|
||||
})
|
||||
}
|
||||
|
||||
func removeFromRegistry(name string) {
|
||||
for i := range themeRegistry {
|
||||
if themeRegistry[i].Name == name {
|
||||
themeRegistry = append(themeRegistry[:i], themeRegistry[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// userThemesDir returns ~/.config/kit/themes, creating it if needed.
|
||||
func userThemesDir() string {
|
||||
cfgDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
dir := filepath.Join(cfgDir, "kit", "themes")
|
||||
_ = os.MkdirAll(dir, 0o755)
|
||||
return dir
|
||||
}
|
||||
|
||||
// projectThemesDir returns .kit/themes/ relative to the working directory.
|
||||
// Returns "" if the directory doesn't exist (does NOT create it).
|
||||
func projectThemesDir() string {
|
||||
dir := filepath.Join(".kit", "themes")
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil || !info.IsDir() {
|
||||
return ""
|
||||
}
|
||||
abs, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return dir
|
||||
}
|
||||
return abs
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public API
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ListThemes returns the names of all available themes (built-in + user).
|
||||
func ListThemes() []string {
|
||||
if themeRegistry == nil {
|
||||
initThemeRegistry()
|
||||
}
|
||||
names := make([]string, len(themeRegistry))
|
||||
for i := range themeRegistry {
|
||||
names[i] = themeRegistry[i].Name
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// LoadThemeByName looks up a theme by name, loads it if needed, and returns it.
|
||||
func LoadThemeByName(name string) (Theme, error) {
|
||||
if themeRegistry == nil {
|
||||
initThemeRegistry()
|
||||
}
|
||||
for i := range themeRegistry {
|
||||
if themeRegistry[i].Name == name {
|
||||
return themeRegistry[i].Theme()
|
||||
}
|
||||
}
|
||||
return Theme{}, fmt.Errorf("theme %q not found", name)
|
||||
}
|
||||
|
||||
// ApplyTheme loads a theme by name and sets it as the active global theme.
|
||||
func ApplyTheme(name string) error {
|
||||
t, err := LoadThemeByName(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
SetTheme(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshThemeRegistry re-scans the themes directory. Call after the user
|
||||
// drops a new file into ~/.config/kit/themes/.
|
||||
func RefreshThemeRegistry() {
|
||||
initThemeRegistry()
|
||||
}
|
||||
|
||||
// RegisterThemeFromConfig adds a theme to the runtime registry from an
|
||||
// extension's ThemeColorConfig (string hex pairs). Replaces any existing
|
||||
// entry with the same name. The theme is immediately available via
|
||||
// ListThemes, LoadThemeByName, and ApplyTheme.
|
||||
func RegisterThemeFromConfig(name string, primary, secondary, success, warning, error_, info, text, muted, veryMuted, background, border, mutedBorder, system, tool, accent, highlight, mdHeading, mdLink, mdKeyword, mdString, mdNumber, mdComment [2]string) {
|
||||
if themeRegistry == nil {
|
||||
initThemeRegistry()
|
||||
}
|
||||
t := makeTheme(presetColors{
|
||||
primary: primary, secondary: secondary,
|
||||
success: success, warning: warning,
|
||||
error_: error_, info: info,
|
||||
text: text, muted: muted,
|
||||
veryMuted: veryMuted, background: background,
|
||||
border: border, mutedBorder: mutedBorder,
|
||||
system: system, tool: tool,
|
||||
accent: accent, highlight: highlight,
|
||||
mdHeading: mdHeading, mdLink: mdLink,
|
||||
mdKeyword: mdKeyword, mdString: mdString,
|
||||
mdNumber: mdNumber, mdComment: mdComment,
|
||||
})
|
||||
removeFromRegistry(name)
|
||||
themeRegistry = append(themeRegistry, ThemeEntry{
|
||||
Name: name,
|
||||
Source: "extension",
|
||||
theme: t,
|
||||
loaded: true,
|
||||
})
|
||||
sortRegistry()
|
||||
}
|
||||
|
||||
// ActiveThemeName returns the name of the currently active theme by comparing
|
||||
// against known entries. Returns "custom" if no match is found.
|
||||
func ActiveThemeName() string {
|
||||
if themeRegistry == nil {
|
||||
initThemeRegistry()
|
||||
}
|
||||
current := GetTheme()
|
||||
for _, e := range themeRegistry {
|
||||
if !e.loaded {
|
||||
continue
|
||||
}
|
||||
if e.theme.Primary == current.Primary &&
|
||||
e.theme.Secondary == current.Secondary &&
|
||||
e.theme.Error == current.Error &&
|
||||
e.theme.Text == current.Text {
|
||||
return e.Name
|
||||
}
|
||||
}
|
||||
return "custom"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// File loading
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// themeFileConfig mirrors config.Theme for unmarshaling theme files.
|
||||
// Uses the same adaptive color structure.
|
||||
type themeFileConfig struct {
|
||||
Primary adaptiveColorPair `json:"primary,omitzero" yaml:"primary,omitempty"`
|
||||
Secondary adaptiveColorPair `json:"secondary,omitzero" yaml:"secondary,omitempty"`
|
||||
Success adaptiveColorPair `json:"success,omitzero" yaml:"success,omitempty"`
|
||||
Warning adaptiveColorPair `json:"warning,omitzero" yaml:"warning,omitempty"`
|
||||
Error adaptiveColorPair `json:"error,omitzero" yaml:"error,omitempty"`
|
||||
Info adaptiveColorPair `json:"info,omitzero" yaml:"info,omitempty"`
|
||||
Text adaptiveColorPair `json:"text,omitzero" yaml:"text,omitempty"`
|
||||
Muted adaptiveColorPair `json:"muted,omitzero" yaml:"muted,omitempty"`
|
||||
VeryMuted adaptiveColorPair `json:"very-muted,omitzero" yaml:"very-muted,omitempty"`
|
||||
Background adaptiveColorPair `json:"background,omitzero" yaml:"background,omitempty"`
|
||||
Border adaptiveColorPair `json:"border,omitzero" yaml:"border,omitempty"`
|
||||
MutedBorder adaptiveColorPair `json:"muted-border,omitzero" yaml:"muted-border,omitempty"`
|
||||
System adaptiveColorPair `json:"system,omitzero" yaml:"system,omitempty"`
|
||||
Tool adaptiveColorPair `json:"tool,omitzero" yaml:"tool,omitempty"`
|
||||
Accent adaptiveColorPair `json:"accent,omitzero" yaml:"accent,omitempty"`
|
||||
Highlight adaptiveColorPair `json:"highlight,omitzero" yaml:"highlight,omitempty"`
|
||||
|
||||
DiffInsertBg adaptiveColorPair `json:"diff-insert-bg,omitzero" yaml:"diff-insert-bg,omitempty"`
|
||||
DiffDeleteBg adaptiveColorPair `json:"diff-delete-bg,omitzero" yaml:"diff-delete-bg,omitempty"`
|
||||
DiffEqualBg adaptiveColorPair `json:"diff-equal-bg,omitzero" yaml:"diff-equal-bg,omitempty"`
|
||||
DiffMissingBg adaptiveColorPair `json:"diff-missing-bg,omitzero" yaml:"diff-missing-bg,omitempty"`
|
||||
CodeBg adaptiveColorPair `json:"code-bg,omitzero" yaml:"code-bg,omitempty"`
|
||||
GutterBg adaptiveColorPair `json:"gutter-bg,omitzero" yaml:"gutter-bg,omitempty"`
|
||||
WriteBg adaptiveColorPair `json:"write-bg,omitzero" yaml:"write-bg,omitempty"`
|
||||
|
||||
Markdown struct {
|
||||
Text adaptiveColorPair `json:"text,omitzero" yaml:"text,omitempty"`
|
||||
Muted adaptiveColorPair `json:"muted,omitzero" yaml:"muted,omitempty"`
|
||||
Heading adaptiveColorPair `json:"heading,omitzero" yaml:"heading,omitempty"`
|
||||
Emph adaptiveColorPair `json:"emph,omitzero" yaml:"emph,omitempty"`
|
||||
Strong adaptiveColorPair `json:"strong,omitzero" yaml:"strong,omitempty"`
|
||||
Link adaptiveColorPair `json:"link,omitzero" yaml:"link,omitempty"`
|
||||
Code adaptiveColorPair `json:"code,omitzero" yaml:"code,omitempty"`
|
||||
Error adaptiveColorPair `json:"error,omitzero" yaml:"error,omitempty"`
|
||||
Keyword adaptiveColorPair `json:"keyword,omitzero" yaml:"keyword,omitempty"`
|
||||
String adaptiveColorPair `json:"string,omitzero" yaml:"string,omitempty"`
|
||||
Number adaptiveColorPair `json:"number,omitzero" yaml:"number,omitempty"`
|
||||
Comment adaptiveColorPair `json:"comment,omitzero" yaml:"comment,omitempty"`
|
||||
} `json:"markdown,omitzero" yaml:"markdown,omitempty"`
|
||||
}
|
||||
|
||||
type adaptiveColorPair struct {
|
||||
Light string `json:"light,omitempty" yaml:"light,omitempty"`
|
||||
Dark string `json:"dark,omitempty" yaml:"dark,omitempty"`
|
||||
}
|
||||
|
||||
// resolve converts an adaptiveColorPair to a resolved color.Color,
|
||||
// falling back to fallback when both Light and Dark are empty.
|
||||
func (a adaptiveColorPair) resolve(fallback color.Color) color.Color {
|
||||
if a.Light == "" && a.Dark == "" {
|
||||
return fallback
|
||||
}
|
||||
return AdaptiveColor(a.Light, a.Dark)
|
||||
}
|
||||
|
||||
func loadThemeFile(path string) (Theme, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return Theme{}, err
|
||||
}
|
||||
|
||||
var cfg themeFileConfig
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
switch ext {
|
||||
case ".json":
|
||||
err = json.Unmarshal(data, &cfg)
|
||||
case ".yaml", ".yml":
|
||||
err = yaml.Unmarshal(data, &cfg)
|
||||
default:
|
||||
return Theme{}, fmt.Errorf("unsupported theme file format: %s", ext)
|
||||
}
|
||||
if err != nil {
|
||||
return Theme{}, err
|
||||
}
|
||||
|
||||
return fileConfigToTheme(cfg), nil
|
||||
}
|
||||
|
||||
func fileConfigToTheme(cfg themeFileConfig) Theme {
|
||||
def := DefaultTheme()
|
||||
return Theme{
|
||||
Primary: cfg.Primary.resolve(def.Primary),
|
||||
Secondary: cfg.Secondary.resolve(def.Secondary),
|
||||
Success: cfg.Success.resolve(def.Success),
|
||||
Warning: cfg.Warning.resolve(def.Warning),
|
||||
Error: cfg.Error.resolve(def.Error),
|
||||
Info: cfg.Info.resolve(def.Info),
|
||||
Text: cfg.Text.resolve(def.Text),
|
||||
Muted: cfg.Muted.resolve(def.Muted),
|
||||
VeryMuted: cfg.VeryMuted.resolve(def.VeryMuted),
|
||||
Background: cfg.Background.resolve(def.Background),
|
||||
Border: cfg.Border.resolve(def.Border),
|
||||
MutedBorder: cfg.MutedBorder.resolve(def.MutedBorder),
|
||||
System: cfg.System.resolve(def.System),
|
||||
Tool: cfg.Tool.resolve(def.Tool),
|
||||
Accent: cfg.Accent.resolve(def.Accent),
|
||||
Highlight: cfg.Highlight.resolve(def.Highlight),
|
||||
|
||||
DiffInsertBg: cfg.DiffInsertBg.resolve(def.DiffInsertBg),
|
||||
DiffDeleteBg: cfg.DiffDeleteBg.resolve(def.DiffDeleteBg),
|
||||
DiffEqualBg: cfg.DiffEqualBg.resolve(def.DiffEqualBg),
|
||||
DiffMissingBg: cfg.DiffMissingBg.resolve(def.DiffMissingBg),
|
||||
CodeBg: cfg.CodeBg.resolve(def.CodeBg),
|
||||
GutterBg: cfg.GutterBg.resolve(def.GutterBg),
|
||||
WriteBg: cfg.WriteBg.resolve(def.WriteBg),
|
||||
|
||||
Markdown: MarkdownThemeColors{
|
||||
Text: cfg.Markdown.Text.resolve(def.Markdown.Text),
|
||||
Muted: cfg.Markdown.Muted.resolve(def.Markdown.Muted),
|
||||
Heading: cfg.Markdown.Heading.resolve(def.Markdown.Heading),
|
||||
Emph: cfg.Markdown.Emph.resolve(def.Markdown.Emph),
|
||||
Strong: cfg.Markdown.Strong.resolve(def.Markdown.Strong),
|
||||
Link: cfg.Markdown.Link.resolve(def.Markdown.Link),
|
||||
Code: cfg.Markdown.Code.resolve(def.Markdown.Code),
|
||||
Error: cfg.Markdown.Error.resolve(def.Markdown.Error),
|
||||
Keyword: cfg.Markdown.Keyword.resolve(def.Markdown.Keyword),
|
||||
String: cfg.Markdown.String.resolve(def.Markdown.String),
|
||||
Number: cfg.Markdown.Number.resolve(def.Markdown.Number),
|
||||
Comment: cfg.Markdown.Comment.resolve(def.Markdown.Comment),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -28,11 +28,12 @@ func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInp
|
||||
ta.SetHeight(4) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
|
||||
// Style the textarea to match huh theme
|
||||
// Style the textarea using theme colors.
|
||||
theme := GetTheme()
|
||||
styles := ta.Styles()
|
||||
styles.Focused.Base = lipgloss.NewStyle()
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(lipgloss.Color("240"))
|
||||
styles.Focused.Text = lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
styles.Focused.Text = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
styles.Focused.Prompt = lipgloss.NewStyle()
|
||||
styles.Focused.CursorLine = lipgloss.NewStyle()
|
||||
ta.SetStyles(styles)
|
||||
@@ -87,9 +88,11 @@ func (t *ToolApprovalInput) View() tea.View {
|
||||
}
|
||||
containerStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := GetTheme()
|
||||
|
||||
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("252")).
|
||||
Foreground(theme.Text).
|
||||
MarginBottom(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
@@ -100,19 +103,19 @@ func (t *ToolApprovalInput) View() tea.View {
|
||||
BorderRight(false).
|
||||
BorderTop(false).
|
||||
BorderBottom(false).
|
||||
BorderForeground(lipgloss.Color("39")).
|
||||
BorderForeground(theme.Primary).
|
||||
PaddingLeft(2). // match message block paddingLeft
|
||||
Width(t.width - 1) // full width minus left border
|
||||
|
||||
// Style for the currently selected/highlighted option
|
||||
selectedStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("42")). // Bright green
|
||||
Foreground(theme.Success).
|
||||
Bold(true).
|
||||
Underline(true)
|
||||
|
||||
// Style for the unselected/unhighlighted option
|
||||
unselectedStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("240")) // Dark gray
|
||||
Foreground(theme.VeryMuted)
|
||||
|
||||
// Build the view
|
||||
var view strings.Builder
|
||||
|
||||
@@ -49,6 +49,10 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
if body := renderBashBody(toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
case toolName == "spawn_subagent":
|
||||
if body := renderSubagentBody(toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
}
|
||||
return "" // fall back to default
|
||||
}
|
||||
@@ -716,6 +720,8 @@ func renderToolBodyCompact(toolName, toolArgs, toolResult string, width int) str
|
||||
case toolName == "bash" || toolName == "run_shell_cmd" ||
|
||||
strings.Contains(toolName, "shell") || strings.Contains(toolName, "command"):
|
||||
return renderBashCompact(toolResult, width)
|
||||
case toolName == "spawn_subagent":
|
||||
return renderSubagentCompact(toolResult)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -870,3 +876,121 @@ func renderBashCompact(toolResult string, width int) string {
|
||||
|
||||
return lipgloss.NewStyle().Foreground(theme.Muted).Render(summary)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Subagent tool renderers — show only summary, not full output
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderSubagentBody renders a clean summary of subagent results.
|
||||
// Extracts timing/token info and shows only a brief summary instead of raw output.
|
||||
func renderSubagentBody(toolResult string, width int) string {
|
||||
theme := getTheme()
|
||||
result := strings.TrimSpace(toolResult)
|
||||
if result == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Parse the subagent result format:
|
||||
// "Subagent completed successfully in Xs. (tokens: N in / M out)\n\nResult:\n..."
|
||||
// or "Subagent failed (exit code X) after Ys.\n\nError: ...\n\nPartial output:\n..."
|
||||
|
||||
lines := strings.Split(result, "\n")
|
||||
if len(lines) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// First line is always the status summary
|
||||
statusLine := lines[0]
|
||||
|
||||
// Build a clean summary
|
||||
var summary strings.Builder
|
||||
summary.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(statusLine))
|
||||
|
||||
// For successful results, extract a brief preview of the actual result
|
||||
if strings.Contains(statusLine, "successfully") {
|
||||
// Find where "Result:" starts and extract a preview
|
||||
if _, resultContent, found := strings.Cut(result, "Result:\n"); found {
|
||||
resultContent = strings.TrimSpace(resultContent)
|
||||
if resultContent != "" {
|
||||
// Show first 3 meaningful lines as preview
|
||||
preview := extractSubagentPreview(resultContent, 3, width-4)
|
||||
if preview != "" {
|
||||
summary.WriteString("\n\n")
|
||||
summary.WriteString(lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true).
|
||||
Render(preview))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return summary.String()
|
||||
}
|
||||
|
||||
// extractSubagentPreview extracts the first N non-empty lines from content,
|
||||
// truncating each line to maxWidth.
|
||||
func extractSubagentPreview(content string, maxLines, maxWidth int) string {
|
||||
lines := strings.Split(content, "\n")
|
||||
var preview []string
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Truncate long lines
|
||||
if len(trimmed) > maxWidth {
|
||||
trimmed = trimmed[:maxWidth-3] + "..."
|
||||
}
|
||||
preview = append(preview, trimmed)
|
||||
|
||||
if len(preview) >= maxLines {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(preview) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := strings.Join(preview, "\n")
|
||||
|
||||
// Count remaining lines for "more" indicator
|
||||
totalLines := 0
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(line) != "" {
|
||||
totalLines++
|
||||
}
|
||||
}
|
||||
if totalLines > maxLines {
|
||||
result += fmt.Sprintf("\n...(%d more lines)", totalLines-maxLines)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// renderSubagentCompact returns a brief one-line summary for subagent results.
|
||||
func renderSubagentCompact(toolResult string) string {
|
||||
result := strings.TrimSpace(toolResult)
|
||||
if result == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
|
||||
// Extract just the first line which contains the status
|
||||
lines := strings.Split(result, "\n")
|
||||
if len(lines) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
statusLine := lines[0]
|
||||
|
||||
// Make it more compact by removing redundant words
|
||||
statusLine = strings.Replace(statusLine, "Subagent completed successfully in ", "Completed in ", 1)
|
||||
statusLine = strings.Replace(statusLine, "Subagent failed", "Failed", 1)
|
||||
|
||||
return lipgloss.NewStyle().Foreground(theme.Muted).Italic(true).Render(statusLine)
|
||||
}
|
||||
|
||||
@@ -217,7 +217,14 @@ func (ts *TreeSelectorComponent) View() tea.View {
|
||||
// Header.
|
||||
b.WriteString(headerStyle.Render("Session Tree"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(helpStyle.Render("↑/↓: move ←/→: page enter: select esc: cancel ^O: cycle filter"))
|
||||
// Adapt help text to terminal width.
|
||||
if ts.width >= 70 {
|
||||
b.WriteString(helpStyle.Render("↑/↓: move ←/→: page enter: select esc: cancel ^O: cycle filter"))
|
||||
} else if ts.width >= 45 {
|
||||
b.WriteString(helpStyle.Render("↑↓ move ↵ select esc cancel ^O filter"))
|
||||
} else {
|
||||
b.WriteString(helpStyle.Render("↑↓ ↵ esc ^O"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
if ts.search != "" {
|
||||
@@ -269,9 +276,10 @@ func (ts *TreeSelectorComponent) IsActive() bool {
|
||||
// --- Internal helpers ---
|
||||
|
||||
func (ts *TreeSelectorComponent) visibleHeight() int {
|
||||
// Reserve lines for header(3) + search(1) + separator(1) + footer(2).
|
||||
h := max(ts.height/2-7, 5)
|
||||
return h
|
||||
// Chrome: header(1) + help(1) + separator(1) + entries + separator(1) + footer(1) = 5 fixed.
|
||||
// Optional search line adds 1 more. Use 7 as a safe estimate.
|
||||
const chromeLines = 7
|
||||
return max(ts.height-chromeLines, 3)
|
||||
}
|
||||
|
||||
func (ts *TreeSelectorComponent) rebuildFlatList() {
|
||||
@@ -389,7 +397,7 @@ func (ts *TreeSelectorComponent) passesFilter(node *session.TreeNode) bool {
|
||||
|
||||
func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool) string {
|
||||
theme := GetTheme()
|
||||
maxWidth := ts.width - 4
|
||||
maxWidth := max(ts.width-4, 10)
|
||||
|
||||
// Cursor indicator.
|
||||
var cursor string
|
||||
@@ -401,9 +409,10 @@ func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool
|
||||
|
||||
// Role-colored content.
|
||||
text := ts.entryDisplayText(node.Entry)
|
||||
if len(text) > maxWidth-len(node.Prefix)-10 {
|
||||
trimLen := maxWidth - len(node.Prefix) - 13
|
||||
if trimLen > 0 && trimLen < len(text) {
|
||||
available := maxWidth - len(node.Prefix) - 10
|
||||
if available > 3 && len(text) > available {
|
||||
trimLen := max(available-3, 1)
|
||||
if trimLen < len(text) {
|
||||
text = text[:trimLen] + "..."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ func TestUsageTracker_OAuthCosts(t *testing.T) {
|
||||
stats := regularTracker.GetLastRequestStats()
|
||||
if stats == nil {
|
||||
t.Fatal("Expected stats to be non-nil")
|
||||
return
|
||||
}
|
||||
|
||||
// Check that costs are calculated for regular API key
|
||||
@@ -48,6 +49,7 @@ func TestUsageTracker_OAuthCosts(t *testing.T) {
|
||||
oauthStats := oauthTracker.GetLastRequestStats()
|
||||
if oauthStats == nil {
|
||||
t.Fatal("Expected OAuth stats to be non-nil")
|
||||
return
|
||||
}
|
||||
|
||||
// Check that all costs are $0 for OAuth
|
||||
|
||||
@@ -0,0 +1,371 @@
|
||||
# Testing Kit Extensions
|
||||
|
||||
The `github.com/mark3labs/kit/pkg/extensions/test` package provides utilities for testing Kit extensions using standard Go testing patterns.
|
||||
|
||||
## Overview
|
||||
|
||||
Extension tests run outside the Yaegi interpreter but load your extension code into an isolated interpreter instance. This allows you to:
|
||||
|
||||
- Test event handlers without running the full Kit TUI
|
||||
- Verify that your extension registers tools/commands correctly
|
||||
- Assert that context methods (Print, SetWidget, etc.) are called as expected
|
||||
- Test blocking and non-blocking event handling
|
||||
|
||||
## Installation
|
||||
|
||||
The test package is part of the Kit codebase. Import it in your extension tests:
|
||||
|
||||
```go
|
||||
import (
|
||||
"testing"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
)
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Testing an Extension File
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
)
|
||||
|
||||
func TestMyExtension(t *testing.T) {
|
||||
// Create a test harness
|
||||
harness := test.New(t)
|
||||
|
||||
// Load your extension
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Emit events and verify behavior
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the extension printed something
|
||||
test.AssertPrinted(t, harness, "session started")
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Inline Extension Code
|
||||
|
||||
For quick tests, you can load extension source directly:
|
||||
|
||||
```go
|
||||
func TestToolBlocking(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
if tc.ToolName == "dangerous" {
|
||||
return &ext.ToolCallResult{Block: true, Reason: "not allowed"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
`
|
||||
harness := test.New(t)
|
||||
harness.LoadString(src, "test-ext.go")
|
||||
|
||||
// Test the tool is blocked
|
||||
result, _ := harness.Emit(extensions.ToolCallEvent{
|
||||
ToolName: "dangerous",
|
||||
Input: "{}",
|
||||
})
|
||||
|
||||
test.AssertBlocked(t, result, "not allowed")
|
||||
}
|
||||
```
|
||||
|
||||
## Common Testing Patterns
|
||||
|
||||
### Testing Tool Registration
|
||||
|
||||
```go
|
||||
func TestToolRegistration(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Verify the tool was registered
|
||||
test.AssertToolRegistered(t, harness, "my_tool")
|
||||
|
||||
// Or inspect tools directly
|
||||
tools := harness.RegisteredTools()
|
||||
for _, tool := range tools {
|
||||
if tool.Name == "my_tool" {
|
||||
t.Logf("Tool description: %s", tool.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Command Registration
|
||||
|
||||
```go
|
||||
func TestCommandRegistration(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
test.AssertCommandRegistered(t, harness, "mycommand")
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Widgets
|
||||
|
||||
```go
|
||||
func TestWidgetBehavior(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Trigger the event that creates the widget
|
||||
_, _ = harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
|
||||
// Verify the widget was set
|
||||
test.AssertWidgetSet(t, harness, "my-widget")
|
||||
|
||||
// Verify specific widget content
|
||||
test.AssertWidgetText(t, harness, "my-widget", "Expected Text")
|
||||
|
||||
// Or verify partial content
|
||||
test.AssertWidgetTextContains(t, harness, "my-widget", "partial")
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Input Handling
|
||||
|
||||
```go
|
||||
func TestInputHandling(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Test that the extension handles certain input
|
||||
result, _ := harness.Emit(extensions.InputEvent{
|
||||
Text: "secret password",
|
||||
Source: "cli",
|
||||
})
|
||||
|
||||
test.AssertInputHandled(t, result, "handled")
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Print Functions
|
||||
|
||||
```go
|
||||
func TestPrintOutput(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
_, _ = harness.Emit(extensions.ToolCallEvent{
|
||||
ToolName: "test",
|
||||
Input: "{}",
|
||||
})
|
||||
|
||||
// Assert exact match
|
||||
test.AssertPrinted(t, harness, "exact output")
|
||||
|
||||
// Or partial match
|
||||
test.AssertPrintedContains(t, harness, "partial")
|
||||
|
||||
// Assert info/error messages
|
||||
test.AssertPrintInfo(t, harness, "info message")
|
||||
test.AssertPrintError(t, harness, "error message")
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Status Bar
|
||||
|
||||
```go
|
||||
func TestStatusBar(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
_, _ = harness.Emit(extensions.AgentEndEvent{})
|
||||
|
||||
test.AssertStatusSet(t, harness, "myext:status")
|
||||
test.AssertStatusText(t, harness, "myext:status", "Ready")
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Prompt Results
|
||||
|
||||
Configure the mock context to return specific prompt results:
|
||||
|
||||
```go
|
||||
func TestWithPrompts(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
// Configure prompt results before emitting events
|
||||
harness.Context().SetPromptSelectResult(extensions.PromptSelectResult{
|
||||
Value: "option1",
|
||||
Index: 0,
|
||||
Cancelled: false,
|
||||
})
|
||||
|
||||
// Now when your extension calls ctx.PromptSelect(), it will get this result
|
||||
_, _ = harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
}
|
||||
```
|
||||
|
||||
## Available Assertions
|
||||
|
||||
The test package provides these assertion helpers:
|
||||
|
||||
**Event Results:**
|
||||
- `AssertNotBlocked(t, result)` - Verify tool was not blocked
|
||||
- `AssertBlocked(t, result, reason)` - Verify tool was blocked with reason
|
||||
- `AssertInputHandled(t, result, action)` - Verify input was handled
|
||||
- `AssertInputTransformed(t, result, text)` - Verify input transformation
|
||||
|
||||
**Context Interactions:**
|
||||
- `AssertPrinted(t, harness, text)` - Verify exact print output
|
||||
- `AssertPrintedContains(t, harness, substring)` - Verify partial print output
|
||||
- `AssertPrintInfo(t, harness, text)` - Verify PrintInfo was called
|
||||
- `AssertPrintError(t, harness, text)` - Verify PrintError was called
|
||||
- `AssertWidgetSet(t, harness, id)` - Verify widget was set
|
||||
- `AssertWidgetNotSet(t, harness, id)` - Verify widget was not set
|
||||
- `AssertWidgetText(t, harness, id, text)` - Verify widget content
|
||||
- `AssertWidgetTextContains(t, harness, id, substring)` - Verify widget contains text
|
||||
- `AssertHeaderSet(t, harness)` - Verify header was set
|
||||
- `AssertFooterSet(t, harness)` - Verify footer was set
|
||||
- `AssertStatusSet(t, harness, key)` - Verify status was set
|
||||
- `AssertStatusText(t, harness, key, text)` - Verify status text
|
||||
|
||||
**Registration:**
|
||||
- `AssertToolRegistered(t, harness, name)` - Verify tool registration
|
||||
- `AssertCommandRegistered(t, harness, name)` - Verify command registration
|
||||
- `AssertHasHandlers(t, harness, eventType)` - Verify handlers exist
|
||||
- `AssertNoHandlers(t, harness, eventType)` - Verify no handlers
|
||||
|
||||
**Messaging:**
|
||||
- `AssertMessageSent(t, harness, text)` - Verify SendMessage was called
|
||||
- `AssertCancelAndSend(t, harness, text)` - Verify CancelAndSend was called
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Accessing the Mock Context
|
||||
|
||||
For custom assertions, access the mock context directly:
|
||||
|
||||
```go
|
||||
func TestCustomAssertion(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("my-ext.go")
|
||||
|
||||
_, _ = harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
|
||||
// Get all recorded prints
|
||||
prints := harness.Context().GetPrints()
|
||||
|
||||
// Check widget directly
|
||||
widget, ok := harness.Context().GetWidget("my-widget")
|
||||
if ok && widget.Style.BorderColor == "#ff0000" {
|
||||
t.Log("Widget has red border")
|
||||
}
|
||||
|
||||
// Check options
|
||||
optionValue := harness.Context().GetOption("my-option")
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Multiple Extensions
|
||||
|
||||
Each harness is isolated:
|
||||
|
||||
```go
|
||||
func TestExtensionIsolation(t *testing.T) {
|
||||
// These run in completely separate interpreters
|
||||
harness1 := test.New(t)
|
||||
harness1.LoadFile("ext1.go")
|
||||
|
||||
harness2 := test.New(t)
|
||||
harness2.LoadFile("ext2.go")
|
||||
|
||||
// Events to one don't affect the other
|
||||
}
|
||||
```
|
||||
|
||||
### Direct Result Extraction
|
||||
|
||||
When you need to inspect result details:
|
||||
|
||||
```go
|
||||
result, _ := harness.Emit(extensions.ToolCallEvent{...})
|
||||
tcr := test.GetToolCallResult(result)
|
||||
if tcr != nil {
|
||||
t.Logf("Block: %v, Reason: %s", tcr.Block, tcr.Reason)
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Test one behavior per test** - Keep tests focused and readable
|
||||
2. **Use inline source for simple tests** - LoadString is great for isolated tests
|
||||
3. **Use LoadFile for integration tests** - Tests the actual extension file
|
||||
4. **Assert on context calls** - Verify your extension interacts with the context correctly
|
||||
5. **Test both positive and negative cases** - Verify tools are blocked AND allowed appropriately
|
||||
6. **Test all event handlers** - Make sure all registered handlers work correctly
|
||||
|
||||
## Limitations
|
||||
|
||||
The test harness has these limitations:
|
||||
|
||||
1. **No TUI rendering** - Widgets are recorded but not rendered visually
|
||||
2. **Prompts return configured values** - You must pre-configure prompt results in tests
|
||||
3. **Subagents don't spawn real processes** - SpawnSubagent returns nil/empty results
|
||||
4. **LLM completions are mocked** - Complete returns empty responses
|
||||
5. **Some context methods are no-ops** - Exit, SetActiveTools, etc. don't have side effects
|
||||
|
||||
These limitations are intentional - the test harness focuses on testing extension logic, not the full Kit runtime.
|
||||
|
||||
## Example: Complete Extension Test
|
||||
|
||||
Here's a complete example testing a realistic extension:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
)
|
||||
|
||||
// Test that the extension properly blocks dangerous tools
|
||||
func TestSafetyExtension_BlocksDangerousTools(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("safety-ext.go")
|
||||
|
||||
// Verify it handles tool calls
|
||||
test.AssertHasHandlers(t, harness, extensions.ToolCall)
|
||||
|
||||
// Test allowed tool
|
||||
result, _ := harness.Emit(extensions.ToolCallEvent{ToolName: "read", Input: "{}"})
|
||||
test.AssertNotBlocked(t, result)
|
||||
|
||||
// Test blocked tool
|
||||
result, _ = harness.Emit(extensions.ToolCallEvent{ToolName: "rm", Input: "{}"})
|
||||
test.AssertBlocked(t, result, "safety block")
|
||||
test.AssertPrintError(t, harness, "Tool rm is blocked")
|
||||
}
|
||||
|
||||
// Test that the extension shows status on agent completion
|
||||
func TestSafetyExtension_ShowsStatus(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("safety-ext.go")
|
||||
|
||||
_, _ = harness.Emit(extensions.AgentEndEvent{})
|
||||
|
||||
test.AssertWidgetSet(t, harness, "safety-widget")
|
||||
test.AssertWidgetTextContains(t, harness, "safety-widget", "Safe")
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,297 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
)
|
||||
|
||||
// AssertNotBlocked fails the test if the tool call result indicates the tool was blocked.
|
||||
func AssertNotBlocked(t *testing.T, result extensions.Result) {
|
||||
t.Helper()
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
if tcr, ok := result.(extensions.ToolCallResult); ok {
|
||||
if tcr.Block {
|
||||
t.Errorf("expected tool to not be blocked, but it was blocked with reason: %q", tcr.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AssertBlocked fails the test if the tool call result does not indicate the tool was blocked.
|
||||
func AssertBlocked(t *testing.T, result extensions.Result, expectedReason string) {
|
||||
t.Helper()
|
||||
if result == nil {
|
||||
t.Error("expected tool to be blocked, but result was nil")
|
||||
return
|
||||
}
|
||||
tcr, ok := result.(extensions.ToolCallResult)
|
||||
if !ok {
|
||||
t.Errorf("expected ToolCallResult, got %T", result)
|
||||
return
|
||||
}
|
||||
if !tcr.Block {
|
||||
t.Error("expected tool to be blocked, but it was not blocked")
|
||||
return
|
||||
}
|
||||
if expectedReason != "" && tcr.Reason != expectedReason {
|
||||
t.Errorf("expected block reason %q, got %q", expectedReason, tcr.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertInputHandled fails the test if the input result does not indicate the input was handled.
|
||||
func AssertInputHandled(t *testing.T, result extensions.Result, expectedAction string) {
|
||||
t.Helper()
|
||||
if result == nil {
|
||||
t.Error("expected input to be handled, but result was nil")
|
||||
return
|
||||
}
|
||||
ir, ok := result.(extensions.InputResult)
|
||||
if !ok {
|
||||
t.Errorf("expected InputResult, got %T", result)
|
||||
return
|
||||
}
|
||||
if ir.Action != expectedAction {
|
||||
t.Errorf("expected action %q, got %q", expectedAction, ir.Action)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertInputTransformed fails the test if the input was not transformed to the expected text.
|
||||
func AssertInputTransformed(t *testing.T, result extensions.Result, expectedText string) {
|
||||
t.Helper()
|
||||
if result == nil {
|
||||
t.Errorf("expected input to be transformed to %q, but result was nil", expectedText)
|
||||
return
|
||||
}
|
||||
ir, ok := result.(extensions.InputResult)
|
||||
if !ok {
|
||||
t.Errorf("expected InputResult, got %T", result)
|
||||
return
|
||||
}
|
||||
if ir.Action != "transform" {
|
||||
t.Errorf("expected action 'transform', got %q", ir.Action)
|
||||
}
|
||||
if ir.Text != expectedText {
|
||||
t.Errorf("expected transformed text %q, got %q", expectedText, ir.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertPrinted fails the test if the expected text was not printed.
|
||||
func AssertPrinted(t *testing.T, harness *Harness, expected string) {
|
||||
t.Helper()
|
||||
prints := harness.Context().GetPrints()
|
||||
if slices.Contains(prints, expected) {
|
||||
return
|
||||
}
|
||||
t.Errorf("expected text %q to be printed, but it was not found in prints: %v", expected, prints)
|
||||
}
|
||||
|
||||
// AssertPrintedContains fails the test if no printed text contains the expected substring.
|
||||
func AssertPrintedContains(t *testing.T, harness *Harness, substring string) {
|
||||
t.Helper()
|
||||
prints := harness.Context().GetPrints()
|
||||
for _, p := range prints {
|
||||
if strings.Contains(p, substring) {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Errorf("expected printed text to contain %q, but it was not found in prints: %v", substring, prints)
|
||||
}
|
||||
|
||||
// AssertPrintInfo fails the test if the expected info message was not printed.
|
||||
func AssertPrintInfo(t *testing.T, harness *Harness, expected string) {
|
||||
t.Helper()
|
||||
infos := harness.Context().GetPrintInfos()
|
||||
if slices.Contains(infos, expected) {
|
||||
return
|
||||
}
|
||||
t.Errorf("expected info message %q, but it was not found in PrintInfos: %v", expected, infos)
|
||||
}
|
||||
|
||||
// AssertPrintError fails the test if the expected error message was not printed.
|
||||
func AssertPrintError(t *testing.T, harness *Harness, expected string) {
|
||||
t.Helper()
|
||||
errors := harness.Context().GetPrintErrors()
|
||||
if slices.Contains(errors, expected) {
|
||||
return
|
||||
}
|
||||
t.Errorf("expected error message %q, but it was not found in PrintErrors: %v", expected, errors)
|
||||
}
|
||||
|
||||
// AssertWidgetSet fails the test if the widget with the given ID was not set.
|
||||
func AssertWidgetSet(t *testing.T, harness *Harness, id string) {
|
||||
t.Helper()
|
||||
if !harness.Context().HasWidget(id) {
|
||||
t.Errorf("expected widget %q to be set, but it was not", id)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertWidgetNotSet fails the test if the widget with the given ID was set.
|
||||
func AssertWidgetNotSet(t *testing.T, harness *Harness, id string) {
|
||||
t.Helper()
|
||||
if harness.Context().HasWidget(id) {
|
||||
t.Errorf("expected widget %q to not be set, but it was", id)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertWidgetText fails the test if the widget with the given ID does not have the expected text.
|
||||
func AssertWidgetText(t *testing.T, harness *Harness, id string, expected string) {
|
||||
t.Helper()
|
||||
widget, ok := harness.Context().GetWidget(id)
|
||||
if !ok {
|
||||
t.Errorf("expected widget %q to be set, but it was not", id)
|
||||
return
|
||||
}
|
||||
if widget.Content.Text != expected {
|
||||
t.Errorf("expected widget %q to have text %q, got %q", id, expected, widget.Content.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertWidgetTextContains fails the test if the widget text does not contain the expected substring.
|
||||
func AssertWidgetTextContains(t *testing.T, harness *Harness, id string, substring string) {
|
||||
t.Helper()
|
||||
widget, ok := harness.Context().GetWidget(id)
|
||||
if !ok {
|
||||
t.Errorf("expected widget %q to be set, but it was not", id)
|
||||
return
|
||||
}
|
||||
if !strings.Contains(widget.Content.Text, substring) {
|
||||
t.Errorf("expected widget %q text to contain %q, but got %q", id, substring, widget.Content.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertHeaderSet fails the test if no header was set.
|
||||
func AssertHeaderSet(t *testing.T, harness *Harness) {
|
||||
t.Helper()
|
||||
if harness.Context().GetHeader() == nil {
|
||||
t.Error("expected header to be set, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
// AssertFooterSet fails the test if no footer was set.
|
||||
func AssertFooterSet(t *testing.T, harness *Harness) {
|
||||
t.Helper()
|
||||
if harness.Context().GetFooter() == nil {
|
||||
t.Error("expected footer to be set, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
// AssertStatusSet fails the test if the status with the given key was not set.
|
||||
func AssertStatusSet(t *testing.T, harness *Harness, key string) {
|
||||
t.Helper()
|
||||
_, ok := harness.Context().GetStatus(key)
|
||||
if !ok {
|
||||
t.Errorf("expected status %q to be set, but it was not", key)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertStatusText fails the test if the status with the given key does not have the expected text.
|
||||
func AssertStatusText(t *testing.T, harness *Harness, key string, expected string) {
|
||||
t.Helper()
|
||||
status, ok := harness.Context().GetStatus(key)
|
||||
if !ok {
|
||||
t.Errorf("expected status %q to be set, but it was not", key)
|
||||
return
|
||||
}
|
||||
if status.Text != expected {
|
||||
t.Errorf("expected status %q to have text %q, got %q", key, expected, status.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertHasHandlers fails the test if no handlers are registered for the given event type.
|
||||
func AssertHasHandlers(t *testing.T, harness *Harness, eventType extensions.EventType) {
|
||||
t.Helper()
|
||||
if !harness.HasHandlers(eventType) {
|
||||
t.Errorf("expected handlers for event type %q, but none were registered", eventType)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertNoHandlers fails the test if any handlers are registered for the given event type.
|
||||
func AssertNoHandlers(t *testing.T, harness *Harness, eventType extensions.EventType) {
|
||||
t.Helper()
|
||||
if harness.HasHandlers(eventType) {
|
||||
t.Errorf("expected no handlers for event type %q, but some were registered", eventType)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertToolRegistered fails the test if the tool with the given name was not registered.
|
||||
func AssertToolRegistered(t *testing.T, harness *Harness, toolName string) {
|
||||
t.Helper()
|
||||
tools := harness.RegisteredTools()
|
||||
for _, tool := range tools {
|
||||
if tool.Name == toolName {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Errorf("expected tool %q to be registered, but it was not found in %v", toolName, tools)
|
||||
}
|
||||
|
||||
// AssertCommandRegistered fails the test if the command with the given name was not registered.
|
||||
func AssertCommandRegistered(t *testing.T, harness *Harness, cmdName string) {
|
||||
t.Helper()
|
||||
cmds := harness.RegisteredCommands()
|
||||
for _, cmd := range cmds {
|
||||
if cmd.Name == cmdName {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Errorf("expected command %q to be registered, but it was not found in %v", cmdName, cmds)
|
||||
}
|
||||
|
||||
// AssertMessageSent fails the test if the expected message was not sent.
|
||||
func AssertMessageSent(t *testing.T, harness *Harness, expected string) {
|
||||
t.Helper()
|
||||
ctx := harness.Context()
|
||||
if slices.Contains(ctx.Messages, expected) {
|
||||
return
|
||||
}
|
||||
t.Errorf("expected message %q to be sent, but it was not found in messages: %v", expected, ctx.Messages)
|
||||
}
|
||||
|
||||
// AssertCancelAndSend fails the test if the expected text was not sent via CancelAndSend.
|
||||
func AssertCancelAndSend(t *testing.T, harness *Harness, expected string) {
|
||||
t.Helper()
|
||||
ctx := harness.Context()
|
||||
if slices.Contains(ctx.CancelSends, expected) {
|
||||
return
|
||||
}
|
||||
t.Errorf("expected CancelAndSend with %q, but it was not found: %v", expected, ctx.CancelSends)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// GetToolCallResult extracts a ToolCallResult from a Result, or nil if not applicable.
|
||||
func GetToolCallResult(result extensions.Result) *extensions.ToolCallResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
if tcr, ok := result.(extensions.ToolCallResult); ok {
|
||||
return &tcr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetInputResult extracts an InputResult from a Result, or nil if not applicable.
|
||||
func GetInputResult(result extensions.Result) *extensions.InputResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
if ir, ok := result.(extensions.InputResult); ok {
|
||||
return &ir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetToolResultResult extracts a ToolResultResult from a Result, or nil if not applicable.
|
||||
func GetToolResultResult(result extensions.Result) *extensions.ToolResultResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
if trr, ok := result.(extensions.ToolResultResult); ok {
|
||||
return &trr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
// Package test provides utilities for testing Kit extensions.
|
||||
//
|
||||
// This package allows extension authors to write standard Go tests that load
|
||||
// and exercise their extensions in a controlled environment. Extensions are
|
||||
// loaded into a Yaegi interpreter with all Kit API symbols available.
|
||||
//
|
||||
// Basic usage:
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "testing"
|
||||
// "github.com/mark3labs/kit/pkg/extensions/test"
|
||||
// )
|
||||
//
|
||||
// func TestMyExtension(t *testing.T) {
|
||||
// // Create a test harness
|
||||
// harness := test.New(t)
|
||||
//
|
||||
// // Load your extension file
|
||||
// ext := harness.LoadFile("my-ext.go")
|
||||
//
|
||||
// // Emit events and check results
|
||||
// result := harness.Emit(test.ToolCallEvent{
|
||||
// ToolName: "my_tool",
|
||||
// Input: `{"key": "value"}`,
|
||||
// })
|
||||
//
|
||||
// // Use assertion helpers
|
||||
// test.AssertNotBlocked(t, result)
|
||||
// test.AssertPrinted(t, harness, "expected output")
|
||||
// }
|
||||
//
|
||||
// The harness provides a mock Context that records all interactions,
|
||||
// allowing you to verify that your extension called SetWidget, Print, etc.
|
||||
package test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/traefik/yaegi/interp"
|
||||
"github.com/traefik/yaegi/stdlib"
|
||||
"github.com/traefik/yaegi/stdlib/unrestricted"
|
||||
)
|
||||
|
||||
// Harness provides a testing environment for Kit extensions.
|
||||
// It loads extensions into an isolated Yaegi interpreter and provides
|
||||
// methods to emit events and verify extension behavior.
|
||||
type Harness struct {
|
||||
t *testing.T
|
||||
runner *extensions.Runner
|
||||
context *MockContext
|
||||
extPath string
|
||||
}
|
||||
|
||||
// New creates a new test harness for the given test.
|
||||
// The harness must be used within a single test function.
|
||||
func New(t *testing.T) *Harness {
|
||||
return &Harness{
|
||||
t: t,
|
||||
context: NewMockContext(),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFile loads an extension from a file path.
|
||||
// The extension is evaluated in a fresh Yaegi interpreter with all
|
||||
// Kit API symbols available. The Init function is called automatically.
|
||||
//
|
||||
// Returns the loaded extension or fails the test on error.
|
||||
func (h *Harness) LoadFile(path string) *extensions.LoadedExtension {
|
||||
h.t.Helper()
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
h.t.Fatalf("extension file not found: %s: %v", path, err)
|
||||
}
|
||||
|
||||
// Read extension source
|
||||
src, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
h.t.Fatalf("failed to read extension file: %v", err)
|
||||
}
|
||||
|
||||
return h.loadSource(string(src), path)
|
||||
}
|
||||
|
||||
// LoadString loads an extension from a source string.
|
||||
// Useful for inline extension tests. The path is used for error reporting.
|
||||
func (h *Harness) LoadString(src string, path string) *extensions.LoadedExtension {
|
||||
h.t.Helper()
|
||||
return h.loadSource(src, path)
|
||||
}
|
||||
|
||||
// loadSource is the internal implementation that loads extension source
|
||||
// into a Yaegi interpreter.
|
||||
func (h *Harness) loadSource(src string, path string) *extensions.LoadedExtension {
|
||||
h.t.Helper()
|
||||
|
||||
// Create a fresh interpreter
|
||||
i := interp.New(interp.Options{})
|
||||
|
||||
// Expose Go stdlib
|
||||
if err := i.Use(stdlib.Symbols); err != nil {
|
||||
h.t.Fatalf("failed to load stdlib symbols: %v", err)
|
||||
}
|
||||
if err := i.Use(unrestricted.Symbols); err != nil {
|
||||
h.t.Fatalf("failed to load unrestricted symbols: %v", err)
|
||||
}
|
||||
|
||||
// Expose Kit extension API symbols
|
||||
if err := i.Use(extensions.Symbols()); err != nil {
|
||||
h.t.Fatalf("failed to load extension symbols: %v", err)
|
||||
}
|
||||
|
||||
// Evaluate the extension source
|
||||
if _, err := i.Eval(src); err != nil {
|
||||
h.t.Fatalf("failed to evaluate extension source: %v", err)
|
||||
}
|
||||
|
||||
// Extract the Init function
|
||||
initVal, err := i.Eval("Init")
|
||||
if err != nil {
|
||||
h.t.Fatalf("extension has no Init function: %v", err)
|
||||
}
|
||||
|
||||
initFn, ok := initVal.Interface().(func(extensions.API))
|
||||
if !ok {
|
||||
h.t.Fatalf("Init has wrong signature (want func(ext.API), got %T)", initVal.Interface())
|
||||
}
|
||||
|
||||
// Create the extension struct
|
||||
ext := &extensions.LoadedExtension{
|
||||
Path: path,
|
||||
Handlers: make(map[extensions.EventType][]extensions.HandlerFunc),
|
||||
}
|
||||
|
||||
// Create the API object using the test helper
|
||||
api := extensions.NewTestAPI(ext)
|
||||
|
||||
// Call Init to register handlers
|
||||
initFn(api)
|
||||
|
||||
// Create runner with the loaded extension
|
||||
h.runner = extensions.NewRunner([]extensions.LoadedExtension{*ext})
|
||||
h.extPath = path
|
||||
|
||||
// Wire the mock context
|
||||
h.runner.SetContext(h.context.ToContext())
|
||||
|
||||
return ext
|
||||
}
|
||||
|
||||
// Emit sends an event to the loaded extension(s) and returns the result.
|
||||
// Events are dispatched in order and blocking results stop propagation.
|
||||
func (h *Harness) Emit(event extensions.Event) (extensions.Result, error) {
|
||||
h.t.Helper()
|
||||
|
||||
if h.runner == nil {
|
||||
h.t.Fatal("no extension loaded, call LoadFile() or LoadString() first")
|
||||
}
|
||||
|
||||
return h.runner.Emit(event)
|
||||
}
|
||||
|
||||
// EmitJSON is a convenience method for emitting a ToolCallEvent with JSON input.
|
||||
func (h *Harness) EmitJSON(toolName string, input string) (*extensions.ToolCallResult, error) {
|
||||
h.t.Helper()
|
||||
|
||||
result, err := h.Emit(extensions.ToolCallEvent{
|
||||
ToolName: toolName,
|
||||
Input: input,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tcr, ok := result.(extensions.ToolCallResult)
|
||||
if !ok {
|
||||
h.t.Fatalf("expected ToolCallResult, got %T", result)
|
||||
}
|
||||
|
||||
return &tcr, nil
|
||||
}
|
||||
|
||||
// Context returns the mock context for inspection.
|
||||
// Use this to verify Print calls, widget settings, etc.
|
||||
func (h *Harness) Context() *MockContext {
|
||||
return h.context
|
||||
}
|
||||
|
||||
// Runner returns the underlying runner for advanced use cases.
|
||||
func (h *Harness) Runner() *extensions.Runner {
|
||||
return h.runner
|
||||
}
|
||||
|
||||
// HasHandlers reports whether any handlers are registered for the given event type.
|
||||
func (h *Harness) HasHandlers(eventType extensions.EventType) bool {
|
||||
if h.runner == nil {
|
||||
return false
|
||||
}
|
||||
return h.runner.HasHandlers(eventType)
|
||||
}
|
||||
|
||||
// RegisteredTools returns all tools registered by the extension.
|
||||
func (h *Harness) RegisteredTools() []extensions.ToolDef {
|
||||
if h.runner == nil {
|
||||
return nil
|
||||
}
|
||||
return h.runner.RegisteredTools()
|
||||
}
|
||||
|
||||
// RegisteredCommands returns all commands registered by the extension.
|
||||
func (h *Harness) RegisteredCommands() []extensions.CommandDef {
|
||||
if h.runner == nil {
|
||||
return nil
|
||||
}
|
||||
return h.runner.RegisteredCommands()
|
||||
}
|
||||
|
||||
// MustLoad is like LoadFile but fails the test immediately on error.
|
||||
// It returns the harness for chaining.
|
||||
func (h *Harness) MustLoad(path string) *Harness {
|
||||
h.t.Helper()
|
||||
h.LoadFile(path)
|
||||
return h
|
||||
}
|
||||
@@ -0,0 +1,568 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
)
|
||||
|
||||
// Test harness with a simple extension
|
||||
func TestHarness_LoadString(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("session started")
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "test-ext.go")
|
||||
|
||||
// Emit session start event
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the extension printed something
|
||||
prints := harness.Context().GetPrints()
|
||||
if len(prints) != 1 || prints[0] != "session started" {
|
||||
t.Errorf("expected ['session started'], got %v", prints)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarness_ToolCallBlocking(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
if tc.ToolName == "banned" {
|
||||
return &ext.ToolCallResult{Block: true, Reason: "tool is banned"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "blocker.go")
|
||||
|
||||
// Test blocked tool
|
||||
result, err := harness.Emit(extensions.ToolCallEvent{ToolName: "banned", Input: "{}"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
AssertBlocked(t, result, "tool is banned")
|
||||
|
||||
// Test allowed tool
|
||||
result2, err := harness.Emit(extensions.ToolCallEvent{ToolName: "allowed", Input: "{}"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result2 != nil {
|
||||
t.Errorf("expected nil result for allowed tool, got %v", result2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarness_ToolRegistration(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.RegisterTool(ext.ToolDef{
|
||||
Name: "my_tool",
|
||||
Description: "does stuff",
|
||||
Parameters: "{}",
|
||||
Execute: func(input string) (string, error) {
|
||||
return "result: " + input, nil
|
||||
},
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "tool-ext.go")
|
||||
|
||||
tools := harness.RegisteredTools()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
|
||||
if tools[0].Name != "my_tool" {
|
||||
t.Errorf("expected tool name 'my_tool', got %q", tools[0].Name)
|
||||
}
|
||||
|
||||
AssertToolRegistered(t, harness, "my_tool")
|
||||
}
|
||||
|
||||
func TestHarness_CommandRegistration(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "hello",
|
||||
Description: "says hello",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
ctx.Print("Hello, " + args)
|
||||
return "greeting sent", nil
|
||||
},
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "cmd-ext.go")
|
||||
|
||||
cmds := harness.RegisteredCommands()
|
||||
if len(cmds) != 1 {
|
||||
t.Fatalf("expected 1 command, got %d", len(cmds))
|
||||
}
|
||||
|
||||
if cmds[0].Name != "hello" {
|
||||
t.Errorf("expected command name 'hello', got %q", cmds[0].Name)
|
||||
}
|
||||
|
||||
AssertCommandRegistered(t, harness, "hello")
|
||||
}
|
||||
|
||||
func TestHarness_WidgetSetting(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "my-widget",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: "Hello, World!"},
|
||||
Style: ext.WidgetStyle{BorderColor: "#ff0000"},
|
||||
})
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "widget-ext.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
AssertWidgetSet(t, harness, "my-widget")
|
||||
AssertWidgetText(t, harness, "my-widget", "Hello, World!")
|
||||
|
||||
// Also verify directly
|
||||
widget, ok := harness.Context().GetWidget("my-widget")
|
||||
if !ok {
|
||||
t.Error("expected widget 'my-widget' to exist")
|
||||
}
|
||||
if widget.Style.BorderColor != "#ff0000" {
|
||||
t.Errorf("expected border color '#ff0000', got %q", widget.Style.BorderColor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarness_FooterSetting(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.SetFooter(ext.HeaderFooterConfig{
|
||||
Content: ext.WidgetContent{Text: "Status: OK"},
|
||||
Style: ext.WidgetStyle{BorderColor: "#00ff00"},
|
||||
})
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "footer-ext.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
AssertFooterSet(t, harness)
|
||||
|
||||
footer := harness.Context().GetFooter()
|
||||
if footer == nil {
|
||||
t.Fatal("expected footer to be set")
|
||||
}
|
||||
if footer.Content.Text != "Status: OK" {
|
||||
t.Errorf("expected footer text 'Status: OK', got %q", footer.Content.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarness_PrintInfoAndError(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.PrintInfo("Information message")
|
||||
ctx.PrintError("Error message")
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "print-ext.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
AssertPrintInfo(t, harness, "Information message")
|
||||
AssertPrintError(t, harness, "Error message")
|
||||
}
|
||||
|
||||
func TestHarness_EmitJSON(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
if tc.ToolName == "test_tool" {
|
||||
return &ext.ToolCallResult{Block: true, Reason: "blocked"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "json-ext.go")
|
||||
|
||||
result, err := harness.EmitJSON("test_tool", `{"key": "value"}`)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
if !result.Block {
|
||||
t.Error("expected Block=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarness_HasHandlers(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnToolCall(func(_ ext.ToolCallEvent, _ ext.Context) *ext.ToolCallResult {
|
||||
return nil
|
||||
})
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, _ ext.Context) {
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "handlers-ext.go")
|
||||
|
||||
AssertHasHandlers(t, harness, extensions.ToolCall)
|
||||
AssertHasHandlers(t, harness, extensions.SessionStart)
|
||||
AssertNoHandlers(t, harness, extensions.AgentEnd)
|
||||
}
|
||||
|
||||
func TestHarness_MultipleExtensions(t *testing.T) {
|
||||
ext1 := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("extension 1")
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
ext2 := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("extension 2")
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
// Load first extension
|
||||
harness1 := New(t)
|
||||
harness1.LoadString(ext1, "ext1.go")
|
||||
|
||||
// Load second extension
|
||||
harness2 := New(t)
|
||||
harness2.LoadString(ext2, "ext2.go")
|
||||
|
||||
// Verify they are isolated
|
||||
_, _ = harness1.Emit(extensions.SessionStartEvent{SessionID: "test1"})
|
||||
_, _ = harness2.Emit(extensions.SessionStartEvent{SessionID: "test2"})
|
||||
|
||||
prints1 := harness1.Context().GetPrints()
|
||||
prints2 := harness2.Context().GetPrints()
|
||||
|
||||
if len(prints1) != 1 || prints1[0] != "extension 1" {
|
||||
t.Errorf("ext1 prints: expected ['extension 1'], got %v", prints1)
|
||||
}
|
||||
|
||||
if len(prints2) != 1 || prints2[0] != "extension 2" {
|
||||
t.Errorf("ext2 prints: expected ['extension 2'], got %v", prints2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarness_InputHandling(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import (
|
||||
"kit/ext"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
|
||||
if strings.Contains(ie.Text, "secret") {
|
||||
return &ext.InputResult{Action: "handled"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "input-ext.go")
|
||||
|
||||
// Test handled input
|
||||
result, err := harness.Emit(extensions.InputEvent{Text: "my secret password", Source: "cli"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
AssertInputHandled(t, result, "handled")
|
||||
|
||||
// Test unhandled input
|
||||
result2, err := harness.Emit(extensions.InputEvent{Text: "normal input", Source: "cli"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result2 != nil {
|
||||
t.Errorf("expected nil result for unhandled input, got %v", result2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarness_StatusSetting(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.SetStatus("myext:status", "Ready", 50)
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "status-ext.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
AssertStatusSet(t, harness, "myext:status")
|
||||
AssertStatusText(t, harness, "myext:status", "Ready")
|
||||
}
|
||||
|
||||
func TestHarness_LoadFile_NotFound(t *testing.T) {
|
||||
// Test that loading a nonexistent file fails the test
|
||||
// We create a mock testing.T to capture the failure
|
||||
mockT := &testing.T{}
|
||||
harness := New(mockT)
|
||||
|
||||
// Just verify the harness was created successfully
|
||||
_ = harness.Context().GetPrints()
|
||||
|
||||
// The actual behavior (Fatalf on missing file) is tested implicitly
|
||||
// whenever LoadFile is used in other tests
|
||||
}
|
||||
|
||||
// MockContext tests
|
||||
func TestMockContext_Prompts(t *testing.T) {
|
||||
ctx := NewMockContext()
|
||||
|
||||
// Configure results
|
||||
ctx.SetPromptSelectResult(extensions.PromptSelectResult{Value: "option1", Index: 0, Cancelled: false})
|
||||
ctx.SetPromptConfirmResult(extensions.PromptConfirmResult{Value: true, Cancelled: false})
|
||||
ctx.SetPromptInputResult(extensions.PromptInputResult{Value: "input text", Cancelled: false})
|
||||
|
||||
extCtx := ctx.ToContext()
|
||||
|
||||
// Test prompts return configured results
|
||||
selectResult := extCtx.PromptSelect(extensions.PromptSelectConfig{Message: "test", Options: []string{"a", "b"}})
|
||||
if selectResult.Value != "option1" {
|
||||
t.Errorf("expected 'option1', got %q", selectResult.Value)
|
||||
}
|
||||
|
||||
confirmResult := extCtx.PromptConfirm(extensions.PromptConfirmConfig{Message: "test"})
|
||||
if !confirmResult.Value {
|
||||
t.Error("expected true")
|
||||
}
|
||||
|
||||
inputResult := extCtx.PromptInput(extensions.PromptInputConfig{Message: "test"})
|
||||
if inputResult.Value != "input text" {
|
||||
t.Errorf("expected 'input text', got %q", inputResult.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMockContext_Options(t *testing.T) {
|
||||
ctx := NewMockContext()
|
||||
extCtx := ctx.ToContext()
|
||||
|
||||
// Initially empty
|
||||
if extCtx.GetOption("key") != "" {
|
||||
t.Error("expected empty option")
|
||||
}
|
||||
|
||||
// Set option
|
||||
extCtx.SetOption("key", "value")
|
||||
if extCtx.GetOption("key") != "value" {
|
||||
t.Errorf("expected 'value', got %q", extCtx.GetOption("key"))
|
||||
}
|
||||
}
|
||||
|
||||
// Assertion helper tests
|
||||
func TestAssertPrintedContains(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("This is a long message with some content")
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "print-ext.go")
|
||||
_, _ = harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
|
||||
AssertPrintedContains(t, harness, "long message")
|
||||
AssertPrintedContains(t, harness, "some content")
|
||||
}
|
||||
|
||||
func TestAssertWidgetTextContains(t *testing.T) {
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "status",
|
||||
Content: ext.WidgetContent{Text: "Build: passing, Tests: 42/42"},
|
||||
})
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "widget-ext.go")
|
||||
_, _ = harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
|
||||
AssertWidgetTextContains(t, harness, "status", "Build: passing")
|
||||
AssertWidgetTextContains(t, harness, "status", "42/42")
|
||||
}
|
||||
|
||||
// Test that shows how to test a realistic extension pattern
|
||||
func TestExample_RealisticExtension(t *testing.T) {
|
||||
// This is an example of a realistic extension that:
|
||||
// 1. Blocks dangerous tools
|
||||
// 2. Shows a status widget
|
||||
// 3. Logs tool calls
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
var blockedTools = []string{"rm", "del", "remove"}
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
// Check if tool is blocked
|
||||
for _, blocked := range blockedTools {
|
||||
if tc.ToolName == blocked {
|
||||
ctx.PrintError("Tool " + tc.ToolName + " is blocked for safety")
|
||||
return &ext.ToolCallResult{Block: true, Reason: "safety block"}
|
||||
}
|
||||
}
|
||||
|
||||
// Log the tool call
|
||||
ctx.SetStatus("tool-logger:last", tc.ToolName, 10)
|
||||
return nil
|
||||
})
|
||||
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "safety-status",
|
||||
Content: ext.WidgetContent{Text: "Safety: Active"},
|
||||
Style: ext.WidgetStyle{BorderColor: "#00ff00"},
|
||||
})
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
harness := New(t)
|
||||
harness.LoadString(src, "safety-ext.go")
|
||||
|
||||
// Verify handlers are registered
|
||||
AssertHasHandlers(t, harness, extensions.ToolCall)
|
||||
AssertHasHandlers(t, harness, extensions.SessionStart)
|
||||
|
||||
// Test session start
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify widget was set
|
||||
AssertWidgetSet(t, harness, "safety-status")
|
||||
AssertWidgetText(t, harness, "safety-status", "Safety: Active")
|
||||
|
||||
// Test allowed tool
|
||||
result, _ := harness.Emit(extensions.ToolCallEvent{ToolName: "read", Input: "{}"})
|
||||
AssertNotBlocked(t, result)
|
||||
|
||||
// Verify status was updated
|
||||
AssertStatusSet(t, harness, "tool-logger:last")
|
||||
AssertStatusText(t, harness, "tool-logger:last", "read")
|
||||
|
||||
// Test blocked tool
|
||||
result2, _ := harness.Emit(extensions.ToolCallEvent{ToolName: "rm", Input: `{"file": "test.txt"}`})
|
||||
AssertBlocked(t, result2, "safety block")
|
||||
AssertPrintError(t, harness, "Tool rm is blocked for safety")
|
||||
}
|
||||
@@ -0,0 +1,460 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
)
|
||||
|
||||
// MockContext records all interactions with the extension context.
|
||||
// It provides a Context object that captures Print calls, widget settings,
|
||||
// and other context operations for verification in tests.
|
||||
type MockContext struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Recorded calls
|
||||
Prints []string
|
||||
PrintInfos []string
|
||||
PrintErrors []string
|
||||
PrintBlocks []extensions.PrintBlockOpts
|
||||
Messages []string
|
||||
CancelSends []string
|
||||
|
||||
// Widget state
|
||||
Widgets map[string]extensions.WidgetConfig
|
||||
RemovedIDs []string
|
||||
Header *extensions.HeaderFooterConfig
|
||||
Footer *extensions.HeaderFooterConfig
|
||||
HeaderRemoved bool
|
||||
FooterRemoved bool
|
||||
|
||||
// Context properties
|
||||
SessionID string
|
||||
CWD string
|
||||
Model string
|
||||
Interactive bool
|
||||
|
||||
// UI visibility
|
||||
UIVisibility *extensions.UIVisibility
|
||||
|
||||
// Status entries
|
||||
StatusEntries map[string]extensions.StatusBarEntry
|
||||
RemovedStatus []string
|
||||
|
||||
// Editor
|
||||
EditorConfig *extensions.EditorConfig
|
||||
EditorReset bool
|
||||
EditorTexts []string
|
||||
|
||||
// Options
|
||||
Options map[string]string
|
||||
|
||||
// Prompt results (configurable for testing)
|
||||
PromptSelectResult extensions.PromptSelectResult
|
||||
PromptConfirmResult extensions.PromptConfirmResult
|
||||
PromptInputResult extensions.PromptInputResult
|
||||
PromptMultiSelectResult extensions.PromptMultiSelectResult
|
||||
|
||||
// Overlay
|
||||
Overlays []extensions.OverlayConfig
|
||||
}
|
||||
|
||||
// StatusBarEntry represents a recorded status bar entry
|
||||
type StatusBarEntry struct {
|
||||
Key string
|
||||
Text string
|
||||
Priority int
|
||||
}
|
||||
|
||||
// NewMockContext creates a new mock context with default values.
|
||||
func NewMockContext() *MockContext {
|
||||
return &MockContext{
|
||||
Prints: make([]string, 0),
|
||||
PrintInfos: make([]string, 0),
|
||||
PrintErrors: make([]string, 0),
|
||||
PrintBlocks: make([]extensions.PrintBlockOpts, 0),
|
||||
Messages: make([]string, 0),
|
||||
CancelSends: make([]string, 0),
|
||||
Widgets: make(map[string]extensions.WidgetConfig),
|
||||
RemovedIDs: make([]string, 0),
|
||||
StatusEntries: make(map[string]extensions.StatusBarEntry),
|
||||
RemovedStatus: make([]string, 0),
|
||||
EditorTexts: make([]string, 0),
|
||||
Options: make(map[string]string),
|
||||
Overlays: make([]extensions.OverlayConfig, 0),
|
||||
Interactive: true,
|
||||
SessionID: "test-session",
|
||||
CWD: "/test",
|
||||
Model: "test-model",
|
||||
}
|
||||
}
|
||||
|
||||
// ToContext returns a extensions.Context wired to record all interactions.
|
||||
func (m *MockContext) ToContext() extensions.Context {
|
||||
return extensions.Context{
|
||||
SessionID: m.SessionID,
|
||||
CWD: m.CWD,
|
||||
Model: m.Model,
|
||||
Interactive: m.Interactive,
|
||||
Print: m.recordPrint,
|
||||
PrintInfo: m.recordPrintInfo,
|
||||
PrintError: m.recordPrintError,
|
||||
PrintBlock: m.recordPrintBlock,
|
||||
SendMessage: m.recordSendMessage,
|
||||
CancelAndSend: m.recordCancelAndSend,
|
||||
SetWidget: m.recordSetWidget,
|
||||
RemoveWidget: m.recordRemoveWidget,
|
||||
SetHeader: m.recordSetHeader,
|
||||
RemoveHeader: m.recordRemoveHeader,
|
||||
SetFooter: m.recordSetFooter,
|
||||
RemoveFooter: m.recordRemoveFooter,
|
||||
PromptSelect: m.recordPromptSelect,
|
||||
PromptConfirm: m.recordPromptConfirm,
|
||||
PromptInput: m.recordPromptInput,
|
||||
PromptMultiSelect: m.recordPromptMultiSelect,
|
||||
SetEditor: m.recordSetEditor,
|
||||
ResetEditor: m.recordResetEditor,
|
||||
SetEditorText: m.recordSetEditorText,
|
||||
SetUIVisibility: m.recordUIVisibility,
|
||||
GetContextStats: m.getContextStats,
|
||||
GetMessages: m.getMessages,
|
||||
GetSessionPath: m.getSessionPath,
|
||||
AppendEntry: m.appendEntry,
|
||||
GetEntries: m.getEntries,
|
||||
SetStatus: m.recordSetStatus,
|
||||
RemoveStatus: m.recordRemoveStatus,
|
||||
GetOption: m.getOption,
|
||||
SetOption: m.setOption,
|
||||
SetModel: m.setModel,
|
||||
GetAllTools: m.getAllTools,
|
||||
SetActiveTools: m.setActiveTools,
|
||||
Exit: m.exit,
|
||||
Complete: m.complete,
|
||||
SuspendTUI: m.suspendTUI,
|
||||
RenderMessage: m.renderMessage,
|
||||
RegisterTheme: m.registerTheme,
|
||||
SetTheme: m.setTheme,
|
||||
ListThemes: m.listThemes,
|
||||
ReloadExtensions: m.reloadExtensions,
|
||||
SpawnSubagent: m.spawnSubagent,
|
||||
ShowOverlay: m.showOverlay,
|
||||
}
|
||||
}
|
||||
|
||||
// Record methods
|
||||
|
||||
func (m *MockContext) recordPrint(text string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.Prints = append(m.Prints, text)
|
||||
}
|
||||
|
||||
func (m *MockContext) recordPrintInfo(text string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.PrintInfos = append(m.PrintInfos, text)
|
||||
}
|
||||
|
||||
func (m *MockContext) recordPrintError(text string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.PrintErrors = append(m.PrintErrors, text)
|
||||
}
|
||||
|
||||
func (m *MockContext) recordPrintBlock(opts extensions.PrintBlockOpts) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.PrintBlocks = append(m.PrintBlocks, opts)
|
||||
}
|
||||
|
||||
func (m *MockContext) recordSendMessage(text string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.Messages = append(m.Messages, text)
|
||||
}
|
||||
|
||||
func (m *MockContext) recordCancelAndSend(text string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.CancelSends = append(m.CancelSends, text)
|
||||
}
|
||||
|
||||
func (m *MockContext) recordSetWidget(config extensions.WidgetConfig) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.Widgets[config.ID] = config
|
||||
}
|
||||
|
||||
func (m *MockContext) recordRemoveWidget(id string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.Widgets, id)
|
||||
m.RemovedIDs = append(m.RemovedIDs, id)
|
||||
}
|
||||
|
||||
func (m *MockContext) recordSetHeader(config extensions.HeaderFooterConfig) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.Header = &config
|
||||
}
|
||||
|
||||
func (m *MockContext) recordRemoveHeader() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.Header = nil
|
||||
m.HeaderRemoved = true
|
||||
}
|
||||
|
||||
func (m *MockContext) recordSetFooter(config extensions.HeaderFooterConfig) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.Footer = &config
|
||||
}
|
||||
|
||||
func (m *MockContext) recordRemoveFooter() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.Footer = nil
|
||||
m.FooterRemoved = true
|
||||
}
|
||||
|
||||
func (m *MockContext) recordSetStatus(key string, text string, priority int) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.StatusEntries[key] = extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockContext) recordRemoveStatus(key string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.StatusEntries, key)
|
||||
m.RemovedStatus = append(m.RemovedStatus, key)
|
||||
}
|
||||
|
||||
func (m *MockContext) recordSetEditor(config extensions.EditorConfig) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.EditorConfig = &config
|
||||
}
|
||||
|
||||
func (m *MockContext) recordResetEditor() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.EditorReset = true
|
||||
m.EditorConfig = nil
|
||||
}
|
||||
|
||||
func (m *MockContext) recordSetEditorText(text string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.EditorTexts = append(m.EditorTexts, text)
|
||||
}
|
||||
|
||||
func (m *MockContext) recordUIVisibility(vis extensions.UIVisibility) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.UIVisibility = &vis
|
||||
}
|
||||
|
||||
func (m *MockContext) recordPromptSelect(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
// Return the configured result (tests can set this)
|
||||
return m.PromptSelectResult
|
||||
}
|
||||
|
||||
func (m *MockContext) recordPromptConfirm(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
return m.PromptConfirmResult
|
||||
}
|
||||
|
||||
func (m *MockContext) recordPromptInput(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
return m.PromptInputResult
|
||||
}
|
||||
|
||||
func (m *MockContext) recordPromptMultiSelect(config extensions.PromptMultiSelectConfig) extensions.PromptMultiSelectResult {
|
||||
return m.PromptMultiSelectResult
|
||||
}
|
||||
|
||||
func (m *MockContext) showOverlay(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.Overlays = append(m.Overlays, config)
|
||||
return extensions.OverlayResult{Cancelled: true} // Default to cancelled for tests
|
||||
}
|
||||
|
||||
// Stub methods that do nothing or return defaults
|
||||
|
||||
func (m *MockContext) getContextStats() extensions.ContextStats {
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: 1000,
|
||||
ContextLimit: 200000,
|
||||
UsagePercent: 0.5,
|
||||
MessageCount: 10,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockContext) getMessages() []extensions.SessionMessage {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockContext) getSessionPath() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *MockContext) appendEntry(entryType string, data string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *MockContext) getEntries(entryType string) []extensions.ExtensionEntry {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockContext) getOption(name string) string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.Options[name]
|
||||
}
|
||||
|
||||
func (m *MockContext) setOption(name string, value string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.Options[name] = value
|
||||
}
|
||||
|
||||
func (m *MockContext) setModel(modelString string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockContext) getAllTools() []extensions.ToolInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockContext) setActiveTools(names []string) {}
|
||||
|
||||
func (m *MockContext) exit() {}
|
||||
|
||||
func (m *MockContext) complete(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return extensions.CompleteResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *MockContext) suspendTUI(callback func()) error {
|
||||
callback()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockContext) renderMessage(rendererName string, content string) {}
|
||||
|
||||
func (m *MockContext) registerTheme(name string, config extensions.ThemeColorConfig) {}
|
||||
|
||||
func (m *MockContext) setTheme(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockContext) listThemes() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockContext) reloadExtensions() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockContext) spawnSubagent(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// Accessor methods for verification
|
||||
|
||||
// GetPrints returns all recorded Print calls.
|
||||
func (m *MockContext) GetPrints() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.Prints))
|
||||
copy(result, m.Prints)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetPrintInfos returns all recorded PrintInfo calls.
|
||||
func (m *MockContext) GetPrintInfos() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.PrintInfos))
|
||||
copy(result, m.PrintInfos)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetPrintErrors returns all recorded PrintError calls.
|
||||
func (m *MockContext) GetPrintErrors() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.PrintErrors))
|
||||
copy(result, m.PrintErrors)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetWidget returns a recorded widget by ID.
|
||||
func (m *MockContext) GetWidget(id string) (extensions.WidgetConfig, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
w, ok := m.Widgets[id]
|
||||
return w, ok
|
||||
}
|
||||
|
||||
// HasWidget reports whether a widget with the given ID was set.
|
||||
func (m *MockContext) HasWidget(id string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
_, ok := m.Widgets[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetHeader returns the recorded header configuration.
|
||||
func (m *MockContext) GetHeader() *extensions.HeaderFooterConfig {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.Header
|
||||
}
|
||||
|
||||
// GetFooter returns the recorded footer configuration.
|
||||
func (m *MockContext) GetFooter() *extensions.HeaderFooterConfig {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.Footer
|
||||
}
|
||||
|
||||
// GetStatus returns a recorded status entry by key.
|
||||
func (m *MockContext) GetStatus(key string) (extensions.StatusBarEntry, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
s, ok := m.StatusEntries[key]
|
||||
return s, ok
|
||||
}
|
||||
|
||||
// SetPromptSelectResult configures the result returned by PromptSelect.
|
||||
func (m *MockContext) SetPromptSelectResult(result extensions.PromptSelectResult) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.PromptSelectResult = result
|
||||
}
|
||||
|
||||
// SetPromptConfirmResult configures the result returned by PromptConfirm.
|
||||
func (m *MockContext) SetPromptConfirmResult(result extensions.PromptConfirmResult) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.PromptConfirmResult = result
|
||||
}
|
||||
|
||||
// SetPromptInputResult configures the result returned by PromptInput.
|
||||
func (m *MockContext) SetPromptInputResult(result extensions.PromptInputResult) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.PromptInputResult = result
|
||||
}
|
||||
|
||||
// SetPromptMultiSelectResult configures the result returned by PromptMultiSelect.
|
||||
func (m *MockContext) SetPromptMultiSelectResult(result extensions.PromptMultiSelectResult) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.PromptMultiSelectResult = result
|
||||
}
|
||||
+96
-11
@@ -1,6 +1,9 @@
|
||||
package kit
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Event types
|
||||
@@ -48,6 +51,54 @@ type Event interface {
|
||||
EventType() EventType
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tool kind constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ToolKind constants classify what a tool does, enabling UIs to render
|
||||
// appropriate visualizations (e.g. diff view for edit tools, command+output
|
||||
// for execute tools) and file trackers to identify which results contain
|
||||
// modifications.
|
||||
const (
|
||||
ToolKindExecute = "execute" // Shell execution (bash)
|
||||
ToolKindEdit = "edit" // File modification (edit, write)
|
||||
ToolKindRead = "read" // File reading (read, ls)
|
||||
ToolKindSearch = "search" // Content/file search (grep, find)
|
||||
ToolKindSubagent = "agent" // Subagent spawning (spawn_subagent)
|
||||
)
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind. MCP and extension
|
||||
// tools without an entry default to ToolKindExecute.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": ToolKindExecute,
|
||||
"edit": ToolKindEdit,
|
||||
"write": ToolKindEdit,
|
||||
"read": ToolKindRead,
|
||||
"ls": ToolKindRead,
|
||||
"grep": ToolKindSearch,
|
||||
"find": ToolKindSearch,
|
||||
"spawn_subagent": ToolKindSubagent,
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
// ToolKindExecute for unknown tools.
|
||||
func toolKindFor(toolName string) string {
|
||||
if kind, ok := coreToolKinds[toolName]; ok {
|
||||
return kind
|
||||
}
|
||||
return ToolKindExecute
|
||||
}
|
||||
|
||||
// parseToolArgs attempts to parse a JSON-encoded tool args string into a map.
|
||||
// Returns nil on failure (non-fatal convenience parsing).
|
||||
func parseToolArgs(toolArgs string) map[string]any {
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal([]byte(toolArgs), &parsed) == nil {
|
||||
return parsed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Concrete event structs
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -62,8 +113,9 @@ func (e TurnStartEvent) EventType() EventType { return EventTurnStart }
|
||||
|
||||
// TurnEndEvent fires after the agent finishes processing.
|
||||
type TurnEndEvent struct {
|
||||
Response string
|
||||
Error error
|
||||
Response string
|
||||
Error error
|
||||
StopReason string // "end_turn", "max_tokens", "tool_use", "error", etc.
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
@@ -101,8 +153,11 @@ func (e MessageEndEvent) EventType() EventType { return EventMessageEnd }
|
||||
|
||||
// ToolCallEvent fires when a tool call has been parsed.
|
||||
type ToolCallEvent struct {
|
||||
ToolName string
|
||||
ToolArgs string
|
||||
ToolCallID string // Stable ID for correlating tool lifecycle events
|
||||
ToolName string
|
||||
ToolKind string // Tool classification: "execute", "edit", "read", "search", "agent"
|
||||
ToolArgs string // JSON-encoded arguments
|
||||
ParsedArgs map[string]any // Pre-parsed arguments for convenience (nil on parse failure)
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
@@ -110,7 +165,10 @@ func (e ToolCallEvent) EventType() EventType { return EventToolCall }
|
||||
|
||||
// ToolExecutionStartEvent fires when a tool begins executing.
|
||||
type ToolExecutionStartEvent struct {
|
||||
ToolName string
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
ToolKind string
|
||||
ToolArgs string
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
@@ -118,7 +176,9 @@ func (e ToolExecutionStartEvent) EventType() EventType { return EventToolExecuti
|
||||
|
||||
// ToolExecutionEndEvent fires when a tool finishes executing.
|
||||
type ToolExecutionEndEvent struct {
|
||||
ToolName string
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
ToolKind string
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
@@ -126,10 +186,35 @@ func (e ToolExecutionEndEvent) EventType() EventType { return EventToolExecution
|
||||
|
||||
// ToolResultEvent fires after a tool execution completes with its result.
|
||||
type ToolResultEvent struct {
|
||||
ToolName string
|
||||
ToolArgs string
|
||||
Result string
|
||||
IsError bool
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
ToolKind string
|
||||
ToolArgs string
|
||||
ParsedArgs map[string]any // Pre-parsed arguments for convenience
|
||||
Result string
|
||||
IsError bool
|
||||
Metadata *ToolResultMetadata // Optional structured metadata from tool execution
|
||||
}
|
||||
|
||||
// ToolResultMetadata carries structured data from tool executions.
|
||||
type ToolResultMetadata struct {
|
||||
FileDiffs []FileDiffInfo `json:"file_diffs,omitempty"` // Present for edit/write tools
|
||||
SubagentSessionID string `json:"subagent_session_id,omitempty"` // Present for spawn_subagent tool
|
||||
}
|
||||
|
||||
// FileDiffInfo describes a file modification from an edit or write tool.
|
||||
type FileDiffInfo struct {
|
||||
Path string `json:"path"` // Absolute file path
|
||||
Additions int `json:"additions"` // Lines added
|
||||
Deletions int `json:"deletions"` // Lines removed
|
||||
IsNew bool `json:"is_new,omitempty"` // True if file was created (write only)
|
||||
DiffBlocks []DiffBlock `json:"diff_blocks,omitempty"`
|
||||
}
|
||||
|
||||
// DiffBlock represents a single old→new text replacement within a file.
|
||||
type DiffBlock struct {
|
||||
OldText string `json:"old_text"`
|
||||
NewText string `json:"new_text"`
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
|
||||
@@ -89,11 +89,13 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
if runner.HasHandlers(extensions.AgentEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(TurnEndEvent); ok {
|
||||
stopReason := "completed"
|
||||
stopReason := ev.StopReason
|
||||
response := ev.Response
|
||||
if ev.Error != nil {
|
||||
stopReason = "error"
|
||||
response = ""
|
||||
} else if stopReason == "" {
|
||||
stopReason = "completed"
|
||||
}
|
||||
_, _ = runner.Emit(extensions.AgentEndEvent{
|
||||
Response: response,
|
||||
|
||||
+16
-12
@@ -31,8 +31,9 @@ const (
|
||||
|
||||
// BeforeToolCallHook is the input for hooks that fire before a tool executes.
|
||||
type BeforeToolCallHook struct {
|
||||
ToolName string
|
||||
ToolArgs string
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
ToolArgs string
|
||||
}
|
||||
|
||||
// BeforeToolCallResult controls whether the tool call proceeds.
|
||||
@@ -43,10 +44,11 @@ type BeforeToolCallResult struct {
|
||||
|
||||
// AfterToolResultHook is the input for hooks that fire after a tool executes.
|
||||
type AfterToolResultHook struct {
|
||||
ToolName string
|
||||
ToolArgs string
|
||||
Result string
|
||||
IsError bool
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
ToolArgs string
|
||||
Result string
|
||||
IsError bool
|
||||
}
|
||||
|
||||
// AfterToolResultResult can modify the tool's output before it reaches the LLM.
|
||||
@@ -258,8 +260,9 @@ func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.To
|
||||
// 1. BeforeToolCall — can block execution.
|
||||
if h.beforeToolCall.hasHooks() {
|
||||
if result := h.beforeToolCall.run(BeforeToolCallHook{
|
||||
ToolName: toolName,
|
||||
ToolArgs: call.Input,
|
||||
ToolCallID: call.ID,
|
||||
ToolName: toolName,
|
||||
ToolArgs: call.Input,
|
||||
}); result != nil && result.Block {
|
||||
reason := result.Reason
|
||||
if reason == "" {
|
||||
@@ -276,10 +279,11 @@ func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.To
|
||||
// 3. AfterToolResult — can modify output.
|
||||
if h.afterToolResult.hasHooks() {
|
||||
if result := h.afterToolResult.run(AfterToolResultHook{
|
||||
ToolName: toolName,
|
||||
ToolArgs: call.Input,
|
||||
Result: resp.Content,
|
||||
IsError: err != nil || resp.IsError,
|
||||
ToolCallID: call.ID,
|
||||
ToolName: toolName,
|
||||
ToolArgs: call.Input,
|
||||
Result: resp.Content,
|
||||
IsError: err != nil || resp.IsError,
|
||||
}); result != nil {
|
||||
if result.Result != nil {
|
||||
resp.Content = *result.Result
|
||||
|
||||
@@ -24,6 +24,7 @@ func TestHookRegistry_RegisterAndRun(t *testing.T) {
|
||||
got := hr.run("hello")
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
if *got != "handled: hello" {
|
||||
t.Errorf("expected 'handled: hello', got %q", *got)
|
||||
@@ -51,6 +52,7 @@ func TestHookRegistry_FirstNonNilWins(t *testing.T) {
|
||||
got := hr.run("test")
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
if *got != "second: test" {
|
||||
t.Errorf("expected 'second: test', got %q", *got)
|
||||
@@ -77,6 +79,7 @@ func TestHookRegistry_PriorityOrdering(t *testing.T) {
|
||||
got := hr.run("x")
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
if *got != "high" {
|
||||
t.Errorf("expected 'high' (priority 0 runs first), got %q", *got)
|
||||
@@ -441,6 +444,7 @@ func TestBeforeTurnHook_PromptOverride(t *testing.T) {
|
||||
result := hr.run(BeforeTurnHook{Prompt: "original"})
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
if result.Prompt == nil || *result.Prompt != "modified prompt" {
|
||||
t.Errorf("expected prompt override, got %v", result.Prompt)
|
||||
@@ -462,6 +466,7 @@ func TestBeforeTurnHook_InjectSystemAndContext(t *testing.T) {
|
||||
result := hr.run(BeforeTurnHook{Prompt: "hello"})
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
if result.SystemPrompt == nil || *result.SystemPrompt != "be concise" {
|
||||
t.Errorf("expected system prompt injection")
|
||||
|
||||
+282
-12
@@ -2,6 +2,7 @@ package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/core"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/kitsetup"
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
@@ -347,6 +349,50 @@ func (m *Kit) GetSessionMessages() []extensions.SessionMessage {
|
||||
return msgs
|
||||
}
|
||||
|
||||
// StructuredMessage represents a conversation message with typed content parts
|
||||
// (tool calls, reasoning, finish markers, etc.) instead of flattened text.
|
||||
type StructuredMessage struct {
|
||||
ID string
|
||||
ParentID string
|
||||
Role MessageRole
|
||||
Parts []ContentPart
|
||||
Model string
|
||||
Provider string
|
||||
Timestamp string // RFC3339 format
|
||||
}
|
||||
|
||||
// GetStructuredMessages returns the conversation messages on the current
|
||||
// branch with full typed content parts. Unlike GetSessionMessages() which
|
||||
// flattens all content to a single text string, this preserves tool calls,
|
||||
// tool results, reasoning blocks, and finish markers as distinct typed parts.
|
||||
func (m *Kit) GetStructuredMessages() []StructuredMessage {
|
||||
if m.treeSession == nil {
|
||||
return nil
|
||||
}
|
||||
branch := m.treeSession.GetBranch("")
|
||||
var msgs []StructuredMessage
|
||||
for _, entry := range branch {
|
||||
me, ok := entry.(*session.MessageEntry)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
msg, err := me.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs = append(msgs, StructuredMessage{
|
||||
ID: me.ID,
|
||||
ParentID: me.ParentID,
|
||||
Role: msg.Role,
|
||||
Parts: msg.Parts,
|
||||
Model: msg.Model,
|
||||
Provider: msg.Provider,
|
||||
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
|
||||
})
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
// GetSessionFilePath returns the JSONL file path of the current session.
|
||||
func (m *Kit) GetSessionFilePath() string {
|
||||
if m.treeSession == nil {
|
||||
@@ -849,11 +895,19 @@ func InitTreeSession(opts *Options) (*session.TreeManager, error) {
|
||||
// New creates a Kit instance using the same initialization as the CLI.
|
||||
// It loads configuration, initializes MCP servers, creates the LLM model, and
|
||||
// sets up the agent for interaction. Returns an error if initialization fails.
|
||||
// viperInitMu serializes viper writes during kit.New(). Viper's global state
|
||||
// is not thread-safe, so concurrent calls (e.g. parallel subagent spawns)
|
||||
// must not overlap the Set()/Get() window.
|
||||
var viperInitMu sync.Mutex
|
||||
|
||||
func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
}
|
||||
|
||||
viperInitMu.Lock()
|
||||
defer viperInitMu.Unlock()
|
||||
|
||||
// Set CLI-equivalent defaults for viper. When used as an SDK (without
|
||||
// cobra), these defaults are not registered via flag bindings.
|
||||
setSDKDefaults()
|
||||
@@ -1150,6 +1204,14 @@ type TurnResult struct {
|
||||
// Response is the assistant's final text response.
|
||||
Response string
|
||||
|
||||
// StopReason indicates why the turn ended. Derived from the LLM
|
||||
// provider's finish reason: "stop", "length" (max tokens), "tool-calls",
|
||||
// "content-filter", "error", "other", "unknown".
|
||||
StopReason string
|
||||
|
||||
// SessionID is the UUID of the session this turn belongs to.
|
||||
SessionID string
|
||||
|
||||
// TotalUsage is the aggregate token usage across all steps in the turn
|
||||
// (includes tool-calling loop iterations). Nil if the provider didn't
|
||||
// report usage.
|
||||
@@ -1165,6 +1227,168 @@ type TurnResult struct {
|
||||
Messages []FantasyMessage
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// In-process subagent
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentConfig configures an in-process subagent spawned via Kit.Subagent().
|
||||
type SubagentConfig struct {
|
||||
// Prompt is the task/instruction for the subagent (required).
|
||||
Prompt string
|
||||
|
||||
// Model overrides the parent's model (e.g. "anthropic/claude-haiku-3-5-20241022").
|
||||
// Empty string uses the parent's current model.
|
||||
Model string
|
||||
|
||||
// SystemPrompt provides domain-specific instructions for the subagent.
|
||||
// Empty string uses a minimal default prompt.
|
||||
SystemPrompt string
|
||||
|
||||
// Tools overrides the tool set. If nil, SubagentTools() is used (all
|
||||
// core tools except spawn_subagent, preventing infinite recursion).
|
||||
Tools []Tool
|
||||
|
||||
// NoSession, when true, uses an in-memory ephemeral session. When false
|
||||
// (default), the subagent's session is persisted and can be loaded for
|
||||
// replay/inspection.
|
||||
NoSession bool
|
||||
|
||||
// Timeout limits execution time. Zero means 5 minute default.
|
||||
Timeout time.Duration
|
||||
|
||||
// OnEvent, when set, receives all events from the subagent's event bus.
|
||||
// This enables the parent to stream subagent tool calls, text chunks,
|
||||
// etc. in real time.
|
||||
OnEvent func(Event)
|
||||
}
|
||||
|
||||
// SubagentResult contains the outcome of an in-process subagent execution.
|
||||
type SubagentResult struct {
|
||||
// Response is the subagent's final text response.
|
||||
Response string
|
||||
// Error is set if the subagent failed (nil on success).
|
||||
Error error
|
||||
// SessionID is the subagent's session identifier (for replay).
|
||||
SessionID string
|
||||
// StopReason is the LLM's finish reason for the subagent's final turn.
|
||||
StopReason string
|
||||
// Usage contains token usage from the subagent's run.
|
||||
Usage *FantasyUsage
|
||||
// Elapsed is the total execution time.
|
||||
Elapsed time.Duration
|
||||
}
|
||||
|
||||
// Subagent spawns an in-process child Kit instance to perform a task. The
|
||||
// child gets its own session, event bus, and agent loop but shares the
|
||||
// parent's config (API keys, provider settings) and defaults to the parent's
|
||||
// model when SubagentConfig.Model is empty.
|
||||
//
|
||||
// This is the recommended way to run subagents in the SDK — no subprocess,
|
||||
// no kit binary dependency, native Go types for results.
|
||||
func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult, error) {
|
||||
if cfg.Prompt == "" {
|
||||
return nil, fmt.Errorf("subagent prompt is required")
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// Default timeout.
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 5 * time.Minute
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// Resolve model: fall back to parent's model, and inherit the parent's
|
||||
// provider when only a bare model name is given (e.g. "claude-haiku"
|
||||
// instead of "anthropic/claude-haiku"). This avoids provider guessing.
|
||||
model := cfg.Model
|
||||
if model == "" {
|
||||
model = m.modelString
|
||||
} else if !strings.Contains(model, "/") {
|
||||
// Bare model name — prepend parent's provider.
|
||||
if parts := strings.SplitN(m.modelString, "/", 2); len(parts) == 2 {
|
||||
model = parts[0] + "/" + model
|
||||
}
|
||||
}
|
||||
|
||||
// Default system prompt.
|
||||
systemPrompt := cfg.SystemPrompt
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = "You are a helpful coding assistant. Complete the task efficiently and thoroughly."
|
||||
}
|
||||
|
||||
// Default tools: everything except spawn_subagent.
|
||||
tools := cfg.Tools
|
||||
if tools == nil {
|
||||
tools = SubagentTools()
|
||||
}
|
||||
|
||||
// Create child Kit instance. If the requested model fails (bad name,
|
||||
// unsupported provider, etc.), fall back to the parent's model so the
|
||||
// agent gets a useful error message instead of a hard failure.
|
||||
childOpts := &Options{
|
||||
Model: model,
|
||||
SystemPrompt: systemPrompt,
|
||||
Tools: tools,
|
||||
NoSession: cfg.NoSession,
|
||||
Quiet: true,
|
||||
}
|
||||
child, err := New(ctx, childOpts)
|
||||
if err != nil && model != m.modelString {
|
||||
// Model-specific failure — retry with parent's model.
|
||||
childOpts.Model = m.modelString
|
||||
child, err = New(ctx, childOpts)
|
||||
if err != nil {
|
||||
return &SubagentResult{
|
||||
Error: fmt.Errorf("failed to create subagent: %w", err),
|
||||
Elapsed: time.Since(start),
|
||||
}, err
|
||||
}
|
||||
// Prepend a note so the agent knows which model is actually running.
|
||||
cfg.Prompt = fmt.Sprintf(
|
||||
"[Note: requested model %q was not available, using %s instead.]\n\n%s",
|
||||
model, m.modelString, cfg.Prompt,
|
||||
)
|
||||
} else if err != nil {
|
||||
return &SubagentResult{
|
||||
Error: fmt.Errorf("failed to create subagent: %w", err),
|
||||
Elapsed: time.Since(start),
|
||||
}, err
|
||||
}
|
||||
defer func() { _ = child.Close() }()
|
||||
|
||||
// Forward events to parent if requested.
|
||||
if cfg.OnEvent != nil {
|
||||
child.Subscribe(cfg.OnEvent)
|
||||
}
|
||||
|
||||
// Run the prompt.
|
||||
result, err := child.PromptResult(ctx, cfg.Prompt)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
return &SubagentResult{
|
||||
Error: err,
|
||||
SessionID: child.GetSessionID(),
|
||||
Elapsed: elapsed,
|
||||
}, err
|
||||
}
|
||||
|
||||
subResult := &SubagentResult{
|
||||
Response: result.Response,
|
||||
SessionID: child.GetSessionID(),
|
||||
StopReason: result.StopReason,
|
||||
Elapsed: elapsed,
|
||||
}
|
||||
if result.TotalUsage != nil {
|
||||
subResult.Usage = result.TotalUsage
|
||||
}
|
||||
|
||||
return subResult, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared generation helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -1173,22 +1397,64 @@ type TurnResult struct {
|
||||
// All prompt modes (Prompt, Steer, FollowUp, PromptWithOptions) share this
|
||||
// single code path so callback wiring is never duplicated.
|
||||
func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.GenerateWithLoopResult, error) {
|
||||
// Inject the in-process subagent spawner into the context so the
|
||||
// spawn_subagent core tool can create child Kit instances without
|
||||
// importing pkg/kit (which would create an import cycle).
|
||||
ctx = core.WithSubagentSpawner(ctx, func(
|
||||
spawnCtx context.Context, prompt, model, systemPrompt string, timeout time.Duration,
|
||||
) (*core.SubagentSpawnResult, error) {
|
||||
result, err := m.Subagent(spawnCtx, SubagentConfig{
|
||||
Prompt: prompt,
|
||||
Model: model,
|
||||
SystemPrompt: systemPrompt,
|
||||
Timeout: timeout,
|
||||
OnEvent: func(e Event) {
|
||||
m.events.emit(e)
|
||||
},
|
||||
})
|
||||
if result == nil {
|
||||
return &core.SubagentSpawnResult{Error: err}, err
|
||||
}
|
||||
sr := &core.SubagentSpawnResult{
|
||||
Response: result.Response,
|
||||
Error: result.Error,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
sr.InputTokens = result.Usage.InputTokens
|
||||
sr.OutputTokens = result.Usage.OutputTokens
|
||||
}
|
||||
return sr, err
|
||||
})
|
||||
|
||||
return m.agent.GenerateWithLoopAndStreaming(ctx, messages,
|
||||
func(toolName, toolArgs string) {
|
||||
m.events.emit(ToolCallEvent{ToolName: toolName, ToolArgs: toolArgs})
|
||||
func(toolCallID, toolName, toolArgs string) {
|
||||
m.events.emit(ToolCallEvent{
|
||||
ToolCallID: toolCallID, ToolName: toolName, ToolKind: toolKindFor(toolName),
|
||||
ToolArgs: toolArgs, ParsedArgs: parseToolArgs(toolArgs),
|
||||
})
|
||||
},
|
||||
func(toolName string, isStarting bool) {
|
||||
func(toolCallID, toolName, toolArgs string, isStarting bool) {
|
||||
if isStarting {
|
||||
m.events.emit(ToolExecutionStartEvent{ToolName: toolName})
|
||||
m.events.emit(ToolExecutionStartEvent{ToolCallID: toolCallID, ToolName: toolName, ToolKind: toolKindFor(toolName), ToolArgs: toolArgs})
|
||||
} else {
|
||||
m.events.emit(ToolExecutionEndEvent{ToolName: toolName})
|
||||
m.events.emit(ToolExecutionEndEvent{ToolCallID: toolCallID, ToolName: toolName, ToolKind: toolKindFor(toolName)})
|
||||
}
|
||||
},
|
||||
func(toolName, toolArgs, resultText string, isError bool) {
|
||||
m.events.emit(ToolResultEvent{
|
||||
ToolName: toolName, ToolArgs: toolArgs,
|
||||
func(toolCallID, toolName, toolArgs, resultText, metadata string, isError bool) {
|
||||
evt := ToolResultEvent{
|
||||
ToolCallID: toolCallID, ToolName: toolName, ToolKind: toolKindFor(toolName),
|
||||
ToolArgs: toolArgs, ParsedArgs: parseToolArgs(toolArgs),
|
||||
Result: resultText, IsError: isError,
|
||||
})
|
||||
}
|
||||
if metadata != "" {
|
||||
var meta ToolResultMetadata
|
||||
if err := json.Unmarshal([]byte(metadata), &meta); err == nil {
|
||||
evt.Metadata = &meta
|
||||
}
|
||||
}
|
||||
m.events.emit(evt)
|
||||
},
|
||||
func(content string) {
|
||||
m.events.emit(ResponseEvent{Content: content})
|
||||
@@ -1317,8 +1583,10 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
m.lastInputTokensMu.Unlock()
|
||||
}
|
||||
|
||||
stopReason := result.StopReason
|
||||
|
||||
m.events.emit(MessageEndEvent{Content: responseText})
|
||||
m.events.emit(TurnEndEvent{Response: responseText})
|
||||
m.events.emit(TurnEndEvent{Response: responseText, StopReason: stopReason})
|
||||
|
||||
// Run AfterTurn hooks.
|
||||
if m.afterTurn.hasHooks() {
|
||||
@@ -1327,8 +1595,10 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
|
||||
// Build TurnResult with usage stats.
|
||||
turnResult := &TurnResult{
|
||||
Response: responseText,
|
||||
Messages: result.ConversationMessages,
|
||||
Response: responseText,
|
||||
StopReason: stopReason,
|
||||
SessionID: m.GetSessionID(),
|
||||
Messages: result.ConversationMessages,
|
||||
}
|
||||
totalUsage := result.TotalUsage
|
||||
turnResult.TotalUsage = &totalUsage
|
||||
|
||||
@@ -51,3 +51,8 @@ func CodingTools(opts ...ToolOption) []Tool { return core.CodingTools(opts...) }
|
||||
// ReadOnlyTools returns tools for read-only exploration:
|
||||
// read, grep, find, ls.
|
||||
func ReadOnlyTools(opts ...ToolOption) []Tool { return core.ReadOnlyTools(opts...) }
|
||||
|
||||
// SubagentTools returns all core tools except spawn_subagent. Use this when
|
||||
// creating child Kit instances (in-process subagents) to prevent infinite
|
||||
// recursion.
|
||||
func SubagentTools(opts ...ToolOption) []Tool { return core.SubagentTools(opts...) }
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
{
|
||||
"version": 1,
|
||||
"skills": {
|
||||
"btca-cli": {
|
||||
"source": "davis7dotsh/better-context",
|
||||
"sourceType": "github",
|
||||
"computedHash": "99bc5301f4f839a6f3be99d98955f32f1cd576c218731fa05fa54a003bd20e9b"
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,4 @@
|
||||
node_modules/
|
||||
out/
|
||||
dist/
|
||||
.DS_Store
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user