mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 49ff4c0678 | |||
| b0802a5c32 | |||
| dfe65ca227 | |||
| d4ec756ce5 | |||
| 2971e73ee8 | |||
| 5aa6c9e116 | |||
| bca08476de | |||
| 6a599d86af | |||
| fd6f200659 | |||
| b295a25946 | |||
| f0e4e2f757 | |||
| d25249506a | |||
| 971521f534 | |||
| 8c00682367 | |||
| 58caf155c1 | |||
| 3f08bf2424 | |||
| 9fbbab05f6 | |||
| b0991c7aa6 | |||
| 9c90563765 | |||
| f36166bee5 | |||
| 879e81f9b5 | |||
| 727b42acfe | |||
| 4830981570 | |||
| dcfebafcc5 | |||
| 1f5c103667 | |||
| 4caa8ba3dc | |||
| 15ef8ad78b | |||
| 551f2710d9 | |||
| 67bda5cad5 | |||
| 01d7d754ef | |||
| c6304f1e92 | |||
| bc3c733ae3 | |||
| 428ee2b8be | |||
| eb1d7fd07e | |||
| 1e3e5cafd3 | |||
| 0b93e58fb9 | |||
| 2bb01ed72c | |||
| b6ecc36ea1 | |||
| d4f27bc912 | |||
| f12e195390 | |||
| b68b3dd0bf | |||
| 48521bf76d | |||
| 16df3a738c | |||
| 9d0b8c8cef | |||
| d9326fcf21 | |||
| 22c479277e | |||
| 8ae204f12f | |||
| 8b1665a4ce | |||
| 941f1daf0b | |||
| ab7e2bda61 |
@@ -42,6 +42,33 @@ Keep this managed block so 'openspec update' can refresh the instructions.
|
||||
- **Extension system** (`internal/extensions/`): Yaegi-interpreted Go, 13 lifecycle events, custom tools/commands/widgets/overlays/editor interceptors
|
||||
- **TUI** (`internal/ui/`): Bubble Tea v2 parent-child model (`AppModel` → `InputComponent`, `StreamComponent`, etc.)
|
||||
- **Decoupling pattern**: `cmd/root.go` has converter functions (e.g. `widgetProviderForUI()`) that bridge `internal/extensions/` types to `internal/ui/` types — the UI never imports extensions directly
|
||||
- **Public SDK** (`pkg/kit/`): The public-facing Go SDK for embedding Kit as a library. See rules below.
|
||||
|
||||
## Public SDK (`pkg/kit/`) Rules
|
||||
|
||||
`pkg/kit/` is the **public API surface** consumed by external Go developers. All exported symbols, types, function names, and godoc comments in this package are part of the SDK contract.
|
||||
|
||||
### No Dependency Name Leakage
|
||||
Internal dependency names (e.g. `charm.land/fantasy`, library-specific jargon) **must not** appear in:
|
||||
- **Exported function/method names** — use generic terms (`LLM`, `Provider`, `Message`) instead of library names
|
||||
- **Exported type names** — type aliases should use domain names (e.g. `LLMMessage`, not `FantasyMessage`)
|
||||
- **Godoc comments** on exported symbols — these are visible in `go doc` output and pkg.go.dev
|
||||
- **Struct field names and tags** on exported types
|
||||
|
||||
Using dependency types directly in **function bodies** (private implementation) is fine — that's invisible to SDK consumers.
|
||||
|
||||
### Naming Conventions for SDK Symbols
|
||||
- Type aliases re-exporting dependency types: use `LLM*` prefix (e.g. `LLMMessage`, `LLMUsage`, `LLMResponse`)
|
||||
- Conversion helpers: use `ConvertToLLM*` / `ConvertFromLLM*` (not the dependency name)
|
||||
- Provider queries: use `GetLLMProviders` (not `GetFantasyProviders`)
|
||||
- When wrapping internal methods, the `pkg/kit/` name should be dependency-agnostic even if the `internal/` method still uses the old name
|
||||
|
||||
### Deprecation Pattern
|
||||
When renaming a public SDK symbol, keep the old name as a deprecated wrapper for one release cycle:
|
||||
```go
|
||||
// Deprecated: Use NewName instead.
|
||||
func OldName() { return NewName() }
|
||||
```
|
||||
|
||||
## Key Patterns
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
|
||||
## Features
|
||||
|
||||
- **Multi-Provider LLM Support**: Anthropic, OpenAI, Google Gemini, Ollama, Azure OpenAI, AWS Bedrock, OpenRouter, and more
|
||||
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls, spawn_subagent - no MCP overhead
|
||||
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls, subagent - no MCP overhead
|
||||
- **MCP Integration**: Connect external MCP servers for expanded capabilities
|
||||
- **Extension System**: Write custom tools, commands, widgets, and UI modifications in Go
|
||||
- **Theming**: 22 built-in color themes (KITT, Catppuccin, Dracula, Nord, etc.) with runtime switching, persistence, and custom theme files
|
||||
@@ -209,7 +209,7 @@ kit auth status # Check authentication status
|
||||
|
||||
# Model database
|
||||
kit models [provider] # List available models (optionally filter by provider)
|
||||
kit models --all # Show all providers (not just Fantasy-compatible)
|
||||
kit models --all # Show all providers (not just LLM-compatible)
|
||||
kit update-models [source] # Update model database (from models.dev, URL, file, or 'embedded')
|
||||
|
||||
# Extension management
|
||||
@@ -307,6 +307,12 @@ kit -e examples/extensions/minimal.go
|
||||
- **Themes**: Register and switch color themes via `RegisterTheme`, `SetTheme`, `ListThemes`
|
||||
- **Custom Events**: Inter-extension communication via `EmitCustomEvent`
|
||||
|
||||
**Bridged SDK APIs** (NEW): Extensions can now access internal SDK capabilities:
|
||||
- **Tree Navigation**: Navigate conversation history (`GetTreeNode`, `GetCurrentBranch`, `NavigateTo`), summarize branches (`SummarizeBranch`), and implement fresh context loops (`CollapseBranch`)
|
||||
- **Skill Loading**: Dynamically load and inject skills at runtime (`LoadSkill`, `DiscoverSkills`, `InjectSkillAsContext`)
|
||||
- **Template Parsing**: Parse and render templates with `{{variables}}` (`ParseTemplate`, `RenderTemplate`), parse CLI-style arguments (`ParseArguments`, `SimpleParseArguments`), and evaluate model conditionals (`EvaluateModelConditional`, `RenderWithModelConditionals`)
|
||||
- **Model Resolution**: Resolve model fallback chains (`ResolveModelChain`), query model capabilities (`GetModelCapabilities`, `CheckModelAvailable`), and extract provider/model ID (`GetCurrentProvider`, `GetCurrentModelID`)
|
||||
|
||||
### Extension Examples
|
||||
|
||||
See the `examples/extensions/` directory:
|
||||
@@ -318,6 +324,7 @@ See the `examples/extensions/` directory:
|
||||
- `compact-notify.go` - Notification on compaction
|
||||
- `confirm-destructive.go` - Confirm destructive operations
|
||||
- `context-inject.go` - Inject context into conversations
|
||||
- `conversation-manager.go` - **NEW** Tree navigation, branch summarization, and fresh context loops
|
||||
- `custom-editor-demo.go` - Vim-like modal editor
|
||||
- `dev-reload.go` - Development live-reload
|
||||
- `header-footer-demo.go` - Custom headers and footers
|
||||
@@ -332,10 +339,10 @@ See the `examples/extensions/` directory:
|
||||
- `plan-mode.go` - Read-only planning mode
|
||||
- `project-rules.go` - Project-specific rules
|
||||
- `prompt-demo.go` - Interactive prompts (select/confirm/input)
|
||||
- `prompt-templates.go` - **NEW** Frontmatter-driven templates with model switching and skill injection
|
||||
- `protected-paths.go` - Path protection for sensitive files
|
||||
- `subagent-widget.go` - Multi-agent orchestration with status widget
|
||||
- `subagent-test.go` - Subagent testing utilities
|
||||
- `subagent-monitor.go` - Real-time monitoring widget for spawned subagents
|
||||
- `summarize.go` - Conversation summarization
|
||||
- `tool-logger.go` - Log all tool calls
|
||||
- `neon-theme.go` - Custom theme registration and switching
|
||||
@@ -495,7 +502,7 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer host.Close()
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
// Send a prompt
|
||||
response, err := host.Prompt(ctx, "What is 2+2?")
|
||||
@@ -536,23 +543,26 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
### With Callbacks
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(
|
||||
unsub := host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
println("Calling tool:", e.ToolName)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
if e.IsError {
|
||||
println("Tool failed:", e.ToolName)
|
||||
}
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
print(e.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
|
||||
response, err := host.Prompt(
|
||||
ctx,
|
||||
"List files in current directory",
|
||||
func(name, args string) {
|
||||
// Tool call started
|
||||
println("Calling tool:", name)
|
||||
},
|
||||
func(name, args, result string, isError bool) {
|
||||
// Tool call completed
|
||||
if isError {
|
||||
println("Tool failed:", name)
|
||||
}
|
||||
},
|
||||
func(chunk string) {
|
||||
// Streaming text chunk
|
||||
print(chunk)
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
@@ -716,7 +726,7 @@ Use `custom/custom` when pointing Kit at any OpenAI-compatible endpoint with `--
|
||||
kit --provider-url "http://localhost:8080/v1" "Hello"
|
||||
```
|
||||
|
||||
This automatically defaults to `custom/custom` without needing to specify a model. The custom provider routes through fantasy's `openaicompat` provider and supports:
|
||||
This automatically defaults to `custom/custom` without needing to specify a model. The custom provider routes through the `openaicompat` provider and supports:
|
||||
|
||||
- Zero cost tracking (input/output = 0)
|
||||
- 262K context window, 65K output limit
|
||||
|
||||
@@ -76,6 +76,12 @@
|
||||
"name": "opencode",
|
||||
"url": "https://github.com/anomalyco/opencode",
|
||||
"branch": "dev"
|
||||
},
|
||||
{
|
||||
"type": "git",
|
||||
"name": "herald",
|
||||
"url": "https://github.com/indaco/herald",
|
||||
"branch": "main"
|
||||
}
|
||||
],
|
||||
"model": "claude-haiku-4-5",
|
||||
|
||||
+1
-1
@@ -55,7 +55,7 @@ func printAllProviders(showAll bool) error {
|
||||
if showAll {
|
||||
providerIDs = kit.GetSupportedProviders()
|
||||
} else {
|
||||
providerIDs = kit.GetFantasyProviders()
|
||||
providerIDs = kit.GetLLMProviders()
|
||||
}
|
||||
sort.Strings(providerIDs)
|
||||
|
||||
|
||||
+516
-105
@@ -415,7 +415,7 @@ func runKit(ctx context.Context) error {
|
||||
// normalised to start with "/" so they integrate with the slash-command
|
||||
// autocomplete and dispatch pipeline.
|
||||
func extensionCommandsForUI(k *kit.Kit) []ui.ExtensionCommand {
|
||||
defs := k.ExtensionCommands()
|
||||
defs := k.Extensions().Commands()
|
||||
if len(defs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -429,12 +429,12 @@ func extensionCommandsForUI(k *kit.Kit) []ui.ExtensionCommand {
|
||||
Name: name,
|
||||
Description: d.Description,
|
||||
Execute: func(args string) (string, error) {
|
||||
return d.Execute(args, k.GetExtensionContext())
|
||||
return d.Execute(args, k.Extensions().GetContext())
|
||||
},
|
||||
}
|
||||
if d.Complete != nil {
|
||||
ec.Complete = func(prefix string) []string {
|
||||
return d.Complete(prefix, k.GetExtensionContext())
|
||||
return d.Complete(prefix, k.Extensions().GetContext())
|
||||
}
|
||||
}
|
||||
cmds = append(cmds, ec)
|
||||
@@ -446,11 +446,11 @@ func extensionCommandsForUI(k *kit.Kit) []ui.ExtensionCommand {
|
||||
// ui.WidgetData for the given placement. Returns nil if extensions are
|
||||
// disabled, which is safe — the UI treats a nil GetWidgets as "no widgets".
|
||||
func widgetProviderForUI(k *kit.Kit) func(string) []ui.WidgetData {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func(placement string) []ui.WidgetData {
|
||||
configs := k.GetExtensionWidgets(extensions.WidgetPlacement(placement))
|
||||
configs := k.Extensions().GetWidgets(extensions.WidgetPlacement(placement))
|
||||
if len(configs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -467,25 +467,34 @@ func widgetProviderForUI(k *kit.Kit) func(string) []ui.WidgetData {
|
||||
}
|
||||
}
|
||||
|
||||
// headerFooterProviderForUI returns a provider func that maps an
|
||||
// extensions.HeaderFooterConfig getter into the ui.WidgetData shape
|
||||
// expected by AppModel. The getter argument selects header vs footer.
|
||||
func headerFooterProviderForUI(k *kit.Kit, getter func() *extensions.HeaderFooterConfig) func() *ui.WidgetData {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() *ui.WidgetData {
|
||||
cfg := getter()
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return &ui.WidgetData{
|
||||
Text: cfg.Content.Text,
|
||||
Markdown: cfg.Content.Markdown,
|
||||
BorderColor: cfg.Style.BorderColor,
|
||||
NoBorder: cfg.Style.NoBorder,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// headerProviderForUI returns a function that converts the extension header
|
||||
// to a *ui.WidgetData for the TUI. Returns nil if extensions are disabled,
|
||||
// which is safe — the UI treats a nil GetHeader as "no header".
|
||||
func headerProviderForUI(k *kit.Kit) func() *ui.WidgetData {
|
||||
if !k.HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() *ui.WidgetData {
|
||||
config := k.GetExtensionHeader()
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
return &ui.WidgetData{
|
||||
Text: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
NoBorder: config.Style.NoBorder,
|
||||
}
|
||||
}
|
||||
return headerFooterProviderForUI(k, func() *extensions.HeaderFooterConfig {
|
||||
return k.Extensions().GetHeader()
|
||||
})
|
||||
}
|
||||
|
||||
// toolRendererProviderForUI returns a function that converts extension tool
|
||||
@@ -493,11 +502,11 @@ func headerProviderForUI(k *kit.Kit) func() *ui.WidgetData {
|
||||
// disabled, which is safe — the UI treats a nil GetToolRenderer as "no
|
||||
// custom renderers".
|
||||
func toolRendererProviderForUI(k *kit.Kit) func(string) *ui.ToolRendererData {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func(toolName string) *ui.ToolRendererData {
|
||||
config := k.GetExtensionToolRenderer(toolName)
|
||||
config := k.Extensions().GetToolRenderer(toolName)
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -517,11 +526,11 @@ func toolRendererProviderForUI(k *kit.Kit) func(string) *ui.ToolRendererData {
|
||||
// Returns nil if extensions are disabled, which is safe — the UI treats a
|
||||
// nil GetEditorInterceptor as "no interceptor".
|
||||
func editorInterceptorProviderForUI(k *kit.Kit) func() *ui.EditorInterceptor {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() *ui.EditorInterceptor {
|
||||
config := k.GetExtensionEditor()
|
||||
config := k.Extensions().GetEditor()
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -555,11 +564,11 @@ func editorInterceptorProviderForUI(k *kit.Kit) func() *ui.EditorInterceptor {
|
||||
// visibility overrides to a *ui.UIVisibility for the TUI. Returns nil if
|
||||
// extensions are disabled — the UI treats nil as "show everything".
|
||||
func uiVisibilityProviderForUI(k *kit.Kit) func() *ui.UIVisibility {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() *ui.UIVisibility {
|
||||
v := k.GetExtensionUIVisibility()
|
||||
v := k.Extensions().GetUIVisibility()
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -576,21 +585,9 @@ func uiVisibilityProviderForUI(k *kit.Kit) func() *ui.UIVisibility {
|
||||
// to a *ui.WidgetData for the TUI. Returns nil if extensions are disabled,
|
||||
// which is safe — the UI treats a nil GetFooter as "no footer".
|
||||
func footerProviderForUI(k *kit.Kit) func() *ui.WidgetData {
|
||||
if !k.HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() *ui.WidgetData {
|
||||
config := k.GetExtensionFooter()
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
return &ui.WidgetData{
|
||||
Text: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
NoBorder: config.Style.NoBorder,
|
||||
}
|
||||
}
|
||||
return headerFooterProviderForUI(k, func() *extensions.HeaderFooterConfig {
|
||||
return k.Extensions().GetFooter()
|
||||
})
|
||||
}
|
||||
|
||||
// statusBarProviderForUI returns a function that fetches extension status bar
|
||||
@@ -598,11 +595,11 @@ func footerProviderForUI(k *kit.Kit) func() *ui.WidgetData {
|
||||
// if extensions are disabled, which is safe — the TUI treats a nil
|
||||
// GetStatusBarEntries as "no extension entries".
|
||||
func statusBarProviderForUI(k *kit.Kit) func() []ui.StatusBarEntryData {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return func() []ui.StatusBarEntryData {
|
||||
entries := k.GetExtensionStatusEntries()
|
||||
entries := k.Extensions().GetStatusEntries()
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -622,30 +619,36 @@ func statusBarProviderForUI(k *kit.Kit) func() []ui.StatusBarEntryData {
|
||||
// and returns (cancelled, reason). Returns nil if extensions are disabled —
|
||||
// the UI treats nil as "no hook".
|
||||
func beforeForkProviderForUI(k *kit.Kit) func(string, bool, string) (bool, string) {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return k.EmitBeforeFork
|
||||
return func(targetID string, isUserMsg bool, userText string) (bool, string) {
|
||||
return k.Extensions().EmitBeforeFork(targetID, isUserMsg, userText)
|
||||
}
|
||||
}
|
||||
|
||||
// beforeSessionSwitchProviderForUI returns a callback that emits a
|
||||
// BeforeSessionSwitch event and returns (cancelled, reason). Returns nil
|
||||
// if extensions are disabled — the UI treats nil as "no hook".
|
||||
func beforeSessionSwitchProviderForUI(k *kit.Kit) func(string) (bool, string) {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return k.EmitBeforeSessionSwitch
|
||||
return func(switchReason string) (bool, string) {
|
||||
return k.Extensions().EmitBeforeSessionSwitch(switchReason)
|
||||
}
|
||||
}
|
||||
|
||||
// globalShortcutsProviderForUI returns a callback that queries the extension
|
||||
// runner for registered keyboard shortcuts. Returns nil if extensions are
|
||||
// disabled — the UI treats nil as "no shortcuts".
|
||||
func globalShortcutsProviderForUI(k *kit.Kit) func() map[string]func() {
|
||||
if !k.HasExtensions() {
|
||||
if !k.Extensions().HasExtensions() {
|
||||
return nil
|
||||
}
|
||||
return k.GetExtensionShortcuts
|
||||
return func() map[string]func() {
|
||||
return k.Extensions().GetShortcuts()
|
||||
}
|
||||
}
|
||||
|
||||
func runNormalMode(ctx context.Context) error {
|
||||
@@ -776,7 +779,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
treeSession := kitInstance.GetTreeSession()
|
||||
var messages []fantasy.Message
|
||||
if treeSession != nil {
|
||||
messages = treeSession.GetFantasyMessages()
|
||||
messages = treeSession.GetLLMMessages()
|
||||
}
|
||||
|
||||
// Create the app.App instance.
|
||||
@@ -799,43 +802,53 @@ func runNormalMode(ctx context.Context) error {
|
||||
appInstance := app.New(appOpts, messages)
|
||||
defer appInstance.Close()
|
||||
|
||||
// Buffer for extension messages during startup (printed after startup banner).
|
||||
var startupExtensionMessages []string
|
||||
|
||||
// Set up extension context and emit SessionStart.
|
||||
if kitInstance.HasExtensions() {
|
||||
if kitInstance.Extensions().HasExtensions() {
|
||||
cwd, _ := os.Getwd()
|
||||
kitInstance.SetExtensionContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: positionalPrompt == "",
|
||||
Print: func(text string) { appInstance.PrintFromExtension("", text) },
|
||||
PrintInfo: func(text string) { appInstance.PrintFromExtension("info", text) },
|
||||
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: positionalPrompt == "",
|
||||
Print: func(text string) {
|
||||
// Capture messages during startup, print after startup banner.
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintInfo: func(text string) {
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintError: func(text string) {
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.SetExtensionWidget(config)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.RemoveExtensionWidget(id)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.SetExtensionHeader(config)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.RemoveExtensionHeader()
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.SetExtensionFooter(config)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.RemoveExtensionFooter()
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
@@ -885,8 +898,8 @@ func runNormalMode(ctx context.Context) error {
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
SetUIVisibility: func(v extensions.UIVisibility) {
|
||||
kitInstance.SetExtensionUIVisibility(v)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
@@ -898,53 +911,52 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
},
|
||||
SetEditor: func(config extensions.EditorConfig) {
|
||||
kitInstance.SetExtensionEditor(config)
|
||||
// Use a goroutine for NotifyWidgetUpdate because this may be
|
||||
// called from within an editor HandleKey callback, which runs
|
||||
// synchronously inside BubbleTea's Update(). Calling prog.Send()
|
||||
// directly from Update() deadlocks the event loop.
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
|
||||
// deadlocks if called synchronously from inside BubbleTea's
|
||||
// Update() handler. All call sites use go-routines uniformly.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
ResetEditor: func() {
|
||||
kitInstance.ResetExtensionEditor()
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.GetSessionMessages()
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionFilePath()
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.AppendExtensionEntry(entryType, data)
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.GetExtensionEntries(entryType)
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
SetEditorText: func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
},
|
||||
SetStatus: func(key string, text string, priority int) {
|
||||
kitInstance.SetExtensionStatus(extensions.StatusBarEntry{
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveStatus: func(key string) {
|
||||
kitInstance.RemoveExtensionStatus(key)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.GetExtensionOption(name)
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.SetExtensionOption(name, value)
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
SetModel: func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.GetExtensionContext().Model
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -953,9 +965,9 @@ func runNormalMode(ctx context.Context) error {
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.EmitModelChange(modelString, previousModel, "extension")
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
@@ -980,7 +992,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.EmitExtensionCustomEvent(name, data)
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
@@ -989,7 +1001,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
},
|
||||
RenderMessage: func(rendererName, content string) {
|
||||
renderer := kitInstance.GetExtensionMessageRenderer(rendererName)
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
@@ -1002,19 +1014,19 @@ func runNormalMode(ctx context.Context) error {
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
},
|
||||
ReloadExtensions: func() error {
|
||||
err := kitInstance.ReloadExtensions()
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.GetExtensionToolInfos()
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.SetExtensionActiveTools(names)
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
@@ -1085,7 +1097,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: result.Error,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
@@ -1097,8 +1109,398 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
return nil, extResult, err
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: kitInstance.GetChildren,
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
// Find skill by name
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
// Inject via SendMessage as a system context message
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: kitInstance.DiscoverSkillsForExtension,
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
ParseTemplate: kit.ParseTemplate,
|
||||
RenderTemplate: kit.RenderTemplate,
|
||||
ParseArguments: kit.ParseArguments,
|
||||
SimpleParseArguments: kit.SimpleParseArguments,
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
ResolveModelChain: kit.ResolveModelChain,
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: kit.CheckModelAvailable,
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
})
|
||||
kitInstance.Extensions().EmitSessionStart()
|
||||
|
||||
// Restore normal print functions for runtime use.
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: positionalPrompt == "",
|
||||
Print: func(text string) { appInstance.PrintFromExtension("", text) },
|
||||
PrintInfo: func(text string) { appInstance.PrintFromExtension("info", text) },
|
||||
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
},
|
||||
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
},
|
||||
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
// In-process subagent via SDK.
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
// Bridge SDK events to extension SubagentEvents.
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(ctx, sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: kitInstance.GetChildren,
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
ParseTemplate: kit.ParseTemplate,
|
||||
RenderTemplate: kit.RenderTemplate,
|
||||
ParseArguments: kit.ParseArguments,
|
||||
SimpleParseArguments: kit.SimpleParseArguments,
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
ResolveModelChain: kit.ResolveModelChain,
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: kit.CheckModelAvailable,
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
})
|
||||
kitInstance.EmitSessionStart()
|
||||
}
|
||||
|
||||
// Convert extension commands to UI-layer type for the interactive TUI.
|
||||
@@ -1166,7 +1568,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
// Update the extension context's Model field so handlers see it.
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// NOTE: We do NOT call appInstance.NotifyModelChanged() here because
|
||||
// this callback runs synchronously inside BubbleTea's Update(), and
|
||||
// NotifyModelChanged calls prog.Send() which deadlocks. The UI layer
|
||||
@@ -1192,7 +1594,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
emitModelChangeForUI := func(newModel, previousModel, source string) {
|
||||
kitInstance.EmitModelChange(newModel, previousModel, source)
|
||||
kitInstance.Extensions().EmitModelChange(newModel, previousModel, source)
|
||||
}
|
||||
|
||||
// Build thinking level callback.
|
||||
@@ -1222,7 +1624,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return fmt.Errorf("--quiet requires a prompt")
|
||||
}
|
||||
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI)
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, startupExtensionMessages)
|
||||
}
|
||||
|
||||
// runNonInteractiveModeApp executes a single prompt via the app layer and exits,
|
||||
@@ -1278,7 +1680,7 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
|
||||
|
||||
// If --no-exit was requested, hand off to the interactive TUI.
|
||||
if noExit {
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession)
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1333,7 +1735,7 @@ func buildJSONOutput(result *kit.TurnResult, model string) ([]byte, error) {
|
||||
}
|
||||
|
||||
for _, fmsg := range result.Messages {
|
||||
converted := kit.ConvertFromFantasyMessage(fmsg)
|
||||
converted := kit.ConvertFromLLMMessage(fmsg)
|
||||
m := jsonMessage{Role: string(converted.Role)}
|
||||
for _, p := range converted.Parts {
|
||||
switch c := p.(type) {
|
||||
@@ -1376,7 +1778,7 @@ func writeJSONError(err error) {
|
||||
// 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit).
|
||||
//
|
||||
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []ui.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error) error {
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []ui.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, startupExtensionMessages []string) error {
|
||||
// Determine terminal size; fall back gracefully.
|
||||
termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd()))
|
||||
if err != nil || termWidth == 0 {
|
||||
@@ -1426,6 +1828,15 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
|
||||
// Print startup info to stdout before Bubble Tea takes over the screen.
|
||||
appModel.PrintStartupInfo()
|
||||
|
||||
// Print any extension messages that were captured during startup.
|
||||
if len(startupExtensionMessages) > 0 {
|
||||
fmt.Println()
|
||||
for _, msg := range startupExtensionMessages {
|
||||
fmt.Println(msg)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
program := tea.NewProgram(appModel)
|
||||
|
||||
// Register the program with the app layer so agent events are sent to the TUI.
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
//go:build ignore
|
||||
|
||||
// bridge_demo.go - Demonstrates the new bridged SDK APIs for extensions.
|
||||
// This extension showcases tree navigation, skill loading, template parsing,
|
||||
// and model resolution capabilities.
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
discoveredSkills []ext.Skill
|
||||
currentBranch []ext.TreeNode
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Register /tree-info command to demonstrate tree navigation
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "tree-info",
|
||||
Description: "Show current conversation tree information",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
info := fmt.Sprintf("Current branch has %d nodes:\n", len(branch))
|
||||
for i, node := range branch {
|
||||
info += fmt.Sprintf(" [%d] %s (%s): %s...\n", i, node.Type, node.ID[:8], truncate(node.Content, 40))
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /discover-skills command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "discover-skills",
|
||||
Description: "Discover and list available skills",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
result := ctx.DiscoverSkills()
|
||||
if result.Error != "" {
|
||||
return "", fmt.Errorf("discovery failed: %s", result.Error)
|
||||
}
|
||||
discoveredSkills = result.Skills
|
||||
|
||||
info := fmt.Sprintf("Discovered %d skills:\n", len(result.Skills))
|
||||
for _, s := range result.Skills {
|
||||
info += fmt.Sprintf(" - %s: %s\n", s.Name, s.Description)
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /parse-template command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "parse-template",
|
||||
Description: "Parse a template and show extracted variables",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
args = "Hello {{name}}, welcome to {{place}}!"
|
||||
}
|
||||
tpl := ctx.ParseTemplate("demo", args)
|
||||
info := fmt.Sprintf("Template: %s\nVariables: %v", tpl.Content, tpl.Variables)
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /render-template command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "render-template",
|
||||
Description: "Render a template with variables (usage: /render-template name=John place=Kit)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
tpl := ctx.ParseTemplate("demo", "Hello {{name}}, welcome to {{place}}!")
|
||||
vars := ctx.ParseArguments(args, ext.ArgumentPattern{
|
||||
Flags: map[string]string{"name": "name", "place": "place"},
|
||||
})
|
||||
rendered := ctx.RenderTemplate(tpl, vars.Vars)
|
||||
ctx.PrintInfo("Rendered: " + rendered)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /check-model command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "check-model",
|
||||
Description: "Check model capabilities and availability",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
model := args
|
||||
if model == "" {
|
||||
model = ctx.Model
|
||||
}
|
||||
|
||||
available := ctx.CheckModelAvailable(model)
|
||||
caps, err := ctx.GetModelCapabilities(model)
|
||||
|
||||
info := fmt.Sprintf("Model: %s\n", model)
|
||||
info += fmt.Sprintf("Available: %v\n", available)
|
||||
if err == "" {
|
||||
info += fmt.Sprintf("Provider: %s\n", caps.Provider)
|
||||
info += fmt.Sprintf("Context Limit: %d\n", caps.ContextLimit)
|
||||
info += fmt.Sprintf("Reasoning: %v\n", caps.Reasoning)
|
||||
} else {
|
||||
info += fmt.Sprintf("Error: %s\n", err)
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /resolve-chain command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "resolve-chain",
|
||||
Description: "Resolve a model chain (usage: /resolve-chain claude-opus,gpt-4o,claude-sonnet)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
args = "anthropic/claude-opus-4,anthropic/claude-sonnet-4,openai/gpt-4o"
|
||||
}
|
||||
prefs := ctx.SimpleParseArguments(args, 1)
|
||||
chain := []string{}
|
||||
if len(prefs) > 1 {
|
||||
// Split the first arg by comma
|
||||
for _, p := range strings.Split(prefs[1], ",") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
chain = append(chain, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := ctx.ResolveModelChain(chain)
|
||||
info, _ := json.MarshalIndent(result, "", " ")
|
||||
ctx.PrintInfo("Resolution Result:\n" + string(info))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /test-conditional command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "test-conditional",
|
||||
Description: "Test model conditional rendering",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
content := `<if-model is="claude-*">This is for Claude models<else>This is for other models</if-model>`
|
||||
rendered := ctx.RenderWithModelConditionals(content)
|
||||
ctx.PrintInfo("Input: " + content)
|
||||
ctx.PrintInfo("Output: " + rendered)
|
||||
ctx.PrintInfo(fmt.Sprintf("Current model matches 'claude-*': %v", ctx.EvaluateModelConditional("claude-*")))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// OnSessionStart: discover skills automatically
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
result := ctx.DiscoverSkills()
|
||||
if result.Error == "" && len(result.Skills) > 0 {
|
||||
discoveredSkills = result.Skills
|
||||
ctx.SetStatus("bridge-demo", fmt.Sprintf("%d skills", len(result.Skills)), 50)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
//go:build ignore
|
||||
|
||||
// conversation-manager.go - Advanced conversation tree navigation and management.
|
||||
// This extension demonstrates:
|
||||
// - Tree navigation (GetTreeNode, GetCurrentBranch, NavigateTo)
|
||||
// - Branch summarization and collapsing
|
||||
// - Interactive tree exploration
|
||||
//
|
||||
// Commands:
|
||||
// /tree - Show conversation tree structure
|
||||
// /branch - Show current branch path
|
||||
// /goto <entry-id> - Navigate to a specific entry
|
||||
// /summarize <n> - Summarize last N messages
|
||||
// /fresh-context - Collapse branch and start fresh
|
||||
// /loop <n> <prompt> - Execute prompt N times with fresh context each iteration
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
loopActive bool
|
||||
loopCount int
|
||||
loopCurrent int
|
||||
loopPrompt string
|
||||
loopStartNode string
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// /tree - Show tree structure
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "tree",
|
||||
Description: "Show conversation tree structure",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
showTree(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /branch - Show current branch
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "branch",
|
||||
Description: "Show current conversation branch",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
showBranch(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /goto - Navigate to entry
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "goto",
|
||||
Description: "Navigate to a specific entry ID (usage: /goto <entry-id>)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
ctx.PrintError("Usage: /goto <entry-id>")
|
||||
return "", nil
|
||||
}
|
||||
result := ctx.NavigateTo(args)
|
||||
if !result.Success {
|
||||
ctx.PrintError(fmt.Sprintf("Navigation failed: %s", result.Error))
|
||||
return "", nil
|
||||
}
|
||||
ctx.PrintInfo(fmt.Sprintf("Navigated to entry: %s", args))
|
||||
|
||||
// Show the node we navigated to
|
||||
node := ctx.GetTreeNode(args)
|
||||
if node != nil {
|
||||
ctx.PrintInfo(fmt.Sprintf("Entry type: %s, Role: %s", node.Type, node.Role))
|
||||
}
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /summarize - Summarize recent messages
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "summarize",
|
||||
Description: "Summarize last N messages (usage: /summarize [n=5])",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
n := 5
|
||||
if args != "" {
|
||||
if parsed, err := strconv.Atoi(args); err == nil && parsed > 0 {
|
||||
n = parsed
|
||||
}
|
||||
}
|
||||
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) < 2 {
|
||||
ctx.PrintError("Not enough messages to summarize")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Find range to summarize
|
||||
startIdx := len(branch) - n - 1
|
||||
if startIdx < 0 {
|
||||
startIdx = 0
|
||||
}
|
||||
endIdx := len(branch) - 1
|
||||
|
||||
fromID := branch[startIdx].ID
|
||||
toID := branch[endIdx].ID
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Summarizing messages %d to %d...", startIdx, endIdx))
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
|
||||
if summary == "" {
|
||||
ctx.PrintError("Failed to generate summary")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#89b4fa",
|
||||
Subtitle: "conversation-manager · Summary",
|
||||
})
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /fresh-context - Collapse and restart
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "fresh-context",
|
||||
Description: "Collapse conversation to summary and start fresh",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) < 3 {
|
||||
ctx.PrintError("Not enough context to collapse")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Keep first message (system), summarize rest
|
||||
fromID := branch[1].ID
|
||||
toID := branch[len(branch)-1].ID
|
||||
|
||||
ctx.PrintInfo("Generating summary for context collapse...")
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
|
||||
if summary == "" {
|
||||
ctx.PrintError("Failed to generate summary")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Collapse the branch
|
||||
result := ctx.CollapseBranch(fromID, toID, summary)
|
||||
if !result.Success {
|
||||
ctx.PrintError(fmt.Sprintf("Collapse failed: %s", result.Error))
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ctx.PrintInfo("Context collapsed. Starting fresh with summary.")
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "conversation-manager · Collapsed Context",
|
||||
})
|
||||
|
||||
// Set a widget showing we're in fresh mode
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "fresh-context",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: "🌱 Fresh Context Mode - Previous conversation collapsed"},
|
||||
Style: ext.WidgetStyle{BorderColor: "#a6e3a1"},
|
||||
})
|
||||
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /loop - Execute with fresh context each iteration
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "loop",
|
||||
Description: "Execute prompt N times with fresh context (usage: /loop 5 analyze this code)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if loopActive {
|
||||
ctx.PrintError("Loop already in progress. Wait for completion.")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Parse arguments
|
||||
parts := strings.SplitN(args, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
ctx.PrintError("Usage: /loop <count> <prompt>")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
count, err := strconv.Atoi(parts[0])
|
||||
if err != nil || count <= 0 || count > 10 {
|
||||
ctx.PrintError("Invalid count (must be 1-10)")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
loopCount = count
|
||||
loopCurrent = 0
|
||||
loopPrompt = parts[1]
|
||||
loopActive = true
|
||||
|
||||
// Store current branch position
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) > 0 {
|
||||
loopStartNode = branch[len(branch)-1].ID
|
||||
}
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Starting loop: %d iterations", loopCount))
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "loop-progress",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: fmt.Sprintf("🔄 Loop: 0/%d - %s", loopCount, loopPrompt)},
|
||||
Style: ext.WidgetStyle{BorderColor: "#fab387"},
|
||||
})
|
||||
|
||||
// Start first iteration
|
||||
executeLoopIteration(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// OnAgentEnd handles loop continuation
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if !loopActive {
|
||||
return
|
||||
}
|
||||
|
||||
loopCurrent++
|
||||
|
||||
if loopCurrent >= loopCount {
|
||||
// Loop complete
|
||||
loopActive = false
|
||||
ctx.RemoveWidget("loop-progress")
|
||||
ctx.PrintInfo(fmt.Sprintf("✅ Loop complete: %d/%d iterations", loopCurrent, loopCount))
|
||||
|
||||
// Show final summary
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) > 0 && loopStartNode != "" {
|
||||
summary := ctx.SummarizeBranch(loopStartNode, branch[len(branch)-1].ID)
|
||||
if summary != "" {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "conversation-manager · Loop Summary",
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Update progress
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "loop-progress",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: fmt.Sprintf("🔄 Loop: %d/%d - %s", loopCurrent, loopCount, loopPrompt)},
|
||||
Style: ext.WidgetStyle{BorderColor: "#fab387"},
|
||||
})
|
||||
|
||||
// Collapse previous iteration for fresh context
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) >= 2 {
|
||||
// Find the user messages (look for the one before the last assistant message)
|
||||
// We want to collapse from the user message that started this iteration
|
||||
// to the last assistant response
|
||||
var collapseStartIdx = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if branch[i].Role == "assistant" {
|
||||
// Found the last assistant message, now find the user message before it
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
if branch[j].Role == "user" {
|
||||
collapseStartIdx = j
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if collapseStartIdx >= 0 {
|
||||
fromID := branch[collapseStartIdx].ID
|
||||
toID := branch[len(branch)-1].ID
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Collapsing iteration %d for fresh context...", loopCurrent))
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
if summary != "" {
|
||||
result := ctx.CollapseBranch(fromID, toID, summary)
|
||||
if result.Success {
|
||||
ctx.PrintInfo("Context collapsed successfully")
|
||||
} else {
|
||||
ctx.PrintError(fmt.Sprintf("Collapse failed: %s", result.Error))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Small delay to let UI update
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Trigger next iteration
|
||||
executeLoopIteration(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// showTree displays the conversation tree structure
|
||||
func showTree(ctx ext.Context) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) == 0 {
|
||||
ctx.PrintInfo("Tree is empty")
|
||||
return
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Conversation Tree (%d nodes):\n\n", len(branch)))
|
||||
|
||||
for i, node := range branch {
|
||||
prefix := " "
|
||||
if i == len(branch)-1 {
|
||||
prefix = "▶ " // Current node
|
||||
} else {
|
||||
prefix = " "
|
||||
}
|
||||
|
||||
roleIcon := "💬"
|
||||
switch node.Role {
|
||||
case "user":
|
||||
roleIcon = "👤"
|
||||
case "assistant":
|
||||
roleIcon = "🤖"
|
||||
case "system":
|
||||
roleIcon = "⚙️"
|
||||
}
|
||||
|
||||
content := truncate(node.Content, 50)
|
||||
if node.Type == "branch_summary" {
|
||||
roleIcon = "📋"
|
||||
content = "[Summary] " + truncate(node.Content, 40)
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("%s%s %s: %s (%s...)\n", prefix, roleIcon, node.Role, node.ID[:8], content))
|
||||
|
||||
// Show children count if any
|
||||
children := ctx.GetChildren(node.ID)
|
||||
if len(children) > 0 {
|
||||
output.WriteString(fmt.Sprintf(" └─ %d branch(es)\n", len(children)))
|
||||
}
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: output.String(),
|
||||
BorderColor: "#89b4fa",
|
||||
Subtitle: "conversation-manager · Tree View",
|
||||
})
|
||||
}
|
||||
|
||||
// showBranch displays the current branch path
|
||||
func showBranch(ctx ext.Context) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) == 0 {
|
||||
ctx.PrintInfo("No active branch")
|
||||
return
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Current Branch (%d nodes from root to leaf):\n\n", len(branch)))
|
||||
|
||||
for i, node := range branch {
|
||||
marker := " "
|
||||
if i == len(branch)-1 {
|
||||
marker = "▶ " // Current leaf
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("%s[%d] %s (%s): %s\n",
|
||||
marker, i, node.Type, node.ID[:8], truncate(node.Content, 40)))
|
||||
}
|
||||
|
||||
// Show current node details
|
||||
leaf := branch[len(branch)-1]
|
||||
output.WriteString(fmt.Sprintf("\nCurrent Leaf:\n"))
|
||||
output.WriteString(fmt.Sprintf(" ID: %s\n", leaf.ID))
|
||||
output.WriteString(fmt.Sprintf(" Type: %s\n", leaf.Type))
|
||||
output.WriteString(fmt.Sprintf(" Role: %s\n", leaf.Role))
|
||||
output.WriteString(fmt.Sprintf(" Model: %s\n", leaf.Model))
|
||||
output.WriteString(fmt.Sprintf(" Children: %d\n", len(leaf.Children)))
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: output.String(),
|
||||
BorderColor: "#cba6f7",
|
||||
Subtitle: "conversation-manager · Branch View",
|
||||
})
|
||||
}
|
||||
|
||||
// executeLoopIteration triggers the next loop iteration
|
||||
func executeLoopIteration(ctx ext.Context) {
|
||||
iterationPrompt := fmt.Sprintf("[%d/%d] %s", loopCurrent+1, loopCount, loopPrompt)
|
||||
ctx.SendMessage(iterationPrompt)
|
||||
}
|
||||
|
||||
// truncate helper
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
@@ -908,7 +908,7 @@ func summarizeToolAction(toolName string, inputJSON string) string {
|
||||
return "searching " + getStr("pattern", "text")
|
||||
case "ls":
|
||||
return "listing " + getStr("path", "directory")
|
||||
case "spawn_subagent":
|
||||
case "subagent":
|
||||
return "spawning subagent"
|
||||
default:
|
||||
return "using " + toolName
|
||||
|
||||
@@ -2,9 +2,7 @@
|
||||
|
||||
// lsp-diagnostics.go — LSP-powered diagnostics for Kit's edit tool.
|
||||
//
|
||||
// Starts language servers on demand and surfaces diagnostics after file edits,
|
||||
// following the same pattern used by Charm's crush editor:
|
||||
//
|
||||
// Starts language servers on demand and surfaces diagnostics after file edits:
|
||||
// 1. After an edit, notify the LSP server of the file change
|
||||
// 2. Wait for the server to publish fresh diagnostics
|
||||
// 3. Append diagnostic output to the edit tool's result
|
||||
@@ -412,7 +410,7 @@ func (c *lspClient) changeFile(absPath, content string) {
|
||||
}
|
||||
|
||||
// waitForDiagnostics polls until the server publishes new diagnostics or
|
||||
// the timeout elapses. Mirrors crush's WaitForDiagnostics pattern.
|
||||
// the timeout elapses.
|
||||
func (c *lspClient) waitForDiagnostics(timeout time.Duration) {
|
||||
c.diagMu.Lock()
|
||||
startVersion := c.diagVersion
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
//go:build ignore
|
||||
|
||||
// prompt-templates.go - Frontmatter-driven prompt templates with model switching.
|
||||
// This extension demonstrates the new bridged SDK APIs:
|
||||
// - Tree navigation for conversation management
|
||||
// - Template parsing with {{variable}} substitution
|
||||
// - Model resolution with fallback chains
|
||||
// - Skill injection
|
||||
//
|
||||
// Usage:
|
||||
// 1. Create ~/.config/kit/prompts/debug.md with frontmatter:
|
||||
// ---
|
||||
// description: Debug Python code
|
||||
// model: claude-sonnet-4-20250514
|
||||
// skill: python
|
||||
// ---
|
||||
// Help me debug this Python code: {{input}}
|
||||
//
|
||||
// 2. In Kit: /debug my_script.py
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// PromptTemplate represents a loaded template with frontmatter
|
||||
type PromptTemplate struct {
|
||||
Name string
|
||||
Description string
|
||||
Model string
|
||||
Skill string
|
||||
Content string
|
||||
Variables []string
|
||||
Path string
|
||||
}
|
||||
|
||||
var (
|
||||
templates = make(map[string]PromptTemplate)
|
||||
templateDir string
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Determine template directory
|
||||
home, _ := os.UserHomeDir()
|
||||
templateDir = filepath.Join(home, ".config", "kit", "prompts")
|
||||
|
||||
// Ensure directory exists
|
||||
os.MkdirAll(templateDir, 0755)
|
||||
|
||||
// Register commands
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "reload-templates",
|
||||
Description: "Reload prompt templates from disk",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
loadTemplates(ctx)
|
||||
ctx.PrintInfo(fmt.Sprintf("Loaded %d templates from %s", len(templates), templateDir))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Dynamic template commands are registered after loading
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
loadTemplates(ctx)
|
||||
registerTemplateCommands(api, ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// loadTemplates discovers and loads all template files
|
||||
func loadTemplates(ctx ext.Context) {
|
||||
templates = make(map[string]PromptTemplate)
|
||||
|
||||
entries, err := os.ReadDir(templateDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(templateDir, entry.Name())
|
||||
tpl, err := loadTemplateFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
name := strings.TrimSuffix(entry.Name(), ".md")
|
||||
templates[name] = tpl
|
||||
}
|
||||
}
|
||||
|
||||
// loadTemplateFile parses a template with YAML frontmatter
|
||||
func loadTemplateFile(path string) (PromptTemplate, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return PromptTemplate{}, err
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
tpl := PromptTemplate{Path: path}
|
||||
|
||||
// Parse frontmatter
|
||||
if strings.HasPrefix(content, "---") {
|
||||
parts := strings.SplitN(content[3:], "---", 2)
|
||||
if len(parts) == 2 {
|
||||
frontmatter := strings.TrimSpace(parts[0])
|
||||
body := strings.TrimSpace(parts[1])
|
||||
|
||||
// Simple line-by-line frontmatter parsing
|
||||
for _, line := range strings.Split(frontmatter, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
key, value, found := strings.Cut(line, ":")
|
||||
if found {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
switch key {
|
||||
case "description":
|
||||
tpl.Description = value
|
||||
case "model":
|
||||
tpl.Model = value
|
||||
case "skill":
|
||||
tpl.Skill = value
|
||||
}
|
||||
}
|
||||
}
|
||||
tpl.Content = body
|
||||
} else {
|
||||
tpl.Content = content
|
||||
}
|
||||
} else {
|
||||
tpl.Content = content
|
||||
}
|
||||
|
||||
// Parse {{variables}} using simple string parsing
|
||||
// (Can't use ctx.ParseTemplate here since we're in Init, not a handler)
|
||||
var vars []string
|
||||
for {
|
||||
start := strings.Index(tpl.Content, "{{")
|
||||
if start == -1 {
|
||||
break
|
||||
}
|
||||
end := strings.Index(tpl.Content[start:], "}}")
|
||||
if end == -1 {
|
||||
break
|
||||
}
|
||||
varName := strings.TrimSpace(tpl.Content[start+2 : start+end])
|
||||
vars = append(vars, varName)
|
||||
tpl.Content = tpl.Content[:start] + "{{" + varName + "}}" + tpl.Content[start+end+2:]
|
||||
}
|
||||
tpl.Variables = vars
|
||||
|
||||
return tpl, nil
|
||||
}
|
||||
|
||||
// registerTemplateCommands dynamically registers commands for each template
|
||||
func registerTemplateCommands(api ext.API, ctx ext.Context) {
|
||||
for name, tpl := range templates {
|
||||
// Skip if already registered (we'd need to track this)
|
||||
tplCopy := tpl // Capture for closure
|
||||
nameCopy := name
|
||||
|
||||
// Build description with metadata
|
||||
desc := tplCopy.Description
|
||||
if desc == "" {
|
||||
desc = fmt.Sprintf("Run %s template", nameCopy)
|
||||
}
|
||||
if tplCopy.Model != "" {
|
||||
desc += fmt.Sprintf(" [%s", tplCopy.Model)
|
||||
if tplCopy.Skill != "" {
|
||||
desc += fmt.Sprintf(" +%s", tplCopy.Skill)
|
||||
}
|
||||
desc += "]"
|
||||
}
|
||||
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: nameCopy,
|
||||
Description: desc,
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
return executeTemplate(ctx, tplCopy, args)
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// executeTemplate runs a template with the given arguments
|
||||
func executeTemplate(ctx ext.Context, tpl PromptTemplate, args string) (string, error) {
|
||||
// Store original model for restoration
|
||||
originalModel := ctx.Model
|
||||
|
||||
// 1. Resolve and switch model if specified
|
||||
if tpl.Model != "" {
|
||||
// Parse model chain (comma-separated)
|
||||
preferences := strings.Split(tpl.Model, ",")
|
||||
for i := range preferences {
|
||||
preferences[i] = strings.TrimSpace(preferences[i])
|
||||
}
|
||||
|
||||
result := ctx.ResolveModelChain(preferences)
|
||||
if result.Error != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Model resolution failed: %s", result.Error))
|
||||
// Continue with current model
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Switching to model: %s", result.Model))
|
||||
if err := ctx.SetModel(result.Model); err != nil {
|
||||
ctx.PrintError(fmt.Sprintf("Failed to switch model: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Inject skill if specified
|
||||
if tpl.Skill != "" {
|
||||
err := ctx.InjectSkillAsContext(tpl.Skill)
|
||||
if err != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Skill injection failed: %s", err))
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Injected skill: %s", tpl.Skill))
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Parse and render template
|
||||
parsed := ctx.ParseTemplate(tpl.Name, tpl.Content)
|
||||
|
||||
// Build variable map
|
||||
vars := make(map[string]string)
|
||||
|
||||
// Simple argument parsing: first arg is $1 (input), rest is $@
|
||||
if len(parsed.Variables) > 0 {
|
||||
argsList := ctx.SimpleParseArguments(args, len(parsed.Variables))
|
||||
for i, varName := range parsed.Variables {
|
||||
if i < len(parsed.Variables) && i+1 < len(argsList) {
|
||||
vars[varName] = argsList[i+1]
|
||||
}
|
||||
}
|
||||
// If single variable, use full args
|
||||
if len(parsed.Variables) == 1 && vars[parsed.Variables[0]] == "" {
|
||||
vars[parsed.Variables[0]] = args
|
||||
}
|
||||
}
|
||||
|
||||
// Render with model conditionals
|
||||
content := ctx.RenderWithModelConditionals(tpl.Content)
|
||||
rendered := ctx.RenderTemplate(ext.PromptTemplate{Name: tpl.Name, Content: content, Variables: parsed.Variables}, vars)
|
||||
|
||||
// 4. Send the rendered prompt
|
||||
ctx.SendMessage(rendered)
|
||||
|
||||
// 5. Schedule model restoration after turn completes
|
||||
// We use a goroutine to wait and restore
|
||||
if tpl.Model != "" && originalModel != "" {
|
||||
go func() {
|
||||
// Note: In a real implementation, we'd use OnAgentEnd event
|
||||
// For now, the user can manually switch back
|
||||
ctx.SetStatus("template-mode", fmt.Sprintf("Template: %s (model will restore)", tpl.Name), 20)
|
||||
}()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Executing template: %s", tpl.Name), nil
|
||||
}
|
||||
@@ -37,7 +37,7 @@ func Init(api ext.API) {
|
||||
"Subagent Test Extension loaded\n\n" +
|
||||
"/subtest <task> Spawn blocking subagent\n" +
|
||||
"/subbg <task> Spawn background subagent\n\n" +
|
||||
"The LLM can also use the spawn_subagent tool.")
|
||||
"The LLM can also use the subagent tool.")
|
||||
})
|
||||
|
||||
api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) {
|
||||
|
||||
@@ -82,6 +82,7 @@ require (
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/indaco/herald v0.9.0 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.2.12 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.6 // indirect
|
||||
|
||||
@@ -187,6 +187,8 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/indaco/herald v0.9.0 h1:LrAfXEHkKz8WmctUKdndppIU/qFpylSbZ8galS0DVAc=
|
||||
github.com/indaco/herald v0.9.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/kaptinlin/go-i18n v0.2.12 h1:ywDsvb4KDFddMC2dpI/rrIzGU2mWUSvHmWUm9BMsdl4=
|
||||
github.com/kaptinlin/go-i18n v0.2.12/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
|
||||
|
||||
@@ -62,8 +62,8 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
// work in ACP mode. TUI-dependent features (widgets, prompts, editor)
|
||||
// become no-ops or return cancelled; all data/model/tool APIs work
|
||||
// identically to interactive mode.
|
||||
if kitInstance.HasExtensions() {
|
||||
kitInstance.SetExtensionContext(extensions.Context{
|
||||
if kitInstance.Extensions().HasExtensions() {
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
SessionID: sessionID,
|
||||
CWD: cwd,
|
||||
Model: kitInstance.GetModelString(),
|
||||
@@ -121,31 +121,31 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionFilePath() },
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.Extensions().GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionPath() },
|
||||
AppendEntry: func(entryType, data string) (string, error) {
|
||||
return kitInstance.AppendExtensionEntry(entryType, data)
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.GetExtensionEntries(entryType)
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
|
||||
// Options, model, and tool management.
|
||||
GetOption: func(name string) string { return kitInstance.GetExtensionOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.SetExtensionOption(name, value) },
|
||||
GetOption: func(name string) string { return kitInstance.Extensions().GetOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.Extensions().SetOption(name, value) },
|
||||
SetModel: func(modelString string) error {
|
||||
previousModel := kitInstance.GetExtensionContext().Model
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
|
||||
return err
|
||||
}
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
kitInstance.EmitModelChange(modelString, previousModel, "extension")
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry { return kitInstance.GetAvailableModels() },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.EmitExtensionCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.GetExtensionToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.SetExtensionActiveTools(names) },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.Extensions().EmitCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.Extensions().GetToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.Extensions().SetActiveTools(names) },
|
||||
|
||||
// LLM completions and subagents.
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
@@ -173,7 +173,7 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: result.Error,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
@@ -188,15 +188,15 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
|
||||
// Render — fall back to logging.
|
||||
RenderMessage: func(name, content string) {
|
||||
renderer := kitInstance.GetExtensionMessageRenderer(name)
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(name)
|
||||
if renderer != nil && renderer.Render != nil {
|
||||
content = renderer.Render(content, 80)
|
||||
}
|
||||
log.Info("extension: message", "renderer", name, "content", content)
|
||||
},
|
||||
ReloadExtensions: func() error { return kitInstance.ReloadExtensions() },
|
||||
ReloadExtensions: func() error { return kitInstance.Extensions().Reload() },
|
||||
})
|
||||
kitInstance.EmitSessionStart()
|
||||
kitInstance.Extensions().EmitSessionStart()
|
||||
}
|
||||
|
||||
sess := &acpSession{
|
||||
|
||||
+33
-28
@@ -31,7 +31,7 @@ type AgentConfig struct {
|
||||
CoreTools []fantasy.AgentTool
|
||||
|
||||
// ToolWrapper is an optional function that wraps the combined tool list
|
||||
// before it is passed to the Fantasy agent. Used by the extensions system
|
||||
// before it is passed to the LLM agent. Used by the extensions system
|
||||
// to intercept tool calls/results.
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
@@ -75,9 +75,9 @@ type ToolOutputHandler = core.ToolOutputCallback
|
||||
// tracking during long-running tool-calling conversations.
|
||||
type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64)
|
||||
|
||||
// Agent represents an AI agent with core tool integration using the fantasy library.
|
||||
// Agent represents an AI agent with core tool integration using the LLM library.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
|
||||
// fantasy.AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
// AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
// Additional tools from external MCP servers can be loaded alongside core tools.
|
||||
type Agent struct {
|
||||
toolManager *tools.MCPToolManager
|
||||
@@ -100,7 +100,7 @@ type GenerateWithLoopResult struct {
|
||||
FinalResponse *fantasy.Response
|
||||
// ConversationMessages contains all messages in the conversation including tool calls and results
|
||||
ConversationMessages []fantasy.Message
|
||||
// Messages contains the conversation as custom content blocks (crush-style)
|
||||
// Messages contains the conversation as custom content blocks
|
||||
Messages []message.Message
|
||||
// TotalUsage contains aggregate token usage across all steps
|
||||
TotalUsage fantasy.Usage
|
||||
@@ -112,13 +112,13 @@ type GenerateWithLoopResult struct {
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are always registered.
|
||||
// External MCP tools are loaded from the config if any MCP servers are configured.
|
||||
func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
// Create the LLM provider via fantasy
|
||||
// Create the LLM provider
|
||||
providerResult, err := models.CreateProvider(ctx, agentConfig.ModelConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create model provider: %v", err)
|
||||
}
|
||||
|
||||
// Register core tools (direct fantasy implementations, no MCP overhead).
|
||||
// Register core tools (direct AgentTool implementations, no MCP overhead).
|
||||
// Use caller-provided tools if set, otherwise default to all core tools.
|
||||
coreTools := agentConfig.CoreTools
|
||||
if len(coreTools) == 0 {
|
||||
@@ -158,7 +158,7 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
allTools = agentConfig.ToolWrapper(allTools)
|
||||
}
|
||||
|
||||
// Build fantasy agent options
|
||||
// Build agent options
|
||||
var agentOpts []fantasy.AgentOption
|
||||
|
||||
if agentConfig.SystemPrompt != "" {
|
||||
@@ -198,7 +198,7 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create the fantasy agent
|
||||
// Create the agent
|
||||
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
// Determine provider type from model string
|
||||
@@ -234,8 +234,8 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the fantasy agent with streaming and callbacks.
|
||||
// Fantasy handles the tool call loop internally. We map fantasy's rich callback system
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally. We map the rich callback system
|
||||
// to kit's existing callback interface for UI integration.
|
||||
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fantasy.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
|
||||
@@ -251,18 +251,21 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
ctx = core.ContextWithToolOutputCallback(ctx, onToolOutput)
|
||||
}
|
||||
|
||||
// Fantasy requires the current user input as Prompt, with prior messages as history.
|
||||
// The agent requires the current user input as Prompt, with prior messages as history.
|
||||
// Extract the last user message text and files as the prompt, and pass everything
|
||||
// before it as Messages. Files (e.g. clipboard images) are passed via the Files
|
||||
// field so Fantasy includes them in the API request.
|
||||
// field so the agent includes them in the API request.
|
||||
prompt, files, history := splitPromptAndHistory(messages)
|
||||
|
||||
// Track current tool call info for callbacks
|
||||
var currentToolName string
|
||||
// Apply message-level cache control for Anthropic models.
|
||||
// This avoids type conflicts with provider-level options.
|
||||
history = applyCacheControlToMessages(history)
|
||||
|
||||
// Track current tool call args for callbacks
|
||||
var currentToolArgs string
|
||||
|
||||
// Use the streaming path when streaming is enabled OR when any callbacks are
|
||||
// provided. Fantasy only exposes tool/step callbacks on AgentStreamCall, so
|
||||
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
|
||||
// Stream is required to observe tool execution in real time. The non-streaming
|
||||
// Generate path is reserved for the simple case with no callbacks at all.
|
||||
hasCallbacks := onToolCall != nil || onToolExecution != nil || onToolResult != nil ||
|
||||
@@ -270,12 +273,12 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
|
||||
if a.streamingEnabled || hasCallbacks {
|
||||
// Track completed step messages so we can return partial results
|
||||
// on cancellation. Fantasy's Stream() discards accumulated steps
|
||||
// on cancellation. The agent's Stream() discards accumulated steps
|
||||
// when it returns an error, but the OnStepFinish callback fires
|
||||
// for every step that completed before the error occurred.
|
||||
var completedStepMessages []fantasy.Message
|
||||
|
||||
// Use fantasy's streaming agent
|
||||
// Use the streaming agent
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
Prompt: prompt,
|
||||
Files: files,
|
||||
@@ -308,7 +311,6 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
currentToolName = tc.ToolName
|
||||
currentToolArgs = tc.Input
|
||||
|
||||
// Notify about the tool call
|
||||
@@ -405,6 +407,11 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onConsumed(len(steered))
|
||||
}
|
||||
}
|
||||
|
||||
// Apply message-level cache control for Anthropic models.
|
||||
// This avoids type conflicts with provider-level options.
|
||||
result.Messages = applyCacheControlToMessages(result.Messages)
|
||||
|
||||
return stepCtx, result, nil
|
||||
}
|
||||
}
|
||||
@@ -452,13 +459,11 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
_ = currentToolName // satisfy compiler for non-streaming path
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
}
|
||||
|
||||
// splitPromptAndHistory extracts the last user message as the prompt string,
|
||||
// and returns everything before it as conversation history. Fantasy's agent
|
||||
// and returns everything before it as conversation history. The agent's
|
||||
// requires the current turn's input as Prompt (string), with prior messages
|
||||
// passed separately as Messages (history).
|
||||
func splitPromptAndHistory(messages []fantasy.Message) (string, []fantasy.FilePart, []fantasy.Message) {
|
||||
@@ -501,8 +506,8 @@ func splitPromptAndHistory(messages []fantasy.Message) (string, []fantasy.FilePa
|
||||
return "", nil, messages
|
||||
}
|
||||
|
||||
// convertAgentResult converts a fantasy AgentResult to our GenerateWithLoopResult.
|
||||
// It builds both the legacy fantasy.Message slice and the new custom content blocks.
|
||||
// convertAgentResult converts an AgentResult to our GenerateWithLoopResult.
|
||||
// It builds both the message slice and the new custom content blocks.
|
||||
func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.Message) *GenerateWithLoopResult {
|
||||
// Collect all conversation messages: original + all step messages
|
||||
var allFantasyMessages []fantasy.Message
|
||||
@@ -515,7 +520,7 @@ func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.
|
||||
// Convert to custom content blocks
|
||||
var allMessages []message.Message
|
||||
for _, fm := range allFantasyMessages {
|
||||
allMessages = append(allMessages, message.FromFantasyMessage(fm))
|
||||
allMessages = append(allMessages, message.FromLLMMessage(fm))
|
||||
}
|
||||
|
||||
return &GenerateWithLoopResult{
|
||||
@@ -527,7 +532,7 @@ func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.
|
||||
}
|
||||
}
|
||||
|
||||
// extractToolResultText extracts the text and error status from a fantasy ToolResultContent.
|
||||
// extractToolResultText extracts the text and error status from a ToolResultContent.
|
||||
// For core tools, the result is already clean text (no MCP JSON wrapping).
|
||||
// For MCP tools, it unwraps the MCP content structure.
|
||||
func extractToolResultText(tr fantasy.ToolResultContent) (string, bool) {
|
||||
@@ -540,7 +545,7 @@ func extractToolResultText(tr fantasy.ToolResultContent) (string, bool) {
|
||||
return errResult.Error.Error(), true
|
||||
}
|
||||
|
||||
// Get text directly from the Fantasy result type.
|
||||
// Get text directly from the result type.
|
||||
if textResult, ok := tr.Result.(fantasy.ToolResultOutputContentText); ok {
|
||||
// Try to unwrap MCP JSON structure (for external MCP tools).
|
||||
// Core tools return plain text, so this is a no-op for them.
|
||||
@@ -653,7 +658,7 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
allTools = a.toolWrapper(allTools)
|
||||
}
|
||||
|
||||
// Rebuild fantasy agent options.
|
||||
// Rebuild agent options.
|
||||
var agentOpts []fantasy.AgentOption
|
||||
if a.systemPrompt != "" {
|
||||
agentOpts = append(agentOpts, fantasy.WithSystemPrompt(a.systemPrompt))
|
||||
@@ -714,7 +719,7 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModel returns the underlying fantasy LanguageModel.
|
||||
// GetModel returns the underlying LanguageModel.
|
||||
func (a *Agent) GetModel() fantasy.LanguageModel {
|
||||
return a.model
|
||||
}
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"charm.land/fantasy"
|
||||
"charm.land/fantasy/providers/anthropic"
|
||||
)
|
||||
|
||||
// cacheControlOptions returns provider options for Anthropic cache control.
|
||||
// This is used at the message level to avoid type conflicts with provider-level options.
|
||||
func cacheControlOptions() fantasy.ProviderOptions {
|
||||
return anthropic.NewProviderCacheControlOptions(&anthropic.ProviderCacheControlOptions{
|
||||
CacheControl: anthropic.CacheControl{
|
||||
Type: "ephemeral",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// applyCacheControlToMessages adds cache control to specific messages.
|
||||
// Anthropic allows max 4 cache blocks per request.
|
||||
// Counts existing cache blocks and only adds new ones up to the limit.
|
||||
func applyCacheControlToMessages(messages []fantasy.Message) []fantasy.Message {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// Make a copy to avoid modifying the original slice
|
||||
result := make([]fantasy.Message, len(messages))
|
||||
copy(result, messages)
|
||||
|
||||
cacheOpts := cacheControlOptions()
|
||||
maxCacheBlocks := 4
|
||||
|
||||
// Helper to check if message already has cache control
|
||||
hasCache := func(msg fantasy.Message) bool {
|
||||
if msg.ProviderOptions == nil {
|
||||
return false
|
||||
}
|
||||
if _, ok := msg.ProviderOptions["anthropic"]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Count existing cache blocks
|
||||
existingCacheCount := 0
|
||||
for _, msg := range result {
|
||||
if hasCache(msg) {
|
||||
existingCacheCount++
|
||||
}
|
||||
}
|
||||
|
||||
// If we're already at or over the limit, don't add more
|
||||
if existingCacheCount >= maxCacheBlocks {
|
||||
return result
|
||||
}
|
||||
|
||||
// How many new cache blocks can we add?
|
||||
remaining := maxCacheBlocks - existingCacheCount
|
||||
|
||||
// First: find and cache the last system message (most important)
|
||||
lastSystemIdx := -1
|
||||
for i, msg := range result {
|
||||
if msg.Role == fantasy.MessageRoleSystem {
|
||||
lastSystemIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
if lastSystemIdx >= 0 && remaining > 0 && !hasCache(result[lastSystemIdx]) {
|
||||
result[lastSystemIdx].ProviderOptions = cacheOpts
|
||||
remaining--
|
||||
}
|
||||
|
||||
// Second: cache the most recent messages (up to remaining limit)
|
||||
// Work backwards from the end to prioritize recent context
|
||||
for i := len(result) - 1; i >= 0 && remaining > 0; i-- {
|
||||
if hasCache(result[i]) {
|
||||
continue
|
||||
}
|
||||
result[i].ProviderOptions = cacheOpts
|
||||
remaining--
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -39,7 +39,7 @@ type AgentCreationOptions struct {
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used.
|
||||
CoreTools []fantasy.AgentTool
|
||||
// ToolWrapper wraps the combined tool list before Fantasy agent creation.
|
||||
// ToolWrapper wraps the combined tool list before agent creation.
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
// ExtraTools are additional tools to include (e.g. from extensions).
|
||||
ExtraTools []fantasy.AgentTool
|
||||
|
||||
+161
-43
@@ -3,8 +3,11 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
@@ -66,6 +69,15 @@ type App struct {
|
||||
// rootCtx/rootCancel are used to signal shutdown to all goroutines.
|
||||
rootCtx context.Context
|
||||
rootCancel context.CancelFunc
|
||||
|
||||
// widgetUpdatePending is set to true when a WidgetUpdateEvent has been
|
||||
// sent to the TUI but not yet consumed by its event loop. While the flag
|
||||
// is set, subsequent NotifyWidgetUpdate calls are coalesced (dropped) to
|
||||
// prevent fast extension tickers from flooding the BubbleTea mailbox with
|
||||
// redundant re-render triggers. The flag is cleared after a short debounce
|
||||
// (~1 frame) so new updates are always let through once the TUI has had a
|
||||
// chance to process the pending event.
|
||||
widgetUpdatePending atomic.Bool
|
||||
}
|
||||
|
||||
// New creates a new App with the provided options and pre-loaded messages.
|
||||
@@ -259,6 +271,17 @@ func (a *App) ClearMessages() {
|
||||
}
|
||||
}
|
||||
|
||||
// ReloadMessagesFromTree clears the in-memory message store and reloads it
|
||||
// from the tree session's current branch. Unlike ClearMessages, this does NOT
|
||||
// reset the tree session's leaf pointer. Used after Branch() to sync the
|
||||
// store with the new branch position.
|
||||
func (a *App) ReloadMessagesFromTree() {
|
||||
a.store.Clear()
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
}
|
||||
|
||||
// GetTreeSession returns the tree session manager, or nil if not configured.
|
||||
func (a *App) GetTreeSession() *session.TreeManager {
|
||||
return a.opts.TreeSession
|
||||
@@ -280,7 +303,7 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
// Reload messages from new session.
|
||||
a.store.Clear()
|
||||
if ts != nil {
|
||||
a.store.Replace(ts.GetFantasyMessages())
|
||||
a.store.Replace(ts.GetLLMMessages())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,7 +319,7 @@ func (a *App) AddContextMessage(text string) {
|
||||
|
||||
// Persist to tree session if active.
|
||||
if ts := a.opts.TreeSession; ts != nil {
|
||||
_, _ = ts.AppendFantasyMessage(msg)
|
||||
_, _ = ts.AppendLLMMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,7 +369,7 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
|
||||
// Sync in-memory store with the compacted session.
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetFantasyMessages())
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
|
||||
a.sendEvent(CompactCompleteEvent{
|
||||
@@ -567,7 +590,7 @@ func (a *App) runQueueBatch(items []queueItem) {
|
||||
// call/result pairs; only the in-progress message or tool
|
||||
// call is discarded. Sync the in-memory store to match.
|
||||
if ts := a.opts.TreeSession; ts != nil {
|
||||
a.store.Replace(ts.GetFantasyMessages())
|
||||
a.store.Replace(ts.GetLLMMessages())
|
||||
}
|
||||
a.sendEvent(StepCancelledEvent{})
|
||||
return
|
||||
@@ -598,9 +621,10 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -620,8 +644,9 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker.
|
||||
a.updateUsageFromTurnResult(result, prompt)
|
||||
// Update usage tracker. If per-step usage was already recorded from
|
||||
// StepUsageEvent callbacks, avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -645,9 +670,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -680,8 +706,8 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
messages = append(messages, item.Prompt)
|
||||
}
|
||||
|
||||
// TODO: Handle file attachments in batch mode
|
||||
// For now, files are ignored in batch mode (rare edge case)
|
||||
// File attachments are not supported in batch mode; fall back to
|
||||
// processing only the first item that carries files.
|
||||
if hasFiles {
|
||||
// If files exist, fall back to processing just the first item with files
|
||||
for _, item := range items {
|
||||
@@ -702,8 +728,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker (using last item's prompt for tracking).
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt)
|
||||
// Update usage tracker (using last item's prompt for fallback estimation).
|
||||
// If per-step usage was already recorded from StepUsageEvent callbacks,
|
||||
// avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -720,9 +748,10 @@ func (a *App) sendEvent(msg tea.Msg) {
|
||||
}
|
||||
|
||||
// subscribeSDKEvents registers temporary SDK event subscribers that convert
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. Returns an
|
||||
// unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. When stepUsageSeen
|
||||
// is provided, it is set to true after any non-zero StepUsageEvent is observed.
|
||||
// Returns an unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Bool) func() {
|
||||
k := a.opts.Kit
|
||||
var unsubs []func()
|
||||
|
||||
@@ -754,17 +783,10 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
Chunk: ev.Chunk,
|
||||
IsStderr: ev.IsStderr,
|
||||
})
|
||||
case kit.StepUsageEvent:
|
||||
if a.opts.UsageTracker != nil {
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(ev.InputTokens),
|
||||
int(ev.OutputTokens),
|
||||
int(ev.CacheReadTokens),
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
}
|
||||
case kit.SteerConsumedEvent:
|
||||
sendFn(SteerConsumedEvent{})
|
||||
case kit.StepUsageEvent:
|
||||
a.recordStepUsage(ev, stepUsageSeen)
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -833,12 +855,32 @@ func (a *App) NotifyModelChanged(provider, model string) {
|
||||
// NotifyWidgetUpdate sends a WidgetUpdateEvent to the TUI so it re-renders
|
||||
// extension widgets. Called from the extension context's SetWidget/RemoveWidget
|
||||
// closures. In non-interactive mode this is a no-op (widgets are TUI-only).
|
||||
//
|
||||
// Coalescing: if a WidgetUpdateEvent is already queued and not yet consumed
|
||||
// by the TUI event loop, additional calls within the same ~16 ms window are
|
||||
// dropped. This prevents fast extension tickers from flooding BubbleTea's
|
||||
// mailbox with redundant re-render triggers.
|
||||
func (a *App) NotifyWidgetUpdate() {
|
||||
// Coalesce: only one pending update at a time.
|
||||
if !a.widgetUpdatePending.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
// Reset the pending flag after a short debounce so subsequent calls
|
||||
// within the same render cycle are also coalesced, but new updates
|
||||
// after the cycle are allowed through.
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
} else {
|
||||
// No program registered (non-interactive mode); clear the flag so
|
||||
// future calls are never permanently blocked.
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -934,30 +976,106 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
}
|
||||
}
|
||||
|
||||
// recordStepUsage applies token/cost usage reported for a completed step.
|
||||
// Step usage events arrive even when a turn is later cancelled, so this keeps
|
||||
// the usage widget accurate on all stop paths.
|
||||
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool) {
|
||||
hasUsage := ev.InputTokens > 0 || ev.OutputTokens > 0 || ev.CacheReadTokens > 0 || ev.CacheWriteTokens > 0
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] recordStepUsage: hasUsage=%v input=%d output=%d cacheRead=%d cacheWrite=%d",
|
||||
hasUsage, ev.InputTokens, ev.OutputTokens, ev.CacheReadTokens, ev.CacheWriteTokens)
|
||||
}
|
||||
if !hasUsage {
|
||||
return
|
||||
}
|
||||
if stepUsageSeen != nil {
|
||||
stepUsageSeen.Store(true)
|
||||
}
|
||||
if a.opts.UsageTracker == nil {
|
||||
return
|
||||
}
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(ev.InputTokens),
|
||||
int(ev.OutputTokens),
|
||||
int(ev.CacheReadTokens),
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// NOTE: We do NOT call SetContextTokens here. Context fill is set once
|
||||
// at turn completion via updateUsageFromTurnResult using FinalUsage.InputTokens,
|
||||
// which reflects the full accumulated context. Per-step context tokens would
|
||||
// cause the display to jump around during multi-step tool calls.
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
// configured UsageTracker. This is the SDK-path equivalent of updateUsage.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string) {
|
||||
// configured UsageTracker. Called once per turn after the turn completes.
|
||||
//
|
||||
// When sawStepUsage is true, totals were already accumulated incrementally via
|
||||
// StepUsageEvent callbacks; in that case this method only updates context fill.
|
||||
// Otherwise it falls back to TotalUsage from the API response.
|
||||
//
|
||||
// NOTE: We only use ACTUAL token counts from API responses for cost tracking.
|
||||
// Estimation is never used for costs - only API-reported tokens are accurate.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string, sawStepUsage bool) {
|
||||
if a.opts.UsageTracker == nil || result == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if result.TotalUsage != nil {
|
||||
inputTokens := int(result.TotalUsage.InputTokens)
|
||||
outputTokens := int(result.TotalUsage.OutputTokens)
|
||||
// Use API-reported tokens if input tokens are available (output may be 0 in some cases)
|
||||
if inputTokens > 0 {
|
||||
cacheReadTokens := int(result.TotalUsage.CacheReadTokens)
|
||||
cacheWriteTokens := int(result.TotalUsage.CacheCreationTokens)
|
||||
a.opts.UsageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
// Debug logging for token tracking
|
||||
if a.opts.Debug {
|
||||
if result.TotalUsage != nil {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult TotalUsage: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.TotalUsage.InputTokens, result.TotalUsage.OutputTokens,
|
||||
result.TotalUsage.CacheReadTokens, result.TotalUsage.CacheCreationTokens)
|
||||
} else {
|
||||
a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response)
|
||||
return
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: TotalUsage=nil")
|
||||
}
|
||||
if result.FinalUsage != nil {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult FinalUsage: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.FinalUsage.InputTokens, result.FinalUsage.OutputTokens,
|
||||
result.FinalUsage.CacheReadTokens, result.FinalUsage.CacheCreationTokens)
|
||||
} else {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: FinalUsage=nil")
|
||||
}
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: sawStepUsage=%v", sawStepUsage)
|
||||
}
|
||||
|
||||
if result.FinalUsage != nil {
|
||||
if ct := int(result.FinalUsage.InputTokens) + int(result.FinalUsage.OutputTokens); ct > 0 {
|
||||
a.opts.UsageTracker.SetContextTokens(ct)
|
||||
// --- Accumulate cost/token totals for the session ---
|
||||
// Only use actual API-reported tokens for cost tracking.
|
||||
// If sawStepUsage is true, totals were already updated via StepUsageEvent.
|
||||
// Check any token field > 0 (not just InputTokens) because cached prompts
|
||||
// can result in InputTokens=0 while OutputTokens>0 (OpenAI-compatible behavior).
|
||||
hasTotalUsage := result.TotalUsage != nil &&
|
||||
(result.TotalUsage.InputTokens > 0 ||
|
||||
result.TotalUsage.OutputTokens > 0 ||
|
||||
result.TotalUsage.CacheReadTokens > 0 ||
|
||||
result.TotalUsage.CacheCreationTokens > 0)
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: hasTotalUsage=%v", hasTotalUsage)
|
||||
}
|
||||
if !sawStepUsage && hasTotalUsage {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: calling UpdateUsage input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.TotalUsage.InputTokens, result.TotalUsage.OutputTokens,
|
||||
result.TotalUsage.CacheReadTokens, result.TotalUsage.CacheCreationTokens)
|
||||
}
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(result.TotalUsage.InputTokens),
|
||||
int(result.TotalUsage.OutputTokens),
|
||||
int(result.TotalUsage.CacheReadTokens),
|
||||
int(result.TotalUsage.CacheCreationTokens),
|
||||
)
|
||||
}
|
||||
|
||||
// --- Context window fill (drives the % bar) ---
|
||||
// Use FinalUsage.InputTokens as the context window fill. The API's InputTokens
|
||||
// already includes the full conversation history (system prompt + all previous
|
||||
// messages + current user message). Adding OutputTokens would double-count since
|
||||
// the output becomes part of the input for the next turn.
|
||||
if result.FinalUsage != nil && result.FinalUsage.InputTokens > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: calling SetContextTokens=%d (FinalUsage.InputTokens)",
|
||||
result.FinalUsage.InputTokens)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(int(result.FinalUsage.InputTokens))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
@@ -14,6 +16,47 @@ import (
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
type usageUpdaterStub struct {
|
||||
mu sync.Mutex
|
||||
|
||||
updateCalls int
|
||||
estimateCalls int
|
||||
contextCalls int
|
||||
|
||||
lastUpdateInput int
|
||||
lastUpdateOutput int
|
||||
lastUpdateCacheRead int
|
||||
lastUpdateCacheWrite int
|
||||
lastContextTokens int
|
||||
lastEstimateInput string
|
||||
lastEstimateOutput string
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.updateCalls++
|
||||
s.lastUpdateInput = inputTokens
|
||||
s.lastUpdateOutput = outputTokens
|
||||
s.lastUpdateCacheRead = cacheReadTokens
|
||||
s.lastUpdateCacheWrite = cacheWriteTokens
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.estimateCalls++
|
||||
s.lastEstimateInput = inputText
|
||||
s.lastEstimateOutput = outputText
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) SetContextTokens(tokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.contextCalls++
|
||||
s.lastContextTokens = tokens
|
||||
}
|
||||
|
||||
// turnResult builds a minimal TurnResult with response text t.
|
||||
func turnResult(t string) *kit.TurnResult {
|
||||
return &kit.TurnResult{Response: t}
|
||||
@@ -489,3 +532,133 @@ func TestQueueLength_reflects(t *testing.T) {
|
||||
t.Fatalf("expected 3, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecordStepUsage_updatesTracker verifies that per-step usage updates are
|
||||
// recorded immediately for cost tracking. Context tokens are NOT updated here
|
||||
// (only via updateUsageFromTurnResult) to avoid display jumps during multi-step
|
||||
// tool calls.
|
||||
func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.recordStepUsage(kit.StepUsageEvent{
|
||||
InputTokens: 120,
|
||||
OutputTokens: 45,
|
||||
CacheReadTokens: 5,
|
||||
CacheWriteTokens: 2,
|
||||
}, nil)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 1 {
|
||||
t.Fatalf("expected 1 update call, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.lastUpdateInput != 120 || usage.lastUpdateOutput != 45 || usage.lastUpdateCacheRead != 5 || usage.lastUpdateCacheWrite != 2 {
|
||||
t.Fatalf("unexpected usage update payload: in=%d out=%d cache_read=%d cache_write=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput, usage.lastUpdateCacheRead, usage.lastUpdateCacheWrite)
|
||||
}
|
||||
// Context tokens should NOT be updated by recordStepUsage (only by updateUsageFromTurnResult)
|
||||
if usage.contextCalls != 0 {
|
||||
t.Fatalf("expected 0 context token updates from recordStepUsage, got %d", usage.contextCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen ensures we avoid
|
||||
// double-counting totals once StepUsageEvent-based updates were already applied.
|
||||
func TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
InputTokens: 999,
|
||||
OutputTokens: 111,
|
||||
CacheReadTokens: 7,
|
||||
CacheCreationTokens: 3,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{InputTokens: 456},
|
||||
}, "prompt", true)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 0 {
|
||||
t.Fatalf("expected no total usage update when sawStepUsage=true, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.estimateCalls != 0 {
|
||||
t.Fatalf("expected no estimate update when sawStepUsage=true, got %d", usage.estimateCalls)
|
||||
}
|
||||
// Context tokens should be InputTokens only (456)
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 456 {
|
||||
t.Fatalf("expected final context tokens=456 (InputTokens only), got calls=%d tokens=%d", usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero verifies that usage
|
||||
// is recorded when InputTokens=0 but OutputTokens>0 (OpenAI-compatible cache behavior).
|
||||
func TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
// Simulate OpenAI-compatible behavior: all prompt tokens cached, InputTokens=0
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
InputTokens: 0, // All cached - subtracted from prompt
|
||||
OutputTokens: 150, // Actual generated tokens
|
||||
CacheReadTokens: 500, // Cache hit
|
||||
CacheCreationTokens: 0,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{InputTokens: 0, OutputTokens: 150},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 1 {
|
||||
t.Fatalf("expected 1 update call when InputTokens=0 but OutputTokens>0, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.lastUpdateInput != 0 || usage.lastUpdateOutput != 150 {
|
||||
t.Fatalf("expected input=0 output=150, got input=%d output=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput)
|
||||
}
|
||||
if usage.lastUpdateCacheRead != 500 {
|
||||
t.Fatalf("expected cache_read=500, got %d", usage.lastUpdateCacheRead)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly verifies that context
|
||||
// window fill uses InputTokens only (not input+output). The API's InputTokens
|
||||
// already includes the full conversation history; adding output would double-count.
|
||||
func TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 200,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{
|
||||
InputTokens: 1000, // Full context including history
|
||||
OutputTokens: 200,
|
||||
},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
// Context tokens should be InputTokens only (1000), not input+output (1200)
|
||||
// because InputTokens already includes the full conversation history
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 1000 {
|
||||
t.Fatalf("expected context tokens=1000 (InputTokens only), got calls=%d tokens=%d",
|
||||
usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,13 +43,30 @@ type OpenAICredentials struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// oauthTokenExpired reports whether an OAuth token with the given type and
|
||||
// expiry unix timestamp is past its expiry. Returns false for API key
|
||||
// credentials or when no expiry is set.
|
||||
func oauthTokenExpired(credType string, expiresAt int64) bool {
|
||||
if credType != "oauth" || expiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= expiresAt
|
||||
}
|
||||
|
||||
// oauthTokenNeedsRefresh reports whether an OAuth token will expire within the
|
||||
// next 5 minutes, allowing proactive refresh before it becomes invalid.
|
||||
// Returns false for API key credentials or when no expiry is set.
|
||||
func oauthTokenNeedsRefresh(credType string, expiresAt int64) bool {
|
||||
if credType != "oauth" || expiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (expiresAt - 300) // 5 minutes buffer
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) IsExpired() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= c.ExpiresAt
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
@@ -57,19 +74,13 @@ func (c *AnthropicCredentials) IsExpired() bool {
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) NeedsRefresh() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) IsExpired() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= c.ExpiresAt
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
@@ -77,10 +88,7 @@ func (c *OpenAICredentials) IsExpired() bool {
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) NeedsRefresh() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// CredentialManager handles secure storage and retrieval of authentication credentials.
|
||||
|
||||
@@ -403,10 +403,9 @@ func FilepathOr[T any](key string, value *T) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filepath.Join(home, absPath[2:])
|
||||
absPath = filepath.Join(home, absPath[2:])
|
||||
}
|
||||
if !filepath.IsAbs(absPath) {
|
||||
// base := GetConfigPath()
|
||||
base := configPath
|
||||
if base == "" {
|
||||
fmt.Fprintf(os.Stderr, "unable to build relative path to config.")
|
||||
|
||||
+5
-18
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -39,20 +40,8 @@ func toolOutputCallbackFromContext(ctx context.Context) ToolOutputCallback {
|
||||
const defaultBashTimeout = 120 * time.Second
|
||||
const maxBashTimeout = 600 * time.Second
|
||||
|
||||
var bannedCommands = []string{
|
||||
"alias ", "bg ", "bind ", "builtin ",
|
||||
"caller ", "command ", "compgen ",
|
||||
"complete ", "compopt ", "coproc ",
|
||||
"dirs ", "disown ", "enable ",
|
||||
"fc ", "fg ", "hash ", "help ",
|
||||
"history ", "jobs ", "kill ",
|
||||
"logout ", "mapfile ", "popd ",
|
||||
"pushd ", "readonly ", "select ",
|
||||
"set ", "shopt ", "source ",
|
||||
"suspend ", "times ", "trap ",
|
||||
"type ", "typeset ", "ulimit ",
|
||||
"umask ", "unalias ", "wait ",
|
||||
}
|
||||
// bannedCmdRe matches bash builtin commands that are not allowed for security reasons.
|
||||
var bannedCmdRe = regexp.MustCompile(`^(alias|bg|bind|builtin|caller|command|compgen|complete|compopt|coproc|dirs|disown|enable|fc|fg|hash|help|history|jobs|kill|logout|mapfile|popd|pushd|readonly|select|set|shopt|source|suspend|times|trap|type|typeset|ulimit|umask|unalias|wait)\s`)
|
||||
|
||||
type bashArgs struct {
|
||||
Command string `json:"command"`
|
||||
@@ -94,10 +83,8 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
|
||||
// Check for banned commands
|
||||
for _, banned := range bannedCommands {
|
||||
if strings.HasPrefix(args.Command, banned) {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", args.Command)), nil
|
||||
}
|
||||
if bannedCmdRe.MatchString(args.Command) {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", args.Command)), nil
|
||||
}
|
||||
|
||||
// Determine timeout
|
||||
|
||||
+234
-44
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
@@ -13,19 +14,45 @@ import (
|
||||
udiff "github.com/aymanbagabas/go-udiff"
|
||||
)
|
||||
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
// Edit represents a single replacement in a multi-edit operation.
|
||||
type Edit struct {
|
||||
OldText string `json:"old_text"`
|
||||
NewText string `json:"new_text"`
|
||||
}
|
||||
|
||||
// editArgs holds the arguments for the edit tool.
|
||||
// Supports both single-edit mode (old_text/new_text) and multi-edit mode (edits array).
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
OldText string `json:"old_text"` // Single-edit mode
|
||||
NewText string `json:"new_text"` // Single-edit mode
|
||||
Edits []Edit `json:"edits"` // Multi-edit mode
|
||||
}
|
||||
|
||||
// replacement represents a normalized edit ready for processing.
|
||||
type replacement struct {
|
||||
oldText string // normalized old text for matching
|
||||
newText string // normalized new text
|
||||
originalOld string // original old text for metadata
|
||||
originalNew string // original new text for metadata
|
||||
index int // index in the original edits array (for error messages)
|
||||
}
|
||||
|
||||
// matchedReplacement represents a replacement with its match location.
|
||||
type matchedReplacement struct {
|
||||
replacement
|
||||
start int // start index in normalized content
|
||||
end int // end index in normalized content
|
||||
usedFuzzyMatch bool // true if fuzzy matching was used
|
||||
}
|
||||
|
||||
// NewEditTool creates the edit core tool.
|
||||
func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
cfg := ApplyOptions(opts)
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "edit",
|
||||
Description: "Edit a file by replacing exact text. The old_text must match exactly (including whitespace). Use this for precise, surgical edits. Fails if old_text is not found or matches multiple locations.",
|
||||
Description: "Edit a file by replacing exact text. Supports single edit via old_text/new_text, or multiple edits via the edits array. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Parameters: map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
@@ -33,14 +60,32 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace (must match exactly)",
|
||||
"description": "Exact text to find and replace (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text to replace the old text with",
|
||||
"description": "New text to replace the old text with (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"edits": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Array of edits for multi-region replacement. Each edit must have unique, non-overlapping old_text. All matches are against the original file content.",
|
||||
"items": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace for this edit",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text for this edit",
|
||||
},
|
||||
},
|
||||
"required": []string{"old_text", "new_text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "old_text", "new_text"},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeEdit(ctx, call, cfg.WorkDir)
|
||||
@@ -51,7 +96,7 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args editArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("path, old_text, and new_text parameters are required"), nil
|
||||
return fantasy.NewTextErrorResponse("failed to parse arguments: " + err.Error()), nil
|
||||
}
|
||||
if args.Path == "" {
|
||||
return fantasy.NewTextErrorResponse("path parameter is required"), nil
|
||||
@@ -69,56 +114,201 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
|
||||
content := string(contentBytes)
|
||||
|
||||
// Normalize line endings for matching
|
||||
normalized := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
normalizedOld := strings.ReplaceAll(args.OldText, "\r\n", "\n")
|
||||
|
||||
// Try exact match first
|
||||
count := strings.Count(normalized, normalizedOld)
|
||||
|
||||
// If no exact match, try fuzzy matching
|
||||
if count == 0 {
|
||||
if idx, matchLen := fuzzyMatch(normalized, normalizedOld); idx >= 0 {
|
||||
// Apply fuzzy match — the matched text is the original content slice
|
||||
matchedText := normalized[idx : idx+matchLen]
|
||||
newContent := normalized[:idx] + args.NewText + normalized[idx+matchLen:]
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, matchedText, args.NewText)), nil
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("old_text not found in %s", args.Path)), nil
|
||||
// Normalize and validate input
|
||||
replacements, err := normalizeEditInput(args)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("found %d matches for old_text in %s. Provide more context to identify the correct match.", count, args.Path)), nil
|
||||
// Apply all edits
|
||||
newContent, applied, err := applyEdits(content, replacements)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Apply the edit
|
||||
newContent := strings.Replace(normalized, normalizedOld, args.NewText, 1)
|
||||
|
||||
// Write the file
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, normalizedOld, args.NewText)), nil
|
||||
// Generate diff
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
diff := generateDiff(absPath, normalizedContent, newContent)
|
||||
|
||||
// Build response with fuzzy match indication
|
||||
fuzzyCount := 0
|
||||
for _, m := range applied {
|
||||
if m.usedFuzzyMatch {
|
||||
fuzzyCount++
|
||||
}
|
||||
}
|
||||
|
||||
var msg string
|
||||
if len(applied) == 1 {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff)
|
||||
}
|
||||
} else {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied %d edits (%d fuzzy) to %s\n%s", len(applied), fuzzyCount, args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied %d edits to %s\n%s", len(applied), args.Path, diff)
|
||||
}
|
||||
}
|
||||
|
||||
resp := fantasy.NewTextResponse(msg)
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, applied)), nil
|
||||
}
|
||||
|
||||
// normalizeEditInput validates and normalizes the edit input.
|
||||
// Returns error if both single-edit and multi-edit modes are used.
|
||||
func normalizeEditInput(args editArgs) ([]replacement, error) {
|
||||
singleMode := args.OldText != "" || args.NewText != ""
|
||||
multiMode := len(args.Edits) > 0
|
||||
|
||||
if singleMode && multiMode {
|
||||
return nil, fmt.Errorf("cannot use old_text/new_text together with edits array")
|
||||
}
|
||||
|
||||
if !singleMode && !multiMode {
|
||||
return nil, fmt.Errorf("must provide either old_text/new_text or edits array")
|
||||
}
|
||||
|
||||
if singleMode {
|
||||
if args.OldText == "" {
|
||||
return nil, fmt.Errorf("old_text is required when using single-edit mode")
|
||||
}
|
||||
if args.NewText == "" {
|
||||
return nil, fmt.Errorf("new_text is required when using single-edit mode")
|
||||
}
|
||||
return []replacement{{
|
||||
oldText: strings.ReplaceAll(args.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(args.NewText, "\r\n", "\n"),
|
||||
originalOld: args.OldText,
|
||||
originalNew: args.NewText,
|
||||
index: 0,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Multi-edit mode
|
||||
var reps []replacement
|
||||
for i, edit := range args.Edits {
|
||||
if edit.OldText == "" {
|
||||
return nil, fmt.Errorf("edits[%d].old_text is required", i)
|
||||
}
|
||||
reps = append(reps, replacement{
|
||||
oldText: strings.ReplaceAll(edit.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(edit.NewText, "\r\n", "\n"),
|
||||
originalOld: edit.OldText,
|
||||
originalNew: edit.NewText,
|
||||
index: i,
|
||||
})
|
||||
}
|
||||
return reps, nil
|
||||
}
|
||||
|
||||
// applyEdits applies multiple replacements to the content.
|
||||
// All matches are against the original content (non-incremental).
|
||||
// Returns the new content, the applied matches, and any error.
|
||||
func applyEdits(content string, edits []replacement) (string, []matchedReplacement, error) {
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
|
||||
// Find all matches
|
||||
var matched []matchedReplacement
|
||||
for _, edit := range edits {
|
||||
m, err := findMatch(normalizedContent, edit)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
matched = append(matched, *m)
|
||||
}
|
||||
|
||||
// Sort by position
|
||||
sort.Slice(matched, func(i, j int) bool {
|
||||
return matched[i].start < matched[j].start
|
||||
})
|
||||
|
||||
// Check for overlaps
|
||||
for i := 1; i < len(matched); i++ {
|
||||
if matched[i-1].end > matched[i].start {
|
||||
return "", nil, fmt.Errorf("edits[%d] and edits[%d] overlap; merge them into a single edit",
|
||||
matched[i-1].index, matched[i].index)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply edits in reverse order (end to start) to maintain stable offsets
|
||||
result := normalizedContent
|
||||
for i := len(matched) - 1; i >= 0; i-- {
|
||||
m := matched[i]
|
||||
result = result[:m.start] + m.newText + result[m.end:]
|
||||
}
|
||||
|
||||
return result, matched, nil
|
||||
}
|
||||
|
||||
// findMatch finds a unique match for the edit in the content.
|
||||
// Returns error if not found or ambiguous.
|
||||
func findMatch(content string, edit replacement) (*matchedReplacement, error) {
|
||||
// Try exact match first
|
||||
count := strings.Count(content, edit.oldText)
|
||||
|
||||
if count == 0 {
|
||||
// Try fuzzy match
|
||||
idx, matchLen := fuzzyMatch(content, edit.oldText)
|
||||
if idx < 0 {
|
||||
return nil, fmt.Errorf("edits[%d]: could not find old_text in file. The text must match exactly (including whitespace)", edit.index)
|
||||
}
|
||||
// Use the matched text from content for the replacement
|
||||
matchedText := content[idx : idx+matchLen]
|
||||
return &matchedReplacement{
|
||||
replacement: replacement{
|
||||
oldText: matchedText,
|
||||
newText: edit.newText,
|
||||
originalOld: edit.originalOld,
|
||||
originalNew: edit.originalNew,
|
||||
index: edit.index,
|
||||
},
|
||||
start: idx,
|
||||
end: idx + matchLen,
|
||||
usedFuzzyMatch: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return nil, fmt.Errorf("found %d matches for edits[%d].old_text; each old_text must be unique, provide more context to identify the correct match", count, edit.index)
|
||||
}
|
||||
|
||||
// Single exact match
|
||||
idx := strings.Index(content, edit.oldText)
|
||||
return &matchedReplacement{
|
||||
replacement: edit,
|
||||
start: idx,
|
||||
end: idx + len(edit.oldText),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// editDiffMeta builds the structured metadata attached to edit tool responses.
|
||||
func editDiffMeta(path, oldText, newText string) map[string]any {
|
||||
func editDiffMeta(path string, applied []matchedReplacement) map[string]any {
|
||||
var diffBlocks []map[string]any
|
||||
totalAdditions, totalDeletions := 0, 0
|
||||
|
||||
for _, m := range applied {
|
||||
diffBlocks = append(diffBlocks, map[string]any{
|
||||
"old_text": m.originalOld,
|
||||
"new_text": m.originalNew,
|
||||
})
|
||||
totalAdditions += strings.Count(m.originalNew, "\n") + 1
|
||||
totalDeletions += strings.Count(m.originalOld, "\n") + 1
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"file_diffs": []map[string]any{{
|
||||
"path": path,
|
||||
"additions": strings.Count(newText, "\n") + 1,
|
||||
"deletions": strings.Count(oldText, "\n") + 1,
|
||||
"diff_blocks": []map[string]any{{
|
||||
"old_text": oldText,
|
||||
"new_text": newText,
|
||||
}},
|
||||
"path": path,
|
||||
"additions": totalAdditions,
|
||||
"deletions": totalDeletions,
|
||||
"diff_blocks": diffBlocks,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -715,3 +715,315 @@ func TestExecuteEdit_MetadataContainsFileDiffs(t *testing.T) {
|
||||
t.Fatal("file_diffs should be a non-empty array")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Multi-edit tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Basic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "multi.txt")
|
||||
writeFileOrFail(t, path, "line1\nline2\nline3\nline4\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "line1", NewText: "LINE1"},
|
||||
{OldText: "line3", NewText: "LINE3"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "LINE1") {
|
||||
t.Error("first edit not applied: missing LINE1")
|
||||
}
|
||||
if !strings.Contains(gotStr, "LINE3") {
|
||||
t.Error("second edit not applied: missing LINE3")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line2") {
|
||||
t.Error("line2 was modified but should be untouched")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line4") {
|
||||
t.Error("line4 was modified but should be untouched")
|
||||
}
|
||||
|
||||
// Check response mentions multiple edits
|
||||
if !strings.Contains(resp.Content, "2 edits") {
|
||||
t.Errorf("response should mention '2 edits', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NonIncrementalMatching(t *testing.T) {
|
||||
// All edits are matched against the original content, not incrementally
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "noninc.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
want := "AAA\nBBB\nccc\n"
|
||||
if gotStr != want {
|
||||
t.Errorf("got %q, want %q", gotStr, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_OverlapDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "overlap.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "hello world", NewText: "GOODBYE"}, // Overlaps with first edit
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for overlapping edits")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "overlap") {
|
||||
t.Errorf("expected 'overlap' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_DuplicateDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "dup.txt")
|
||||
writeFileOrFail(t, path, "hello\nworld\nhello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "world", NewText: "WORLD"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for ambiguous old_text (duplicate matches)")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "unique") {
|
||||
t.Errorf("expected 'unique' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello\nworld\nhello\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "notfound.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "nonexistent", NewText: "REPLACEMENT"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for not found")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "edits[0]") {
|
||||
t.Errorf("expected 'edits[0]' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_EmptyArray(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for empty edits array")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "mixed.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(map[string]any{
|
||||
"path": path,
|
||||
"old_text": "hello",
|
||||
"new_text": "HELLO",
|
||||
"edits": []Edit{
|
||||
{OldText: "hello", NewText: "HI"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error when mixing single and multi-edit modes")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "cannot use") {
|
||||
t.Errorf("expected 'cannot use' in error, got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_FuzzyMatch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "fuzzy_multi.txt")
|
||||
// File has trailing whitespace
|
||||
original := "func foo() { \n\treturn 1 \n}\nfunc bar() { \n\treturn 2 \n}\n"
|
||||
writeFileOrFail(t, path, original)
|
||||
|
||||
// Search without trailing whitespace (common LLM behavior)
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "func foo() {\n\treturn 1\n}", NewText: "func foo() {\n\treturn 10\n}"},
|
||||
{OldText: "func bar() {\n\treturn 2\n}", NewText: "func bar() {\n\treturn 20\n}"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "return 10") {
|
||||
t.Error("first edit not applied")
|
||||
}
|
||||
if !strings.Contains(gotStr, "return 20") {
|
||||
t.Error("second edit not applied")
|
||||
}
|
||||
|
||||
// Response should mention fuzzy match
|
||||
if !strings.Contains(resp.Content, "fuzzy") {
|
||||
t.Errorf("response should mention 'fuzzy', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Metadata(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "meta_multi.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
var meta map[string]any
|
||||
if err := json.Unmarshal([]byte(resp.Metadata), &meta); err != nil {
|
||||
t.Fatalf("metadata is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
diffs, ok := meta["file_diffs"].([]any)
|
||||
if !ok || len(diffs) == 0 {
|
||||
t.Fatal("metadata missing file_diffs")
|
||||
}
|
||||
|
||||
firstDiff, ok := diffs[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("first diff is not an object")
|
||||
}
|
||||
|
||||
// Check that diff_blocks contains both edits
|
||||
diffBlocks, ok := firstDiff["diff_blocks"].([]any)
|
||||
if !ok || len(diffBlocks) != 2 {
|
||||
t.Fatalf("expected 2 diff_blocks, got %d", len(diffBlocks))
|
||||
}
|
||||
|
||||
// Verify each block has old_text and new_text
|
||||
for i, block := range diffBlocks {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("diff_block[%d] is not an object", i)
|
||||
}
|
||||
if _, ok := b["old_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing old_text", i)
|
||||
}
|
||||
if _, ok := b["new_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing new_text", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,14 +28,14 @@ type SubagentSpawnResult struct {
|
||||
// SubagentSpawnFunc is a callback that spawns an in-process subagent. The
|
||||
// parent Kit instance injects this into the context so the core tool can
|
||||
// call back without importing pkg/kit (which would create a cycle).
|
||||
// The toolCallID parameter is the LLM-assigned ID of the spawn_subagent
|
||||
// The toolCallID parameter is the LLM-assigned ID of the subagent
|
||||
// tool call, enabling the parent to correlate subagent events.
|
||||
type SubagentSpawnFunc func(ctx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration) (*SubagentSpawnResult, error)
|
||||
|
||||
type subagentCtxKey struct{}
|
||||
|
||||
// WithSubagentSpawner stores a spawn function in the context so that the
|
||||
// spawn_subagent core tool can create in-process subagents.
|
||||
// subagent core tool can create in-process subagents.
|
||||
func WithSubagentSpawner(ctx context.Context, fn SubagentSpawnFunc) context.Context {
|
||||
return context.WithValue(ctx, subagentCtxKey{}, fn)
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func getSubagentSpawner(ctx context.Context) SubagentSpawnFunc {
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// spawn_subagent tool
|
||||
// subagent tool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type subagentArgs struct {
|
||||
@@ -59,11 +59,11 @@ type subagentArgs struct {
|
||||
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// NewSubagentTool creates the spawn_subagent core tool.
|
||||
// NewSubagentTool creates the subagent core tool.
|
||||
func NewSubagentTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "spawn_subagent",
|
||||
Name: "subagent",
|
||||
Description: `Spawn a subagent to perform a task autonomously.
|
||||
|
||||
The subagent runs as a separate in-process Kit instance with full tool access
|
||||
|
||||
@@ -86,7 +86,7 @@ func ReadOnlyTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
}
|
||||
}
|
||||
|
||||
// SubagentTools returns all core tools except spawn_subagent. This prevents
|
||||
// SubagentTools returns all core tools except subagent. This prevents
|
||||
// infinite recursion when a subagent is itself a Kit instance.
|
||||
func SubagentTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
return []fantasy.AgentTool{
|
||||
|
||||
+243
-5
@@ -572,6 +572,102 @@ type Context struct {
|
||||
// })
|
||||
// // handle.Kill() to cancel, handle.Wait() to block
|
||||
SpawnSubagent func(SubagentConfig) (*SubagentHandle, *SubagentResult, error)
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// GetTreeNode returns a node by ID with full metadata and children.
|
||||
// Returns nil if entry not found.
|
||||
GetTreeNode func(entryID string) *TreeNode
|
||||
|
||||
// GetCurrentBranch returns the path from root to current leaf.
|
||||
// Each node contains full metadata (unlike GetMessages which flattens).
|
||||
GetCurrentBranch func() []TreeNode
|
||||
|
||||
// GetChildren returns direct child IDs of an entry.
|
||||
GetChildren func(entryID string) []string
|
||||
|
||||
// NavigateTo branches/forks the session to the specified entry ID.
|
||||
// Equivalent to SDK's Branch() but for extensions.
|
||||
NavigateTo func(entryID string) TreeNavigationResult
|
||||
|
||||
// SummarizeBranch uses LLM to summarize a branch range.
|
||||
// Returns summary text or error string (empty if success).
|
||||
SummarizeBranch func(fromID, toID string) string
|
||||
|
||||
// CollapseBranch replaces a branch range with a summary entry.
|
||||
// This is the "fresh context" primitive for context window management.
|
||||
CollapseBranch func(fromID, toID, summary string) TreeNavigationResult
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// LoadSkill loads a single skill file from path.
|
||||
// Parses YAML frontmatter, returns skill with content ready for injection.
|
||||
LoadSkill func(path string) (*Skill, string)
|
||||
|
||||
// LoadSkillsFromDir discovers and loads all skills from a directory.
|
||||
LoadSkillsFromDir func(dir string) SkillLoadResult
|
||||
|
||||
// DiscoverSkills finds skills in standard locations.
|
||||
// Checks ~/.config/kit/skills/, .kit/skills/, .agents/skills/
|
||||
DiscoverSkills func() SkillLoadResult
|
||||
|
||||
// InjectSkillAsContext sends a skill's content as a system message.
|
||||
// Looks up skill by name from discovered skills.
|
||||
InjectSkillAsContext func(skillName string) string
|
||||
|
||||
// InjectRawSkillAsContext loads and immediately injects a skill file.
|
||||
InjectRawSkillAsContext func(path string) string
|
||||
|
||||
// GetAvailableSkills returns all currently loaded/discovered skills.
|
||||
GetAvailableSkills func() []Skill
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// ParseTemplate extracts {{variables}} from template content.
|
||||
ParseTemplate func(name, content string) PromptTemplate
|
||||
|
||||
// RenderTemplate substitutes variables into template content.
|
||||
RenderTemplate func(tpl PromptTemplate, vars map[string]string) string
|
||||
|
||||
// ParseArguments parses command-line style arguments.
|
||||
ParseArguments func(input string, pattern ArgumentPattern) ParseResult
|
||||
|
||||
// SimpleParseArguments parses $1, $2, $@ style arguments.
|
||||
// Returns slice where [0]=full input, [1]=$1, [2]=$2, ... [n]=$@
|
||||
SimpleParseArguments func(input string, count int) []string
|
||||
|
||||
// EvaluateModelConditional checks if condition matches current model.
|
||||
// Condition supports wildcards: * matches any, ? matches single char.
|
||||
EvaluateModelConditional func(condition string) bool
|
||||
|
||||
// RenderWithModelConditionals processes <if-model> blocks in content.
|
||||
RenderWithModelConditionals func(content string) string
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// ResolveModelChain attempts each model in order until one is available.
|
||||
ResolveModelChain func(preferences []string) ModelResolutionResult
|
||||
|
||||
// GetModelCapabilities returns capabilities for a specific model.
|
||||
// If model is empty, uses current model.
|
||||
GetModelCapabilities func(model string) (ModelCapabilities, string)
|
||||
|
||||
// CheckModelAvailable verifies if a model string is valid.
|
||||
CheckModelAvailable func(model string) bool
|
||||
|
||||
// GetCurrentProvider returns just the provider part of current model.
|
||||
GetCurrentProvider func() string
|
||||
|
||||
// GetCurrentModelID returns just the model ID part of current model.
|
||||
GetCurrentModelID func() string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -598,6 +694,148 @@ type SessionMessage struct {
|
||||
Timestamp string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tree navigation types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TreeNode represents a node in the session tree for navigation.
|
||||
// Extensions use this to traverse conversation history and implement
|
||||
// features like "fresh context" loops and branch summarization.
|
||||
type TreeNode struct {
|
||||
// ID is the unique entry identifier.
|
||||
ID string
|
||||
// ParentID links this entry to its parent (empty if root).
|
||||
ParentID string
|
||||
// Type is the entry type: "message", "branch_summary", "model_change", "extension_data", "tool_execution".
|
||||
Type string
|
||||
// Role is the message role for message entries: "user", "assistant", "system", "tool".
|
||||
Role string
|
||||
// Content is the text content or summary.
|
||||
Content string
|
||||
// Model is the model that generated this (for assistant messages).
|
||||
Model string
|
||||
// Provider is the provider used.
|
||||
Provider string
|
||||
// Timestamp is the RFC3339-formatted creation time.
|
||||
Timestamp string
|
||||
// Children is the list of child entry IDs for tree traversal.
|
||||
Children []string
|
||||
}
|
||||
|
||||
// TreeNavigationResult reports success or failure of tree operations.
|
||||
type TreeNavigationResult struct {
|
||||
// Success is true if the operation completed.
|
||||
Success bool
|
||||
// Error describes what went wrong (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Skill types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Skill represents a loaded skill file with parsed YAML frontmatter.
|
||||
type Skill struct {
|
||||
// Name is the human-readable identifier.
|
||||
Name string
|
||||
// Description summarizes what this skill provides.
|
||||
Description string
|
||||
// Content is the markdown body (frontmatter stripped).
|
||||
Content string
|
||||
// Path is the absolute filesystem path.
|
||||
Path string
|
||||
// Tags are optional labels for categorization.
|
||||
Tags []string
|
||||
// When controls automatic inclusion: "always", "on-demand", or file-glob.
|
||||
When string
|
||||
}
|
||||
|
||||
// SkillLoadResult reports skills loaded from a directory.
|
||||
type SkillLoadResult struct {
|
||||
// Skills is the list of loaded skills.
|
||||
Skills []Skill
|
||||
// Error describes loading failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Template parsing types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// PromptTemplate represents a parsed template with variable placeholders.
|
||||
type PromptTemplate struct {
|
||||
// Name is the template identifier.
|
||||
Name string
|
||||
// Content is the original template content.
|
||||
Content string
|
||||
// Variables are the extracted {{variable}} names.
|
||||
Variables []string
|
||||
}
|
||||
|
||||
// ArgumentPattern defines how to parse command arguments.
|
||||
type ArgumentPattern struct {
|
||||
// Positional names for $1, $2, etc.
|
||||
Positional []string
|
||||
// Rest is the variable name for $@ (all remaining).
|
||||
Rest string
|
||||
// Flags maps flag names to variable names (e.g., "--loop" -> "loop").
|
||||
Flags map[string]string
|
||||
}
|
||||
|
||||
// ParseResult reports argument parsing outcome.
|
||||
type ParseResult struct {
|
||||
// Vars maps variable names to values for positional args.
|
||||
Vars map[string]string
|
||||
// Flags maps flag names to values.
|
||||
Flags map[string]string
|
||||
// Rest is remaining unparsed text.
|
||||
Rest string
|
||||
// Error describes parsing failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ModelConditional represents an <if-model> block for evaluation.
|
||||
type ModelConditional struct {
|
||||
// Condition is the model pattern (e.g., "claude-*", "anthropic/*").
|
||||
Condition string
|
||||
// Content is rendered if condition matches.
|
||||
Content string
|
||||
// Else is rendered if condition doesn't match.
|
||||
Else string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model resolution types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ModelCapabilities describes what a model supports.
|
||||
type ModelCapabilities struct {
|
||||
// Provider is the provider ID (e.g., "anthropic").
|
||||
Provider string
|
||||
// ModelID is the model identifier (e.g., "claude-sonnet-4-20250929").
|
||||
ModelID string
|
||||
// ContextLimit is the maximum context window in tokens.
|
||||
ContextLimit int
|
||||
// OutputLimit is the maximum output tokens.
|
||||
OutputLimit int
|
||||
// Reasoning indicates if the model supports reasoning/thinking.
|
||||
Reasoning bool
|
||||
// Streaming indicates if the model supports streaming.
|
||||
Streaming bool
|
||||
}
|
||||
|
||||
// ModelResolutionResult reports model chain resolution outcome.
|
||||
type ModelResolutionResult struct {
|
||||
// Model is the selected model in "provider/model" format.
|
||||
Model string
|
||||
// Capabilities describes the selected model.
|
||||
Capabilities ModelCapabilities
|
||||
// Attempted lists models tried before success.
|
||||
Attempted []string
|
||||
// Error describes resolution failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ExtensionEntry represents persisted extension data stored in the session.
|
||||
// Extensions use AppendEntry to save custom state and GetEntries to retrieve
|
||||
// it on session resume.
|
||||
@@ -784,7 +1022,7 @@ func (a *API) OnToolResult(handler func(ToolResultEvent, Context) *ToolResultRes
|
||||
a.onToolResult(handler)
|
||||
}
|
||||
|
||||
// OnSubagentStart registers a handler that fires when a spawn_subagent tool
|
||||
// OnSubagentStart registers a handler that fires when a subagent tool
|
||||
// call begins executing. Use the ToolCallID to correlate with subsequent
|
||||
// OnSubagentChunk and OnSubagentEnd events for the same subagent.
|
||||
func (a *API) OnSubagentStart(handler func(SubagentStartEvent, Context)) {
|
||||
@@ -799,7 +1037,7 @@ func (a *API) OnSubagentChunk(handler func(SubagentChunkEvent, Context)) {
|
||||
a.onSubagentChunk(handler)
|
||||
}
|
||||
|
||||
// OnSubagentEnd registers a handler that fires when a spawn_subagent call
|
||||
// OnSubagentEnd registers a handler that fires when a subagent call
|
||||
// completes. ErrorMsg is non-empty when the subagent failed.
|
||||
func (a *API) OnSubagentEnd(handler func(SubagentEndEvent, Context)) {
|
||||
a.onSubagentEnd(handler)
|
||||
@@ -1808,9 +2046,9 @@ func (BeforeCompactResult) isResult() {}
|
||||
// Subagent lifecycle events (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentStartEvent fires when a spawn_subagent tool call begins executing.
|
||||
// SubagentStartEvent fires when a subagent tool call begins executing.
|
||||
type SubagentStartEvent struct {
|
||||
// ToolCallID is the LLM-assigned ID of the spawn_subagent tool call.
|
||||
// ToolCallID is the LLM-assigned ID of the subagent tool call.
|
||||
// Use this to correlate SubagentChunkEvent and SubagentEndEvent.
|
||||
ToolCallID string
|
||||
// Task is the task description passed to the subagent.
|
||||
@@ -1850,7 +2088,7 @@ type SubagentChunkEvent struct {
|
||||
|
||||
func (e SubagentChunkEvent) Type() EventType { return SubagentChunk }
|
||||
|
||||
// SubagentEndEvent fires when a spawn_subagent tool call completes.
|
||||
// SubagentEndEvent fires when a subagent tool call completes.
|
||||
type SubagentEndEvent struct {
|
||||
// ToolCallID matches the SubagentStartEvent.ToolCallID for this subagent.
|
||||
ToolCallID string
|
||||
|
||||
@@ -72,7 +72,7 @@ const (
|
||||
// cancel compaction by returning Cancel=true.
|
||||
BeforeCompact EventType = "before_compact"
|
||||
|
||||
// SubagentStart fires when a spawn_subagent tool call begins executing.
|
||||
// SubagentStart fires when a subagent tool call begins executing.
|
||||
// Carries the tool call ID and the task description.
|
||||
SubagentStart EventType = "subagent_start"
|
||||
|
||||
@@ -80,7 +80,7 @@ const (
|
||||
// subagent: text chunks, tool calls, tool results, etc.
|
||||
SubagentChunk EventType = "subagent_chunk"
|
||||
|
||||
// SubagentEnd fires when a spawn_subagent tool call completes (success
|
||||
// SubagentEnd fires when a subagent tool call completes (success
|
||||
// or error). Carries the final response and any error message.
|
||||
SubagentEnd EventType = "subagent_end"
|
||||
)
|
||||
|
||||
@@ -47,46 +47,56 @@ func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
// pathSet is a thread-safe helper for deduplicating and ordering file paths.
|
||||
type pathSet struct {
|
||||
m map[string]bool
|
||||
list []string
|
||||
}
|
||||
|
||||
func newPathSet() *pathSet {
|
||||
return &pathSet{m: make(map[string]bool)}
|
||||
}
|
||||
|
||||
func (ps *pathSet) add(p string) bool {
|
||||
abs, err := filepath.Abs(p)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if ps.m[abs] {
|
||||
return false
|
||||
}
|
||||
ps.m[abs] = true
|
||||
ps.list = append(ps.list, abs)
|
||||
return true
|
||||
}
|
||||
|
||||
// discoverExtensionPaths returns deduplicated paths to extension files in
|
||||
// load-order (global first, then project-local, then explicit).
|
||||
func discoverExtensionPaths(extraPaths []string) []string {
|
||||
seen := make(map[string]bool)
|
||||
var paths []string
|
||||
|
||||
add := func(p string) {
|
||||
abs, err := filepath.Abs(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if seen[abs] {
|
||||
return
|
||||
}
|
||||
seen[abs] = true
|
||||
paths = append(paths, abs)
|
||||
}
|
||||
ps := newPathSet()
|
||||
|
||||
// Global extensions: $XDG_CONFIG_HOME/kit/extensions/ (default ~/.config/kit/extensions/)
|
||||
globalDir := globalExtensionsDir()
|
||||
for _, p := range findExtensionsInDir(globalDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Global installed git packages: $XDG_DATA_HOME/kit/git/
|
||||
globalGitDir := globalGitInstallRoot()
|
||||
for _, p := range findExtensionsInGitPackages(globalGitDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Project-local extensions: .kit/extensions/
|
||||
localDir := filepath.Join(".kit", "extensions")
|
||||
for _, p := range findExtensionsInDir(localDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Project-local installed git packages: .kit/git/
|
||||
projectGitDir := filepath.Join(".kit", "git")
|
||||
for _, p := range findExtensionsInGitPackages(projectGitDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Explicit paths (highest precedence)
|
||||
@@ -97,14 +107,14 @@ func discoverExtensionPaths(extraPaths []string) []string {
|
||||
}
|
||||
if info.IsDir() {
|
||||
for _, found := range findExtensionsInDir(p) {
|
||||
add(found)
|
||||
ps.add(found)
|
||||
}
|
||||
} else if strings.HasSuffix(p, ".go") {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
}
|
||||
|
||||
return paths
|
||||
return ps.list
|
||||
}
|
||||
|
||||
// findExtensionsInDir returns .go files in dir and main.go in immediate subdirs.
|
||||
|
||||
@@ -56,11 +56,261 @@ func NewRunner(exts []LoadedExtension) *Runner {
|
||||
}
|
||||
|
||||
// SetContext updates the runtime context (session ID, model, etc.) that is
|
||||
// passed to every handler invocation. Thread-safe.
|
||||
// passed to every handler invocation. Nil function fields are replaced with
|
||||
// safe no-ops so extension handlers never panic on a missing callback.
|
||||
// Thread-safe.
|
||||
func (r *Runner) SetContext(ctx Context) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.ctx = ctx
|
||||
r.ctx = normalizeContext(ctx)
|
||||
}
|
||||
|
||||
// normalizeContext replaces nil function fields in ctx with no-op stubs so
|
||||
// that extension handlers can call any ctx method without a nil-function panic.
|
||||
func normalizeContext(ctx Context) Context {
|
||||
if ctx.Print == nil {
|
||||
ctx.Print = func(string) {}
|
||||
}
|
||||
if ctx.PrintInfo == nil {
|
||||
ctx.PrintInfo = func(string) {}
|
||||
}
|
||||
if ctx.PrintError == nil {
|
||||
ctx.PrintError = func(string) {}
|
||||
}
|
||||
if ctx.PrintBlock == nil {
|
||||
ctx.PrintBlock = func(PrintBlockOpts) {}
|
||||
}
|
||||
if ctx.SendMessage == nil {
|
||||
ctx.SendMessage = func(string) {}
|
||||
}
|
||||
if ctx.CancelAndSend == nil {
|
||||
ctx.CancelAndSend = func(string) {}
|
||||
}
|
||||
if ctx.SetWidget == nil {
|
||||
ctx.SetWidget = func(WidgetConfig) {}
|
||||
}
|
||||
if ctx.RemoveWidget == nil {
|
||||
ctx.RemoveWidget = func(string) {}
|
||||
}
|
||||
if ctx.SetHeader == nil {
|
||||
ctx.SetHeader = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveHeader == nil {
|
||||
ctx.RemoveHeader = func() {}
|
||||
}
|
||||
if ctx.SetFooter == nil {
|
||||
ctx.SetFooter = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveFooter == nil {
|
||||
ctx.RemoveFooter = func() {}
|
||||
}
|
||||
if ctx.PromptSelect == nil {
|
||||
ctx.PromptSelect = func(PromptSelectConfig) PromptSelectResult {
|
||||
return PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptConfirm == nil {
|
||||
ctx.PromptConfirm = func(PromptConfirmConfig) PromptConfirmResult {
|
||||
return PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptInput == nil {
|
||||
ctx.PromptInput = func(PromptInputConfig) PromptInputResult {
|
||||
return PromptInputResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptMultiSelect == nil {
|
||||
ctx.PromptMultiSelect = func(PromptMultiSelectConfig) PromptMultiSelectResult {
|
||||
return PromptMultiSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.ShowOverlay == nil {
|
||||
ctx.ShowOverlay = func(OverlayConfig) OverlayResult {
|
||||
return OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
}
|
||||
if ctx.SetEditor == nil {
|
||||
ctx.SetEditor = func(EditorConfig) {}
|
||||
}
|
||||
if ctx.ResetEditor == nil {
|
||||
ctx.ResetEditor = func() {}
|
||||
}
|
||||
if ctx.SetEditorText == nil {
|
||||
ctx.SetEditorText = func(string) {}
|
||||
}
|
||||
if ctx.SetUIVisibility == nil {
|
||||
ctx.SetUIVisibility = func(UIVisibility) {}
|
||||
}
|
||||
if ctx.SetStatus == nil {
|
||||
ctx.SetStatus = func(string, string, int) {}
|
||||
}
|
||||
if ctx.RemoveStatus == nil {
|
||||
ctx.RemoveStatus = func(string) {}
|
||||
}
|
||||
if ctx.GetContextStats == nil {
|
||||
ctx.GetContextStats = func() ContextStats { return ContextStats{} }
|
||||
}
|
||||
if ctx.GetMessages == nil {
|
||||
ctx.GetMessages = func() []SessionMessage { return nil }
|
||||
}
|
||||
if ctx.GetSessionPath == nil {
|
||||
ctx.GetSessionPath = func() string { return "" }
|
||||
}
|
||||
if ctx.AppendEntry == nil {
|
||||
ctx.AppendEntry = func(string, string) (string, error) { return "", nil }
|
||||
}
|
||||
if ctx.GetEntries == nil {
|
||||
ctx.GetEntries = func(string) []ExtensionEntry { return nil }
|
||||
}
|
||||
if ctx.GetOption == nil {
|
||||
ctx.GetOption = func(string) string { return "" }
|
||||
}
|
||||
if ctx.SetOption == nil {
|
||||
ctx.SetOption = func(string, string) {}
|
||||
}
|
||||
if ctx.SetModel == nil {
|
||||
ctx.SetModel = func(string) error { return nil }
|
||||
}
|
||||
if ctx.GetAvailableModels == nil {
|
||||
ctx.GetAvailableModels = func() []ModelInfoEntry { return nil }
|
||||
}
|
||||
if ctx.EmitCustomEvent == nil {
|
||||
ctx.EmitCustomEvent = func(string, string) {}
|
||||
}
|
||||
if ctx.GetAllTools == nil {
|
||||
ctx.GetAllTools = func() []ToolInfo { return nil }
|
||||
}
|
||||
if ctx.SetActiveTools == nil {
|
||||
ctx.SetActiveTools = func([]string) {}
|
||||
}
|
||||
if ctx.Exit == nil {
|
||||
ctx.Exit = func() {}
|
||||
}
|
||||
if ctx.Complete == nil {
|
||||
ctx.Complete = func(CompleteRequest) (CompleteResponse, error) {
|
||||
return CompleteResponse{}, nil
|
||||
}
|
||||
}
|
||||
if ctx.SuspendTUI == nil {
|
||||
ctx.SuspendTUI = func(callback func()) error { callback(); return nil }
|
||||
}
|
||||
if ctx.RenderMessage == nil {
|
||||
ctx.RenderMessage = func(string, string) {}
|
||||
}
|
||||
if ctx.RegisterTheme == nil {
|
||||
ctx.RegisterTheme = func(string, ThemeColorConfig) {}
|
||||
}
|
||||
if ctx.SetTheme == nil {
|
||||
ctx.SetTheme = func(string) error { return nil }
|
||||
}
|
||||
if ctx.ListThemes == nil {
|
||||
ctx.ListThemes = func() []string { return nil }
|
||||
}
|
||||
if ctx.ReloadExtensions == nil {
|
||||
ctx.ReloadExtensions = func() error { return nil }
|
||||
}
|
||||
if ctx.SpawnSubagent == nil {
|
||||
ctx.SpawnSubagent = func(SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.GetTreeNode == nil {
|
||||
ctx.GetTreeNode = func(string) *TreeNode { return nil }
|
||||
}
|
||||
if ctx.GetCurrentBranch == nil {
|
||||
ctx.GetCurrentBranch = func() []TreeNode { return nil }
|
||||
}
|
||||
if ctx.GetChildren == nil {
|
||||
ctx.GetChildren = func(string) []string { return nil }
|
||||
}
|
||||
if ctx.NavigateTo == nil {
|
||||
ctx.NavigateTo = func(string) TreeNavigationResult {
|
||||
return TreeNavigationResult{Success: false, Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
if ctx.SummarizeBranch == nil {
|
||||
ctx.SummarizeBranch = func(string, string) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
if ctx.CollapseBranch == nil {
|
||||
ctx.CollapseBranch = func(string, string, string) TreeNavigationResult {
|
||||
return TreeNavigationResult{Success: false, Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.LoadSkill == nil {
|
||||
ctx.LoadSkill = func(string) (*Skill, string) { return nil, "" }
|
||||
}
|
||||
if ctx.LoadSkillsFromDir == nil {
|
||||
ctx.LoadSkillsFromDir = func(string) SkillLoadResult { return SkillLoadResult{} }
|
||||
}
|
||||
if ctx.DiscoverSkills == nil {
|
||||
ctx.DiscoverSkills = func() SkillLoadResult { return SkillLoadResult{} }
|
||||
}
|
||||
if ctx.InjectSkillAsContext == nil {
|
||||
ctx.InjectSkillAsContext = func(string) string { return "" }
|
||||
}
|
||||
if ctx.InjectRawSkillAsContext == nil {
|
||||
ctx.InjectRawSkillAsContext = func(string) string { return "" }
|
||||
}
|
||||
if ctx.GetAvailableSkills == nil {
|
||||
ctx.GetAvailableSkills = func() []Skill { return nil }
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.ParseTemplate == nil {
|
||||
ctx.ParseTemplate = func(string, string) PromptTemplate { return PromptTemplate{} }
|
||||
}
|
||||
if ctx.RenderTemplate == nil {
|
||||
ctx.RenderTemplate = func(PromptTemplate, map[string]string) string { return "" }
|
||||
}
|
||||
if ctx.ParseArguments == nil {
|
||||
ctx.ParseArguments = func(string, ArgumentPattern) ParseResult { return ParseResult{} }
|
||||
}
|
||||
if ctx.SimpleParseArguments == nil {
|
||||
ctx.SimpleParseArguments = func(string, int) []string { return nil }
|
||||
}
|
||||
if ctx.EvaluateModelConditional == nil {
|
||||
ctx.EvaluateModelConditional = func(string) bool { return false }
|
||||
}
|
||||
if ctx.RenderWithModelConditionals == nil {
|
||||
ctx.RenderWithModelConditionals = func(string) string { return "" }
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.ResolveModelChain == nil {
|
||||
ctx.ResolveModelChain = func([]string) ModelResolutionResult {
|
||||
return ModelResolutionResult{Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
if ctx.GetModelCapabilities == nil {
|
||||
ctx.GetModelCapabilities = func(string) (ModelCapabilities, string) {
|
||||
return ModelCapabilities{}, "not implemented"
|
||||
}
|
||||
}
|
||||
if ctx.CheckModelAvailable == nil {
|
||||
ctx.CheckModelAvailable = func(string) bool { return false }
|
||||
}
|
||||
if ctx.GetCurrentProvider == nil {
|
||||
ctx.GetCurrentProvider = func() string { return "" }
|
||||
}
|
||||
if ctx.GetCurrentModelID == nil {
|
||||
ctx.GetCurrentModelID = func() string { return "" }
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// GetContext returns a snapshot of the current runtime context. Thread-safe.
|
||||
|
||||
@@ -173,10 +173,10 @@ type subagentJSONOutput struct {
|
||||
} `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
var subagentCounter uint64
|
||||
var subagentCounter atomic.Uint64
|
||||
|
||||
func generateSubagentID() string {
|
||||
n := atomic.AddUint64(&subagentCounter, 1)
|
||||
n := subagentCounter.Add(1)
|
||||
return fmt.Sprintf("sub-%d-%d", time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
|
||||
@@ -128,6 +128,24 @@ func Symbols() interp.Exports {
|
||||
"ThemeColor": reflect.ValueOf((*ThemeColor)(nil)),
|
||||
"ThemeColorConfig": reflect.ValueOf((*ThemeColorConfig)(nil)),
|
||||
|
||||
// Tree navigation types
|
||||
"TreeNode": reflect.ValueOf((*TreeNode)(nil)),
|
||||
"TreeNavigationResult": reflect.ValueOf((*TreeNavigationResult)(nil)),
|
||||
|
||||
// Skill types
|
||||
"Skill": reflect.ValueOf((*Skill)(nil)),
|
||||
"SkillLoadResult": reflect.ValueOf((*SkillLoadResult)(nil)),
|
||||
|
||||
// Template parsing types
|
||||
"PromptTemplate": reflect.ValueOf((*PromptTemplate)(nil)),
|
||||
"ArgumentPattern": reflect.ValueOf((*ArgumentPattern)(nil)),
|
||||
"ParseResult": reflect.ValueOf((*ParseResult)(nil)),
|
||||
"ModelConditional": reflect.ValueOf((*ModelConditional)(nil)),
|
||||
|
||||
// Model resolution types
|
||||
"ModelCapabilities": reflect.ValueOf((*ModelCapabilities)(nil)),
|
||||
"ModelResolutionResult": reflect.ValueOf((*ModelResolutionResult)(nil)),
|
||||
|
||||
// Event structs
|
||||
"ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)),
|
||||
"ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)),
|
||||
|
||||
@@ -42,14 +42,14 @@ func ExtensionToolsAsFantasy(defs []ToolDef, runner *Runner) []fantasy.AgentTool
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind classification.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"spawn_subagent": "agent",
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"subagent": "agent",
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
|
||||
@@ -115,9 +115,9 @@ const (
|
||||
)
|
||||
|
||||
// Message is a single conversation message containing a heterogeneous slice
|
||||
// of ContentPart blocks. This design (borrowed from crush) enables a single
|
||||
// assistant message to carry text, reasoning, and multiple tool calls as
|
||||
// discrete, typed blocks rather than flattening everything into strings.
|
||||
// of ContentPart blocks. This design enables a single assistant message to
|
||||
// carry text, reasoning, and multiple tool calls as discrete, typed blocks
|
||||
// rather than flattening everything into strings.
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Role MessageRole `json:"role"`
|
||||
@@ -312,12 +312,18 @@ func UnmarshalParts(data []byte) ([]ContentPart, error) {
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
// --- Fantasy bridge ---
|
||||
// --- LLM bridge ---
|
||||
|
||||
// ToFantasyMessages converts a Message to one or more fantasy.Message values.
|
||||
// An assistant message with tool calls produces a single fantasy message with
|
||||
// ToLLMMessages converts a Message to one or more LLM message values.
|
||||
// An assistant message with tool calls produces a single message with
|
||||
// mixed TextPart and ToolCallPart content. Tool-role messages produce
|
||||
// ToolResultPart entries.
|
||||
func (m *Message) ToLLMMessages() []fantasy.Message {
|
||||
return m.ToFantasyMessages()
|
||||
}
|
||||
|
||||
// Deprecated: Use ToLLMMessages instead.
|
||||
// ToFantasyMessages converts a Message to one or more LLM message values.
|
||||
func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
switch m.Role {
|
||||
case RoleAssistant:
|
||||
@@ -416,7 +422,14 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
}
|
||||
}
|
||||
|
||||
// FromFantasyMessage converts a fantasy.Message into our Message type,
|
||||
// FromLLMMessage converts an LLM message into our Message type,
|
||||
// extracting all content parts into the appropriate block types.
|
||||
func FromLLMMessage(msg fantasy.Message) Message {
|
||||
return FromFantasyMessage(msg)
|
||||
}
|
||||
|
||||
// Deprecated: Use FromLLMMessage instead.
|
||||
// FromFantasyMessage converts an LLM message into our Message type,
|
||||
// extracting all content parts into the appropriate block types.
|
||||
func FromFantasyMessage(msg fantasy.Message) Message {
|
||||
m := Message{
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"maps"
|
||||
"os"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"charm.land/fantasy/providers/openai"
|
||||
)
|
||||
|
||||
// buildCacheProviderOptions returns caching options for supported models.
|
||||
// Caching is enabled by default for all supported models to reduce costs.
|
||||
// Set KIT_DISABLE_CACHE=1 or ProviderConfig.DisableCaching=true to opt out.
|
||||
func buildCacheProviderOptions(modelInfo *ModelInfo, config *ProviderConfig) fantasy.ProviderOptions {
|
||||
// Check explicit opt-out via config
|
||||
if config.DisableCaching {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check global opt-out via environment
|
||||
if os.Getenv("KIT_DISABLE_CACHE") != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if model supports caching
|
||||
if modelInfo == nil || !modelInfo.SupportsCaching() {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch modelInfo.CacheType() {
|
||||
case "anthropic-ephemeral":
|
||||
// Provider-level Anthropic caching disabled - use message-level caching instead.
|
||||
return nil
|
||||
case "openai-prompt-cache":
|
||||
return buildOpenAICacheOptions(config, modelInfo.ID)
|
||||
case "google-cached-content":
|
||||
// Google caching not yet implemented.
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// buildOpenAICacheOptions enables prompt caching for OpenAI models.
|
||||
// Uses a deterministic cache key based on system prompt and model ID.
|
||||
func buildOpenAICacheOptions(config *ProviderConfig, modelID string) fantasy.ProviderOptions {
|
||||
cacheKey := generateCacheKey(config.SystemPrompt, modelID)
|
||||
|
||||
return fantasy.ProviderOptions{
|
||||
openai.Name: &openai.ProviderOptions{
|
||||
PromptCacheKey: &cacheKey,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// generateCacheKey creates a deterministic cache key from system prompt and model.
|
||||
// This ensures the same system prompt + model combination gets cache hits.
|
||||
func generateCacheKey(systemPrompt, modelID string) string {
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = "default"
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
h.Write([]byte(systemPrompt))
|
||||
h.Write([]byte(modelID))
|
||||
|
||||
// Prefix with "kit-" to identify KIT-generated cache keys
|
||||
return "kit-" + hex.EncodeToString(h.Sum(nil))[:24]
|
||||
}
|
||||
|
||||
// mergeProviderOptions merges multiple ProviderOptions maps.
|
||||
// Later maps take precedence over earlier ones.
|
||||
func mergeProviderOptions(opts ...fantasy.ProviderOptions) fantasy.ProviderOptions {
|
||||
result := make(fantasy.ProviderOptions)
|
||||
|
||||
for _, opt := range opts {
|
||||
maps.Copy(result, opt)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
func TestModelInfo_SupportsCaching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
family string
|
||||
expected bool
|
||||
}{
|
||||
{"Claude model", "claude-3-5-sonnet", true},
|
||||
{"Claude 4 model", "claude-4-opus", true},
|
||||
{"GPT model", "gpt-4", true},
|
||||
{"GPT-5 model", "gpt-5", true},
|
||||
{"O1 model", "o1", true},
|
||||
{"O3 model", "o3", true},
|
||||
{"O4 model", "o4-mini", true},
|
||||
{"Codex model", "codex", true},
|
||||
{"Gemini model", "gemini-2.5-pro", true},
|
||||
{"Gemini 1.5 model", "gemini-1.5-flash", true},
|
||||
{"Llama model", "llama-3", false},
|
||||
{"Unknown model", "unknown", false},
|
||||
{"Empty family", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &ModelInfo{Family: tt.family}
|
||||
if got := m.SupportsCaching(); got != tt.expected {
|
||||
t.Errorf("ModelInfo.SupportsCaching() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelInfo_CacheType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
family string
|
||||
expected string
|
||||
}{
|
||||
{"Claude model", "claude-3-5-sonnet", "anthropic-ephemeral"},
|
||||
{"GPT model", "gpt-4", "openai-prompt-cache"},
|
||||
{"O1 model", "o1", "openai-prompt-cache"},
|
||||
{"Gemini model", "gemini-2.5-pro", "google-cached-content"},
|
||||
{"Unknown model", "llama-3", ""},
|
||||
{"Empty family", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &ModelInfo{Family: tt.family}
|
||||
if got := m.CacheType(); got != tt.expected {
|
||||
t.Errorf("ModelInfo.CacheType() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCacheKey(t *testing.T) {
|
||||
key1 := generateCacheKey("system prompt", "model-id")
|
||||
key2 := generateCacheKey("system prompt", "model-id")
|
||||
if key1 != key2 {
|
||||
t.Errorf("generateCacheKey should be deterministic: got %q and %q", key1, key2)
|
||||
}
|
||||
|
||||
key3 := generateCacheKey("different prompt", "model-id")
|
||||
if key1 == key3 {
|
||||
t.Errorf("generateCacheKey should produce different keys for different inputs")
|
||||
}
|
||||
|
||||
key4 := generateCacheKey("", "model-id")
|
||||
key5 := generateCacheKey("default", "model-id")
|
||||
if key4 != key5 {
|
||||
t.Errorf("generateCacheKey should treat empty prompt as 'default'")
|
||||
}
|
||||
|
||||
if len(key1) < 4 || key1[:4] != "kit-" {
|
||||
t.Errorf("generateCacheKey should produce keys with 'kit-' prefix, got %q", key1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_Disabled(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: true}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when DisableCaching=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_EnvironmentVariable(t *testing.T) {
|
||||
_ = os.Setenv("KIT_DISABLE_CACHE", "1")
|
||||
defer func() { _ = os.Unsetenv("KIT_DISABLE_CACHE") }()
|
||||
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when KIT_DISABLE_CACHE is set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_UnsupportedModel(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "llama-3", ID: "llama-3-70b"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil for unsupported model families")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_NilModelInfo(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
|
||||
if opts := buildCacheProviderOptions(nil, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when modelInfo is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_Anthropic(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
opts := buildCacheProviderOptions(modelInfo, config)
|
||||
// Provider-level Anthropic caching is disabled; message-level caching is used instead
|
||||
if opts != nil {
|
||||
t.Logf("Provider-level Anthropic caching disabled; using message-level caching")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_OpenAI(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
config := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test system prompt",
|
||||
}
|
||||
modelInfo := &ModelInfo{Family: "gpt-4", ID: "gpt-4o"}
|
||||
|
||||
opts := buildCacheProviderOptions(modelInfo, config)
|
||||
if opts == nil {
|
||||
t.Fatalf("buildCacheProviderOptions should return options for OpenAI models")
|
||||
}
|
||||
|
||||
if _, ok := opts["openai"]; !ok {
|
||||
t.Errorf("buildCacheProviderOptions should include 'openai' key for GPT models")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachingPriorityOverThinking(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
// Anthropic uses message-level caching; provider-level returns nil
|
||||
config1 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
ThinkingLevel: ThinkingOff,
|
||||
}
|
||||
modelInfo1 := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
opts1 := buildCacheProviderOptions(modelInfo1, config1)
|
||||
if opts1 != nil {
|
||||
t.Logf("Provider-level Anthropic caching disabled; using message-level caching")
|
||||
}
|
||||
|
||||
// OpenAI provider-level caching works with thinking enabled
|
||||
config2 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test prompt",
|
||||
ThinkingLevel: ThinkingMedium,
|
||||
}
|
||||
modelInfo2 := &ModelInfo{Family: "gpt-4", ID: "gpt-4o"}
|
||||
opts2 := buildCacheProviderOptions(modelInfo2, config2)
|
||||
if opts2 == nil {
|
||||
t.Errorf("OpenAI caching should work with thinking enabled")
|
||||
}
|
||||
|
||||
// OpenAI caching also works with thinking disabled
|
||||
config3 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test prompt",
|
||||
ThinkingLevel: ThinkingOff,
|
||||
}
|
||||
opts3 := buildCacheProviderOptions(modelInfo2, config3)
|
||||
if opts3 == nil {
|
||||
t.Errorf("OpenAI caching should work when thinking is OFF")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeProviderOptions(t *testing.T) {
|
||||
opts1 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "value1"},
|
||||
}
|
||||
opts2 := fantasy.ProviderOptions{
|
||||
"provider2": &testProviderData{value: "value2"},
|
||||
}
|
||||
|
||||
merged := mergeProviderOptions(opts1, opts2)
|
||||
|
||||
if len(merged) != 2 {
|
||||
t.Errorf("mergeProviderOptions should combine options from multiple maps, got %d items", len(merged))
|
||||
}
|
||||
|
||||
if _, ok := merged["provider1"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider1' key")
|
||||
}
|
||||
|
||||
if _, ok := merged["provider2"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider2' key")
|
||||
}
|
||||
|
||||
// Later options should override earlier ones
|
||||
opts3 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "overridden"},
|
||||
}
|
||||
merged2 := mergeProviderOptions(opts1, opts3)
|
||||
|
||||
if data, ok := merged2["provider1"].(*testProviderData); ok {
|
||||
if data.value != "overridden" {
|
||||
t.Errorf("later options should override earlier ones, got %q", data.value)
|
||||
}
|
||||
}
|
||||
|
||||
if mergeProviderOptions() != nil {
|
||||
t.Errorf("mergeProviderOptions with no args should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// testProviderData is a simple implementation of ProviderOptionsData for testing
|
||||
type testProviderData struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (t *testProviderData) Options() {}
|
||||
|
||||
func (t *testProviderData) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + t.value + `"`), nil
|
||||
}
|
||||
|
||||
func (t *testProviderData) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
@@ -48,10 +48,10 @@ type modelsDBLimit struct {
|
||||
Output int `json:"output"`
|
||||
}
|
||||
|
||||
// npmToFantasyProvider maps npm package names from models.dev to fantasy
|
||||
// npmToLLMProvider maps npm package names from models.dev to LLM
|
||||
// provider identifiers. Providers not in this map but with an api URL
|
||||
// can be auto-routed through openaicompat.
|
||||
var npmToFantasyProvider = map[string]string{
|
||||
var npmToLLMProvider = map[string]string{
|
||||
"@ai-sdk/anthropic": "anthropic",
|
||||
"@ai-sdk/openai": "openai",
|
||||
"@ai-sdk/google": "google",
|
||||
|
||||
@@ -155,6 +155,7 @@ type ProviderConfig struct {
|
||||
MainGPU *int32
|
||||
TLSSkipVerify bool
|
||||
ThinkingLevel ThinkingLevel
|
||||
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
|
||||
}
|
||||
|
||||
// ProviderResult contains the result of provider creation.
|
||||
@@ -237,30 +238,59 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
validateModelConfig(config, modelInfo)
|
||||
}
|
||||
|
||||
// Create the base provider
|
||||
var result *ProviderResult
|
||||
var createErr error
|
||||
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return createAnthropicProvider(ctx, config, modelName)
|
||||
result, createErr = createAnthropicProvider(ctx, config, modelName)
|
||||
case "openai":
|
||||
return createOpenAIProvider(ctx, config, modelName)
|
||||
result, createErr = createOpenAIProvider(ctx, config, modelName)
|
||||
case "google", "gemini":
|
||||
return createGoogleProvider(ctx, config, modelName)
|
||||
result, createErr = createGoogleProvider(ctx, config, modelName)
|
||||
case "ollama":
|
||||
return createOllamaProvider(ctx, config, modelName)
|
||||
result, createErr = createOllamaProvider(ctx, config, modelName)
|
||||
case "azure":
|
||||
return createAzureProvider(ctx, config, modelName)
|
||||
result, createErr = createAzureProvider(ctx, config, modelName)
|
||||
case "google-vertex-anthropic":
|
||||
return createVertexAnthropicProvider(ctx, config, modelName)
|
||||
result, createErr = createVertexAnthropicProvider(ctx, config, modelName)
|
||||
case "openrouter":
|
||||
return createOpenRouterProvider(ctx, config, modelName)
|
||||
result, createErr = createOpenRouterProvider(ctx, config, modelName)
|
||||
case "bedrock":
|
||||
return createBedrockProvider(ctx, config, modelName)
|
||||
result, createErr = createBedrockProvider(ctx, config, modelName)
|
||||
case "vercel":
|
||||
return createVercelProvider(ctx, config, modelName)
|
||||
result, createErr = createVercelProvider(ctx, config, modelName)
|
||||
case "custom":
|
||||
return createCustomProvider(ctx, config, modelName)
|
||||
result, createErr = createCustomProvider(ctx, config, modelName)
|
||||
default:
|
||||
return autoRouteProvider(ctx, config, provider, modelName, registry)
|
||||
result, createErr = autoRouteProvider(ctx, config, provider, modelName, registry)
|
||||
}
|
||||
|
||||
if createErr != nil {
|
||||
return nil, createErr
|
||||
}
|
||||
|
||||
// AUTOMATICALLY ENABLE CACHING for supported models (unless disabled).
|
||||
// This works for BOTH native and auto-routed providers by detecting
|
||||
// the model family from the model metadata.
|
||||
if cacheOpts := buildCacheProviderOptions(modelInfo, config); cacheOpts != nil {
|
||||
if result.ProviderOptions == nil {
|
||||
result.ProviderOptions = cacheOpts
|
||||
} else {
|
||||
// Merge cache options with existing provider options.
|
||||
// Only add cache options for providers that don't already have
|
||||
// options set, to avoid type conflicts (e.g., Anthropic has
|
||||
// different types for regular options vs cache control options).
|
||||
for k, v := range cacheOpts {
|
||||
if _, exists := result.ProviderOptions[k]; !exists {
|
||||
result.ProviderOptions[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// autoRouteProvider attempts to create a provider by looking up its npm package
|
||||
@@ -280,14 +310,14 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
|
||||
npmPackage = modelInfo.ProviderNPM
|
||||
}
|
||||
|
||||
// Determine the fantasy provider for this npm package
|
||||
fantasyProvider := npmToFantasyProvider[npmPackage]
|
||||
if fantasyProvider == "" && providerInfo.API != "" {
|
||||
// Determine the LLM provider for this npm package
|
||||
llmProvider := npmToLLMProvider[npmPackage]
|
||||
if llmProvider == "" && providerInfo.API != "" {
|
||||
// Unknown npm but has API URL → route through openaicompat
|
||||
fantasyProvider = "openaicompat"
|
||||
llmProvider = "openaicompat"
|
||||
}
|
||||
|
||||
switch fantasyProvider {
|
||||
switch llmProvider {
|
||||
case "openaicompat":
|
||||
return createAutoRoutedOpenAICompatProvider(ctx, config, modelName, providerInfo)
|
||||
case "anthropic":
|
||||
@@ -301,7 +331,7 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
|
||||
}
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no fantasy mapping)", provider, npmPackage)
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no LLM provider mapping)", provider, npmPackage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,10 +540,15 @@ func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort
|
||||
// SendReasoning to true and configures the thinking budget. For thinking-off
|
||||
// or non-reasoning models the returned map is nil.
|
||||
//
|
||||
// NOTE: With message-level caching, thinking and caching can work together.
|
||||
// Message-level cache control (ProviderCacheControlOptions) doesn't conflict
|
||||
// with provider-level thinking options (ProviderOptions).
|
||||
//
|
||||
// Anthropic requires max_tokens > thinking.budget_tokens. If the configured
|
||||
// MaxTokens is too low, it is bumped to budget + 4096 to leave room for the
|
||||
// actual response.
|
||||
func buildAnthropicProviderOptions(config *ProviderConfig, modelName string) fantasy.ProviderOptions {
|
||||
// Thinking is OFF by default. If user hasn't explicitly enabled it, return nil.
|
||||
if config.ThinkingLevel == "" || config.ThinkingLevel == ThinkingOff {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ var embeddedModelsJSON []byte
|
||||
type ModelInfo struct {
|
||||
ID string
|
||||
Name string
|
||||
Family string // Model family (e.g., "claude", "gpt", "gemini")
|
||||
Attachment bool
|
||||
Reasoning bool
|
||||
Temperature bool
|
||||
@@ -25,6 +26,44 @@ type ModelInfo struct {
|
||||
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
||||
}
|
||||
|
||||
// SupportsCaching returns true if this model family supports prompt caching.
|
||||
// This enables automatic cost savings for supported models regardless of provider.
|
||||
func (m *ModelInfo) SupportsCaching() bool {
|
||||
switch {
|
||||
case strings.HasPrefix(m.Family, "claude"):
|
||||
return true
|
||||
case strings.HasPrefix(m.Family, "gpt"),
|
||||
strings.HasPrefix(m.Family, "o1"),
|
||||
strings.HasPrefix(m.Family, "o3"),
|
||||
strings.HasPrefix(m.Family, "o4"),
|
||||
strings.HasPrefix(m.Family, "codex"):
|
||||
return true
|
||||
case strings.HasPrefix(m.Family, "gemini"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// CacheType returns the appropriate cache mechanism for this model family.
|
||||
// Returns empty string if caching is not supported.
|
||||
func (m *ModelInfo) CacheType() string {
|
||||
switch {
|
||||
case strings.HasPrefix(m.Family, "claude"):
|
||||
return "anthropic-ephemeral"
|
||||
case strings.HasPrefix(m.Family, "gpt"),
|
||||
strings.HasPrefix(m.Family, "o1"),
|
||||
strings.HasPrefix(m.Family, "o3"),
|
||||
strings.HasPrefix(m.Family, "o4"),
|
||||
strings.HasPrefix(m.Family, "codex"):
|
||||
return "openai-prompt-cache"
|
||||
case strings.HasPrefix(m.Family, "gemini"):
|
||||
return "google-cached-content"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Cost represents the pricing information for a model.
|
||||
type Cost struct {
|
||||
Input float64
|
||||
@@ -86,6 +125,7 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
modelsMap[modelID] = ModelInfo{
|
||||
ID: dm.ID,
|
||||
Name: dm.Name,
|
||||
Family: dm.Family,
|
||||
Attachment: dm.Attachment,
|
||||
Reasoning: dm.Reasoning,
|
||||
Temperature: dm.Temperature,
|
||||
@@ -308,27 +348,32 @@ func (r *ModelsRegistry) GetSupportedProviders() []string {
|
||||
return providers
|
||||
}
|
||||
|
||||
// GetFantasyProviders returns provider IDs that can be used with fantasy,
|
||||
// GetLLMProviders returns provider IDs that have LLM support,
|
||||
// either through a native provider or via openaicompat auto-routing.
|
||||
func (r *ModelsRegistry) GetFantasyProviders() []string {
|
||||
func (r *ModelsRegistry) GetLLMProviders() []string {
|
||||
var providers []string
|
||||
for providerID, info := range r.providers {
|
||||
if isProviderFantasySupported(providerID, &info) {
|
||||
if isProviderLLMSupported(providerID, &info) {
|
||||
providers = append(providers, providerID)
|
||||
}
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// isProviderFantasySupported checks if a provider can be used with fantasy.
|
||||
func isProviderFantasySupported(providerID string, info *ProviderInfo) bool {
|
||||
// Deprecated: Use GetLLMProviders instead.
|
||||
func (r *ModelsRegistry) GetFantasyProviders() []string {
|
||||
return r.GetLLMProviders()
|
||||
}
|
||||
|
||||
// isProviderLLMSupported checks if a provider can be used with the LLM layer.
|
||||
func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
// Ollama is always supported (via openaicompat pointed at localhost)
|
||||
if providerID == "ollama" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if npm maps to a fantasy provider
|
||||
if _, ok := npmToFantasyProvider[info.NPM]; ok {
|
||||
// Check if npm maps to an LLM provider
|
||||
if _, ok := npmToLLMProvider[info.NPM]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ func OpenTreeSession(path string) (*TreeManager, error) {
|
||||
|
||||
// Set leaf to the last entry.
|
||||
if len(tm.entries) > 0 {
|
||||
tm.leafID = tm.entryID(tm.entries[len(tm.entries)-1])
|
||||
tm.leafID = tm.EntryID(tm.entries[len(tm.entries)-1])
|
||||
}
|
||||
|
||||
// Open file for appending.
|
||||
@@ -242,9 +242,14 @@ func (tm *TreeManager) AppendMessage(msg message.Message) (string, error) {
|
||||
return entry.ID, nil
|
||||
}
|
||||
|
||||
// AppendFantasyMessage converts a fantasy.Message and appends it.
|
||||
// AppendLLMMessage converts an LLM message and appends it.
|
||||
func (tm *TreeManager) AppendLLMMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendMessage(message.FromLLMMessage(msg))
|
||||
}
|
||||
|
||||
// Deprecated: Use AppendLLMMessage instead.
|
||||
func (tm *TreeManager) AppendFantasyMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendMessage(message.FromFantasyMessage(msg))
|
||||
return tm.AppendLLMMessage(msg)
|
||||
}
|
||||
|
||||
// AppendModelChange records a model/provider change.
|
||||
@@ -521,7 +526,7 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
for _, entry := range branch {
|
||||
// Once we reach the first kept entry, stop skipping.
|
||||
if skipping {
|
||||
entryID := tm.entryID(entry)
|
||||
entryID := tm.EntryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
@@ -535,7 +540,7 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToFantasyMessages()
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
@@ -684,7 +689,7 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
skipping := lastCompaction != nil
|
||||
for _, entry := range branch {
|
||||
if skipping {
|
||||
entryID := tm.entryID(entry)
|
||||
entryID := tm.EntryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
@@ -698,7 +703,7 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToFantasyMessages()
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
@@ -737,31 +742,41 @@ func (tm *TreeManager) GetLastCompaction() *CompactionEntry {
|
||||
|
||||
// --- Legacy bridge ---
|
||||
|
||||
// AddFantasyMessages appends multiple fantasy messages as entries. This is
|
||||
// AddLLMMessages appends multiple LLM messages as entries. This is
|
||||
// used when syncing from the agent's ConversationMessages after a step.
|
||||
func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error {
|
||||
func (tm *TreeManager) AddLLMMessages(msgs []fantasy.Message) error {
|
||||
for _, msg := range msgs {
|
||||
if _, err := tm.AppendFantasyMessage(msg); err != nil {
|
||||
if _, err := tm.AppendLLMMessage(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFantasyMessages builds the context and returns just the messages.
|
||||
// Deprecated: Use AddLLMMessages instead.
|
||||
func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error {
|
||||
return tm.AddLLMMessages(msgs)
|
||||
}
|
||||
|
||||
// GetLLMMessages builds the context and returns just the messages.
|
||||
// This satisfies the same conceptual role as the old Manager.GetMessages().
|
||||
func (tm *TreeManager) GetFantasyMessages() []fantasy.Message {
|
||||
func (tm *TreeManager) GetLLMMessages() []fantasy.Message {
|
||||
msgs, _, _ := tm.BuildContext()
|
||||
return msgs
|
||||
}
|
||||
|
||||
// Deprecated: Use GetLLMMessages instead.
|
||||
func (tm *TreeManager) GetFantasyMessages() []fantasy.Message {
|
||||
return tm.GetLLMMessages()
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
// addEntryToIndex adds an entry to the in-memory indices.
|
||||
func (tm *TreeManager) addEntryToIndex(entry any) {
|
||||
tm.entries = append(tm.entries, entry)
|
||||
|
||||
id := tm.entryID(entry)
|
||||
id := tm.EntryID(entry)
|
||||
parentID := tm.entryParentID(entry)
|
||||
|
||||
if id != "" {
|
||||
@@ -798,8 +813,8 @@ func (tm *TreeManager) writeEntry(entry any) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// entryID extracts the ID from any entry type.
|
||||
func (tm *TreeManager) entryID(entry any) string {
|
||||
// EntryID extracts the ID from any entry type.
|
||||
func (tm *TreeManager) EntryID(entry any) string {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
return e.ID
|
||||
|
||||
@@ -127,9 +127,7 @@ func (p *MCPConnectionPool) GetConnection(ctx context.Context, serverName string
|
||||
return conn, nil
|
||||
} else {
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Connection %s unhealthy, removing", serverName))
|
||||
}
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Connection %s unhealthy, removing", serverName))
|
||||
}
|
||||
_ = conn.client.Close()
|
||||
delete(p.connections, serverName)
|
||||
|
||||
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -70,7 +71,7 @@ func TestMCPToolManager_LoadTools_GracefulFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
// The error should mention that all servers failed
|
||||
if err != nil && !contains(err.Error(), "all MCP servers failed") {
|
||||
if err != nil && !strings.Contains(err.Error(), "all MCP servers failed") {
|
||||
t.Errorf("Expected error message to mention all servers failed, got: %v", err)
|
||||
}
|
||||
|
||||
@@ -459,13 +460,3 @@ func sliceEqual(a, b []any) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Helper function to check if a string contains a substring
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
+100
-22
@@ -149,11 +149,13 @@ func TestInputComponent_QuitReturnsTeaQuit(t *testing.T) {
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TestInputComponent_ClearCallsClearMessages verifies that /clear (and its
|
||||
// aliases) calls appCtrl.ClearMessages() and returns no submitMsg.
|
||||
// TestInputComponent_ClearForwardsAsSubmitMsg verifies that /clear (and its
|
||||
// aliases) are forwarded as submitMsg to the parent model so that the parent
|
||||
// can call ClearMessages(), update scrollback, and print the confirmation
|
||||
// message in one place. InputComponent must NOT call ClearMessages() directly.
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestInputComponent_ClearCallsClearMessages(t *testing.T) {
|
||||
func TestInputComponent_ClearForwardsAsSubmitMsg(t *testing.T) {
|
||||
aliases := []string{"/clear", "/c", "/cls"}
|
||||
for _, alias := range aliases {
|
||||
t.Run(alias, func(t *testing.T) {
|
||||
@@ -164,22 +166,29 @@ func TestInputComponent_ClearCallsClearMessages(t *testing.T) {
|
||||
|
||||
_, cmd := sendInputMsg(c, tea.KeyPressMsg{Code: tea.KeyEnter})
|
||||
|
||||
if ctrl.clearMsgCalled != 1 {
|
||||
t.Fatalf("%s: expected ClearMessages() called once, got %d", alias, ctrl.clearMsgCalled)
|
||||
// InputComponent must NOT call ClearMessages() directly.
|
||||
if ctrl.clearMsgCalled != 0 {
|
||||
t.Fatalf("%s: InputComponent must not call ClearMessages(), got %d", alias, ctrl.clearMsgCalled)
|
||||
}
|
||||
// No cmd should be returned (no submitMsg forwarded to parent).
|
||||
if cmd != nil {
|
||||
msg := runCmd(cmd)
|
||||
if _, ok := msg.(submitMsg); ok {
|
||||
t.Fatalf("%s: /clear should not emit submitMsg, got submitMsg", alias)
|
||||
}
|
||||
// A submitMsg must be emitted so the parent model handles /clear.
|
||||
if cmd == nil {
|
||||
t.Fatalf("%s: expected submitMsg cmd, got nil", alias)
|
||||
}
|
||||
msg := runCmd(cmd)
|
||||
sm, ok := msg.(submitMsg)
|
||||
if !ok {
|
||||
t.Fatalf("%s: expected submitMsg, got %T", alias, msg)
|
||||
}
|
||||
if sm.Text != alias {
|
||||
t.Fatalf("%s: expected submitMsg text %q, got %q", alias, alias, sm.Text)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInputComponent_ClearNilCtrl_NoPanic verifies that /clear with a nil
|
||||
// appCtrl does not panic.
|
||||
// appCtrl does not panic. Since /clear is now forwarded to the parent via
|
||||
// submitMsg, no appCtrl interaction happens in InputComponent at all.
|
||||
func TestInputComponent_ClearNilCtrl_NoPanic(t *testing.T) {
|
||||
c := newTestInput(nil)
|
||||
c.textarea.SetValue("/clear")
|
||||
@@ -349,7 +358,7 @@ func TestStreamComponent_SpinnerKeepsRunningDuringStreaming(t *testing.T) {
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "hello"})
|
||||
|
||||
// Flush pending chunks (simulates the 16ms tick firing).
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{})
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
|
||||
if !c.spinning {
|
||||
t.Fatal("expected spinning=true after first chunk")
|
||||
@@ -376,7 +385,7 @@ func TestStreamComponent_ChunkAccumulation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Flush pending chunks (simulates the 16ms tick firing).
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{})
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
|
||||
got := c.streamContent.String()
|
||||
want := "Hello, world!"
|
||||
@@ -396,6 +405,7 @@ func TestStreamComponent_ToolExecution_IsStarting_ShowsSpinner(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
_, cmd := c.Update(app.ToolExecutionEvent{
|
||||
ToolCallID: "call-exec-1",
|
||||
ToolName: "exec_tool",
|
||||
IsStarting: true,
|
||||
})
|
||||
@@ -403,8 +413,9 @@ func TestStreamComponent_ToolExecution_IsStarting_ShowsSpinner(t *testing.T) {
|
||||
if !c.spinning {
|
||||
t.Fatal("expected spinning=true during tool execution")
|
||||
}
|
||||
if len(c.activeTools) != 1 || !strings.Contains(c.activeTools[0], "exec_tool") {
|
||||
t.Fatalf("expected activeTools to contain tool name, got %v", c.activeTools)
|
||||
tools := c.activeToolDisplays()
|
||||
if len(tools) != 1 || !strings.Contains(tools[0], "exec_tool") {
|
||||
t.Fatalf("expected activeTools to contain tool name, got %v", tools)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Fatal("expected tick cmd from ToolExecutionEvent{IsStarting:true}")
|
||||
@@ -418,11 +429,13 @@ func TestStreamComponent_ToolExecution_NotStarting_KeepsSpinning(t *testing.T) {
|
||||
c = sendStreamMsg(c, app.SpinnerEvent{Show: true})
|
||||
// Simulate a tool starting
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{
|
||||
ToolCallID: "call-some-1",
|
||||
ToolName: "some_tool",
|
||||
IsStarting: true,
|
||||
})
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{
|
||||
ToolCallID: "call-some-1",
|
||||
ToolName: "some_tool",
|
||||
IsStarting: false,
|
||||
})
|
||||
@@ -440,9 +453,9 @@ func TestStreamComponent_ParallelToolExecution(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
// Start three tools in parallel
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "grep", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "find", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read", ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-grep", ToolName: "grep", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-find", ToolName: "find", IsStarting: true})
|
||||
|
||||
if len(c.activeTools) != 3 {
|
||||
t.Fatalf("expected 3 active tools, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
@@ -455,19 +468,44 @@ func TestStreamComponent_ParallelToolExecution(t *testing.T) {
|
||||
}
|
||||
|
||||
// Finish one tool
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "grep", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-grep", ToolName: "grep", IsStarting: false})
|
||||
if len(c.activeTools) != 2 {
|
||||
t.Fatalf("expected 2 active tools after one finished, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
}
|
||||
|
||||
// Finish remaining tools
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "read", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "find", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read", ToolName: "read", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-find", ToolName: "find", IsStarting: false})
|
||||
if len(c.activeTools) != 0 {
|
||||
t.Fatalf("expected 0 active tools after all finished, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_ParallelSameToolName_UsesToolCallID verifies finishing one
|
||||
// tool call does not remove another concurrent call with the same tool name.
|
||||
func TestStreamComponent_ParallelSameToolName_UsesToolCallID(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-1", ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-2", ToolName: "read", IsStarting: true})
|
||||
|
||||
tools := c.activeToolDisplays()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("expected 2 active read calls, got %d (%v)", len(tools), tools)
|
||||
}
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-1", ToolName: "read", IsStarting: false})
|
||||
tools = c.activeToolDisplays()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 active read call after finishing one ID, got %d (%v)", len(tools), tools)
|
||||
}
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-2", ToolName: "read", IsStarting: false})
|
||||
if len(c.activeToolDisplays()) != 0 {
|
||||
t.Fatalf("expected no active tools after finishing both IDs, got %v", c.activeToolDisplays())
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TestStreamComponent_GetRenderedContent verifies the method returns rendered
|
||||
// text when content is accumulated, and empty string when not.
|
||||
@@ -621,3 +659,43 @@ func TestStreamComponent_StaleTick_Discarded(t *testing.T) {
|
||||
t.Fatal("current-gen tick should reschedule")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_StaleFlushTick_Discarded verifies that flush ticks from a
|
||||
// previous generation (e.g. pre-Reset) are ignored.
|
||||
func TestStreamComponent_StaleFlushTick_Discarded(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
// Start a pending flush and capture its generation.
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "old"})
|
||||
staleGen := c.flushGeneration
|
||||
if !c.flushPending {
|
||||
t.Fatal("precondition: expected flushPending=true after first chunk")
|
||||
}
|
||||
|
||||
// Reset should invalidate in-flight flush ticks.
|
||||
c.Reset()
|
||||
if c.flushGeneration == staleGen {
|
||||
t.Fatal("expected flushGeneration to change after Reset")
|
||||
}
|
||||
|
||||
// New content in a new generation.
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "new"})
|
||||
if got := c.pendingStream.String(); got != "new" {
|
||||
t.Fatalf("expected pendingStream='new', got %q", got)
|
||||
}
|
||||
|
||||
// Stale flush tick should be ignored.
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: staleGen})
|
||||
if got := c.pendingStream.String(); got != "new" {
|
||||
t.Fatalf("stale flush tick should not commit pending stream, got %q", got)
|
||||
}
|
||||
|
||||
// Current generation flush should commit.
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
if got := c.pendingStream.String(); got != "" {
|
||||
t.Fatalf("expected pendingStream empty after current flush, got %q", got)
|
||||
}
|
||||
if got := c.streamContent.String(); got != "new" {
|
||||
t.Fatalf("expected streamContent='new' after current flush, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
+7
-9
@@ -179,9 +179,8 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) {
|
||||
}
|
||||
|
||||
// UpdateUsageFromResponse records token usage using metadata from the fantasy
|
||||
// response when available. Falls back to text-based estimation if the metadata is
|
||||
// missing or appears unreliable. This provides more accurate usage tracking when
|
||||
// providers supply token count information.
|
||||
// response. Only actual API-reported tokens are used for cost tracking.
|
||||
// If the provider doesn't report token counts, no usage is recorded.
|
||||
func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText string) {
|
||||
if c.usageTracker == nil {
|
||||
return
|
||||
@@ -191,8 +190,9 @@ func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText stri
|
||||
inputTokens := int(usage.InputTokens)
|
||||
outputTokens := int(usage.OutputTokens)
|
||||
|
||||
// Validate that the metadata seems reasonable
|
||||
// Use API-reported tokens if input tokens are available (output may be 0 in some cases)
|
||||
// Only use actual API-reported tokens for cost tracking.
|
||||
// We intentionally do NOT estimate tokens - estimation is inaccurate
|
||||
// and should never be used for cost calculations.
|
||||
if inputTokens > 0 {
|
||||
cacheReadTokens := int(usage.CacheReadTokens)
|
||||
cacheWriteTokens := int(usage.CacheCreationTokens)
|
||||
@@ -200,11 +200,9 @@ func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText stri
|
||||
// Per-response usage is a single API call, so it represents the
|
||||
// actual context window fill level.
|
||||
c.usageTracker.SetContextTokens(inputTokens + outputTokens)
|
||||
} else {
|
||||
// Fallback to estimation if no metadata is available.
|
||||
// EstimateAndUpdateUsage sets context tokens internally.
|
||||
c.usageTracker.EstimateAndUpdateUsage(inputText, response.Content.Text())
|
||||
}
|
||||
// If inputTokens is 0, the provider didn't report usage - we skip recording
|
||||
// rather than estimating, to ensure cost accuracy.
|
||||
}
|
||||
|
||||
// DisplayUsageAfterResponse renders and displays token usage information immediately
|
||||
|
||||
@@ -43,7 +43,7 @@ func (r *CompactRenderer) SetWidth(width int) {
|
||||
// while minimizing vertical space usage. Returns a UIMessage with formatted content
|
||||
// and metadata.
|
||||
func (r *CompactRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Info).Render(">")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Info).Bold(true).Render("User")
|
||||
|
||||
@@ -96,7 +96,7 @@ func (r *CompactRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
}
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Primary).Render("<")
|
||||
|
||||
// Use the full model name, fallback to "Assistant" if empty
|
||||
@@ -127,35 +127,11 @@ func (r *CompactRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolCallMessage renders a tool call notification in compact format, showing
|
||||
// the tool being executed with its arguments in a single line. The tool name is
|
||||
// highlighted and arguments are displayed in a muted color for visual distinction.
|
||||
func (r *CompactRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("[")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Tool).Bold(true).Render(toolName)
|
||||
|
||||
// Format args for compact display
|
||||
argsDisplay := r.formatToolArgs(toolArgs)
|
||||
if argsDisplay != "" {
|
||||
argsDisplay = lipgloss.NewStyle().Foreground(theme.Muted).Render(argsDisplay)
|
||||
}
|
||||
|
||||
line := fmt.Sprintf("%s %s %s", symbol, label, argsDisplay)
|
||||
|
||||
return UIMessage{
|
||||
Type: ToolCallMessage,
|
||||
Content: line,
|
||||
Height: 1,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolMessage renders a unified tool block in compact format, combining
|
||||
// the tool invocation header (icon + display name + params) with the execution
|
||||
// result body. Status is indicated by icon: checkmark for success, cross for error.
|
||||
func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage {
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
|
||||
// Resolve extension renderer once for all overrides.
|
||||
var extRd *ToolRendererData
|
||||
@@ -244,7 +220,7 @@ func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
// compact format with a distinctive symbol (*) and "System" label. Content is
|
||||
// formatted to fit on a single line for minimal space usage.
|
||||
func (r *CompactRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Muted).Render("◇")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Muted).Bold(true).Render("System")
|
||||
|
||||
@@ -264,7 +240,7 @@ func (r *CompactRenderer) RenderSystemMessage(content string, timestamp time.Tim
|
||||
// distinctive error symbol (!) and styling to ensure visibility. The error
|
||||
// content is displayed in a single line with appropriate color highlighting.
|
||||
func (r *CompactRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Error).Render("!")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Error).Bold(true).Render("Error")
|
||||
|
||||
@@ -284,7 +260,7 @@ func (r *CompactRenderer) RenderErrorMessage(errorMsg string, timestamp time.Tim
|
||||
// mode is enabled. Messages are truncated if they exceed the available width to
|
||||
// maintain single-line display.
|
||||
func (r *CompactRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Tool).Bold(true).Render("Debug")
|
||||
|
||||
@@ -308,7 +284,7 @@ func (r *CompactRenderer) RenderDebugMessage(message string, timestamp time.Time
|
||||
// debugging purposes. Config entries are displayed as key=value pairs separated
|
||||
// by commas, truncated if necessary to fit on a single line.
|
||||
func (r *CompactRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Tool).Bold(true).Render("Debug")
|
||||
|
||||
@@ -426,32 +402,6 @@ func (r *CompactRenderer) wrapText(text string, width int) string {
|
||||
return strings.Join(wrappedLines, "\n")
|
||||
}
|
||||
|
||||
// formatToolArgs formats tool arguments for compact display
|
||||
func (r *CompactRenderer) formatToolArgs(args string) string {
|
||||
if args == "" || args == "{}" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Remove JSON braces and format compactly
|
||||
args = strings.TrimSpace(args)
|
||||
if strings.HasPrefix(args, "{") && strings.HasSuffix(args, "}") {
|
||||
args = strings.TrimPrefix(args, "{")
|
||||
args = strings.TrimSuffix(args, "}")
|
||||
args = strings.TrimSpace(args)
|
||||
}
|
||||
|
||||
// Remove quotes around simple values
|
||||
args = strings.ReplaceAll(args, `"`, "")
|
||||
|
||||
// Remove parameter names (e.g., "command: ls" -> "ls", "path: /home" -> "/home")
|
||||
// Look for pattern "key: value" and extract just the value
|
||||
if colonIndex := strings.Index(args, ":"); colonIndex != -1 {
|
||||
args = strings.TrimSpace(args[colonIndex+1:])
|
||||
}
|
||||
|
||||
return r.formatCompactContent(args)
|
||||
}
|
||||
|
||||
// formatToolResult formats tool results preserving formatting but limiting to 5 lines
|
||||
func (r *CompactRenderer) formatToolResult(result string) string {
|
||||
if result == "" {
|
||||
@@ -490,5 +440,5 @@ func (r *CompactRenderer) formatToolResult(result string) string {
|
||||
// and styling appropriately. Delegates tag parsing to the shared parseBashOutput
|
||||
// helper.
|
||||
func (r *CompactRenderer) formatBashOutput(result string) string {
|
||||
return parseBashOutput(result, getTheme())
|
||||
return parseBashOutput(result, GetTheme())
|
||||
}
|
||||
|
||||
@@ -35,8 +35,11 @@ func GetTheme() Theme {
|
||||
|
||||
// SetTheme updates the global UI theme, affecting all subsequent rendering
|
||||
// operations. This allows runtime theme switching for different visual preferences.
|
||||
// It also invalidates the markdownRendererCache so the next call to
|
||||
// GetMarkdownRenderer picks up the new theme's colors.
|
||||
func SetTheme(theme Theme) {
|
||||
currentTheme = theme
|
||||
markdownRendererCache = nil // invalidate cached renderer; colors may have changed
|
||||
}
|
||||
|
||||
// MarkdownThemeColors defines colors for markdown rendering and syntax highlighting.
|
||||
@@ -291,45 +294,3 @@ func ApplyGradient(text string, colorA, colorB color.Color) string {
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// CreateGradientText creates styled text with a gradient effect between two colors.
|
||||
func CreateGradientText(text string, startColor, endColor color.Color) string {
|
||||
return ApplyGradient(text, startColor, endColor)
|
||||
}
|
||||
|
||||
// Compact styling utilities
|
||||
|
||||
// StyleCompactSymbol creates a lipgloss style for message type indicators in
|
||||
// compact mode, using bold colored text to distinguish different message categories.
|
||||
func StyleCompactSymbol(symbol string, c color.Color) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(c).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleCompactLabel creates a lipgloss style for message labels in compact mode
|
||||
// with fixed width for alignment and bold colored text for readability.
|
||||
func StyleCompactLabel(c color.Color) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(c).
|
||||
Bold(true).
|
||||
Width(8)
|
||||
}
|
||||
|
||||
// StyleCompactContent creates a simple lipgloss style for message content in
|
||||
// compact mode, applying only color without additional formatting.
|
||||
func StyleCompactContent(c color.Color) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(c)
|
||||
}
|
||||
|
||||
// FormatCompactLine assembles a complete compact mode message line with consistent
|
||||
// spacing and styling. Combines a symbol, fixed-width label, and content with their
|
||||
// respective colors to create a uniform appearance across all message types.
|
||||
func FormatCompactLine(symbol, label, content string, symbolColor, labelColor, contentColor color.Color) string {
|
||||
styledSymbol := StyleCompactSymbol(symbol, symbolColor).Render(symbol)
|
||||
styledLabel := StyleCompactLabel(labelColor).Render(label)
|
||||
styledContent := StyleCompactContent(contentColor).Render(content)
|
||||
|
||||
return fmt.Sprintf("%s %-8s %s", styledSymbol, styledLabel, styledContent)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// FileSuggestion represents a single file or directory suggestion for the @
|
||||
@@ -345,44 +344,14 @@ func scoreFilePath(query, path string) int {
|
||||
}
|
||||
|
||||
// Fuzzy character match on basename.
|
||||
if score := fuzzyCharMatch(query, baseNameLower); score > 0 {
|
||||
if score := fuzzyCharacterMatch(query, baseNameLower); score > 0 {
|
||||
return score
|
||||
}
|
||||
|
||||
// Fuzzy character match on full path.
|
||||
if score := fuzzyCharMatch(query, pathLower); score > 0 {
|
||||
if score := fuzzyCharacterMatch(query, pathLower); score > 0 {
|
||||
return score - 50
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// fuzzyCharMatch performs character-by-character fuzzy matching. Returns a
|
||||
// positive score if all query characters appear in order in the target.
|
||||
func fuzzyCharMatch(query, target string) int {
|
||||
if utf8.RuneCountInString(query) > utf8.RuneCountInString(target) {
|
||||
return 0
|
||||
}
|
||||
|
||||
qRunes := []rune(query)
|
||||
tRunes := []rune(target)
|
||||
qi := 0
|
||||
score := 100
|
||||
consecutive := 0
|
||||
|
||||
for ti := 0; ti < len(tRunes) && qi < len(qRunes); ti++ {
|
||||
if tRunes[ti] == qRunes[qi] {
|
||||
qi++
|
||||
consecutive++
|
||||
score += consecutive * 5
|
||||
} else {
|
||||
consecutive = 0
|
||||
score -= 2
|
||||
}
|
||||
}
|
||||
|
||||
if qi < len(qRunes) {
|
||||
return 0
|
||||
}
|
||||
return score
|
||||
}
|
||||
|
||||
+11
-7
@@ -113,19 +113,23 @@ func fuzzyScore(query string, cmd *SlashCommand) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// fuzzyCharacterMatch performs character-by-character fuzzy matching
|
||||
// fuzzyCharacterMatch performs character-by-character fuzzy matching using
|
||||
// rune-safe iteration so multi-byte Unicode characters are handled correctly.
|
||||
// Returns a positive score if all query runes appear in order within target.
|
||||
func fuzzyCharacterMatch(query, target string) int {
|
||||
if len(query) > len(target) {
|
||||
qRunes := []rune(query)
|
||||
tRunes := []rune(target)
|
||||
if len(qRunes) > len(tRunes) {
|
||||
return 0
|
||||
}
|
||||
|
||||
queryIdx := 0
|
||||
qi := 0
|
||||
score := 100
|
||||
consecutiveMatches := 0
|
||||
|
||||
for i := 0; i < len(target) && queryIdx < len(query); i++ {
|
||||
if target[i] == query[queryIdx] {
|
||||
queryIdx++
|
||||
for ti := 0; ti < len(tRunes) && qi < len(qRunes); ti++ {
|
||||
if tRunes[ti] == qRunes[qi] {
|
||||
qi++
|
||||
consecutiveMatches++
|
||||
score += consecutiveMatches * 10
|
||||
} else {
|
||||
@@ -135,7 +139,7 @@ func fuzzyCharacterMatch(query, target string) int {
|
||||
}
|
||||
|
||||
// Must match all characters in query
|
||||
if queryIdx < len(query) {
|
||||
if qi < len(qRunes) {
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
+4
-11
@@ -409,21 +409,14 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
|
||||
}
|
||||
|
||||
// Resolve via canonical command lookup so aliases are handled uniformly.
|
||||
// Only /quit and /clear are handled locally — /clear-queue must go
|
||||
// through the parent model so it can update queueCount directly
|
||||
// (calling ClearQueue here would skip the UI state update since we
|
||||
// can't send events from within Update without deadlocking).
|
||||
// Only /quit is handled locally — all other slash commands (including
|
||||
// /clear and /clear-queue) are forwarded to the parent model via
|
||||
// submitMsg so the parent can update its own state (scrollback, queue
|
||||
// counts, etc.) in one place.
|
||||
if sc := GetCommandByName(trimmed); sc != nil {
|
||||
switch sc.Name {
|
||||
case "/quit":
|
||||
return tea.Quit
|
||||
|
||||
case "/clear":
|
||||
if s.appCtrl != nil {
|
||||
s.appCtrl.ClearMessages()
|
||||
}
|
||||
// Don't forward to app.Run(); just clear silently.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+117
-358
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/indaco/herald"
|
||||
)
|
||||
|
||||
// ansiEscapeRe matches ANSI escape sequences used for terminal styling.
|
||||
@@ -22,9 +23,9 @@ const (
|
||||
UserMessage MessageType = iota
|
||||
AssistantMessage
|
||||
ToolMessage
|
||||
ToolCallMessage // New type for showing tool calls in progress
|
||||
SystemMessage // New type for KIT system messages (help, tools, etc.)
|
||||
ErrorMessage // New type for error messages
|
||||
ToolCallMessage
|
||||
SystemMessage
|
||||
ErrorMessage
|
||||
)
|
||||
|
||||
// UIMessage encapsulates a fully rendered message ready for display in the UI,
|
||||
@@ -40,29 +41,9 @@ type UIMessage struct {
|
||||
Streaming bool
|
||||
}
|
||||
|
||||
// Helper functions to get theme colors
|
||||
func getTheme() Theme {
|
||||
return GetTheme()
|
||||
}
|
||||
|
||||
// toolDisplayNames maps raw tool names to human-friendly display names.
|
||||
var toolDisplayNames = map[string]string{
|
||||
"bash": "Bash",
|
||||
"read": "Read",
|
||||
"write": "Write",
|
||||
"edit": "Edit",
|
||||
"grep": "Grep",
|
||||
"find": "Find",
|
||||
"ls": "Ls",
|
||||
"run_shell_cmd": "Bash",
|
||||
}
|
||||
|
||||
// toolDisplayName returns a human-friendly display name for a tool.
|
||||
// Falls back to capitalizing the first letter of the raw name.
|
||||
// toolDisplayName returns a human-friendly display name for a tool,
|
||||
// title-casing the first letter of the raw name.
|
||||
func toolDisplayName(rawName string) string {
|
||||
if display, ok := toolDisplayNames[rawName]; ok {
|
||||
return display
|
||||
}
|
||||
if rawName != "" {
|
||||
return strings.ToUpper(rawName[:1]) + rawName[1:]
|
||||
}
|
||||
@@ -70,8 +51,6 @@ func toolDisplayName(rawName string) string {
|
||||
}
|
||||
|
||||
// formatToolParams formats tool input parameters for inline header display.
|
||||
// Extracts the primary parameter (command/filePath) first, then shows
|
||||
// remaining params as (key=val, ...). Truncates to maxWidth.
|
||||
func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
args := strings.TrimSpace(toolArgs)
|
||||
if args == "" || args == "{}" {
|
||||
@@ -80,7 +59,6 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
|
||||
// Fallback: strip braces and return raw content
|
||||
args = strings.TrimPrefix(args, "{")
|
||||
args = strings.TrimSuffix(args, "}")
|
||||
args = strings.TrimSpace(args)
|
||||
@@ -94,7 +72,6 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Identify primary parameter by checking known keys in priority order
|
||||
primaryKeys := []string{"command", "filePath", "path", "pattern", "query", "url"}
|
||||
var primaryKey string
|
||||
var primaryVal string
|
||||
@@ -111,14 +88,13 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
result.WriteString(primaryVal)
|
||||
}
|
||||
|
||||
// Collect remaining parameters, skipping body-content keys (already
|
||||
// rendered in the tool body) and any values that are too large.
|
||||
bodyKeys := map[string]bool{
|
||||
"content": true,
|
||||
"old_text": true,
|
||||
"new_text": true,
|
||||
"oldText": true,
|
||||
"newText": true,
|
||||
"edits": true,
|
||||
"todos": true,
|
||||
}
|
||||
var remaining []string
|
||||
@@ -154,65 +130,35 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
}
|
||||
|
||||
// MessageRenderer handles the formatting and rendering of different message types
|
||||
// with consistent styling, markdown support, and appropriate visual hierarchies
|
||||
// for the standard (non-compact) display mode.
|
||||
type MessageRenderer struct {
|
||||
width int
|
||||
debug bool
|
||||
|
||||
// getToolRenderer returns extension-provided rendering overrides for a
|
||||
// specific tool. May be nil if no extensions are loaded. Used in
|
||||
// RenderToolMessage to check for custom header/body formatting before
|
||||
// falling back to builtin renderers.
|
||||
width int
|
||||
debug bool
|
||||
ty *herald.Typography
|
||||
getToolRenderer func(toolName string) *ToolRendererData
|
||||
}
|
||||
|
||||
// newMessageRenderer creates and initializes a new MessageRenderer with the specified
|
||||
// terminal width and debug mode setting. The width parameter determines line wrapping
|
||||
// and layout calculations.
|
||||
// newMessageRenderer creates and initializes a new MessageRenderer
|
||||
func newMessageRenderer(width int, debug bool) *MessageRenderer {
|
||||
return &MessageRenderer{
|
||||
width: width,
|
||||
debug: debug,
|
||||
ty: createTypography(GetTheme()),
|
||||
}
|
||||
}
|
||||
|
||||
// SetWidth updates the terminal width for the renderer, affecting how content
|
||||
// is wrapped and formatted in subsequent render operations.
|
||||
// SetWidth updates the terminal width for the renderer
|
||||
func (r *MessageRenderer) SetWidth(width int) {
|
||||
r.width = width
|
||||
}
|
||||
|
||||
// RenderUserMessage renders a user's input message with distinctive right-aligned
|
||||
// formatting, including the system username, timestamp, and markdown-rendered content.
|
||||
// The message is displayed with a colored right border for visual distinction.
|
||||
// RenderUserMessage renders a user's input message using herald Tip alert
|
||||
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
// Only run markdown rendering when the message contains code spans or
|
||||
// fenced code blocks. Plain text is rendered directly so that newlines
|
||||
// are preserved without the extra paragraph spacing glamour adds.
|
||||
var messageContent string
|
||||
if strings.Contains(content, "`") {
|
||||
// Glamour treats single \n as a soft break, so convert to paragraph
|
||||
// breaks and collapse the resulting blank lines after rendering.
|
||||
mdContent := strings.ReplaceAll(content, "\n", "\n\n")
|
||||
messageContent = r.renderMarkdown(mdContent, r.width-8)
|
||||
messageContent = removeBlankLines(messageContent)
|
||||
} else {
|
||||
messageContent = content
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
fullContent := strings.TrimSuffix(messageContent, "\n")
|
||||
|
||||
// Left border with Blue color for user messages.
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Info),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
rendered := r.ty.Tip(content)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: UserMessage,
|
||||
@@ -222,12 +168,8 @@ func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderAssistantMessage renders an AI assistant's response with left-aligned formatting,
|
||||
// including the model name, timestamp, and markdown-rendered content. Empty responses
|
||||
// are ignored and return an empty message. The message features a colored left border
|
||||
// for visual distinction.
|
||||
// RenderAssistantMessage renders an AI assistant's response
|
||||
func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
|
||||
// Ignore empty responses - don't render anything
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
@@ -237,17 +179,9 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
}
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
messageContent := r.renderMarkdown(content, r.width-8)
|
||||
fullContent := strings.TrimSuffix(messageContent, "\n")
|
||||
|
||||
// Left border with Primary (Mauve) color for assistant messages.
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithBorderColor(theme.Primary),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
// Use markdown rendering with Chroma syntax highlighting
|
||||
rendered := toMarkdown(content, r.width-4)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
@@ -257,30 +191,14 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
}
|
||||
}
|
||||
|
||||
// RenderSystemMessage renders KIT system messages such as help text, command outputs,
|
||||
// and informational notifications. These messages are displayed with a distinctive system
|
||||
// color border and "KIT System" label to differentiate them from user and AI content.
|
||||
// RenderSystemMessage renders KIT system messages using herald Note alert
|
||||
func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
var messageContent string
|
||||
if strings.TrimSpace(content) == "" {
|
||||
messageContent = "No content available"
|
||||
} else if strings.Contains(content, "`") {
|
||||
messageContent = r.renderMarkdown(content, r.width-8)
|
||||
} else {
|
||||
messageContent = content
|
||||
content = "No content available"
|
||||
}
|
||||
|
||||
fullContent := "◇ " + strings.TrimSuffix(messageContent, "\n")
|
||||
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithNoBorder(),
|
||||
WithForeground(theme.Muted),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
rendered := r.ty.Note(content)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
@@ -290,27 +208,9 @@ func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugMessage renders diagnostic and debugging information with special formatting
|
||||
// including a debug icon, colored border, and structured layout. Debug messages are only
|
||||
// displayed when debug mode is enabled and help developers troubleshoot issues.
|
||||
// RenderDebugMessage renders diagnostic and debugging information
|
||||
func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := getTheme()
|
||||
style := baseStyle.
|
||||
Width(r.width - 3).
|
||||
BorderLeft(true).
|
||||
Foreground(theme.Muted).
|
||||
BorderForeground(theme.Tool).
|
||||
BorderStyle(lipgloss.ThickBorder()).
|
||||
PaddingLeft(1).
|
||||
MarginLeft(2).
|
||||
MarginBottom(1)
|
||||
|
||||
header := baseStyle.
|
||||
Foreground(theme.Tool).
|
||||
Bold(true).
|
||||
Render("🔍 Debug Output")
|
||||
header := r.ty.H6("🔍 Debug Output")
|
||||
|
||||
lines := strings.Split(message, "\n")
|
||||
var formattedLines []string
|
||||
@@ -320,87 +220,52 @@ func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time
|
||||
}
|
||||
}
|
||||
|
||||
content := baseStyle.
|
||||
Foreground(theme.Muted).
|
||||
Render(strings.Join(formattedLines, "\n"))
|
||||
|
||||
fullContent := lipgloss.JoinVertical(lipgloss.Left,
|
||||
content := r.ty.Compose(
|
||||
header,
|
||||
content,
|
||||
r.ty.P(strings.Join(formattedLines, "\n")),
|
||||
)
|
||||
content = styleMarginBottom1.Render(content)
|
||||
|
||||
return UIMessage{
|
||||
Content: style.Render(fullContent),
|
||||
Height: lipgloss.Height(style.Render(fullContent)),
|
||||
Content: content,
|
||||
Height: lipgloss.Height(content),
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugConfigMessage renders configuration settings in a formatted debug display
|
||||
// with key-value pairs shown in a structured layout. Used to display runtime configuration
|
||||
// for debugging purposes with a distinctive icon and border styling.
|
||||
// RenderDebugConfigMessage renders configuration settings
|
||||
func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := getTheme()
|
||||
style := baseStyle.
|
||||
Width(r.width - 1).
|
||||
BorderLeft(true).
|
||||
Foreground(theme.Muted).
|
||||
BorderForeground(theme.Tool).
|
||||
BorderStyle(lipgloss.ThickBorder()).
|
||||
PaddingLeft(1)
|
||||
|
||||
header := baseStyle.
|
||||
Foreground(theme.Tool).
|
||||
Bold(true).
|
||||
Render("🔧 Debug Configuration")
|
||||
header := r.ty.H6("🔧 Debug Configuration")
|
||||
|
||||
var configLines []string
|
||||
for key, value := range config {
|
||||
if value != nil {
|
||||
configLines = append(configLines, fmt.Sprintf(" %s: %v", key, value))
|
||||
configLines = append(configLines, fmt.Sprintf("%s: %v", key, value))
|
||||
}
|
||||
}
|
||||
|
||||
configContent := baseStyle.
|
||||
Foreground(theme.Muted).
|
||||
Render(strings.Join(configLines, "\n"))
|
||||
|
||||
parts := []string{header}
|
||||
var content string
|
||||
if len(configLines) > 0 {
|
||||
parts = append(parts, configContent)
|
||||
content = r.ty.Compose(
|
||||
header,
|
||||
r.ty.P(strings.Join(configLines, "\n")),
|
||||
)
|
||||
} else {
|
||||
content = header
|
||||
}
|
||||
|
||||
rendered := style.Render(
|
||||
lipgloss.JoinVertical(lipgloss.Left, parts...),
|
||||
)
|
||||
content = styleMarginBottom1.Render(content)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Content: content,
|
||||
Height: lipgloss.Height(content),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderErrorMessage renders error notifications with distinctive red coloring and
|
||||
// bold text to ensure visibility. Error messages include timestamp information and
|
||||
// are displayed with an error-colored border for immediate recognition.
|
||||
// RenderErrorMessage renders error notifications
|
||||
func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
errorContent := lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true).
|
||||
Render(errorMsg)
|
||||
|
||||
rendered := renderContentBlock(
|
||||
errorContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Error),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
rendered := r.ty.Caution(errorMsg)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: ErrorMessage,
|
||||
@@ -410,93 +275,18 @@ func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolCallMessage renders a notification that a tool is being executed, showing
|
||||
// the tool name, formatted arguments (if any), and execution timestamp. The message
|
||||
// uses tool-specific coloring to distinguish it from regular conversation messages.
|
||||
func (r *MessageRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage {
|
||||
// Format timestamp
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
|
||||
// Format arguments with better presentation
|
||||
theme := getTheme()
|
||||
var argsContent string
|
||||
if toolArgs != "" && toolArgs != "{}" {
|
||||
argsContent = lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true).
|
||||
Render(fmt.Sprintf("Arguments: %s", r.formatToolArgs(toolArgs)))
|
||||
}
|
||||
|
||||
// Create info line
|
||||
info := fmt.Sprintf(" Executing %s (%s)", toolName, timeStr)
|
||||
|
||||
// Combine parts
|
||||
var fullContent string
|
||||
if argsContent != "" {
|
||||
fullContent = argsContent + "\n" +
|
||||
lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(info)
|
||||
} else {
|
||||
fullContent = lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(info)
|
||||
}
|
||||
|
||||
// Use the new block renderer
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Tool),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
|
||||
return UIMessage{
|
||||
Type: ToolCallMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolMessage renders a unified tool block combining the tool invocation
|
||||
// header (icon + display name + params) with the execution result body. The
|
||||
// border color indicates status: green for success, red for error. This replaces
|
||||
// the previous two-block approach (separate call + result blocks).
|
||||
// RenderToolMessage renders a unified tool block
|
||||
func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
// Resolve extension renderer once for all overrides.
|
||||
var extRd *ToolRendererData
|
||||
if r.getToolRenderer != nil {
|
||||
extRd = r.getToolRenderer(toolName)
|
||||
}
|
||||
|
||||
// --- Header: [icon] [name] [params] ---
|
||||
var icon string
|
||||
borderColor := theme.Success
|
||||
iconColor := theme.Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
borderColor = theme.Error
|
||||
iconColor = theme.Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// Extension can override border color (applies to both success and error).
|
||||
if extRd != nil && extRd.BorderColor != "" {
|
||||
borderColor = lipgloss.Color(extRd.BorderColor)
|
||||
}
|
||||
|
||||
iconStr := lipgloss.NewStyle().Foreground(iconColor).Bold(true).Render(icon)
|
||||
|
||||
// Extension can override display name.
|
||||
displayName := toolDisplayName(toolName)
|
||||
if extRd != nil && extRd.DisplayName != "" {
|
||||
displayName = extRd.DisplayName
|
||||
}
|
||||
nameStr := lipgloss.NewStyle().Foreground(theme.Info).Bold(true).Render(displayName)
|
||||
|
||||
// Format params with width budget for the header line.
|
||||
// Check extension renderer for custom header params first.
|
||||
paramBudget := max(r.width-10-len(displayName), 20)
|
||||
var params string
|
||||
if extRd != nil && extRd.RenderHeader != nil {
|
||||
@@ -506,97 +296,70 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
params = formatToolParams(toolArgs, paramBudget)
|
||||
}
|
||||
|
||||
header := iconStr + " " + nameStr
|
||||
if params != "" {
|
||||
header += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
var icon string
|
||||
iconColor := GetTheme().Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
iconColor = GetTheme().Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// --- Body: check extension renderer first, then builtin, then default ---
|
||||
// Style the tool name with color
|
||||
theme := GetTheme()
|
||||
nameColor := theme.Info
|
||||
if isError {
|
||||
nameColor = theme.Error
|
||||
}
|
||||
styledName := lipgloss.NewStyle().Foreground(nameColor).Bold(true).Render(displayName)
|
||||
styledIcon := lipgloss.NewStyle().Foreground(iconColor).Render(icon)
|
||||
|
||||
// Build the content: icon + name + params on first line, then body
|
||||
headerLine := styledIcon + " " + styledName
|
||||
if params != "" {
|
||||
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
}
|
||||
|
||||
// Get body content
|
||||
var body string
|
||||
if extRd != nil && extRd.RenderBody != nil {
|
||||
body = extRd.RenderBody(toolResult, isError, r.width-8)
|
||||
// Apply markdown rendering if requested and body is non-empty.
|
||||
if body != "" && extRd.BodyMarkdown {
|
||||
body = strings.TrimSuffix(toMarkdown(body, r.width-8), "\n")
|
||||
}
|
||||
}
|
||||
if body == "" {
|
||||
if isError {
|
||||
body = lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Render(toolResult)
|
||||
body = r.formatToolResult(toolName, toolResult)
|
||||
} else {
|
||||
body = renderToolBody(toolName, toolArgs, toolResult, r.width-8)
|
||||
if body == "" {
|
||||
body = r.formatToolResult(toolName, toolResult, r.width-8)
|
||||
body = r.formatToolResult(toolName, toolResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(body) == "" {
|
||||
body = lipgloss.NewStyle().
|
||||
Italic(true).
|
||||
Foreground(theme.Muted).
|
||||
Render("(no output)")
|
||||
body = r.ty.Italic("(no output)")
|
||||
}
|
||||
|
||||
// Combine header + body into a single block.
|
||||
fullContent := header + "\n\n" + strings.TrimSuffix(body, "\n")
|
||||
|
||||
// Build rendering options; extension can override background.
|
||||
blockOpts := []renderingOption{
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(borderColor),
|
||||
WithMarginBottom(1),
|
||||
}
|
||||
if extRd != nil && extRd.Background != "" {
|
||||
blockOpts = append(blockOpts, WithBackground(lipgloss.Color(extRd.Background)))
|
||||
}
|
||||
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
blockOpts...,
|
||||
// Compose: icon + name + params, then body
|
||||
fullContent := r.ty.Compose(
|
||||
headerLine,
|
||||
"",
|
||||
body,
|
||||
)
|
||||
fullContent = styleMarginBottom1.Render(fullContent)
|
||||
|
||||
return UIMessage{
|
||||
Type: ToolMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Content: fullContent,
|
||||
Height: lipgloss.Height(fullContent),
|
||||
}
|
||||
}
|
||||
|
||||
// formatToolArgs formats tool arguments for display
|
||||
func (r *MessageRenderer) formatToolArgs(args string) string {
|
||||
// Remove outer braces and clean up JSON formatting
|
||||
args = strings.TrimSpace(args)
|
||||
if strings.HasPrefix(args, "{") && strings.HasSuffix(args, "}") {
|
||||
args = strings.TrimPrefix(args, "{")
|
||||
args = strings.TrimSuffix(args, "}")
|
||||
args = strings.TrimSpace(args)
|
||||
}
|
||||
|
||||
// If it's empty after cleanup, return a placeholder
|
||||
if args == "" {
|
||||
return "(no arguments)"
|
||||
}
|
||||
|
||||
// Truncate if too long, but skip truncation in debug mode
|
||||
if !r.debug {
|
||||
maxLen := 100
|
||||
if len(args) > maxLen {
|
||||
return args[:maxLen] + "..."
|
||||
}
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// formatToolResult formats tool results based on tool type
|
||||
func (r *MessageRenderer) formatToolResult(toolName, result string, width int) string {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
// Truncate very long results only if not in debug mode
|
||||
func (r *MessageRenderer) formatToolResult(toolName, result string) string {
|
||||
if !r.debug {
|
||||
maxLines := 10
|
||||
lines := strings.Split(result, "\n")
|
||||
@@ -605,51 +368,47 @@ func (r *MessageRenderer) formatToolResult(toolName, result string, width int) s
|
||||
}
|
||||
}
|
||||
|
||||
// Format bash/command output with better formatting
|
||||
if strings.Contains(toolName, "bash") || strings.Contains(toolName, "command") || strings.Contains(toolName, "shell") || toolName == "run_shell_cmd" {
|
||||
theme := getTheme()
|
||||
|
||||
// Split result into sections if it contains both stdout and stderr
|
||||
if strings.Contains(toolName, "bash") || strings.Contains(toolName, "command") ||
|
||||
strings.Contains(toolName, "shell") {
|
||||
if strings.Contains(result, "<stdout>") || strings.Contains(result, "<stderr>") {
|
||||
return r.formatBashOutput(result, width, theme)
|
||||
return parseBashOutput(result, GetTheme())
|
||||
}
|
||||
|
||||
// For simple output, just render as monospace text with proper line breaks
|
||||
return baseStyle.
|
||||
Width(width).
|
||||
Foreground(theme.Muted).
|
||||
Render(result)
|
||||
}
|
||||
|
||||
// For other tools, render as muted text
|
||||
theme := getTheme()
|
||||
return baseStyle.
|
||||
Width(width).
|
||||
Foreground(theme.Muted).
|
||||
Render(result)
|
||||
return result
|
||||
}
|
||||
|
||||
// formatBashOutput formats bash command output with proper section handling.
|
||||
// Delegates tag parsing to the shared parseBashOutput helper.
|
||||
func (r *MessageRenderer) formatBashOutput(result string, width int, theme Theme) string {
|
||||
parsed := parseBashOutput(result, theme)
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
Foreground(theme.Muted).
|
||||
Render(parsed)
|
||||
}
|
||||
|
||||
// renderMarkdown renders markdown content using glamour
|
||||
func (r *MessageRenderer) renderMarkdown(content string, width int) string {
|
||||
rendered := toMarkdown(content, width)
|
||||
return strings.TrimSuffix(rendered, "\n")
|
||||
// createTypography creates a typography instance from theme
|
||||
func createTypography(theme Theme) *herald.Typography {
|
||||
return herald.New(
|
||||
herald.WithPalette(herald.ColorPalette{
|
||||
Primary: theme.Primary,
|
||||
Secondary: theme.Secondary,
|
||||
Tertiary: theme.Info,
|
||||
Accent: theme.Accent,
|
||||
Highlight: theme.Highlight,
|
||||
Muted: theme.Muted,
|
||||
Text: theme.Text,
|
||||
Surface: theme.Background,
|
||||
Base: theme.CodeBg,
|
||||
}),
|
||||
herald.WithAlertPalette(herald.AlertPalette{
|
||||
Note: theme.Info,
|
||||
Tip: theme.Success,
|
||||
Important: theme.Accent,
|
||||
Warning: theme.Warning,
|
||||
Caution: theme.Error,
|
||||
}),
|
||||
herald.WithCodeLineNumbers(true),
|
||||
// Customize alert labels
|
||||
herald.WithAlertLabel(herald.AlertNote, "Info"),
|
||||
herald.WithAlertLabel(herald.AlertTip, "You"),
|
||||
herald.WithAlertLabel(herald.AlertWarning, "Working"),
|
||||
herald.WithAlertLabel(herald.AlertCaution, "Error"),
|
||||
)
|
||||
}
|
||||
|
||||
// removeBlankLines removes lines that are visually blank from rendered output.
|
||||
// Glamour wraps every character (including padding spaces) with ANSI color
|
||||
// codes, so we must strip escape sequences before checking whether a line is
|
||||
// empty. This collapses paragraph spacing so user messages render without
|
||||
// extra vertical gaps.
|
||||
func removeBlankLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
filtered := lines[:0]
|
||||
|
||||
+134
-131
@@ -8,7 +8,6 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
@@ -75,6 +74,11 @@ type AppController interface {
|
||||
ClearQueue()
|
||||
// ClearMessages clears the conversation history.
|
||||
ClearMessages()
|
||||
// ReloadMessagesFromTree clears the in-memory message store and reloads
|
||||
// it from the tree session's current branch. Unlike ClearMessages, this
|
||||
// does NOT reset the tree session's leaf pointer. Used after Branch() to
|
||||
// sync the store with the new branch position.
|
||||
ReloadMessagesFromTree()
|
||||
// CompactConversation summarises older messages to free context space.
|
||||
// Runs asynchronously; results are delivered via CompactCompleteEvent or
|
||||
// CompactErrorEvent sent through the registered tea.Program. Returns an
|
||||
@@ -149,6 +153,22 @@ type ToolRendererData struct {
|
||||
RenderBody func(toolResult string, isError bool, width int) string
|
||||
}
|
||||
|
||||
// noopCmd is a sentinel tea.Cmd returned by handlers that have consumed an
|
||||
// event but produce no side-effects. It returns a nil Msg which BubbleTea
|
||||
// discards, but its non-nil value lets callers distinguish "handled" from
|
||||
// "not handled" (nil tea.Cmd).
|
||||
var noopCmd tea.Cmd = func() tea.Msg { return nil }
|
||||
|
||||
// Package-level lipgloss styles that are invariant across frames (only depend
|
||||
// on theme colors, which are updated via SetTheme). Defined at package level
|
||||
// to avoid allocating new lipgloss.Style structs on every render call.
|
||||
//
|
||||
// Note: theme-sensitive styles (those using theme.Warning, theme.Muted, etc.)
|
||||
// are rebuilt on theme change via ApplyTheme. The cancel warning style
|
||||
// intentionally reads the theme at render time because themes can change at
|
||||
// runtime; only truly static styles belong here.
|
||||
var styleMarginBottom1 = lipgloss.NewStyle().MarginBottom(1)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Editor interceptor types (UI-layer, decoupled from extensions package)
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -584,24 +604,30 @@ type AppModel struct {
|
||||
streamingBashStderr []string
|
||||
// streamingBashMaxLines caps how many lines to accumulate to prevent memory issues.
|
||||
streamingBashMaxLines int
|
||||
// streamingMu protects the streaming bash output fields from concurrent access.
|
||||
streamingMu sync.RWMutex
|
||||
// streaming bash fields are only mutated/read from the Bubble Tea event loop
|
||||
// (Update/View), so no mutex is required here.
|
||||
// streamingBashCommand holds the command being executed for display as a header.
|
||||
streamingBashCommand string
|
||||
|
||||
// ---------- Cached layout heights (invalidated by layoutDirty) ----------
|
||||
|
||||
// layoutDirty marks that distributeHeight must recompute the stream height
|
||||
// on the next View() call. Set by any state change that affects sizing
|
||||
// (resize, queue changes, widget updates, visibility changes, etc.).
|
||||
// View() calls distributeHeight() when this is true and then clears it.
|
||||
layoutDirty bool
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Child component interfaces (stubs until TAS-15/16/17 implement them)
|
||||
// Child component interfaces
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// inputComponentIface is the interface the parent requires from InputComponent.
|
||||
// It will be satisfied by the real InputComponent created in TAS-15.
|
||||
type inputComponentIface interface {
|
||||
tea.Model
|
||||
}
|
||||
|
||||
// streamComponentIface is the interface the parent requires from StreamComponent.
|
||||
// It will be satisfied by the real StreamComponent created in TAS-16.
|
||||
type streamComponentIface interface {
|
||||
tea.Model
|
||||
// Reset clears accumulated state between agent steps.
|
||||
@@ -754,16 +780,9 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
|
||||
// Init implements tea.Model. Initialises child components. Startup info is
|
||||
// printed to stdout before the program starts via PrintStartupInfo().
|
||||
func (m *AppModel) Init() tea.Cmd {
|
||||
var cmds []tea.Cmd
|
||||
|
||||
if m.input != nil {
|
||||
cmds = append(cmds, m.input.Init())
|
||||
}
|
||||
if m.stream != nil {
|
||||
cmds = append(cmds, m.stream.Init())
|
||||
}
|
||||
|
||||
return tea.Batch(cmds...)
|
||||
// m.input is always set by NewAppModel; its Init starts the textarea cursor blink.
|
||||
// m.stream.Init() always returns nil, so there is nothing to batch.
|
||||
return m.input.Init()
|
||||
}
|
||||
|
||||
// uiVis returns the current UIVisibility, defaulting to zero value (show all)
|
||||
@@ -787,28 +806,29 @@ func (m *AppModel) PrintStartupInfo() {
|
||||
return
|
||||
}
|
||||
|
||||
render := func(text string) string {
|
||||
return m.renderer.RenderSystemMessage(text, time.Now()).Content
|
||||
}
|
||||
// Create typography instance for startup rendering
|
||||
ty := createTypography(GetTheme())
|
||||
|
||||
fmt.Println()
|
||||
|
||||
// Build the combined startup content.
|
||||
var lines []string
|
||||
// Build key-value pairs for startup info
|
||||
var pairs [][2]string
|
||||
|
||||
if m.providerName != "" && m.modelName != "" {
|
||||
lines = append(lines, fmt.Sprintf("Model loaded: %s (%s)", m.providerName, m.modelName))
|
||||
pairs = append(pairs, [2]string{"Model", fmt.Sprintf("%s (%s)", m.providerName, m.modelName)})
|
||||
}
|
||||
|
||||
if m.loadingMessage != "" {
|
||||
lines = append(lines, m.loadingMessage)
|
||||
pairs = append(pairs, [2]string{"Status", m.loadingMessage})
|
||||
}
|
||||
|
||||
// Context — loaded AGENTS.md files.
|
||||
if len(m.contextPaths) > 0 {
|
||||
for _, p := range m.contextPaths {
|
||||
lines = append(lines, fmt.Sprintf("Context: %s", tildeHome(p)))
|
||||
contextStr := tildeHome(m.contextPaths[0])
|
||||
if len(m.contextPaths) > 1 {
|
||||
contextStr += fmt.Sprintf(" +%d more", len(m.contextPaths)-1)
|
||||
}
|
||||
pairs = append(pairs, [2]string{"Context", contextStr})
|
||||
}
|
||||
|
||||
// Skills — listed by name.
|
||||
@@ -817,21 +837,23 @@ func (m *AppModel) PrintStartupInfo() {
|
||||
for i, si := range m.skillItems {
|
||||
names[i] = si.Name
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("Skills: %s", strings.Join(names, ", ")))
|
||||
pairs = append(pairs, [2]string{"Skills", strings.Join(names, ", ")})
|
||||
}
|
||||
|
||||
// Extension tool count (only shown when > 0).
|
||||
if m.extensionToolCount > 0 {
|
||||
lines = append(lines, fmt.Sprintf("Loaded %d extension tools", m.extensionToolCount))
|
||||
pairs = append(pairs, [2]string{"Extensions", fmt.Sprintf("%d tools", m.extensionToolCount)})
|
||||
}
|
||||
|
||||
// MCP tool count (only shown when > 0).
|
||||
if m.mcpToolCount > 0 {
|
||||
lines = append(lines, fmt.Sprintf("Loaded %d tools from MCP servers", m.mcpToolCount))
|
||||
pairs = append(pairs, [2]string{"MCP", fmt.Sprintf("%d tools", m.mcpToolCount)})
|
||||
}
|
||||
|
||||
if len(lines) > 0 {
|
||||
fmt.Println(render(strings.Join(lines, "\n\n")))
|
||||
if len(pairs) > 0 {
|
||||
rendered := ty.KVGroup(pairs)
|
||||
rendered = styleMarginBottom1.Render(rendered)
|
||||
fmt.Println(rendered)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -901,7 +923,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}()
|
||||
m.treeSelector = nil
|
||||
m.state = stateInput
|
||||
return m, func() tea.Msg { return nil }
|
||||
return m, noopCmd
|
||||
}
|
||||
|
||||
cmds = append(cmds, m.performFork(targetID, msg.IsUser, msg.UserText))
|
||||
@@ -983,14 +1005,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
m.height = msg.Height
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = true
|
||||
// Propagate to children.
|
||||
if m.input != nil {
|
||||
_, cmd := m.input.Update(msg)
|
||||
updated, cmd := m.input.Update(msg)
|
||||
m.input, _ = updated.(inputComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(msg)
|
||||
updated, cmd := m.stream.Update(msg)
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
@@ -1114,7 +1138,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
sLen := m.appCtrl.Steer(processedText)
|
||||
if sLen > 0 {
|
||||
m.steeringMessages = append(m.steeringMessages, text)
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = true
|
||||
} else {
|
||||
// Started immediately (agent was idle).
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, text)
|
||||
@@ -1185,63 +1209,17 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// ── Input submitted ──────────────────────────────────────────────────────
|
||||
case submitMsg:
|
||||
// Handle slash commands locally — they should never reach app.Run().
|
||||
if sc := GetCommandByName(msg.Text); sc != nil {
|
||||
if cmd := m.handleSlashCommand(sc); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
// /compact and /model support optional args (e.g. "/compact Focus on API",
|
||||
// "/model anthropic/claude-haiku-3-5-20241022").
|
||||
// GetCommandByName won't match the full text, so check the prefix.
|
||||
if name, args, ok := strings.Cut(msg.Text, " "); ok {
|
||||
// Parse once: split on the first space so argument-bearing commands
|
||||
// (e.g. "/model anthropic/foo", "/compact Focus on X") are matched by
|
||||
// their name and their args are passed through to the handler.
|
||||
if strings.HasPrefix(msg.Text, "/") {
|
||||
name, args, _ := strings.Cut(msg.Text, " ")
|
||||
if sc := GetCommandByName(name); sc != nil {
|
||||
switch sc.Name {
|
||||
case "/compact":
|
||||
if cmd := m.handleCompactCommand(strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
case "/model":
|
||||
if cmd := m.handleModelCommand(strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
case "/thinking":
|
||||
if cmd := m.handleThinkingCommand(strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
case "/theme":
|
||||
if cmd := m.handleThemeCommand(strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
case "/name":
|
||||
if cmd := m.handleNameCommand(strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
case "/export":
|
||||
if cmd := m.handleExportCommand(strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
case "/import":
|
||||
if cmd := m.handleImportCommand(strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
if cmd := m.handleSlashCommand(sc, strings.TrimSpace(args)); cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1298,7 +1276,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// "queued" badge. It will be printed to scrollback when
|
||||
// the agent picks it up (via SpinnerEvent).
|
||||
m.queuedMessages = append(m.queuedMessages, displayText)
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = true
|
||||
} else {
|
||||
// Started immediately. Flush any leftover stream content
|
||||
// from the previous step first, then print the user
|
||||
@@ -1319,7 +1297,8 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Show spinner while the shell command runs.
|
||||
m.state = stateWorking
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: true})
|
||||
updated, cmd := m.stream.Update(app.SpinnerEvent{Show: true})
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
// Execute the shell command asynchronously so the TUI stays responsive.
|
||||
@@ -1328,7 +1307,8 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case shellCommandResultMsg:
|
||||
// Stop spinner now that the command has finished.
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: false})
|
||||
updated, cmd := m.stream.Update(app.SpinnerEvent{Show: false})
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
m.state = stateInput
|
||||
@@ -1346,22 +1326,25 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if msg.Show {
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
m.state = stateWorking
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = true
|
||||
}
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(msg)
|
||||
updated, cmd := m.stream.Update(msg)
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
case app.ReasoningChunkEvent:
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(msg)
|
||||
updated, cmd := m.stream.Update(msg)
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
case app.StreamChunkEvent:
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(msg)
|
||||
updated, cmd := m.stream.Update(msg)
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
@@ -1378,16 +1361,15 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
Command string `json:"command"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(msg.ToolArgs), &args); err == nil && args.Command != "" {
|
||||
m.streamingMu.Lock()
|
||||
m.streamingBashCommand = args.Command
|
||||
m.streamingMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
case app.ToolExecutionEvent:
|
||||
// Pass to stream component for execution spinner display.
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(msg)
|
||||
updated, cmd := m.stream.Update(msg)
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
@@ -1395,20 +1377,18 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Buffer tool result for scrollback.
|
||||
m.printToolResult(msg)
|
||||
// Clear streaming bash output since tool completed.
|
||||
m.streamingMu.Lock()
|
||||
m.streamingBashOutput = nil
|
||||
m.streamingBashStderr = nil
|
||||
m.streamingBashCommand = ""
|
||||
m.streamingMu.Unlock()
|
||||
// Start spinner again while waiting for the next LLM response.
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: true})
|
||||
updated, cmd := m.stream.Update(app.SpinnerEvent{Show: true})
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
case app.ToolOutputEvent:
|
||||
// Accumulate streaming bash output for display.
|
||||
m.streamingMu.Lock()
|
||||
if msg.IsStderr {
|
||||
m.streamingBashStderr = append(m.streamingBashStderr, msg.Chunk)
|
||||
// Cap stderr lines to prevent memory issues.
|
||||
@@ -1422,7 +1402,6 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.streamingBashOutput = m.streamingBashOutput[len(m.streamingBashOutput)-m.streamingBashMaxLines:]
|
||||
}
|
||||
}
|
||||
m.streamingMu.Unlock()
|
||||
|
||||
case app.ToolCallContentEvent:
|
||||
// In streaming mode this text was already delivered via StreamChunkEvents
|
||||
@@ -1456,7 +1435,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.queuedMessages = m.queuedMessages[1:]
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, text)
|
||||
}
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = true
|
||||
|
||||
case app.SteerConsumedEvent:
|
||||
// Steering messages were consumed — either injected mid-turn via
|
||||
@@ -1481,13 +1460,13 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.printUserMessage(text)
|
||||
}
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = true
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
} else {
|
||||
// Case 2: post-turn — defer so SpinnerEvent orders correctly.
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, m.steeringMessages...)
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = true
|
||||
}
|
||||
|
||||
case app.StepCompleteEvent:
|
||||
@@ -1498,7 +1477,8 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// scrollback when the next step starts (SpinnerEvent{Show: true}).
|
||||
// Just stop the spinner and return to input state.
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: false})
|
||||
updated, cmd := m.stream.Update(app.SpinnerEvent{Show: false})
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
m.state = stateInput
|
||||
@@ -1508,7 +1488,8 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// User cancelled the step (double-ESC). Keep partial stream content
|
||||
// visible (same reasoning as StepCompleteEvent). Just stop the spinner.
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: false})
|
||||
updated, cmd := m.stream.Update(app.SpinnerEvent{Show: false})
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
m.state = stateInput
|
||||
@@ -1519,7 +1500,8 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// StepCompleteEvent). Print the error to scrollback — it appears
|
||||
// above the view, and the partial response stays visible below.
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: false})
|
||||
updated, cmd := m.stream.Update(app.SpinnerEvent{Show: false})
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
if msg.Err != nil {
|
||||
@@ -1552,7 +1534,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Extension widget changed — recalculate height distribution so the
|
||||
// stream region accounts for widget space. View() will read the
|
||||
// latest widget state on the next render.
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = true
|
||||
|
||||
// Refresh extension commands (e.g. after hot-reload). The callback
|
||||
// returns the current set from the runner which may have changed.
|
||||
@@ -1697,11 +1679,13 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
default:
|
||||
// Pass unrecognised messages to all children.
|
||||
if m.input != nil {
|
||||
_, cmd := m.input.Update(msg)
|
||||
updated, cmd := m.input.Update(msg)
|
||||
m.input, _ = updated.(inputComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(msg)
|
||||
updated, cmd := m.stream.Update(msg)
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
}
|
||||
@@ -1735,6 +1719,15 @@ func (m *AppModel) View() tea.View {
|
||||
return tea.NewView(m.overlay.Render())
|
||||
}
|
||||
|
||||
// Recompute layout heights if any Update() changed state that affects
|
||||
// sizing. Deferring this to View() guarantees exactly one call per frame
|
||||
// regardless of how many events triggered a layout change in a single
|
||||
// Update() invocation.
|
||||
if m.layoutDirty {
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = false
|
||||
}
|
||||
|
||||
vis := m.uiVis()
|
||||
|
||||
streamView := m.renderStream()
|
||||
@@ -1847,13 +1840,11 @@ func (m *AppModel) renderStream() string {
|
||||
// Lines are truncated to the terminal width and capped to maxBashLines to prevent
|
||||
// long-running commands from blowing up the TUI layout.
|
||||
func (m *AppModel) renderStreamingBashOutput(theme Theme) string {
|
||||
m.streamingMu.RLock()
|
||||
stdoutLines := make([]string, len(m.streamingBashOutput))
|
||||
copy(stdoutLines, m.streamingBashOutput)
|
||||
stderrLines := make([]string, len(m.streamingBashStderr))
|
||||
copy(stderrLines, m.streamingBashStderr)
|
||||
command := m.streamingBashCommand
|
||||
m.streamingMu.RUnlock()
|
||||
|
||||
if len(stdoutLines) == 0 && len(stderrLines) == 0 {
|
||||
return ""
|
||||
@@ -2244,10 +2235,10 @@ func (m *AppModel) printErrorResponse(evt app.StepErrorEvent) {
|
||||
// Slash command handlers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// handleSlashCommand executes a recognized slash command and returns a tea.Cmd
|
||||
// that emits the appropriate output to scrollback. Returns tea.Quit for /quit,
|
||||
// nil for commands with no visible output, or a tea.Println cmd for display.
|
||||
func (m *AppModel) handleSlashCommand(sc *SlashCommand) tea.Cmd {
|
||||
// handleSlashCommand executes a recognized slash command and returns a tea.Cmd.
|
||||
// args contains any text after the command name (may be empty). Returns tea.Quit
|
||||
// for /quit, nil for commands with no output, or a tea.Println cmd for display.
|
||||
func (m *AppModel) handleSlashCommand(sc *SlashCommand, args string) tea.Cmd {
|
||||
switch sc.Name {
|
||||
case "/quit":
|
||||
return tea.Quit
|
||||
@@ -2262,13 +2253,13 @@ func (m *AppModel) handleSlashCommand(sc *SlashCommand) tea.Cmd {
|
||||
case "/reset-usage":
|
||||
m.printResetUsage()
|
||||
case "/model":
|
||||
return m.handleModelCommand("")
|
||||
return m.handleModelCommand(args)
|
||||
case "/theme":
|
||||
return m.handleThemeCommand("")
|
||||
return m.handleThemeCommand(args)
|
||||
case "/thinking":
|
||||
return m.handleThinkingCommand("")
|
||||
return m.handleThinkingCommand(args)
|
||||
case "/compact":
|
||||
return m.handleCompactCommand("")
|
||||
return m.handleCompactCommand(args)
|
||||
case "/clear":
|
||||
if m.appCtrl != nil {
|
||||
m.appCtrl.ClearMessages()
|
||||
@@ -2280,7 +2271,7 @@ func (m *AppModel) handleSlashCommand(sc *SlashCommand) tea.Cmd {
|
||||
}
|
||||
m.queuedMessages = m.queuedMessages[:0]
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = true
|
||||
|
||||
case "/tree":
|
||||
return m.handleTreeCommand()
|
||||
@@ -2289,15 +2280,15 @@ func (m *AppModel) handleSlashCommand(sc *SlashCommand) tea.Cmd {
|
||||
case "/new":
|
||||
return m.handleNewCommand()
|
||||
case "/name":
|
||||
return m.handleNameCommand("")
|
||||
return m.handleNameCommand(args)
|
||||
case "/resume":
|
||||
return m.handleResumeCommand()
|
||||
case "/export":
|
||||
return m.handleExportCommand("")
|
||||
return m.handleExportCommand(args)
|
||||
case "/share":
|
||||
return m.handleShareCommand()
|
||||
case "/import":
|
||||
return m.handleImportCommand("")
|
||||
return m.handleImportCommand(args)
|
||||
case "/session":
|
||||
return m.handleSessionInfoCommand()
|
||||
|
||||
@@ -2384,7 +2375,7 @@ func (m *AppModel) handleExtensionCommand(text string) tea.Cmd {
|
||||
// Return a non-nil Cmd so the caller knows the command was handled
|
||||
// and doesn't fall through to the regular prompt path. The Cmd itself
|
||||
// is a no-op.
|
||||
return func() tea.Msg { return nil }
|
||||
return noopCmd
|
||||
}
|
||||
|
||||
// expandPromptTemplate checks if the submitted text matches a prompt template
|
||||
@@ -3010,7 +3001,7 @@ func (m *AppModel) handleNewCommand() tea.Cmd {
|
||||
reason: reason,
|
||||
})
|
||||
}()
|
||||
return func() tea.Msg { return nil }
|
||||
return noopCmd
|
||||
}
|
||||
|
||||
return m.performNewSession()
|
||||
@@ -3027,6 +3018,10 @@ func (m *AppModel) performNewSession() tea.Cmd {
|
||||
if m.appCtrl != nil {
|
||||
m.appCtrl.ClearMessages()
|
||||
}
|
||||
// Reset usage statistics for fresh session
|
||||
if m.usageTracker != nil {
|
||||
m.usageTracker.Reset()
|
||||
}
|
||||
m.printSystemMessage("Conversation cleared. Starting fresh.")
|
||||
return nil
|
||||
}
|
||||
@@ -3040,6 +3035,10 @@ func (m *AppModel) performNewSession() tea.Cmd {
|
||||
|
||||
// Switch to the new session, closing the old one
|
||||
m.appCtrl.SwitchTreeSession(newTs)
|
||||
// Reset usage statistics for the new session
|
||||
if m.usageTracker != nil {
|
||||
m.usageTracker.Reset()
|
||||
}
|
||||
m.printSystemMessage("New session started. Previous conversation saved.")
|
||||
return nil
|
||||
}
|
||||
@@ -3053,8 +3052,12 @@ func (m *AppModel) performFork(targetID string, isUser bool, userText string) te
|
||||
return nil
|
||||
}
|
||||
|
||||
// Branch the tree session to the target entry. We must NOT call
|
||||
// ClearMessages() here because it resets the leaf pointer back to "",
|
||||
// undoing the branch we just set. Instead, branch first and then
|
||||
// reload the in-memory store from the tree session's current branch.
|
||||
_ = ts.Branch(targetID)
|
||||
m.appCtrl.ClearMessages()
|
||||
m.appCtrl.ReloadMessagesFromTree()
|
||||
|
||||
// If it was a user message, populate the input with the text.
|
||||
if isUser && userText != "" {
|
||||
|
||||
@@ -50,7 +50,7 @@ func NewModelSelector(currentModel string, width, height int) *ModelSelectorComp
|
||||
registry := models.GetGlobalRegistry()
|
||||
var allModels []ModelEntry
|
||||
|
||||
for _, providerID := range registry.GetFantasyProviders() {
|
||||
for _, providerID := range registry.GetLLMProviders() {
|
||||
// Only include providers with valid API keys configured.
|
||||
if err := registry.ValidateEnvironment(providerID, ""); err != nil {
|
||||
continue
|
||||
|
||||
@@ -46,6 +46,10 @@ func (s *stubAppController) ClearMessages() {
|
||||
s.clearMsgCalled++
|
||||
}
|
||||
|
||||
func (s *stubAppController) ReloadMessagesFromTree() {
|
||||
// no-op in tests
|
||||
}
|
||||
|
||||
func (s *stubAppController) CompactConversation(_ string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -142,7 +146,11 @@ func newTestAppModel(ctrl AppController) (*AppModel, *stubStreamComponent, *stub
|
||||
// sendMsg calls m.Update once with the given message and returns the updated model.
|
||||
func sendMsg(m *AppModel, msg tea.Msg) *AppModel {
|
||||
updated, _ := m.Update(msg)
|
||||
return updated.(*AppModel)
|
||||
result := updated.(*AppModel)
|
||||
// Simulate BubbleTea's frame cycle: View() is called after every Update().
|
||||
// This flushes any pending layoutDirty work (e.g. distributeHeight).
|
||||
_ = result.View()
|
||||
return result
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
@@ -1,352 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"charm.land/bubbles/v2/key"
|
||||
"charm.land/bubbles/v2/textarea"
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
// SlashCommandInput provides an interactive text input field with intelligent
|
||||
// slash command autocomplete functionality. It displays a popup menu of matching
|
||||
// commands as the user types, supporting fuzzy matching and keyboard navigation.
|
||||
type SlashCommandInput struct {
|
||||
textarea textarea.Model
|
||||
commands []SlashCommand
|
||||
showPopup bool
|
||||
filtered []FuzzyMatch
|
||||
selected int
|
||||
width int
|
||||
lastValue string
|
||||
popupHeight int
|
||||
title string
|
||||
quitting bool
|
||||
value string
|
||||
submitNext bool // Flag to submit on next update
|
||||
renderedLines int // Track how many lines were rendered
|
||||
hideHint bool // Suppress the "enter submit · ctrl+j..." hint
|
||||
}
|
||||
|
||||
// NewSlashCommandInput creates and initializes a new slash command input field with
|
||||
// the specified width and title. The input supports multi-line text entry, command
|
||||
// autocomplete, and is styled to match the application's theme.
|
||||
func NewSlashCommandInput(width int, title string) *SlashCommandInput {
|
||||
ta := textarea.New()
|
||||
ta.Placeholder = "Type your message..."
|
||||
ta.ShowLineNumbers = false
|
||||
ta.Prompt = ""
|
||||
ta.CharLimit = 5000
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(3) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
|
||||
// Override InsertNewline so only ctrl+j and shift+enter insert newlines.
|
||||
// Enter always submits the input.
|
||||
ta.KeyMap.InsertNewline = key.NewBinding(
|
||||
key.WithKeys("ctrl+j", "shift+enter"),
|
||||
key.WithHelp("ctrl+j", "insert newline"),
|
||||
)
|
||||
|
||||
// Style the textarea using theme colors.
|
||||
theme := GetTheme()
|
||||
styles := ta.Styles()
|
||||
styles.Focused.Base = lipgloss.NewStyle()
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
styles.Focused.Text = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
styles.Focused.Prompt = lipgloss.NewStyle()
|
||||
styles.Focused.CursorLine = lipgloss.NewStyle()
|
||||
ta.SetStyles(styles)
|
||||
|
||||
return &SlashCommandInput{
|
||||
textarea: ta,
|
||||
commands: SlashCommands,
|
||||
width: width,
|
||||
popupHeight: 7,
|
||||
title: title,
|
||||
}
|
||||
}
|
||||
|
||||
// Init implements the tea.Model interface, returning the initial command to start
|
||||
// the cursor blinking animation for the text input field.
|
||||
func (s *SlashCommandInput) Init() tea.Cmd {
|
||||
return textarea.Blink
|
||||
}
|
||||
|
||||
// Update implements the tea.Model interface, handling keyboard input for text entry,
|
||||
// command selection, and navigation. Manages the autocomplete popup display and
|
||||
// processes submission or cancellation actions.
|
||||
func (s *SlashCommandInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
|
||||
// Check if we need to submit after updating the view
|
||||
if s.submitNext {
|
||||
s.value = s.textarea.Value()
|
||||
s.quitting = true
|
||||
return s, tea.Quit
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyPressMsg: // Check for quit keys first (when popup is not shown)
|
||||
if !s.showPopup {
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "esc":
|
||||
s.quitting = true
|
||||
return s, tea.Quit
|
||||
case "ctrl+d", "enter": // Enter always submits
|
||||
s.value = s.textarea.Value()
|
||||
s.quitting = true
|
||||
return s, tea.Quit
|
||||
}
|
||||
}
|
||||
|
||||
// Handle popup navigation
|
||||
if s.showPopup {
|
||||
switch {
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("up"), key.WithHelp("↑", "up"))):
|
||||
if s.selected > 0 {
|
||||
s.selected--
|
||||
}
|
||||
return s, nil
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("down"), key.WithHelp("↓", "down"))):
|
||||
if s.selected < len(s.filtered)-1 {
|
||||
s.selected++
|
||||
}
|
||||
return s, nil
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("tab"))):
|
||||
if s.selected < len(s.filtered) {
|
||||
// Complete with selected command
|
||||
s.textarea.SetValue(s.filtered[s.selected].Command.Name)
|
||||
s.showPopup = false
|
||||
s.selected = 0
|
||||
// Move cursor to end
|
||||
s.textarea.CursorEnd()
|
||||
}
|
||||
return s, nil
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))):
|
||||
if s.selected < len(s.filtered) {
|
||||
// Populate the field with the selected command
|
||||
s.textarea.SetValue(s.filtered[s.selected].Command.Name)
|
||||
s.textarea.CursorEnd()
|
||||
// Hide the popup
|
||||
s.showPopup = false
|
||||
s.selected = 0
|
||||
// Set flag to submit on next update (after view refresh)
|
||||
s.submitNext = true
|
||||
// Force a refresh
|
||||
return s, nil
|
||||
}
|
||||
return s, nil
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))):
|
||||
s.showPopup = false
|
||||
s.selected = 0
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Update textarea
|
||||
s.textarea, cmd = s.textarea.Update(msg)
|
||||
|
||||
// Check if we should show/update popup
|
||||
value := s.textarea.Value()
|
||||
if value != s.lastValue {
|
||||
s.lastValue = value
|
||||
// Only show popup if we're on the first line and it starts with /
|
||||
lines := strings.Split(value, "\n")
|
||||
if len(lines) > 0 && strings.HasPrefix(lines[0], "/") && !strings.Contains(lines[0], " ") && len(lines) == 1 {
|
||||
// Show and update popup
|
||||
s.showPopup = true
|
||||
s.filtered = FuzzyMatchCommands(lines[0], s.commands)
|
||||
s.selected = 0
|
||||
} else {
|
||||
// Hide popup
|
||||
s.showPopup = false
|
||||
}
|
||||
}
|
||||
return s, cmd
|
||||
|
||||
default:
|
||||
// Pass through other messages
|
||||
s.textarea, cmd = s.textarea.Update(msg)
|
||||
return s, cmd
|
||||
}
|
||||
}
|
||||
|
||||
// View implements the tea.Model interface, rendering the complete input field
|
||||
// including the title, text area, autocomplete popup (when active), and help text.
|
||||
// The view adapts based on whether single or multi-line input is detected.
|
||||
func (s *SlashCommandInput) View() tea.View {
|
||||
containerStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := GetTheme()
|
||||
|
||||
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Text).
|
||||
MarginBottom(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
// Input box with huh-like styling
|
||||
inputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.ThickBorder()).
|
||||
BorderLeft(true).
|
||||
BorderRight(false).
|
||||
BorderTop(false).
|
||||
BorderBottom(false).
|
||||
BorderForeground(theme.Primary).
|
||||
PaddingLeft(2). // match message block paddingLeft
|
||||
Width(s.width - 1) // full width minus left border
|
||||
|
||||
// Build the view
|
||||
var view strings.Builder
|
||||
view.WriteString(titleStyle.Render(s.title))
|
||||
view.WriteString("\n")
|
||||
view.WriteString(inputBoxStyle.Render(s.textarea.View()))
|
||||
// Count rendered lines
|
||||
s.renderedLines = 2 + s.textarea.Height() // title + newline + textarea height
|
||||
|
||||
// Add popup if visible
|
||||
if s.showPopup && len(s.filtered) > 0 {
|
||||
view.WriteString("\n")
|
||||
view.WriteString(s.renderPopup())
|
||||
// Add popup lines
|
||||
visibleItems := min(len(s.filtered), s.popupHeight)
|
||||
scrollIndicators := 0
|
||||
if s.selected >= s.popupHeight {
|
||||
scrollIndicators++ // top indicator
|
||||
}
|
||||
if len(s.filtered) > s.popupHeight {
|
||||
scrollIndicators++ // bottom indicator
|
||||
}
|
||||
popupLines := visibleItems + scrollIndicators + 5 // items + scroll + border + padding + footer
|
||||
s.renderedLines += 1 + popupLines // newline + popup
|
||||
}
|
||||
|
||||
// Add help text at bottom (unless hidden by extension).
|
||||
if !s.hideHint {
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted).
|
||||
MarginTop(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
helpText := "enter submit • ctrl+j / shift+enter new line"
|
||||
|
||||
view.WriteString("\n")
|
||||
view.WriteString(helpStyle.Render(helpText))
|
||||
s.renderedLines += 2 // newline + help text
|
||||
}
|
||||
|
||||
// Apply container padding to entire view
|
||||
return tea.NewView(containerStyle.Render(view.String()))
|
||||
}
|
||||
|
||||
// renderPopup renders the autocomplete popup
|
||||
func (s *SlashCommandInput) renderPopup() string {
|
||||
theme := GetTheme()
|
||||
|
||||
// Popup styling
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.MutedBorder).
|
||||
Padding(1, 2).
|
||||
Width(s.width - 4). // Account for container padding
|
||||
MarginLeft(0) // No extra margin needed due to container padding
|
||||
|
||||
var items []string
|
||||
|
||||
// Calculate visible window
|
||||
visibleItems := min(len(s.filtered), s.popupHeight)
|
||||
startIdx := 0
|
||||
|
||||
// Adjust window to keep selected item visible
|
||||
if s.selected >= s.popupHeight {
|
||||
startIdx = s.selected - s.popupHeight + 1
|
||||
}
|
||||
|
||||
endIdx := min(startIdx+visibleItems, len(s.filtered))
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
match := s.filtered[i]
|
||||
cmd := match.Command
|
||||
// Create the selection indicator
|
||||
var indicator string
|
||||
if i == s.selected {
|
||||
indicator = lipgloss.NewStyle().
|
||||
Foreground(theme.Primary).
|
||||
Render("> ")
|
||||
} else {
|
||||
indicator = " "
|
||||
}
|
||||
|
||||
// Format item
|
||||
nameStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Bold(true)
|
||||
|
||||
descStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted)
|
||||
|
||||
// Highlight selected item
|
||||
if i == s.selected {
|
||||
nameStyle = nameStyle.Foreground(theme.Primary)
|
||||
descStyle = descStyle.Foreground(theme.Text)
|
||||
}
|
||||
|
||||
// Format with proper spacing
|
||||
nameWidth := 15
|
||||
name := nameStyle.Width(nameWidth - 2).Render(cmd.Name)
|
||||
|
||||
// Truncate description if needed
|
||||
desc := cmd.Description
|
||||
maxDescLen := s.width - nameWidth - 14 // Account for padding and indicator
|
||||
if len(desc) > maxDescLen && maxDescLen > 3 {
|
||||
desc = desc[:maxDescLen-3] + "..."
|
||||
}
|
||||
|
||||
line := indicator + name + descStyle.Render(desc)
|
||||
items = append(items, line)
|
||||
}
|
||||
|
||||
// Add scroll indicators if needed
|
||||
if startIdx > 0 {
|
||||
scrollUpStyle := lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
items = append([]string{scrollUpStyle.Render(" ↑ more above")}, items...)
|
||||
}
|
||||
if endIdx < len(s.filtered) {
|
||||
scrollDownStyle := lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
items = append(items, scrollDownStyle.Render(" ↓ more below"))
|
||||
}
|
||||
// Join items
|
||||
content := strings.Join(items, "\n")
|
||||
|
||||
// Add footer hint
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted).
|
||||
Italic(true)
|
||||
footer := footerStyle.Render("↑↓ navigate • tab complete • ↵ select • esc dismiss")
|
||||
|
||||
// Combine content and footer
|
||||
popupContent := content + "\n\n" + footer
|
||||
|
||||
return popupStyle.Render(popupContent)
|
||||
}
|
||||
|
||||
// Value returns the final text value entered by the user after submission.
|
||||
// This will be empty if the input was cancelled.
|
||||
func (s *SlashCommandInput) Value() string {
|
||||
return s.value
|
||||
}
|
||||
|
||||
// Cancelled returns true if the user cancelled the input operation (e.g., by
|
||||
// pressing ESC or Ctrl+C) without submitting any text.
|
||||
func (s *SlashCommandInput) Cancelled() bool {
|
||||
return s.quitting && s.value == ""
|
||||
}
|
||||
|
||||
// RenderedLines returns the total number of terminal lines used by the last
|
||||
// rendered view, including the title, input area, popup, and help text. This
|
||||
// is used for proper screen clearing when the input is dismissed.
|
||||
func (s *SlashCommandInput) RenderedLines() int {
|
||||
return s.renderedLines
|
||||
}
|
||||
+136
-94
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/indaco/herald"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
)
|
||||
|
||||
@@ -79,7 +80,12 @@ func streamSpinnerTickCmd(generation uint64) tea.Cmd {
|
||||
// streamFlushTickMsg fires when it's time to commit pending chunks to the
|
||||
// main content builders and trigger a re-render. This coalesces rapid
|
||||
// streaming chunks into fewer expensive markdown re-renders.
|
||||
type streamFlushTickMsg struct{}
|
||||
//
|
||||
// generation ties the tick to the pending flush session that created it so
|
||||
// stale ticks from a prior Reset() are discarded.
|
||||
type streamFlushTickMsg struct {
|
||||
generation uint64
|
||||
}
|
||||
|
||||
// streamFlushInterval is the coalescing window for stream chunks. Chunks
|
||||
// arriving within this window are batched into a single render pass.
|
||||
@@ -89,9 +95,9 @@ const streamFlushInterval = 16 * time.Millisecond
|
||||
|
||||
// streamFlushTickCmd returns a tea.Cmd that fires streamFlushTickMsg after
|
||||
// the coalescing interval.
|
||||
func streamFlushTickCmd() tea.Cmd {
|
||||
func streamFlushTickCmd(generation uint64) tea.Cmd {
|
||||
return tea.Tick(streamFlushInterval, func(_ time.Time) tea.Msg {
|
||||
return streamFlushTickMsg{}
|
||||
return streamFlushTickMsg{generation: generation}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -149,9 +155,11 @@ type StreamComponent struct {
|
||||
// spinnerFrame is the current frame index.
|
||||
spinnerFrame int
|
||||
|
||||
// activeTools tracks the names of tools currently executing in parallel.
|
||||
// When multiple tools run concurrently, all are displayed in the spinner.
|
||||
activeTools []string
|
||||
// activeTools maps ToolCallID -> display label for currently running tools.
|
||||
activeTools map[string]string
|
||||
|
||||
// activeToolOrder preserves deterministic display order for active tools.
|
||||
activeToolOrder []string
|
||||
|
||||
// streamContent holds committed streaming text (flushed from pending).
|
||||
streamContent strings.Builder
|
||||
@@ -172,6 +180,10 @@ type StreamComponent struct {
|
||||
// the same coalescing window.
|
||||
flushPending bool
|
||||
|
||||
// flushGeneration is incremented when stream state resets so stale flush
|
||||
// ticks from a previous step can be discarded.
|
||||
flushGeneration uint64
|
||||
|
||||
// renderCache holds the last rendered output string. Reused by View()
|
||||
// between flush ticks to avoid redundant markdown re-parsing.
|
||||
renderCache string
|
||||
@@ -190,14 +202,8 @@ type StreamComponent struct {
|
||||
// reasoningDuration holds the total reasoning time, frozen when streaming text begins.
|
||||
reasoningDuration time.Duration
|
||||
|
||||
// messageRenderer renders assistant messages in standard mode.
|
||||
messageRenderer *MessageRenderer
|
||||
|
||||
// compactRenderer renders assistant messages in compact mode.
|
||||
compactRenderer *CompactRenderer
|
||||
|
||||
// compactMode selects which renderer to use.
|
||||
compactMode bool
|
||||
// renderer renders streaming assistant text in either compact or standard mode.
|
||||
renderer Renderer
|
||||
|
||||
// modelName is displayed in the streaming text header.
|
||||
modelName string
|
||||
@@ -211,6 +217,9 @@ type StreamComponent struct {
|
||||
// height constrains the render output to at most this many lines.
|
||||
// 0 means unconstrained.
|
||||
height int
|
||||
|
||||
// ty provides typography functions for rendering text.
|
||||
ty *herald.Typography
|
||||
}
|
||||
|
||||
// NewStreamComponent creates a new StreamComponent ready to be embedded in AppModel.
|
||||
@@ -218,13 +227,20 @@ func NewStreamComponent(compactMode bool, width int, modelName string) *StreamCo
|
||||
if width == 0 {
|
||||
width = 80
|
||||
}
|
||||
|
||||
var renderer Renderer
|
||||
if compactMode {
|
||||
renderer = NewCompactRenderer(width, false)
|
||||
} else {
|
||||
renderer = newMessageRenderer(width, false)
|
||||
}
|
||||
|
||||
return &StreamComponent{
|
||||
spinnerFrames: knightRiderFrames(),
|
||||
compactMode: compactMode,
|
||||
modelName: modelName,
|
||||
messageRenderer: newMessageRenderer(width, false),
|
||||
compactRenderer: NewCompactRenderer(width, false),
|
||||
width: width,
|
||||
spinnerFrames: knightRiderFrames(),
|
||||
modelName: modelName,
|
||||
renderer: renderer,
|
||||
width: width,
|
||||
ty: createTypography(GetTheme()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,11 +267,13 @@ func (s *StreamComponent) Reset() {
|
||||
s.spinnerGeneration++ // invalidate any in-flight tick commands
|
||||
s.spinnerFrame = 0
|
||||
s.activeTools = nil
|
||||
s.activeToolOrder = nil
|
||||
s.streamContent.Reset()
|
||||
s.reasoningContent.Reset()
|
||||
s.pendingStream.Reset()
|
||||
s.pendingReasoning.Reset()
|
||||
s.flushPending = false
|
||||
s.flushGeneration++
|
||||
s.renderCache = ""
|
||||
s.renderDirty = false
|
||||
s.timestamp = time.Time{}
|
||||
@@ -282,7 +300,8 @@ func (s *StreamComponent) GetRenderedContent() string {
|
||||
|
||||
text := s.streamContent.String()
|
||||
if text != "" {
|
||||
sections = append(sections, s.renderStreamingText(text))
|
||||
rendered := s.renderStreamingText(text)
|
||||
sections = append(sections, rendered)
|
||||
}
|
||||
|
||||
if len(sections) == 0 {
|
||||
@@ -322,8 +341,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
case tea.WindowSizeMsg:
|
||||
s.width = msg.Width
|
||||
s.messageRenderer.SetWidth(s.width)
|
||||
s.compactRenderer.SetWidth(s.width)
|
||||
if s.renderer != nil {
|
||||
s.renderer.SetWidth(s.width)
|
||||
}
|
||||
// Invalidate render cache — width change affects wrapping/styling.
|
||||
s.renderCache = ""
|
||||
s.renderDirty = true
|
||||
@@ -359,6 +379,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
case streamFlushTickMsg:
|
||||
if msg.generation != s.flushGeneration {
|
||||
break
|
||||
}
|
||||
s.flushPending = false
|
||||
s.commitPending()
|
||||
|
||||
@@ -373,7 +396,7 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.pendingReasoning.WriteString(msg.Delta)
|
||||
if !s.flushPending {
|
||||
s.flushPending = true
|
||||
return s, streamFlushTickCmd()
|
||||
return s, streamFlushTickCmd(s.flushGeneration)
|
||||
}
|
||||
|
||||
case app.StreamChunkEvent:
|
||||
@@ -388,14 +411,25 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.pendingStream.WriteString(msg.Content)
|
||||
if !s.flushPending {
|
||||
s.flushPending = true
|
||||
return s, streamFlushTickCmd()
|
||||
return s, streamFlushTickCmd(s.flushGeneration)
|
||||
}
|
||||
|
||||
case app.ToolExecutionEvent:
|
||||
toolID := msg.ToolCallID
|
||||
if toolID == "" {
|
||||
// Defensive fallback for older/third-party emitters that may omit
|
||||
// ToolCallID. Best-effort only: same-name+args concurrent calls can
|
||||
// still collide without a stable ID.
|
||||
toolID = fmt.Sprintf("%s|%s", msg.ToolName, msg.ToolArgs)
|
||||
}
|
||||
if msg.IsStarting {
|
||||
// Add tool to active list for parallel execution display.
|
||||
toolDisplay := formatToolExecutionMessage(msg.ToolName, msg.ToolArgs)
|
||||
s.activeTools = append(s.activeTools, toolDisplay)
|
||||
if s.activeTools == nil {
|
||||
s.activeTools = make(map[string]string)
|
||||
}
|
||||
if _, exists := s.activeTools[toolID]; !exists {
|
||||
s.activeToolOrder = append(s.activeToolOrder, toolID)
|
||||
}
|
||||
s.activeTools[toolID] = formatToolExecutionMessage(msg.ToolName)
|
||||
s.spinnerFrame = 0
|
||||
if !s.spinning {
|
||||
s.phase = streamPhaseActive
|
||||
@@ -404,9 +438,10 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return s, streamSpinnerTickCmd(s.spinnerGeneration)
|
||||
}
|
||||
} else {
|
||||
// Tool finished — remove from active list but keep spinning if others remain.
|
||||
toolDisplay := formatToolExecutionMessage(msg.ToolName, msg.ToolArgs)
|
||||
s.activeTools = removeFromSlice(s.activeTools, toolDisplay)
|
||||
if s.activeTools != nil {
|
||||
delete(s.activeTools, toolID)
|
||||
}
|
||||
s.activeToolOrder = removeToolID(s.activeToolOrder, toolID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -415,7 +450,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// View implements tea.Model. Renders the current stream region content.
|
||||
func (s *StreamComponent) View() tea.View {
|
||||
return tea.NewView(s.render())
|
||||
fullContent := s.render()
|
||||
visibleContent := s.viewContent(fullContent)
|
||||
return tea.NewView(visibleContent)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -458,54 +495,51 @@ func (s *StreamComponent) render() string {
|
||||
|
||||
content := strings.Join(sections, "\n")
|
||||
|
||||
// Clamp to height if constrained: keep the last h lines so the most
|
||||
// recent output is always visible.
|
||||
if s.height > 0 && content != "" {
|
||||
lines := strings.Split(content, "\n")
|
||||
if len(lines) > s.height {
|
||||
lines = lines[len(lines)-s.height:]
|
||||
content = strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Cache FULL content without height clamping.
|
||||
// Height clamping is applied in View() for display only.
|
||||
s.renderCache = content
|
||||
s.renderDirty = false
|
||||
return content
|
||||
}
|
||||
|
||||
// renderReasoningBlock renders the reasoning/thinking content in a surface-tinted
|
||||
// box. When collapsed, shows the last 10 lines with a truncation hint. When
|
||||
// viewContent returns the visible portion of content based on height constraint.
|
||||
// This is called by View() to get the slice that fits in the terminal.
|
||||
func (s *StreamComponent) viewContent(fullContent string) string {
|
||||
if s.height > 0 && fullContent != "" {
|
||||
lines := strings.Split(fullContent, "\n")
|
||||
if len(lines) > s.height {
|
||||
// Keep only the last h lines so the most recent output is visible.
|
||||
lines = lines[len(lines)-s.height:]
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
return fullContent
|
||||
}
|
||||
|
||||
// renderReasoningBlock renders the reasoning/thinking content using blockquote.
|
||||
// When collapsed, shows the last 10 lines with a truncation hint. When
|
||||
// expanded, shows all lines. Includes a "Thought for Xs" duration footer.
|
||||
func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
theme := GetTheme()
|
||||
maxWidth := max(s.width-4, 20)
|
||||
|
||||
lines := strings.Split(strings.TrimRight(reasoning, "\n"), "\n")
|
||||
|
||||
contentStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.MutedBorder).
|
||||
Italic(true)
|
||||
|
||||
var parts []string
|
||||
|
||||
// When collapsed and content exceeds 10 lines, show only the last 10
|
||||
// with a truncation hint (matching iteratr's thinking block pattern).
|
||||
// with a truncation hint.
|
||||
const maxCollapsedLines = 10
|
||||
if !s.thinkingVisible && len(lines) > maxCollapsedLines {
|
||||
hidden := len(lines) - maxCollapsedLines
|
||||
hintStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted).
|
||||
Background(theme.MutedBorder).
|
||||
Italic(true)
|
||||
parts = append(parts, hintStyle.Render(fmt.Sprintf("... (%d lines hidden)", hidden)))
|
||||
parts = append(parts, s.ty.Italic(fmt.Sprintf("... (%d lines hidden)", hidden)))
|
||||
lines = lines[len(lines)-maxCollapsedLines:]
|
||||
}
|
||||
|
||||
// Render reasoning text.
|
||||
parts = append(parts, contentStyle.Width(maxWidth).Render(strings.Join(lines, "\n")))
|
||||
// Main content using Italic with Muted color for visual distinction.
|
||||
content := strings.TrimLeft(strings.Join(lines, "\n"), " \t\n")
|
||||
theme := GetTheme()
|
||||
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
parts = append(parts, mutedStyle.Render(s.ty.Italic(content)))
|
||||
|
||||
// Duration footer.
|
||||
// Duration footer with VeryMuted label and Accent duration.
|
||||
var duration time.Duration
|
||||
if s.reasoningDuration > 0 {
|
||||
duration = s.reasoningDuration
|
||||
@@ -519,21 +553,21 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.1fs", duration.Seconds())
|
||||
}
|
||||
footer := lipgloss.NewStyle().Foreground(theme.VeryMuted).Background(theme.MutedBorder).Render("Thought for ") +
|
||||
lipgloss.NewStyle().Foreground(theme.Info).Background(theme.MutedBorder).Render(durationStr)
|
||||
parts = append(parts, footer)
|
||||
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
|
||||
durationStyled := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
|
||||
parts = append(parts, label+durationStyled)
|
||||
}
|
||||
|
||||
innerContent := strings.Join(parts, "\n")
|
||||
|
||||
// Wrap in box with surface background for visual distinction.
|
||||
boxStyle := lipgloss.NewStyle().
|
||||
Background(theme.MutedBorder). // Surface0 (#313244)
|
||||
PaddingLeft(1).
|
||||
Width(maxWidth + 2).
|
||||
MarginBottom(1)
|
||||
|
||||
return boxStyle.Render(innerContent)
|
||||
// Concatenate parts with newline between blockquote and footer
|
||||
var result string
|
||||
if len(parts) == 1 {
|
||||
result = parts[0]
|
||||
} else if len(parts) == 2 {
|
||||
result = parts[0] + "\n" + parts[1]
|
||||
} else {
|
||||
result = strings.Join(parts, "\n")
|
||||
}
|
||||
return styleMarginBottom1.Render(result)
|
||||
}
|
||||
|
||||
// SetThinkingVisible sets whether reasoning blocks are shown or collapsed.
|
||||
@@ -559,7 +593,8 @@ func (s *StreamComponent) SpinnerView() string {
|
||||
return ""
|
||||
}
|
||||
frame := s.spinnerFrames[s.spinnerFrame%len(s.spinnerFrames)]
|
||||
if len(s.activeTools) == 0 {
|
||||
tools := s.activeToolDisplays()
|
||||
if len(tools) == 0 {
|
||||
return " " + frame
|
||||
}
|
||||
theme := GetTheme()
|
||||
@@ -569,10 +604,10 @@ func (s *StreamComponent) SpinnerView() string {
|
||||
|
||||
// Format active tools list
|
||||
var toolsMsg string
|
||||
if len(s.activeTools) == 1 {
|
||||
toolsMsg = s.activeTools[0]
|
||||
if len(tools) == 1 {
|
||||
toolsMsg = tools[0]
|
||||
} else {
|
||||
toolsMsg = "Running: " + strings.Join(s.activeTools, ", ")
|
||||
toolsMsg = "Running: " + strings.Join(tools, ", ")
|
||||
}
|
||||
return " " + frame + " " + msgStyle.Render(toolsMsg)
|
||||
}
|
||||
@@ -584,30 +619,37 @@ func (s *StreamComponent) renderStreamingText(text string) string {
|
||||
if ts.IsZero() {
|
||||
ts = time.Now()
|
||||
}
|
||||
|
||||
if s.compactMode {
|
||||
msg := s.compactRenderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
return msg.Content
|
||||
if s.renderer == nil {
|
||||
return text
|
||||
}
|
||||
msg := s.messageRenderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
msg := s.renderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
return msg.Content
|
||||
}
|
||||
|
||||
// removeFromSlice removes the first occurrence of a string from a slice.
|
||||
func removeFromSlice(slice []string, s string) []string {
|
||||
for i, v := range slice {
|
||||
if v == s {
|
||||
return append(slice[:i], slice[i+1:]...)
|
||||
func (s *StreamComponent) activeToolDisplays() []string {
|
||||
if len(s.activeTools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(s.activeToolOrder))
|
||||
for _, id := range s.activeToolOrder {
|
||||
if display, ok := s.activeTools[id]; ok {
|
||||
out = append(out, display)
|
||||
}
|
||||
}
|
||||
return slice
|
||||
return out
|
||||
}
|
||||
|
||||
// removeToolID removes the first occurrence of a tool ID from a slice.
|
||||
func removeToolID(ids []string, id string) []string {
|
||||
for i, v := range ids {
|
||||
if v == id {
|
||||
return append(ids[:i], ids[i+1:]...)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// formatToolExecutionMessage creates a descriptive spinner message for tool execution.
|
||||
// For spawn_subagent, it shows simply as "Subagent" with optional task preview.
|
||||
func formatToolExecutionMessage(toolName, toolArgs string) string {
|
||||
if toolName == "spawn_subagent" {
|
||||
return "Subagent"
|
||||
}
|
||||
func formatToolExecutionMessage(toolName string) string {
|
||||
return toolName
|
||||
}
|
||||
|
||||
+20
-3
@@ -33,14 +33,31 @@ func colorHexPtr(c color.Color) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// GetMarkdownRenderer creates and returns a configured glamour.TermRenderer for
|
||||
// rendering markdown content with syntax highlighting and proper formatting. The
|
||||
// renderer is customized with our theme colors and adapted to the specified width.
|
||||
// markdownRendererCache holds the last-created TermRenderer so we avoid
|
||||
// re-initializing a full goldmark parser on every streaming flush tick.
|
||||
// The cache is keyed by width; it is invalidated (set to nil) by SetTheme
|
||||
// whenever the active theme changes.
|
||||
// This is only accessed from BubbleTea's single-threaded Update/View cycle,
|
||||
// so no mutex is required.
|
||||
var (
|
||||
markdownRendererCache *glamour.TermRenderer
|
||||
markdownRendererWidth int
|
||||
)
|
||||
|
||||
// GetMarkdownRenderer returns a glamour.TermRenderer configured for our theme
|
||||
// and the given content width. The renderer is cached by width — it is only
|
||||
// rebuilt when the width changes, avoiding expensive goldmark re-initialization
|
||||
// on every streaming flush tick.
|
||||
func GetMarkdownRenderer(width int) *glamour.TermRenderer {
|
||||
if markdownRendererCache != nil && markdownRendererWidth == width {
|
||||
return markdownRendererCache
|
||||
}
|
||||
r, _ := glamour.NewTermRenderer(
|
||||
glamour.WithStyles(generateMarkdownStyleConfig()),
|
||||
glamour.WithWordWrap(width),
|
||||
)
|
||||
markdownRendererCache = r
|
||||
markdownRendererWidth = width
|
||||
return r
|
||||
}
|
||||
|
||||
|
||||
+10
-3
@@ -129,13 +129,20 @@ type presetColors struct {
|
||||
}
|
||||
|
||||
func makeTheme(p presetColors) Theme {
|
||||
ac := func(pair [2]string) color.Color { return AdaptiveColor(pair[0], pair[1]) }
|
||||
def := DefaultTheme()
|
||||
acOr := func(pair [2]string, fb color.Color) color.Color {
|
||||
ac := func(pair [2]string) color.Color {
|
||||
c := AdaptiveColor(pair[0], pair[1])
|
||||
if pair[0] == "" && pair[1] == "" {
|
||||
return nil
|
||||
}
|
||||
return c
|
||||
}
|
||||
acOr := func(pair [2]string, fb color.Color) color.Color {
|
||||
c := ac(pair)
|
||||
if c == nil {
|
||||
return fb
|
||||
}
|
||||
return ac(pair)
|
||||
return c
|
||||
}
|
||||
t := Theme{
|
||||
Primary: ac(p.primary),
|
||||
|
||||
+100
-48
@@ -26,6 +26,14 @@ const (
|
||||
maxLsLines = 20 // lines for Ls directory listings
|
||||
)
|
||||
|
||||
// isShellTool reports if the tool name matches a shell-like tool (bash, grep, find, or
|
||||
// tools with "shell"/"command" in the name). Used by both renderToolBody and
|
||||
// renderToolBodyCompact to avoid code duplication.
|
||||
func isShellTool(toolName string) bool {
|
||||
return toolName == "bash" || toolName == "grep" || toolName == "find" ||
|
||||
strings.Contains(toolName, "shell") || strings.Contains(toolName, "command")
|
||||
}
|
||||
|
||||
// renderToolBody dispatches to tool-specific body renderers based on tool name.
|
||||
// Returns the styled body string, or empty string to fall back to default rendering.
|
||||
func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
@@ -46,12 +54,11 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
if body := renderWriteBody(toolArgs, toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
case toolName == "bash" || toolName == "run_shell_cmd" ||
|
||||
strings.Contains(toolName, "shell") || strings.Contains(toolName, "command"):
|
||||
case isShellTool(toolName):
|
||||
if body := renderBashBody(toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
case toolName == "spawn_subagent":
|
||||
case toolName == "subagent":
|
||||
if body := renderSubagentBody(toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
@@ -64,21 +71,44 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderEditBody renders a side-by-side diff from old_text/new_text in toolArgs.
|
||||
// Supports both single-edit mode and multi-edit mode (edits array).
|
||||
func renderEditBody(toolArgs, toolResult string, width int) string {
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
// Check for multi-edit mode (edits array)
|
||||
if editsArr, ok := args["edits"].([]any); ok && len(editsArr) > 0 {
|
||||
var results []string
|
||||
for _, edit := range editsArr {
|
||||
if e, ok := edit.(map[string]any); ok {
|
||||
oldText, _ := e["old_text"].(string)
|
||||
newText, _ := e["new_text"].(string)
|
||||
if oldText != "" || newText != "" {
|
||||
diff := renderDiffBlock(oldText, newText, startLine, width)
|
||||
if diff != "" {
|
||||
results = append(results, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(results) > 0 {
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Single-edit mode (legacy)
|
||||
oldText, _ := args["old_text"].(string)
|
||||
newText, _ := args["new_text"].(string)
|
||||
if oldText == "" && newText == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
return renderDiffBlock(oldText, newText, startLine, width)
|
||||
}
|
||||
|
||||
@@ -221,7 +251,7 @@ func renderDiffBlock(before, after string, startLine int, width int) string {
|
||||
gutterWidth := max(len(fmt.Sprintf("%d", maxLineNum)), 3)
|
||||
contentWidth := max(panelWidth-gutterWidth-4, 10) // gutter + " - " or " + "
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
|
||||
// Styles for each cell type
|
||||
gutterInsert := lipgloss.NewStyle().Foreground(theme.Muted).Background(theme.DiffInsertBg)
|
||||
@@ -326,7 +356,7 @@ func renderLsBody(toolResult string, width int) string {
|
||||
const indent = " "
|
||||
codeWidth := max(width-len(indent), 20)
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
codeStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
var result []string
|
||||
@@ -440,7 +470,7 @@ func renderCodeBlock(content, fileName string, width int) string {
|
||||
gutterWidth := max(maxNumWidth+2, 5)
|
||||
codeWidth := max(width-gutterWidth-len(codeIndent), 20)
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
gutterStyle := lipgloss.NewStyle().Foreground(theme.Muted).Background(theme.GutterBg).PaddingRight(1)
|
||||
codeStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
@@ -535,7 +565,7 @@ func renderWriteBlock(content, fileName string, width int) string {
|
||||
gutterWidth := numDigits + 2
|
||||
codeWidth := max(width-gutterWidth-len(codeIndent), 20)
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
gutterStyle := lipgloss.NewStyle().Foreground(theme.Muted).Background(theme.GutterBg).PaddingRight(1)
|
||||
writeStyle := lipgloss.NewStyle().Background(theme.WriteBg).PaddingLeft(1)
|
||||
|
||||
@@ -587,7 +617,7 @@ func renderBashBody(toolResult string, width int) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
outputStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
stderrStyle := lipgloss.NewStyle().Foreground(theme.Error).Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
@@ -604,7 +634,6 @@ func renderBashBody(toolResult string, width int) string {
|
||||
|
||||
const lineIndent = " "
|
||||
// Truncate individual lines to the available width so they never wrap.
|
||||
// This mirrors Crush's approach: truncate, don't wrap.
|
||||
lineWidth := max(width-len(lineIndent), 20)
|
||||
// Account for PaddingLeft(1) on the output/stderr styles
|
||||
maxLineChars := lineWidth - 1
|
||||
@@ -754,10 +783,9 @@ func renderToolBodyCompact(toolName, toolArgs, toolResult string, width int) str
|
||||
return renderReadCompact(toolResult)
|
||||
case toolName == "write":
|
||||
return renderWriteCompact(toolArgs)
|
||||
case toolName == "bash" || toolName == "run_shell_cmd" ||
|
||||
strings.Contains(toolName, "shell") || strings.Contains(toolName, "command"):
|
||||
case isShellTool(toolName):
|
||||
return renderBashCompact(toolResult, width)
|
||||
case toolName == "spawn_subagent":
|
||||
case toolName == "subagent":
|
||||
return renderSubagentCompact(toolResult)
|
||||
}
|
||||
return ""
|
||||
@@ -786,7 +814,7 @@ func renderReadCompact(toolResult string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
summary := fmt.Sprintf("%d lines", codeLines)
|
||||
return lipgloss.NewStyle().Foreground(theme.Muted).Italic(true).Render(summary)
|
||||
}
|
||||
@@ -807,7 +835,7 @@ func renderEditCompact(toolArgs, toolResult string) string {
|
||||
oldCount := len(strings.Split(oldText, "\n"))
|
||||
newCount := len(strings.Split(newText, "\n"))
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
var summary string
|
||||
if oldCount == newCount {
|
||||
summary = fmt.Sprintf("%d lines modified", oldCount)
|
||||
@@ -830,7 +858,7 @@ func renderWriteCompact(toolArgs string) string {
|
||||
}
|
||||
|
||||
count := len(strings.Split(content, "\n"))
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
summary := fmt.Sprintf("%d lines written", count)
|
||||
return lipgloss.NewStyle().Foreground(theme.Muted).Italic(true).Render(summary)
|
||||
}
|
||||
@@ -843,7 +871,7 @@ func renderLsCompact(toolResult string) string {
|
||||
}
|
||||
|
||||
entries := strings.Split(content, "\n")
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
summary := fmt.Sprintf("%d entries", len(entries))
|
||||
return lipgloss.NewStyle().Foreground(theme.Muted).Italic(true).Render(summary)
|
||||
}
|
||||
@@ -881,14 +909,14 @@ func renderBashCompact(toolResult string, width int) string {
|
||||
|
||||
if len(outputLines) == 0 {
|
||||
if exitCode != "" {
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
return lipgloss.NewStyle().Foreground(theme.Error).Render(exitCode)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
const maxLines = 3
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
|
||||
display := outputLines
|
||||
if len(display) > maxLines {
|
||||
@@ -916,10 +944,10 @@ func renderBashCompact(toolResult string, width int) string {
|
||||
// Subagent tool renderers — show only summary, not full output
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderSubagentBody renders a clean summary of subagent results.
|
||||
// Extracts timing/token info and shows only a brief summary instead of raw output.
|
||||
// renderSubagentBody renders a clean summary of subagent results with bash-style
|
||||
// background styling for consistency with other tools.
|
||||
func renderSubagentBody(toolResult string, width int) string {
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
result := strings.TrimSpace(toolResult)
|
||||
if result == "" {
|
||||
return ""
|
||||
@@ -937,9 +965,19 @@ func renderSubagentBody(toolResult string, width int) string {
|
||||
// First line is always the status summary
|
||||
statusLine := lines[0]
|
||||
|
||||
// Build a clean summary
|
||||
var summary strings.Builder
|
||||
summary.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(statusLine))
|
||||
// Build content lines for display with bash-style background
|
||||
outputStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
errorStyle := lipgloss.NewStyle().Foreground(theme.Error).Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
const lineIndent = " "
|
||||
lineWidth := max(width-len(lineIndent), 20)
|
||||
maxLineChars := lineWidth - 1 // account for PaddingLeft(1)
|
||||
|
||||
var contentLines []string
|
||||
|
||||
// Add status line
|
||||
styledStatus := outputStyle.Width(lineWidth).Render(truncateLine(statusLine, maxLineChars))
|
||||
contentLines = append(contentLines, lineIndent+styledStatus)
|
||||
|
||||
// For successful results, extract a brief preview of the actual result
|
||||
if strings.Contains(statusLine, "successfully") {
|
||||
@@ -947,25 +985,45 @@ func renderSubagentBody(toolResult string, width int) string {
|
||||
if _, resultContent, found := strings.Cut(result, "Result:\n"); found {
|
||||
resultContent = strings.TrimSpace(resultContent)
|
||||
if resultContent != "" {
|
||||
// Show first 3 meaningful lines as preview
|
||||
preview := extractSubagentPreview(resultContent, 3, width-4)
|
||||
if preview != "" {
|
||||
summary.WriteString("\n\n")
|
||||
summary.WriteString(lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true).
|
||||
Render(preview))
|
||||
// Show first few meaningful lines as preview
|
||||
previewLines := extractSubagentPreviewLines(resultContent, 5, maxLineChars)
|
||||
if len(previewLines) > 0 {
|
||||
// Add blank separator line
|
||||
blankLine := outputStyle.Width(lineWidth).Render("")
|
||||
contentLines = append(contentLines, lineIndent+blankLine)
|
||||
|
||||
for _, line := range previewLines {
|
||||
styled := outputStyle.Width(lineWidth).Render(line)
|
||||
contentLines = append(contentLines, lineIndent+styled)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For failed results, show error info
|
||||
if _, errorContent, found := strings.Cut(result, "Error:\n"); found {
|
||||
errorContent = strings.TrimSpace(errorContent)
|
||||
if errorContent != "" {
|
||||
previewLines := extractSubagentPreviewLines(errorContent, 3, maxLineChars)
|
||||
if len(previewLines) > 0 {
|
||||
blankLine := outputStyle.Width(lineWidth).Render("")
|
||||
contentLines = append(contentLines, lineIndent+blankLine)
|
||||
|
||||
for _, line := range previewLines {
|
||||
styled := errorStyle.Width(lineWidth).Render(line)
|
||||
contentLines = append(contentLines, lineIndent+styled)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return summary.String()
|
||||
return strings.Join(contentLines, "\n")
|
||||
}
|
||||
|
||||
// extractSubagentPreview extracts the first N non-empty lines from content,
|
||||
// truncating each line to maxWidth.
|
||||
func extractSubagentPreview(content string, maxLines, maxWidth int) string {
|
||||
// extractSubagentPreviewLines extracts the first N non-empty lines from content,
|
||||
// truncating each line to maxWidth. Returns as a slice of strings.
|
||||
func extractSubagentPreviewLines(content string, maxLines, maxWidth int) []string {
|
||||
lines := strings.Split(content, "\n")
|
||||
var preview []string
|
||||
|
||||
@@ -984,12 +1042,6 @@ func extractSubagentPreview(content string, maxLines, maxWidth int) string {
|
||||
}
|
||||
}
|
||||
|
||||
if len(preview) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := strings.Join(preview, "\n")
|
||||
|
||||
// Count remaining lines for "more" indicator
|
||||
totalLines := 0
|
||||
for _, line := range lines {
|
||||
@@ -998,10 +1050,10 @@ func extractSubagentPreview(content string, maxLines, maxWidth int) string {
|
||||
}
|
||||
}
|
||||
if totalLines > maxLines {
|
||||
result += fmt.Sprintf("\n...(%d more lines)", totalLines-maxLines)
|
||||
preview = append(preview, fmt.Sprintf("...(%d more lines)", totalLines-maxLines))
|
||||
}
|
||||
|
||||
return result
|
||||
return preview
|
||||
}
|
||||
|
||||
// renderSubagentCompact returns a brief one-line summary for subagent results.
|
||||
@@ -1011,7 +1063,7 @@ func renderSubagentCompact(toolResult string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
theme := GetTheme()
|
||||
|
||||
// Extract just the first line which contains the status
|
||||
lines := strings.Split(result, "\n")
|
||||
|
||||
@@ -134,13 +134,23 @@ func (ut *UsageTracker) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
}
|
||||
|
||||
// SetContextTokens records the approximate current context window utilization.
|
||||
// This should be set from the final API call's input + output tokens (i.e.
|
||||
// FinalResponse.Usage) rather than the aggregate TotalUsage, because TotalUsage
|
||||
// This should be set from FinalUsage.InputTokens, which already includes the
|
||||
// full conversation history (system prompt + all previous messages). Do NOT
|
||||
// add OutputTokens as that would double-count (output becomes input next turn).
|
||||
// Use FinalResponse.Usage rather than aggregate TotalUsage, because TotalUsage
|
||||
// sums across all tool-calling steps and overstates the actual window fill level.
|
||||
func (ut *UsageTracker) SetContextTokens(tokens int) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
ut.contextTokens = tokens
|
||||
// Track the maximum context seen so far. In multi-step tool calls,
|
||||
// FinalUsage.InputTokens may reflect only the last step's input, which
|
||||
// can be smaller than previous steps. We want to show the largest context
|
||||
// the model has processed in this session.
|
||||
if tokens > ut.contextTokens {
|
||||
ut.contextTokens = tokens
|
||||
}
|
||||
// If tokens < current, we keep the larger value (no-op)
|
||||
// This prevents the display from dropping during multi-step tool calls.
|
||||
}
|
||||
|
||||
// RenderUsageInfo generates a formatted string displaying current usage statistics
|
||||
@@ -151,10 +161,6 @@ func (ut *UsageTracker) RenderUsageInfo() string {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
if ut.sessionStats.RequestCount == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
// Display the current context window token count (from the last API call),
|
||||
|
||||
@@ -67,3 +67,62 @@ func TestUsageTracker_RenderUsageInfo_OAuth(t *testing.T) {
|
||||
t.Errorf("Expected regular rendered output to show actual cost, got: %s", regularRendered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTracker_RenderUsageInfo_StartupState(t *testing.T) {
|
||||
// Create a mock model info with costs and context limit
|
||||
modelInfo := &models.ModelInfo{
|
||||
ID: "claude-3-5-sonnet-20241022",
|
||||
Name: "Claude 3.5 Sonnet v2",
|
||||
Cost: models.Cost{
|
||||
Input: 3.0,
|
||||
Output: 15.0,
|
||||
},
|
||||
Limit: models.Limit{
|
||||
Context: 200000,
|
||||
Output: 8192,
|
||||
},
|
||||
}
|
||||
|
||||
// Test startup state (no requests made yet) - Regular API key
|
||||
regularTracker := NewUsageTracker(modelInfo, "anthropic", 80, false)
|
||||
rendered := stripAnsi(regularTracker.RenderUsageInfo())
|
||||
|
||||
// Should NOT return empty string on startup
|
||||
if rendered == "" {
|
||||
t.Errorf("Expected non-empty output on startup, got empty string")
|
||||
}
|
||||
|
||||
// Should show 0 tokens
|
||||
if !strings.Contains(rendered, "Tokens: 0") {
|
||||
t.Errorf("Expected 'Tokens: 0' on startup, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Should NOT show percentage when tokens are 0
|
||||
if strings.Contains(rendered, "(%") {
|
||||
t.Errorf("Expected no percentage on startup with 0 tokens, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Should show $0.0000 cost for regular API key
|
||||
if !strings.Contains(rendered, "Cost: $0.0000") {
|
||||
t.Errorf("Expected 'Cost: $0.0000' on startup, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Test startup state (no requests made yet) - OAuth
|
||||
oauthTracker := NewUsageTracker(modelInfo, "anthropic", 80, true)
|
||||
oauthRendered := stripAnsi(oauthTracker.RenderUsageInfo())
|
||||
|
||||
// Should NOT return empty string on startup
|
||||
if oauthRendered == "" {
|
||||
t.Errorf("Expected non-empty output on startup for OAuth, got empty string")
|
||||
}
|
||||
|
||||
// Should show 0 tokens for OAuth
|
||||
if !strings.Contains(oauthRendered, "Tokens: 0") {
|
||||
t.Errorf("Expected 'Tokens: 0' on startup for OAuth, got: %s", oauthRendered)
|
||||
}
|
||||
|
||||
// Should show $0.00 cost for OAuth
|
||||
if !strings.Contains(oauthRendered, "Cost: $0.00") {
|
||||
t.Errorf("Expected 'Cost: $0.00' on startup for OAuth, got: %s", oauthRendered)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,6 @@ type Harness struct {
|
||||
t *testing.T
|
||||
runner *extensions.Runner
|
||||
context *MockContext
|
||||
extPath string
|
||||
}
|
||||
|
||||
// New creates a new test harness for the given test.
|
||||
@@ -72,15 +71,9 @@ func New(t *testing.T) *Harness {
|
||||
func (h *Harness) LoadFile(path string) *extensions.LoadedExtension {
|
||||
h.t.Helper()
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
h.t.Fatalf("extension file not found: %s: %v", path, err)
|
||||
}
|
||||
|
||||
// Read extension source
|
||||
src, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
h.t.Fatalf("failed to read extension file: %v", err)
|
||||
h.t.Fatalf("failed to read extension file %s: %v", path, err)
|
||||
}
|
||||
|
||||
return h.loadSource(string(src), path)
|
||||
@@ -144,7 +137,6 @@ func (h *Harness) loadSource(src string, path string) *extensions.LoadedExtensio
|
||||
|
||||
// Create runner with the loaded extension
|
||||
h.runner = extensions.NewRunner([]extensions.LoadedExtension{*ext})
|
||||
h.extPath = path
|
||||
|
||||
// Wire the mock context
|
||||
h.runner.SetContext(h.context.ToContext())
|
||||
@@ -222,11 +214,3 @@ func (h *Harness) RegisteredCommands() []extensions.CommandDef {
|
||||
}
|
||||
return h.runner.RegisteredCommands()
|
||||
}
|
||||
|
||||
// MustLoad is like LoadFile but fails the test immediately on error.
|
||||
// It returns the harness for chaining.
|
||||
func (h *Harness) MustLoad(path string) *Harness {
|
||||
h.t.Helper()
|
||||
h.LoadFile(path)
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -59,29 +59,12 @@ type MockContext struct {
|
||||
Overlays []extensions.OverlayConfig
|
||||
}
|
||||
|
||||
// StatusBarEntry represents a recorded status bar entry
|
||||
type StatusBarEntry struct {
|
||||
Key string
|
||||
Text string
|
||||
Priority int
|
||||
}
|
||||
|
||||
// NewMockContext creates a new mock context with default values.
|
||||
func NewMockContext() *MockContext {
|
||||
return &MockContext{
|
||||
Prints: make([]string, 0),
|
||||
PrintInfos: make([]string, 0),
|
||||
PrintErrors: make([]string, 0),
|
||||
PrintBlocks: make([]extensions.PrintBlockOpts, 0),
|
||||
Messages: make([]string, 0),
|
||||
CancelSends: make([]string, 0),
|
||||
Widgets: make(map[string]extensions.WidgetConfig),
|
||||
RemovedIDs: make([]string, 0),
|
||||
StatusEntries: make(map[string]extensions.StatusBarEntry),
|
||||
RemovedStatus: make([]string, 0),
|
||||
EditorTexts: make([]string, 0),
|
||||
Options: make(map[string]string),
|
||||
Overlays: make([]extensions.OverlayConfig, 0),
|
||||
Interactive: true,
|
||||
SessionID: "test-session",
|
||||
CWD: "/test",
|
||||
|
||||
+76
-41
@@ -1,6 +1,6 @@
|
||||
# KIT SDK
|
||||
|
||||
The KIT SDK allows you to use KIT programmatically from Go applications without spawning OS processes.
|
||||
The KIT SDK (`pkg/kit`) lets you embed Kit's full agent capabilities — LLM interactions, tool execution, session management, streaming, hooks — into any Go application.
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -17,26 +17,26 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
// Create Kit instance with default configuration
|
||||
host, err := kit.New(ctx, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer host.Close()
|
||||
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
// Send a prompt
|
||||
response, err := host.Prompt(ctx, "What is 2+2?")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
fmt.Println(response)
|
||||
}
|
||||
```
|
||||
@@ -56,11 +56,23 @@ You can override specific settings:
|
||||
```go
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "ollama/llama3", // Override model
|
||||
SystemPrompt: "You are a helpful bot", // Override system prompt
|
||||
ConfigFile: "/path/to/config.yml", // Use specific config file
|
||||
MaxSteps: 10, // Override max steps
|
||||
Streaming: true, // Enable streaming
|
||||
Quiet: true, // Suppress debug output
|
||||
SystemPrompt: "You are a helpful bot", // Override system prompt
|
||||
ConfigFile: "/path/to/config.yml", // Use specific config file
|
||||
MaxSteps: 10, // Override max steps
|
||||
Streaming: true, // Enable streaming
|
||||
Quiet: true, // Suppress debug output
|
||||
|
||||
// Session options
|
||||
SessionPath: "./session.jsonl", // Open specific session
|
||||
Continue: true, // Resume most recent session
|
||||
NoSession: true, // Ephemeral mode
|
||||
|
||||
// Tool options
|
||||
Tools: []kit.Tool{kit.NewBashTool()}, // Replace default tool set
|
||||
ExtraTools: []kit.Tool{myTool}, // Add alongside defaults
|
||||
|
||||
// Compaction
|
||||
AutoCompact: true, // Auto-compact near context limit
|
||||
})
|
||||
```
|
||||
|
||||
@@ -71,22 +83,28 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
Monitor tool execution in real-time:
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(
|
||||
unsub := host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
fmt.Printf("Calling tool: %s\n", e.ToolName)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
if e.IsError {
|
||||
fmt.Printf("Tool %s failed: %s\n", e.ToolName, e.Result)
|
||||
} else {
|
||||
fmt.Printf("Tool %s succeeded\n", e.ToolName)
|
||||
}
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
fmt.Print(e.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
|
||||
response, err := host.Prompt(
|
||||
ctx,
|
||||
"List files in the current directory",
|
||||
func(name, args string) {
|
||||
fmt.Printf("Calling tool: %s\n", name)
|
||||
},
|
||||
func(name, args, result string, isError bool) {
|
||||
if isError {
|
||||
fmt.Printf("Tool %s failed: %s\n", name, result)
|
||||
} else {
|
||||
fmt.Printf("Tool %s succeeded\n", name)
|
||||
}
|
||||
},
|
||||
func(chunk string) {
|
||||
fmt.Print(chunk) // Stream output
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
@@ -102,35 +120,52 @@ host.Prompt(ctx, "My name is Alice")
|
||||
response, _ := host.Prompt(ctx, "What's my name?")
|
||||
// Response: "Your name is Alice"
|
||||
|
||||
// Save session
|
||||
host.SaveSession("./session.json")
|
||||
|
||||
// Load session later
|
||||
host.LoadSession("./session.json")
|
||||
|
||||
// Clear session
|
||||
// Clear conversation history
|
||||
host.ClearSession()
|
||||
```
|
||||
|
||||
## Re-exported Types
|
||||
|
||||
The SDK re-exports types so you don't need direct internal imports:
|
||||
|
||||
```go
|
||||
// Message types
|
||||
kit.Message, kit.MessageRole, kit.ContentPart
|
||||
kit.TextContent, kit.ReasoningContent, kit.ToolCall, kit.ToolResult, kit.Finish
|
||||
kit.RoleUser, kit.RoleAssistant, kit.RoleTool, kit.RoleSystem
|
||||
|
||||
// LLM types (re-exported from the underlying LLM library)
|
||||
kit.LLMMessage, kit.LLMUsage, kit.LLMResponse, kit.LLMFilePart
|
||||
|
||||
// Conversion helpers
|
||||
msgs := kit.ConvertToLLMMessages(&msg) // SDK message → LLM messages
|
||||
msg := kit.ConvertFromLLMMessage(fMsg) // LLM message → SDK message
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Types
|
||||
|
||||
- `Kit` - Main SDK type
|
||||
- `Options` - Configuration options
|
||||
- `Message` - Conversation message
|
||||
- `ToolCall` - Tool invocation details
|
||||
- `Message` - Conversation message with typed content parts
|
||||
- `Tool` - Agent tool interface
|
||||
- `TurnResult` - Full result from a prompt including usage stats
|
||||
|
||||
### Methods
|
||||
### Key Methods
|
||||
|
||||
- `New(ctx, opts)` - Create new Kit instance
|
||||
- `Prompt(ctx, message)` - Send message and get response
|
||||
- `PromptWithCallbacks(ctx, message, ...)` - Send message with progress callbacks
|
||||
- `LoadSession(path)` - Load session from file
|
||||
- `SaveSession(path)` - Save session to file
|
||||
- `ClearSession()` - Clear conversation history
|
||||
- `GetSessionManager()` - Get session manager for advanced usage
|
||||
- `Prompt(ctx, message)` - Send message and get response string
|
||||
- `PromptResult(ctx, message)` - Send message and get full TurnResult
|
||||
- `PromptWithOptions(ctx, message, opts)` - Prompt with per-call options
|
||||
- `Steer(ctx, instruction)` - System-level steering
|
||||
- `FollowUp(ctx, text)` - Continue without new user input
|
||||
- `SetModel(ctx, model)` - Switch model at runtime
|
||||
- `GetModelString()` - Get current model string
|
||||
- `GetModelInfo()` - Get model capabilities and limits
|
||||
- `ClearSession()` - Clear conversation history
|
||||
- `GetSessionPath()` - Get session file path
|
||||
- `GetSessionID()` - Get session UUID
|
||||
- `Close()` - Clean up resources
|
||||
|
||||
## Environment Variables
|
||||
|
||||
+11
-9
@@ -1,6 +1,10 @@
|
||||
package kit
|
||||
|
||||
import "github.com/mark3labs/kit/internal/auth"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
)
|
||||
|
||||
// CredentialManager manages API keys and OAuth credentials.
|
||||
type CredentialManager = auth.CredentialManager
|
||||
@@ -66,14 +70,12 @@ func HasOpenAICredentials() bool {
|
||||
// Returns an empty string if no key is found.
|
||||
func GetOpenAIAPIKey() string {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// Try to get valid access token (handles OAuth refresh)
|
||||
token, err := cm.GetValidOpenAIAccessToken()
|
||||
if err == nil && token != "" {
|
||||
return token
|
||||
if err == nil {
|
||||
// Try to get valid access token (handles OAuth refresh)
|
||||
if token, err := cm.GetValidOpenAIAccessToken(); err == nil && token != "" {
|
||||
return token
|
||||
}
|
||||
}
|
||||
// Fall back to environment variable
|
||||
return ""
|
||||
return os.Getenv("OPENAI_API_KEY")
|
||||
}
|
||||
|
||||
+46
-42
@@ -2,6 +2,7 @@ package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -17,10 +18,14 @@ type ContextStats struct {
|
||||
MessageCount int // Number of messages in the conversation
|
||||
}
|
||||
|
||||
// defaultReserveTokens is the number of tokens to keep free in the context
|
||||
// window as a safety margin during compaction checks.
|
||||
const defaultReserveTokens = 16384
|
||||
|
||||
// EstimateContextTokens returns the estimated token count of the current
|
||||
// conversation based on tree session messages.
|
||||
func (m *Kit) EstimateContextTokens() int {
|
||||
messages := m.treeSession.GetFantasyMessages()
|
||||
messages := m.treeSession.GetLLMMessages()
|
||||
return compaction.EstimateMessageTokens(messages)
|
||||
}
|
||||
|
||||
@@ -34,12 +39,12 @@ func (m *Kit) ShouldCompact() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
reserveTokens := 16384
|
||||
reserveTokens := defaultReserveTokens
|
||||
if m.compactionOpts != nil && m.compactionOpts.ReserveTokens > 0 {
|
||||
reserveTokens = m.compactionOpts.ReserveTokens
|
||||
}
|
||||
|
||||
messages := m.treeSession.GetFantasyMessages()
|
||||
messages := m.treeSession.GetLLMMessages()
|
||||
return compaction.ShouldCompact(messages, info.Limit.Context, reserveTokens)
|
||||
}
|
||||
|
||||
@@ -52,7 +57,7 @@ func (m *Kit) ShouldCompact() bool {
|
||||
// because it includes system prompts, tool definitions, and other overhead
|
||||
// that the heuristic cannot account for.
|
||||
func (m *Kit) GetContextStats() ContextStats {
|
||||
messages := m.treeSession.GetFantasyMessages()
|
||||
messages := m.treeSession.GetLLMMessages()
|
||||
|
||||
// Prefer the real API-reported input token count when available.
|
||||
m.lastInputTokensMu.RLock()
|
||||
@@ -111,7 +116,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
|
||||
}
|
||||
}
|
||||
|
||||
messages := m.treeSession.GetFantasyMessages()
|
||||
messages := m.treeSession.GetLLMMessages()
|
||||
if len(messages) < 2 {
|
||||
return nil, fmt.Errorf("cannot compact: need at least 2 messages")
|
||||
}
|
||||
@@ -131,7 +136,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
|
||||
if reason == "" {
|
||||
reason = "compaction cancelled by extension"
|
||||
}
|
||||
return nil, fmt.Errorf("%s", reason)
|
||||
return nil, errors.New(reason)
|
||||
}
|
||||
// Extension provided a custom summary — use it directly.
|
||||
if hookResult.Summary != "" {
|
||||
@@ -166,27 +171,10 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
|
||||
firstKeptEntryID = entryIDs[result.CutPoint]
|
||||
}
|
||||
|
||||
if _, err := m.treeSession.AppendCompaction(
|
||||
result.Summary,
|
||||
firstKeptEntryID,
|
||||
result.OriginalTokens,
|
||||
result.CompactedTokens,
|
||||
result.MessagesRemoved,
|
||||
result.ReadFiles,
|
||||
result.ModifiedFiles,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist compaction entry: %w", err)
|
||||
if err := m.persistAndEmitCompaction(result.Summary, firstKeptEntryID, result.OriginalTokens, result.CompactedTokens, result.MessagesRemoved, result.ReadFiles, result.ModifiedFiles); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.events.emit(CompactionEvent{
|
||||
Summary: result.Summary,
|
||||
OriginalTokens: result.OriginalTokens,
|
||||
CompactedTokens: result.CompactedTokens,
|
||||
MessagesRemoved: result.MessagesRemoved,
|
||||
ReadFiles: result.ReadFiles,
|
||||
ModifiedFiles: result.ModifiedFiles,
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -218,17 +206,6 @@ func (m *Kit) applyCustomCompaction(summary string, messages []fantasy.Message,
|
||||
recentTokens := compaction.EstimateMessageTokens(messages[cutPoint:])
|
||||
compactedTokens := summaryTokens + recentTokens
|
||||
|
||||
if _, err := m.treeSession.AppendCompaction(
|
||||
summary,
|
||||
firstKeptEntryID,
|
||||
originalTokens,
|
||||
compactedTokens,
|
||||
cutPoint,
|
||||
nil, nil, // no file tracking for custom summaries
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist compaction entry: %w", err)
|
||||
}
|
||||
|
||||
result := &CompactionResult{
|
||||
Summary: summary,
|
||||
OriginalTokens: originalTokens,
|
||||
@@ -236,12 +213,39 @@ func (m *Kit) applyCustomCompaction(summary string, messages []fantasy.Message,
|
||||
MessagesRemoved: cutPoint,
|
||||
}
|
||||
|
||||
m.events.emit(CompactionEvent{
|
||||
Summary: result.Summary,
|
||||
OriginalTokens: result.OriginalTokens,
|
||||
CompactedTokens: result.CompactedTokens,
|
||||
MessagesRemoved: result.MessagesRemoved,
|
||||
})
|
||||
if err := m.persistAndEmitCompaction(summary, firstKeptEntryID, originalTokens, compactedTokens, cutPoint, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// persistAndEmitCompaction writes a CompactionEntry to the session tree and
|
||||
// emits a CompactionEvent. It is the single implementation shared by
|
||||
// compactInternal and applyCustomCompaction.
|
||||
func (m *Kit) persistAndEmitCompaction(
|
||||
summary, firstKeptEntryID string,
|
||||
originalTokens, compactedTokens, messagesRemoved int,
|
||||
readFiles, modifiedFiles []string,
|
||||
) error {
|
||||
if _, err := m.treeSession.AppendCompaction(
|
||||
summary,
|
||||
firstKeptEntryID,
|
||||
originalTokens,
|
||||
compactedTokens,
|
||||
messagesRemoved,
|
||||
readFiles,
|
||||
modifiedFiles,
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to persist compaction entry: %w", err)
|
||||
}
|
||||
m.events.emit(CompactionEvent{
|
||||
Summary: summary,
|
||||
OriginalTokens: originalTokens,
|
||||
CompactedTokens: compactedTokens,
|
||||
MessagesRemoved: messagesRemoved,
|
||||
ReadFiles: readFiles,
|
||||
ModifiedFiles: modifiedFiles,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
+11
-11
@@ -12,6 +12,10 @@ import (
|
||||
// defaultSystemPrompt is the built-in system prompt used when no custom
|
||||
// prompt is configured. It describes the available core tools and provides
|
||||
// usage guidelines.
|
||||
//
|
||||
// NOTE: Keep this in sync with the CLI default in cmd/root.go (search for
|
||||
// defaultSystemPrompt or system-prompt flag default). Changes here should
|
||||
// generally be reflected there, and vice versa.
|
||||
const defaultSystemPrompt = `You are an expert coding assistant operating inside kit, a coding agent harness. You help users by reading files, executing commands, editing code, and writing new files.
|
||||
|
||||
Available tools:
|
||||
@@ -78,20 +82,16 @@ func InitConfig(configFile string, debug bool) error {
|
||||
viper.AddConfigPath(home)
|
||||
|
||||
configLoaded := false
|
||||
configNames := []string{".kit"}
|
||||
|
||||
for _, name := range configNames {
|
||||
viper.SetConfigName(name)
|
||||
if err := viper.ReadInConfig(); err == nil {
|
||||
configPath := viper.ConfigFileUsed()
|
||||
if err := LoadConfigWithEnvSubstitution(configPath); err != nil {
|
||||
if strings.Contains(err.Error(), "environment variable substitution failed") {
|
||||
return fmt.Errorf("error reading config file '%s': %w", configPath, err)
|
||||
}
|
||||
continue
|
||||
viper.SetConfigName(".kit")
|
||||
if err := viper.ReadInConfig(); err == nil {
|
||||
configPath := viper.ConfigFileUsed()
|
||||
if err := LoadConfigWithEnvSubstitution(configPath); err != nil {
|
||||
if strings.Contains(err.Error(), "environment variable substitution failed") {
|
||||
return fmt.Errorf("error reading config file '%s': %w", configPath, err)
|
||||
}
|
||||
} else {
|
||||
configLoaded = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+16
-52
@@ -70,20 +70,20 @@ const (
|
||||
ToolKindEdit = "edit" // File modification (edit, write)
|
||||
ToolKindRead = "read" // File reading (read, ls)
|
||||
ToolKindSearch = "search" // Content/file search (grep, find)
|
||||
ToolKindSubagent = "agent" // Subagent spawning (spawn_subagent)
|
||||
ToolKindSubagent = "agent" // Subagent spawning (subagent)
|
||||
)
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind. MCP and extension
|
||||
// tools without an entry default to ToolKindExecute.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": ToolKindExecute,
|
||||
"edit": ToolKindEdit,
|
||||
"write": ToolKindEdit,
|
||||
"read": ToolKindRead,
|
||||
"ls": ToolKindRead,
|
||||
"grep": ToolKindSearch,
|
||||
"find": ToolKindSearch,
|
||||
"spawn_subagent": ToolKindSubagent,
|
||||
"bash": ToolKindExecute,
|
||||
"edit": ToolKindEdit,
|
||||
"write": ToolKindEdit,
|
||||
"read": ToolKindRead,
|
||||
"ls": ToolKindRead,
|
||||
"grep": ToolKindSearch,
|
||||
"find": ToolKindSearch,
|
||||
"subagent": ToolKindSubagent,
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
@@ -216,7 +216,7 @@ type ToolResultEvent struct {
|
||||
// ToolResultMetadata carries structured data from tool executions.
|
||||
type ToolResultMetadata struct {
|
||||
FileDiffs []FileDiffInfo `json:"file_diffs,omitempty"` // Present for edit/write tools
|
||||
SubagentSessionID string `json:"subagent_session_id,omitempty"` // Present for spawn_subagent tool
|
||||
SubagentSessionID string `json:"subagent_session_id,omitempty"` // Present for subagent tool
|
||||
}
|
||||
|
||||
// FileDiffInfo describes a file modification from an edit or write tool.
|
||||
@@ -416,68 +416,32 @@ func (m *Kit) OnTurnEnd(handler func(TurnEndEvent)) func() {
|
||||
// Subagent event subscriptions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// subagentListenerSet holds per-tool-call listeners for subagent events.
|
||||
type subagentListenerSet struct {
|
||||
mu sync.RWMutex
|
||||
listeners map[int]EventListener
|
||||
nextID int
|
||||
}
|
||||
|
||||
func newSubagentListenerSet() *subagentListenerSet {
|
||||
return &subagentListenerSet{listeners: make(map[int]EventListener)}
|
||||
}
|
||||
|
||||
func (s *subagentListenerSet) add(listener EventListener) func() {
|
||||
s.mu.Lock()
|
||||
id := s.nextID
|
||||
s.nextID++
|
||||
s.listeners[id] = listener
|
||||
s.mu.Unlock()
|
||||
return func() {
|
||||
s.mu.Lock()
|
||||
delete(s.listeners, id)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *subagentListenerSet) emit(event Event) {
|
||||
s.mu.RLock()
|
||||
snapshot := make([]EventListener, 0, len(s.listeners))
|
||||
for _, l := range s.listeners {
|
||||
snapshot = append(snapshot, l)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
for _, l := range snapshot {
|
||||
l(event)
|
||||
}
|
||||
}
|
||||
|
||||
// SubscribeSubagent registers a listener for real-time events from a subagent
|
||||
// identified by its tool call ID. Returns an unsubscribe function.
|
||||
//
|
||||
// The listener receives the same event types as Subscribe() (ToolCallEvent,
|
||||
// MessageUpdateEvent, etc.) but scoped to the child agent's activity. If the
|
||||
// tool call ID doesn't correspond to an active or future spawn_subagent call,
|
||||
// tool call ID doesn't correspond to an active or future subagent call,
|
||||
// the listener simply never fires.
|
||||
//
|
||||
// Typical usage — register inside an OnToolCall handler:
|
||||
//
|
||||
// kit.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
// if e.ToolName == "spawn_subagent" {
|
||||
// if e.ToolName == "subagent" {
|
||||
// kit.SubscribeSubagent(e.ToolCallID, func(child kit.Event) {
|
||||
// // real-time subagent events
|
||||
// })
|
||||
// }
|
||||
// })
|
||||
func (m *Kit) SubscribeSubagent(toolCallID string, listener EventListener) func() {
|
||||
actual, _ := m.subagentListeners.LoadOrStore(toolCallID, newSubagentListenerSet())
|
||||
return actual.(*subagentListenerSet).add(listener)
|
||||
actual, _ := m.subagentListeners.LoadOrStore(toolCallID, newEventBus())
|
||||
return actual.(*eventBus).subscribe(listener)
|
||||
}
|
||||
|
||||
// getSubagentListenerSet returns the listener set for a tool call, or nil.
|
||||
func (m *Kit) getSubagentListenerSet(toolCallID string) *subagentListenerSet {
|
||||
func (m *Kit) getSubagentListenerSet(toolCallID string) *eventBus {
|
||||
if v, ok := m.subagentListeners.Load(toolCallID); ok {
|
||||
return v.(*subagentListenerSet)
|
||||
return v.(*eventBus)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+23
-2
@@ -140,8 +140,14 @@ func TestEventBusConcurrentSubscribeEmit(t *testing.T) {
|
||||
wg.Wait()
|
||||
|
||||
// We can't assert an exact count because subscribe/emit ordering is
|
||||
// non-deterministic, but it must not panic or deadlock.
|
||||
t.Logf("total events received across subscribers: %d", total.Load())
|
||||
// non-deterministic, but we can assert the count is non-negative and
|
||||
// that no events were lost (each subscriber that registered before an
|
||||
// emit must have received it at least partially).
|
||||
got := total.Load()
|
||||
if got < 0 {
|
||||
t.Errorf("expected non-negative total event count, got %d", got)
|
||||
}
|
||||
t.Logf("total events received across subscribers: %d", got)
|
||||
}
|
||||
|
||||
// TestEventBusEmitNoListeners verifies emit is a no-op with no subscribers.
|
||||
@@ -169,6 +175,11 @@ func TestEventTypes(t *testing.T) {
|
||||
{ToolResultEvent{}, EventToolResult},
|
||||
{ToolCallContentEvent{}, EventToolCallContent},
|
||||
{ResponseEvent{}, EventResponse},
|
||||
{CompactionEvent{}, EventCompaction},
|
||||
{ReasoningDeltaEvent{}, EventReasoningDelta},
|
||||
{ToolOutputEvent{}, EventToolOutput},
|
||||
{StepUsageEvent{}, EventStepUsage},
|
||||
{SteerConsumedEvent{}, EventSteerConsumed},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -212,26 +223,36 @@ func TestEventOrdering(t *testing.T) {
|
||||
EventTurnStart,
|
||||
EventMessageStart,
|
||||
EventMessageUpdate,
|
||||
EventReasoningDelta,
|
||||
EventToolOutput,
|
||||
EventToolCall,
|
||||
EventToolExecutionStart,
|
||||
EventToolExecutionEnd,
|
||||
EventToolResult,
|
||||
EventToolCallContent,
|
||||
EventMessageEnd,
|
||||
EventStepUsage,
|
||||
EventResponse,
|
||||
EventCompaction,
|
||||
EventSteerConsumed,
|
||||
EventTurnEnd,
|
||||
}
|
||||
|
||||
bus.emit(TurnStartEvent{})
|
||||
bus.emit(MessageStartEvent{})
|
||||
bus.emit(MessageUpdateEvent{Chunk: "hello"})
|
||||
bus.emit(ReasoningDeltaEvent{Delta: "thinking..."})
|
||||
bus.emit(ToolOutputEvent{ToolName: "bash", Chunk: "output"})
|
||||
bus.emit(ToolCallEvent{ToolName: "bash"})
|
||||
bus.emit(ToolExecutionStartEvent{ToolName: "bash"})
|
||||
bus.emit(ToolExecutionEndEvent{ToolName: "bash"})
|
||||
bus.emit(ToolResultEvent{ToolName: "bash", Result: "ok"})
|
||||
bus.emit(ToolCallContentEvent{Content: "I'll run bash"})
|
||||
bus.emit(MessageEndEvent{Content: "done"})
|
||||
bus.emit(StepUsageEvent{InputTokens: 100})
|
||||
bus.emit(ResponseEvent{Content: "done"})
|
||||
bus.emit(CompactionEvent{Summary: "compacted"})
|
||||
bus.emit(SteerConsumedEvent{Count: 1})
|
||||
bus.emit(TurnEndEvent{Response: "done"})
|
||||
|
||||
if len(types) != len(expected) {
|
||||
|
||||
@@ -0,0 +1,435 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
)
|
||||
|
||||
// ExtensionAPI provides grouped access to all extension-related functionality.
|
||||
// This cleans up the main Kit API surface while keeping all extension capabilities available.
|
||||
type ExtensionAPI interface {
|
||||
// Context management
|
||||
SetContext(ctx extensions.Context)
|
||||
GetContext() extensions.Context
|
||||
UpdateContextModel(model string)
|
||||
|
||||
// Widgets
|
||||
SetWidget(config extensions.WidgetConfig)
|
||||
RemoveWidget(id string)
|
||||
GetWidgets(placement extensions.WidgetPlacement) []extensions.WidgetConfig
|
||||
|
||||
// Header/Footer
|
||||
SetHeader(config extensions.HeaderFooterConfig)
|
||||
RemoveHeader()
|
||||
GetHeader() *extensions.HeaderFooterConfig
|
||||
SetFooter(config extensions.HeaderFooterConfig)
|
||||
RemoveFooter()
|
||||
GetFooter() *extensions.HeaderFooterConfig
|
||||
|
||||
// Editor
|
||||
SetEditor(config extensions.EditorConfig)
|
||||
ResetEditor()
|
||||
GetEditor() *extensions.EditorConfig
|
||||
|
||||
// UI Visibility
|
||||
SetUIVisibility(v extensions.UIVisibility)
|
||||
GetUIVisibility() *extensions.UIVisibility
|
||||
|
||||
// Tool rendering
|
||||
GetToolRenderer(toolName string) *extensions.ToolRenderConfig
|
||||
GetMessageRenderer(name string) *extensions.MessageRendererConfig
|
||||
|
||||
// Session data
|
||||
GetSessionMessages() []extensions.SessionMessage
|
||||
AppendEntry(extType, data string) (string, error)
|
||||
GetEntries(extType string) []extensions.ExtensionEntry
|
||||
|
||||
// Status bar
|
||||
SetStatus(entry extensions.StatusBarEntry)
|
||||
RemoveStatus(key string)
|
||||
GetStatusEntries() []extensions.StatusBarEntry
|
||||
|
||||
// Shortcuts
|
||||
GetShortcuts() map[string]func()
|
||||
|
||||
// Tools
|
||||
GetToolInfos() []extensions.ToolInfo
|
||||
SetActiveTools(names []string)
|
||||
|
||||
// Options
|
||||
GetOption(name string) string
|
||||
SetOption(name, value string)
|
||||
|
||||
// Events
|
||||
EmitSessionStart()
|
||||
EmitModelChange(newModel, previousModel, source string)
|
||||
EmitCustomEvent(name, data string)
|
||||
EmitBeforeFork(targetID string, isUserMsg bool, userText string) (cancelled bool, reason string)
|
||||
EmitBeforeSessionSwitch(switchReason string) (cancelled bool, reason string)
|
||||
|
||||
// Commands
|
||||
Commands() []extensions.CommandDef
|
||||
|
||||
// Lifecycle
|
||||
Reload() error
|
||||
HasExtensions() bool
|
||||
}
|
||||
|
||||
// extensionAPI implements ExtensionAPI by wrapping a Kit instance.
|
||||
type extensionAPI struct {
|
||||
kit *Kit
|
||||
}
|
||||
|
||||
// Extensions returns the ExtensionAPI for accessing all extension-related functionality.
|
||||
func (m *Kit) Extensions() ExtensionAPI {
|
||||
return &extensionAPI{kit: m}
|
||||
}
|
||||
|
||||
// Context management
|
||||
|
||||
func (e *extensionAPI) SetContext(ctx extensions.Context) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetContext() extensions.Context {
|
||||
if e.kit.extRunner != nil {
|
||||
return e.kit.extRunner.GetContext()
|
||||
}
|
||||
return extensions.Context{}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) UpdateContextModel(model string) {
|
||||
if e.kit.extRunner != nil {
|
||||
ctx := e.kit.extRunner.GetContext()
|
||||
ctx.Model = model
|
||||
e.kit.extRunner.SetContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Widgets
|
||||
|
||||
func (e *extensionAPI) SetWidget(config extensions.WidgetConfig) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetWidget(config)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) RemoveWidget(id string) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.RemoveWidget(id)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetWidgets(placement extensions.WidgetPlacement) []extensions.WidgetConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.GetWidgets(placement)
|
||||
}
|
||||
|
||||
// Header/Footer
|
||||
|
||||
func (e *extensionAPI) SetHeader(config extensions.HeaderFooterConfig) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetHeader(config)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) RemoveHeader() {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.RemoveHeader()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetHeader() *extensions.HeaderFooterConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.GetHeader()
|
||||
}
|
||||
|
||||
func (e *extensionAPI) SetFooter(config extensions.HeaderFooterConfig) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetFooter(config)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) RemoveFooter() {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.RemoveFooter()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetFooter() *extensions.HeaderFooterConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.GetFooter()
|
||||
}
|
||||
|
||||
// Editor
|
||||
|
||||
func (e *extensionAPI) SetEditor(config extensions.EditorConfig) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetEditor(config)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) ResetEditor() {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.ResetEditor()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetEditor() *extensions.EditorConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.GetEditor()
|
||||
}
|
||||
|
||||
// UI Visibility
|
||||
|
||||
func (e *extensionAPI) SetUIVisibility(v extensions.UIVisibility) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetUIVisibility(v)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetUIVisibility() *extensions.UIVisibility {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.GetUIVisibility()
|
||||
}
|
||||
|
||||
// Tool rendering
|
||||
|
||||
func (e *extensionAPI) GetToolRenderer(toolName string) *extensions.ToolRenderConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.GetToolRenderer(toolName)
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetMessageRenderer(name string) *extensions.MessageRendererConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.GetMessageRenderer(name)
|
||||
}
|
||||
|
||||
// Session data
|
||||
|
||||
func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage {
|
||||
return iterBranchMessages(e.kit.treeSession, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage {
|
||||
return extensions.SessionMessage{
|
||||
ID: me.ID,
|
||||
Role: string(msg.Role),
|
||||
Content: msg.Content(),
|
||||
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (e *extensionAPI) AppendEntry(extType, data string) (string, error) {
|
||||
if e.kit.treeSession == nil {
|
||||
return "", fmt.Errorf("no session available")
|
||||
}
|
||||
return e.kit.treeSession.AppendExtensionData(extType, data)
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetEntries(extType string) []extensions.ExtensionEntry {
|
||||
if e.kit.treeSession == nil {
|
||||
return nil
|
||||
}
|
||||
entries := e.kit.treeSession.GetExtensionData(extType)
|
||||
result := make([]extensions.ExtensionEntry, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
result = append(result, extensions.ExtensionEntry{
|
||||
ID: e.ID,
|
||||
EntryType: e.ExtType,
|
||||
Data: e.Data,
|
||||
Timestamp: e.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Status bar
|
||||
|
||||
func (e *extensionAPI) SetStatus(entry extensions.StatusBarEntry) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetStatusEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) RemoveStatus(key string) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.RemoveStatusEntry(key)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetStatusEntries() []extensions.StatusBarEntry {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.GetStatusEntries()
|
||||
}
|
||||
|
||||
// Shortcuts
|
||||
|
||||
func (e *extensionAPI) GetShortcuts() map[string]func() {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
entries := e.kit.extRunner.GetShortcuts()
|
||||
if entries == nil {
|
||||
return nil
|
||||
}
|
||||
result := make(map[string]func(), len(entries))
|
||||
for key, entry := range entries {
|
||||
h := entry.Handler
|
||||
r := e.kit.extRunner
|
||||
result[key] = func() {
|
||||
ctx := r.GetContext()
|
||||
h(ctx)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Tools
|
||||
|
||||
func (e *extensionAPI) GetToolInfos() []extensions.ToolInfo {
|
||||
agentTools := e.kit.agent.GetTools()
|
||||
coreCount := e.kit.agent.GetCoreToolCount()
|
||||
mcpCount := e.kit.agent.GetMCPToolCount()
|
||||
|
||||
result := make([]extensions.ToolInfo, 0, len(agentTools))
|
||||
for i, t := range agentTools {
|
||||
info := t.Info()
|
||||
source := "core"
|
||||
if i >= coreCount && i < coreCount+mcpCount {
|
||||
source = "mcp"
|
||||
} else if i >= coreCount+mcpCount {
|
||||
source = "extension"
|
||||
}
|
||||
enabled := true
|
||||
if e.kit.extRunner != nil && e.kit.extRunner.IsToolDisabled(info.Name) {
|
||||
enabled = false
|
||||
}
|
||||
result = append(result, extensions.ToolInfo{
|
||||
Name: info.Name,
|
||||
Description: info.Description,
|
||||
Source: source,
|
||||
Enabled: enabled,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (e *extensionAPI) SetActiveTools(names []string) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetActiveTools(names)
|
||||
}
|
||||
}
|
||||
|
||||
// Options
|
||||
|
||||
func (e *extensionAPI) GetOption(name string) string {
|
||||
if e.kit.extRunner == nil {
|
||||
return ""
|
||||
}
|
||||
return e.kit.extRunner.GetOption(name)
|
||||
}
|
||||
|
||||
func (e *extensionAPI) SetOption(name, value string) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetOption(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Events
|
||||
|
||||
func (e *extensionAPI) EmitSessionStart() {
|
||||
if e.kit.extRunner != nil && e.kit.extRunner.HasHandlers(extensions.SessionStart) {
|
||||
_, _ = e.kit.extRunner.Emit(extensions.SessionStartEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) EmitModelChange(newModel, previousModel, source string) {
|
||||
if e.kit.extRunner != nil && e.kit.extRunner.HasHandlers(extensions.ModelChange) {
|
||||
_, _ = e.kit.extRunner.Emit(extensions.ModelChangeEvent{
|
||||
NewModel: newModel,
|
||||
PreviousModel: previousModel,
|
||||
Source: source,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) EmitCustomEvent(name, data string) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.EmitCustomEvent(name, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) EmitBeforeFork(targetID string, isUserMsg bool, userText string) (cancelled bool, reason string) {
|
||||
if e.kit.extRunner == nil || !e.kit.extRunner.HasHandlers(extensions.BeforeFork) {
|
||||
return false, ""
|
||||
}
|
||||
result, _ := e.kit.extRunner.Emit(extensions.BeforeForkEvent{
|
||||
TargetID: targetID,
|
||||
IsUserMessage: isUserMsg,
|
||||
UserText: userText,
|
||||
})
|
||||
if r, ok := result.(extensions.BeforeForkResult); ok && r.Cancel {
|
||||
reason := r.Reason
|
||||
if reason == "" {
|
||||
reason = "Fork cancelled by extension."
|
||||
}
|
||||
return true, reason
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
func (e *extensionAPI) EmitBeforeSessionSwitch(switchReason string) (cancelled bool, reason string) {
|
||||
if e.kit.extRunner == nil || !e.kit.extRunner.HasHandlers(extensions.BeforeSessionSwitch) {
|
||||
return false, ""
|
||||
}
|
||||
result, _ := e.kit.extRunner.Emit(extensions.BeforeSessionSwitchEvent{
|
||||
Reason: switchReason,
|
||||
})
|
||||
if r, ok := result.(extensions.BeforeSessionSwitchResult); ok && r.Cancel {
|
||||
reason := r.Reason
|
||||
if reason == "" {
|
||||
reason = "Session switch cancelled by extension."
|
||||
}
|
||||
return true, reason
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Commands
|
||||
|
||||
func (e *extensionAPI) Commands() []extensions.CommandDef {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.RegisteredCommands()
|
||||
}
|
||||
|
||||
// Lifecycle
|
||||
|
||||
func (e *extensionAPI) Reload() error {
|
||||
return e.kit.ReloadExtensions()
|
||||
}
|
||||
|
||||
func (e *extensionAPI) HasExtensions() bool {
|
||||
return e.kit.extRunner != nil
|
||||
}
|
||||
@@ -104,11 +104,9 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
if runner.HasHandlers(extensions.AgentEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(TurnEndEvent); ok {
|
||||
stopReason := ev.StopReason
|
||||
response := ev.Response
|
||||
stopReason, response := ev.StopReason, ev.Response
|
||||
if ev.Error != nil {
|
||||
stopReason = "error"
|
||||
response = ""
|
||||
stopReason, response = "error", ""
|
||||
} else if stopReason == "" {
|
||||
stopReason = "completed"
|
||||
}
|
||||
@@ -126,9 +124,9 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
// extension runner.
|
||||
//
|
||||
// Flow:
|
||||
// ToolExecutionStartEvent(spawn_subagent) → emit SubagentStartEvent
|
||||
// ToolExecutionStartEvent(subagent) → emit SubagentStartEvent
|
||||
// → SubscribeSubagent → emit SubagentChunkEvents
|
||||
// ToolResultEvent(spawn_subagent) → emit SubagentEndEvent
|
||||
// ToolResultEvent(subagent) → emit SubagentEndEvent
|
||||
//
|
||||
// We use ToolExecutionStart (not ToolCall) for SubagentStart because that
|
||||
// is when the subagent actually begins running. We use ToolResult for
|
||||
@@ -141,12 +139,12 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
// taskByCallID tracks the task description extracted from ToolCall input,
|
||||
// keyed by toolCallID. Populated on ToolCall, consumed on ToolResult.
|
||||
taskByCallID := make(map[string]string)
|
||||
var taskMu = &taskMutex{}
|
||||
var taskMu sync.Mutex
|
||||
|
||||
// Intercept ToolCall to capture the task and subscribe to child events.
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolCallEvent)
|
||||
if !ok || ev.ToolName != "spawn_subagent" {
|
||||
if !ok || ev.ToolName != "subagent" {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -157,7 +155,9 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
task = t
|
||||
}
|
||||
}
|
||||
taskMu.set(taskByCallID, ev.ToolCallID, task)
|
||||
taskMu.Lock()
|
||||
taskByCallID[ev.ToolCallID] = task
|
||||
taskMu.Unlock()
|
||||
|
||||
// Subscribe to child events so we can forward them as SubagentChunkEvents.
|
||||
if runner.HasHandlers(extensions.SubagentChunk) {
|
||||
@@ -201,10 +201,12 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
if runner.HasHandlers(extensions.SubagentStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolExecutionStartEvent)
|
||||
if !ok || ev.ToolName != "spawn_subagent" {
|
||||
if !ok || ev.ToolName != "subagent" {
|
||||
return
|
||||
}
|
||||
task := taskMu.get(taskByCallID, ev.ToolCallID)
|
||||
taskMu.Lock()
|
||||
task := taskByCallID[ev.ToolCallID]
|
||||
taskMu.Unlock()
|
||||
_, _ = runner.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Task: task,
|
||||
@@ -216,11 +218,13 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
if runner.HasHandlers(extensions.SubagentEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolResultEvent)
|
||||
if !ok || ev.ToolName != "spawn_subagent" {
|
||||
if !ok || ev.ToolName != "subagent" {
|
||||
return
|
||||
}
|
||||
task := taskMu.get(taskByCallID, ev.ToolCallID)
|
||||
taskMu.del(taskByCallID, ev.ToolCallID)
|
||||
taskMu.Lock()
|
||||
task := taskByCallID[ev.ToolCallID]
|
||||
delete(taskByCallID, ev.ToolCallID)
|
||||
taskMu.Unlock()
|
||||
errMsg := ""
|
||||
if ev.IsError {
|
||||
errMsg = ev.Result
|
||||
@@ -243,7 +247,7 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
// Extension ContextPrepare → SDK ContextPrepare hook.
|
||||
if runner.HasHandlers(extensions.ContextPrepare) {
|
||||
m.OnContextPrepare(HookPriorityNormal, func(h ContextPrepareHook) *ContextPrepareResult {
|
||||
// Convert fantasy.Message slice to extension ContextMessage slice.
|
||||
// Convert LLM message slice to extension ContextMessage slice.
|
||||
extMsgs := make([]extensions.ContextMessage, len(h.Messages))
|
||||
for i, msg := range h.Messages {
|
||||
// Extract text from content parts.
|
||||
@@ -266,7 +270,7 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rebuild fantasy.Message slice from extension result.
|
||||
// Rebuild LLM message slice from extension result.
|
||||
rebuilt := make([]fantasy.Message, 0, len(r.Messages))
|
||||
for _, cm := range r.Messages {
|
||||
if cm.Index >= 0 && cm.Index < len(h.Messages) {
|
||||
@@ -324,27 +328,3 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// taskMutex is a simple mutex-protected map helper used by bridgeExtensions.
|
||||
// It lives in this file to avoid polluting the kit package with unexported types.
|
||||
type taskMutex struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *taskMutex) set(m map[string]string, key, val string) {
|
||||
t.mu.Lock()
|
||||
m[key] = val
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
func (t *taskMutex) get(m map[string]string, key string) string {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return m[key]
|
||||
}
|
||||
|
||||
func (t *taskMutex) del(m map[string]string, key string) {
|
||||
t.mu.Lock()
|
||||
delete(m, key)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
+7
-2
@@ -82,7 +82,7 @@ type AfterTurnResult struct{}
|
||||
// is assembled from the session tree (including compaction) and before the
|
||||
// messages are sent to the LLM. Hooks can filter, reorder, or inject messages.
|
||||
type ContextPrepareHook struct {
|
||||
// Messages is the current context as fantasy.Message objects.
|
||||
// Messages is the current context as LLM message objects.
|
||||
Messages []fantasy.Message
|
||||
}
|
||||
|
||||
@@ -167,8 +167,13 @@ func (hr *hookRegistry[In, Out]) register(p HookPriority, h func(In) *Out) func(
|
||||
}
|
||||
|
||||
// run executes all hooks in priority order. The first non-nil result wins.
|
||||
// Returns nil immediately if no hooks are registered.
|
||||
func (hr *hookRegistry[In, Out]) run(input In) *Out {
|
||||
hr.mu.RLock()
|
||||
if len(hr.hooks) == 0 {
|
||||
hr.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
snapshot := make([]hookEntry[In, Out], len(hr.hooks))
|
||||
copy(snapshot, hr.hooks)
|
||||
hr.mu.RUnlock()
|
||||
@@ -247,7 +252,7 @@ func (m *Kit) OnBeforeCompact(p HookPriority, h func(BeforeCompactHook) *BeforeC
|
||||
// Tool wrapping via hooks
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// hookedTool wraps a fantasy.AgentTool to run BeforeToolCall and
|
||||
// hookedTool wraps an AgentTool to run BeforeToolCall and
|
||||
// AfterToolResult hooks around each execution. The registries are referenced
|
||||
// by pointer so hooks added after agent creation are still invoked.
|
||||
type hookedTool struct {
|
||||
|
||||
+13
-26
@@ -107,6 +107,11 @@ func TestHookRegistry_SamePriorityPreservesOrder(t *testing.T) {
|
||||
func TestHookRegistry_Unregister(t *testing.T) {
|
||||
hr := newHookRegistry[string, string]()
|
||||
|
||||
// Verify initial state (merged from TestHookRegistry_HasHooks).
|
||||
if hr.hasHooks() {
|
||||
t.Error("expected hasHooks to be false initially")
|
||||
}
|
||||
|
||||
unregister := hr.register(HookPriorityNormal, func(input string) *string {
|
||||
result := "should be gone"
|
||||
return &result
|
||||
@@ -137,24 +142,6 @@ func TestHookRegistry_NoHooksReturnsNil(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookRegistry_HasHooks(t *testing.T) {
|
||||
hr := newHookRegistry[string, string]()
|
||||
|
||||
if hr.hasHooks() {
|
||||
t.Error("expected hasHooks to be false initially")
|
||||
}
|
||||
|
||||
unsub := hr.register(HookPriorityNormal, func(_ string) *string { return nil })
|
||||
if !hr.hasHooks() {
|
||||
t.Error("expected hasHooks to be true after registration")
|
||||
}
|
||||
|
||||
unsub()
|
||||
if hr.hasHooks() {
|
||||
t.Error("expected hasHooks to be false after unregister")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookRegistry_ConcurrentAccess(t *testing.T) {
|
||||
hr := newHookRegistry[int, int]()
|
||||
|
||||
@@ -187,7 +174,7 @@ func TestHookRegistry_ConcurrentAccess(t *testing.T) {
|
||||
// hookedTool tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// mockAgentTool implements fantasy.AgentTool for testing.
|
||||
// mockAgentTool implements the AgentTool interface for testing.
|
||||
type mockAgentTool struct {
|
||||
name string
|
||||
runFn func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error)
|
||||
@@ -206,10 +193,14 @@ func (m *mockAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy
|
||||
return fantasy.NewTextResponse("default output"), nil
|
||||
}
|
||||
|
||||
func TestHookedTool_Passthrough(t *testing.T) {
|
||||
// newEmptyHookedTool creates a hookedTool with empty hook registries and the given mock tool.
|
||||
func newEmptyHookedTool(mock *mockAgentTool) *hookedTool {
|
||||
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
|
||||
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
|
||||
return &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
}
|
||||
|
||||
func TestHookedTool_Passthrough(t *testing.T) {
|
||||
mock := &mockAgentTool{
|
||||
name: "test_tool",
|
||||
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
@@ -217,7 +208,7 @@ func TestHookedTool_Passthrough(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
ht := newEmptyHookedTool(mock)
|
||||
|
||||
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
|
||||
if err != nil {
|
||||
@@ -372,11 +363,7 @@ func TestHookedTool_HookReceivesToolInfo(t *testing.T) {
|
||||
|
||||
func TestHookedTool_InfoDelegates(t *testing.T) {
|
||||
mock := &mockAgentTool{name: "delegate_test"}
|
||||
ht := &hookedTool{
|
||||
inner: mock,
|
||||
beforeToolCall: newHookRegistry[BeforeToolCallHook, BeforeToolCallResult](),
|
||||
afterToolResult: newHookRegistry[AfterToolResultHook, AfterToolResultResult](),
|
||||
}
|
||||
ht := newEmptyHookedTool(mock)
|
||||
|
||||
if ht.Info().Name != "delegate_test" {
|
||||
t.Errorf("expected Info() to delegate to inner tool")
|
||||
|
||||
+113
-572
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
charmlog "github.com/charmbracelet/log"
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
@@ -67,8 +68,16 @@ type Kit struct {
|
||||
// SubscribeSubagent(). Keyed by toolCallID → *subagentListenerSet.
|
||||
subagentListeners sync.Map
|
||||
|
||||
// skillCache holds skills discovered for this Kit instance.
|
||||
// Using a per-instance cache avoids cross-contamination when multiple
|
||||
// Kit instances exist in the same process.
|
||||
skillCache struct {
|
||||
skills []*skills.Skill
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// steerCh is a buffered channel used to inject steering messages into
|
||||
// the running agent turn via Fantasy's PrepareStep. Created fresh for
|
||||
// the running agent turn via the LLM library's PrepareStep. Created fresh for
|
||||
// each generate() call and set to nil when idle. Protected by steerMu.
|
||||
steerMu sync.Mutex
|
||||
steerCh chan string
|
||||
@@ -76,31 +85,14 @@ type Kit struct {
|
||||
}
|
||||
|
||||
// Subscribe registers an EventListener that will be called for every lifecycle
|
||||
// event emitted during Prompt() and PromptWithCallbacks(). Returns an
|
||||
// unsubscribe function that removes the listener.
|
||||
// event emitted during Prompt(). Returns an unsubscribe function that removes
|
||||
// the listener.
|
||||
func (m *Kit) Subscribe(listener EventListener) func() {
|
||||
return m.events.subscribe(listener)
|
||||
}
|
||||
|
||||
// GetExtRunner returns the extension runner (nil if extensions are disabled).
|
||||
//
|
||||
// Deprecated: Use SetExtensionContext and EmitSessionStart instead. GetExtRunner
|
||||
// leaks the internal extensions.Runner type across the SDK boundary.
|
||||
func (m *Kit) GetExtRunner() *extensions.Runner { return m.extRunner }
|
||||
|
||||
// GetBufferedLogger returns the buffered debug logger (nil if not configured).
|
||||
//
|
||||
// Deprecated: Use GetBufferedDebugMessages instead.
|
||||
func (m *Kit) GetBufferedLogger() *tools.BufferedDebugLogger { return m.bufferedLogger }
|
||||
|
||||
// GetAgent returns the underlying agent.
|
||||
//
|
||||
// Deprecated: Use GetToolNames, GetLoadingMessage, GetLoadedServerNames,
|
||||
// GetMCPToolCount, GetExtensionToolCount instead.
|
||||
func (m *Kit) GetAgent() *agent.Agent { return m.agent }
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Narrow accessors — prefer these over GetAgent/GetExtRunner/GetBufferedLogger
|
||||
// Narrow accessors
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// GetToolNames returns the names of all tools available to the agent.
|
||||
@@ -144,222 +136,6 @@ func (m *Kit) GetBufferedDebugMessages() []string {
|
||||
return m.bufferedLogger.GetMessages()
|
||||
}
|
||||
|
||||
// SetExtensionContext configures the extension runner with the given context
|
||||
// functions. No-op if extensions are disabled.
|
||||
func (m *Kit) SetExtensionContext(ctx extensions.Context) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.SetContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// GetExtensionContext returns the current extension runtime context.
|
||||
// Returns a zero Context if extensions are disabled.
|
||||
func (m *Kit) GetExtensionContext() extensions.Context {
|
||||
if m.extRunner != nil {
|
||||
return m.extRunner.GetContext()
|
||||
}
|
||||
return extensions.Context{}
|
||||
}
|
||||
|
||||
// UpdateExtensionContextModel updates the Model field on the extension
|
||||
// context so subsequent event handlers see the new model. This is a
|
||||
// targeted update that avoids replacing the entire Context struct.
|
||||
func (m *Kit) UpdateExtensionContextModel(model string) {
|
||||
if m.extRunner != nil {
|
||||
ctx := m.extRunner.GetContext()
|
||||
ctx.Model = model
|
||||
m.extRunner.SetContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// EmitSessionStart fires the SessionStart event for extensions.
|
||||
// No-op if extensions are disabled or no handlers are registered.
|
||||
func (m *Kit) EmitSessionStart() {
|
||||
if m.extRunner != nil && m.extRunner.HasHandlers(extensions.SessionStart) {
|
||||
_, _ = m.extRunner.Emit(extensions.SessionStartEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// ExtensionCommands returns the slash commands registered by extensions.
|
||||
// Returns nil if extensions are disabled or no commands are registered.
|
||||
func (m *Kit) ExtensionCommands() []extensions.CommandDef {
|
||||
if m.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return m.extRunner.RegisteredCommands()
|
||||
}
|
||||
|
||||
// SetExtensionWidget places or updates a persistent extension widget.
|
||||
// Delegates to the extension runner. No-op if extensions are disabled.
|
||||
func (m *Kit) SetExtensionWidget(config extensions.WidgetConfig) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.SetWidget(config)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveExtensionWidget removes a previously placed extension widget by ID.
|
||||
// Delegates to the extension runner. No-op if extensions are disabled.
|
||||
func (m *Kit) RemoveExtensionWidget(id string) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.RemoveWidget(id)
|
||||
}
|
||||
}
|
||||
|
||||
// GetExtensionWidgets returns extension widgets matching the given placement.
|
||||
// Returns nil if extensions are disabled or no widgets match.
|
||||
func (m *Kit) GetExtensionWidgets(placement extensions.WidgetPlacement) []extensions.WidgetConfig {
|
||||
if m.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return m.extRunner.GetWidgets(placement)
|
||||
}
|
||||
|
||||
// SetExtensionHeader places or replaces the custom header from extensions.
|
||||
// Delegates to the extension runner. No-op if extensions are disabled.
|
||||
func (m *Kit) SetExtensionHeader(config extensions.HeaderFooterConfig) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.SetHeader(config)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveExtensionHeader removes the custom extension header.
|
||||
// Delegates to the extension runner. No-op if extensions are disabled.
|
||||
func (m *Kit) RemoveExtensionHeader() {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.RemoveHeader()
|
||||
}
|
||||
}
|
||||
|
||||
// GetExtensionHeader returns the current custom header, or nil if none is set.
|
||||
// Returns nil if extensions are disabled.
|
||||
func (m *Kit) GetExtensionHeader() *extensions.HeaderFooterConfig {
|
||||
if m.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return m.extRunner.GetHeader()
|
||||
}
|
||||
|
||||
// SetExtensionFooter places or replaces the custom footer from extensions.
|
||||
// Delegates to the extension runner. No-op if extensions are disabled.
|
||||
func (m *Kit) SetExtensionFooter(config extensions.HeaderFooterConfig) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.SetFooter(config)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveExtensionFooter removes the custom extension footer.
|
||||
// Delegates to the extension runner. No-op if extensions are disabled.
|
||||
func (m *Kit) RemoveExtensionFooter() {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.RemoveFooter()
|
||||
}
|
||||
}
|
||||
|
||||
// GetExtensionFooter returns the current custom footer, or nil if none is set.
|
||||
// Returns nil if extensions are disabled.
|
||||
func (m *Kit) GetExtensionFooter() *extensions.HeaderFooterConfig {
|
||||
if m.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return m.extRunner.GetFooter()
|
||||
}
|
||||
|
||||
// GetExtensionToolRenderer returns the custom renderer for the named tool, or
|
||||
// nil if no extension registered a renderer for it. Returns nil if extensions
|
||||
// are disabled.
|
||||
func (m *Kit) GetExtensionToolRenderer(toolName string) *extensions.ToolRenderConfig {
|
||||
if m.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return m.extRunner.GetToolRenderer(toolName)
|
||||
}
|
||||
|
||||
// SetExtensionEditor installs an editor interceptor from extensions.
|
||||
// Delegates to the extension runner. No-op if extensions are disabled.
|
||||
func (m *Kit) SetExtensionEditor(config extensions.EditorConfig) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.SetEditor(config)
|
||||
}
|
||||
}
|
||||
|
||||
// ResetExtensionEditor removes the active editor interceptor from extensions.
|
||||
// Delegates to the extension runner. No-op if extensions are disabled.
|
||||
func (m *Kit) ResetExtensionEditor() {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.ResetEditor()
|
||||
}
|
||||
}
|
||||
|
||||
// GetExtensionEditor returns the current editor interceptor, or nil if none
|
||||
// is set. Returns nil if extensions are disabled.
|
||||
func (m *Kit) GetExtensionEditor() *extensions.EditorConfig {
|
||||
if m.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return m.extRunner.GetEditor()
|
||||
}
|
||||
|
||||
// SetExtensionUIVisibility stores extension-provided UI visibility overrides.
|
||||
// No-op if extensions are disabled.
|
||||
func (m *Kit) SetExtensionUIVisibility(v extensions.UIVisibility) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.SetUIVisibility(v)
|
||||
}
|
||||
}
|
||||
|
||||
// GetExtensionUIVisibility returns extension-provided UI visibility overrides,
|
||||
// or nil if none have been set. Returns nil if extensions are disabled.
|
||||
func (m *Kit) GetExtensionUIVisibility() *extensions.UIVisibility {
|
||||
if m.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return m.extRunner.GetUIVisibility()
|
||||
}
|
||||
|
||||
// GetSessionMessages returns the conversation messages on the current branch
|
||||
// as extension-facing SessionMessage structs, ordered root to leaf.
|
||||
func (m *Kit) GetSessionMessages() []extensions.SessionMessage {
|
||||
if m.treeSession == nil {
|
||||
return nil
|
||||
}
|
||||
branch := m.treeSession.GetBranch("")
|
||||
var msgs []extensions.SessionMessage
|
||||
for _, entry := range branch {
|
||||
me, ok := entry.(*session.MessageEntry)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
msg, err := me.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Flatten content parts into a single text string.
|
||||
var content strings.Builder
|
||||
for _, p := range msg.Parts {
|
||||
switch pt := p.(type) {
|
||||
case message.TextContent:
|
||||
content.WriteString(pt.Text)
|
||||
case message.ReasoningContent:
|
||||
content.WriteString(pt.Thinking)
|
||||
case message.ToolCall:
|
||||
fmt.Fprintf(&content, "[tool_call: %s(%s)]", pt.Name, pt.Input)
|
||||
case message.ToolResult:
|
||||
fmt.Fprintf(&content, "[tool_result: %s]", pt.Content)
|
||||
}
|
||||
}
|
||||
msgs = append(msgs, extensions.SessionMessage{
|
||||
ID: me.ID,
|
||||
ParentID: me.ParentID,
|
||||
Role: string(msg.Role),
|
||||
Content: content.String(),
|
||||
Model: msg.Model,
|
||||
Provider: msg.Provider,
|
||||
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
|
||||
})
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
// StructuredMessage represents a conversation message with typed content parts
|
||||
// (tool calls, reasoning, finish markers, etc.) instead of flattened text.
|
||||
type StructuredMessage struct {
|
||||
@@ -377,11 +153,29 @@ type StructuredMessage struct {
|
||||
// flattens all content to a single text string, this preserves tool calls,
|
||||
// tool results, reasoning blocks, and finish markers as distinct typed parts.
|
||||
func (m *Kit) GetStructuredMessages() []StructuredMessage {
|
||||
if m.treeSession == nil {
|
||||
return iterBranchMessages(m.treeSession, func(me *session.MessageEntry, msg message.Message) StructuredMessage {
|
||||
return StructuredMessage{
|
||||
ID: me.ID,
|
||||
ParentID: me.ParentID,
|
||||
Role: msg.Role,
|
||||
Parts: msg.Parts,
|
||||
Model: msg.Model,
|
||||
Provider: msg.Provider,
|
||||
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// iterBranchMessages iterates over the current branch's MessageEntry items,
|
||||
// converting each to a message.Message and calling fn to build the result.
|
||||
// Returns nil if there is no tree session. Skips entries that are not
|
||||
// MessageEntry or that fail conversion.
|
||||
func iterBranchMessages[T any](tm *session.TreeManager, fn func(*session.MessageEntry, message.Message) T) []T {
|
||||
if tm == nil {
|
||||
return nil
|
||||
}
|
||||
branch := m.treeSession.GetBranch("")
|
||||
var msgs []StructuredMessage
|
||||
branch := tm.GetBranch("")
|
||||
var results []T
|
||||
for _, entry := range branch {
|
||||
me, ok := entry.(*session.MessageEntry)
|
||||
if !ok {
|
||||
@@ -391,137 +185,9 @@ func (m *Kit) GetStructuredMessages() []StructuredMessage {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs = append(msgs, StructuredMessage{
|
||||
ID: me.ID,
|
||||
ParentID: me.ParentID,
|
||||
Role: msg.Role,
|
||||
Parts: msg.Parts,
|
||||
Model: msg.Model,
|
||||
Provider: msg.Provider,
|
||||
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
|
||||
})
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
// GetSessionFilePath returns the JSONL file path of the current session.
|
||||
func (m *Kit) GetSessionFilePath() string {
|
||||
if m.treeSession == nil {
|
||||
return ""
|
||||
}
|
||||
return m.treeSession.GetFilePath()
|
||||
}
|
||||
|
||||
// AppendExtensionEntry persists custom extension data in the session tree.
|
||||
func (m *Kit) AppendExtensionEntry(extType, data string) (string, error) {
|
||||
if m.treeSession == nil {
|
||||
return "", fmt.Errorf("no session available")
|
||||
}
|
||||
return m.treeSession.AppendExtensionData(extType, data)
|
||||
}
|
||||
|
||||
// GetExtensionEntries retrieves persisted extension data entries for a type.
|
||||
func (m *Kit) GetExtensionEntries(extType string) []extensions.ExtensionEntry {
|
||||
if m.treeSession == nil {
|
||||
return nil
|
||||
}
|
||||
entries := m.treeSession.GetExtensionData(extType)
|
||||
result := make([]extensions.ExtensionEntry, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
result = append(result, extensions.ExtensionEntry{
|
||||
ID: e.ID,
|
||||
EntryType: e.ExtType,
|
||||
Data: e.Data,
|
||||
Timestamp: e.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SetExtensionStatus places or updates a keyed status bar entry.
|
||||
func (m *Kit) SetExtensionStatus(entry extensions.StatusBarEntry) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.SetStatusEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveExtensionStatus removes a keyed status bar entry.
|
||||
func (m *Kit) RemoveExtensionStatus(key string) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.RemoveStatusEntry(key)
|
||||
}
|
||||
}
|
||||
|
||||
// GetExtensionStatusEntries returns all extension status bar entries sorted by priority.
|
||||
func (m *Kit) GetExtensionStatusEntries() []extensions.StatusBarEntry {
|
||||
if m.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return m.extRunner.GetStatusEntries()
|
||||
}
|
||||
|
||||
// GetExtensionShortcuts returns a map of key bindings to handler functions
|
||||
// from all loaded extensions. Returns nil if no shortcuts are registered or
|
||||
// extensions are disabled. Handlers are closures that capture the runner's
|
||||
// current context, so they can call Print/SetStatus/etc.
|
||||
func (m *Kit) GetExtensionShortcuts() map[string]func() {
|
||||
if m.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
entries := m.extRunner.GetShortcuts()
|
||||
if entries == nil {
|
||||
return nil
|
||||
}
|
||||
result := make(map[string]func(), len(entries))
|
||||
for key, entry := range entries {
|
||||
h := entry.Handler
|
||||
r := m.extRunner
|
||||
result[key] = func() {
|
||||
ctx := r.GetContext()
|
||||
h(ctx)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetExtensionToolInfos returns information about all tools available to the
|
||||
// agent, including enabled/disabled status from SetActiveTools. Each tool is
|
||||
// categorized by source: "core", "mcp", or "extension".
|
||||
func (m *Kit) GetExtensionToolInfos() []extensions.ToolInfo {
|
||||
agentTools := m.agent.GetTools()
|
||||
coreCount := m.agent.GetCoreToolCount()
|
||||
mcpCount := m.agent.GetMCPToolCount()
|
||||
|
||||
result := make([]extensions.ToolInfo, 0, len(agentTools))
|
||||
for i, t := range agentTools {
|
||||
info := t.Info()
|
||||
source := "core"
|
||||
if i >= coreCount && i < coreCount+mcpCount {
|
||||
source = "mcp"
|
||||
} else if i >= coreCount+mcpCount {
|
||||
source = "extension"
|
||||
}
|
||||
enabled := true
|
||||
if m.extRunner != nil && m.extRunner.IsToolDisabled(info.Name) {
|
||||
enabled = false
|
||||
}
|
||||
result = append(result, extensions.ToolInfo{
|
||||
Name: info.Name,
|
||||
Description: info.Description,
|
||||
Source: source,
|
||||
Enabled: enabled,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SetExtensionActiveTools restricts the tool set to the named tools. All
|
||||
// other tools are blocked from execution. Pass nil to re-enable all tools.
|
||||
// No-op if extensions are disabled.
|
||||
func (m *Kit) SetExtensionActiveTools(names []string) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.SetActiveTools(names)
|
||||
results = append(results, fn(me, msg))
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// SetModel changes the active model at runtime. The existing tools, system
|
||||
@@ -538,6 +204,10 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
// Build a provider config from current settings, overriding the model.
|
||||
// Load system prompt properly (handles both file paths and inline content).
|
||||
systemPrompt, _ := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
thinkingLevel := models.ParseThinkingLevel(viper.GetString("thinking-level"))
|
||||
|
||||
// With message-level caching, thinking and caching can work together.
|
||||
// No need to disable caching when thinking is enabled.
|
||||
config := &models.ProviderConfig{
|
||||
ModelString: modelString,
|
||||
SystemPrompt: systemPrompt,
|
||||
@@ -545,7 +215,8 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
ThinkingLevel: thinkingLevel,
|
||||
DisableCaching: false, // Caching enabled by default, works with thinking
|
||||
}
|
||||
temperature := float32(viper.GetFloat64("temperature"))
|
||||
config.Temperature = &temperature
|
||||
@@ -577,7 +248,7 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
func (m *Kit) GetAvailableModels() []extensions.ModelInfoEntry {
|
||||
registry := models.GetGlobalRegistry()
|
||||
var result []extensions.ModelInfoEntry
|
||||
for _, providerID := range registry.GetFantasyProviders() {
|
||||
for _, providerID := range registry.GetLLMProviders() {
|
||||
modelsMap, err := registry.GetModelsForProvider(providerID)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -596,50 +267,6 @@ func (m *Kit) GetAvailableModels() []extensions.ModelInfoEntry {
|
||||
return result
|
||||
}
|
||||
|
||||
// GetExtensionOption resolves a named extension option value.
|
||||
func (m *Kit) GetExtensionOption(name string) string {
|
||||
if m.extRunner == nil {
|
||||
return ""
|
||||
}
|
||||
return m.extRunner.GetOption(name)
|
||||
}
|
||||
|
||||
// SetExtensionOption stores a runtime override for a named extension option.
|
||||
func (m *Kit) SetExtensionOption(name, value string) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.SetOption(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
// EmitModelChange fires the ModelChange event for extensions.
|
||||
// No-op if extensions are disabled or no handlers are registered.
|
||||
func (m *Kit) EmitModelChange(newModel, previousModel, source string) {
|
||||
if m.extRunner != nil && m.extRunner.HasHandlers(extensions.ModelChange) {
|
||||
_, _ = m.extRunner.Emit(extensions.ModelChangeEvent{
|
||||
NewModel: newModel,
|
||||
PreviousModel: previousModel,
|
||||
Source: source,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// EmitExtensionCustomEvent dispatches a named event to all extension handlers.
|
||||
// No-op if extensions are disabled.
|
||||
func (m *Kit) EmitExtensionCustomEvent(name, data string) {
|
||||
if m.extRunner != nil {
|
||||
m.extRunner.EmitCustomEvent(name, data)
|
||||
}
|
||||
}
|
||||
|
||||
// GetExtensionMessageRenderer returns the named message renderer, or nil
|
||||
// if no extension registered a renderer with that name.
|
||||
func (m *Kit) GetExtensionMessageRenderer(name string) *extensions.MessageRendererConfig {
|
||||
if m.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return m.extRunner.GetMessageRenderer(name)
|
||||
}
|
||||
|
||||
// ReloadExtensions hot-reloads all extensions from disk. Event handlers,
|
||||
// commands, renderers, and shortcuts update immediately. Extension-defined
|
||||
// tools are NOT updated (they are baked into the agent at creation time).
|
||||
@@ -714,7 +341,7 @@ func (m *Kit) ExecuteCompletion(ctx context.Context, req extensions.CompleteRequ
|
||||
}
|
||||
defer closer()
|
||||
|
||||
// Build fantasy agent options (no tools — just a simple completion).
|
||||
// Build agent options (no tools — just a simple completion).
|
||||
var agentOpts []fantasy.AgentOption
|
||||
if req.System != "" {
|
||||
agentOpts = append(agentOpts, fantasy.WithSystemPrompt(req.System))
|
||||
@@ -728,7 +355,7 @@ func (m *Kit) ExecuteCompletion(ctx context.Context, req extensions.CompleteRequ
|
||||
|
||||
completionAgent := fantasy.NewAgent(llmModel, agentOpts...)
|
||||
|
||||
// Convert extension SessionMessage history to fantasy.Message slice.
|
||||
// Convert extension SessionMessage history to LLM message slice.
|
||||
var messages []fantasy.Message
|
||||
for _, sm := range req.Messages {
|
||||
messages = append(messages, fantasy.Message{
|
||||
@@ -776,53 +403,6 @@ func (m *Kit) ExecuteCompletion(ctx context.Context, req extensions.CompleteRequ
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EmitBeforeFork emits a BeforeFork event to extensions and returns
|
||||
// whether the fork was cancelled and the reason. No-op if extensions are
|
||||
// disabled (returns false, "").
|
||||
func (m *Kit) EmitBeforeFork(targetID string, isUserMsg bool, userText string) (cancelled bool, reason string) {
|
||||
if m.extRunner == nil || !m.extRunner.HasHandlers(extensions.BeforeFork) {
|
||||
return false, ""
|
||||
}
|
||||
result, _ := m.extRunner.Emit(extensions.BeforeForkEvent{
|
||||
TargetID: targetID,
|
||||
IsUserMessage: isUserMsg,
|
||||
UserText: userText,
|
||||
})
|
||||
if r, ok := result.(extensions.BeforeForkResult); ok && r.Cancel {
|
||||
reason := r.Reason
|
||||
if reason == "" {
|
||||
reason = "Fork cancelled by extension."
|
||||
}
|
||||
return true, reason
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// EmitBeforeSessionSwitch emits a BeforeSessionSwitch event to extensions
|
||||
// and returns whether the switch was cancelled and the reason. No-op if
|
||||
// extensions are disabled (returns false, "").
|
||||
func (m *Kit) EmitBeforeSessionSwitch(switchReason string) (cancelled bool, reason string) {
|
||||
if m.extRunner == nil || !m.extRunner.HasHandlers(extensions.BeforeSessionSwitch) {
|
||||
return false, ""
|
||||
}
|
||||
result, _ := m.extRunner.Emit(extensions.BeforeSessionSwitchEvent{
|
||||
Reason: switchReason,
|
||||
})
|
||||
if r, ok := result.(extensions.BeforeSessionSwitchResult); ok && r.Cancel {
|
||||
reason := r.Reason
|
||||
if reason == "" {
|
||||
reason = "Session switch cancelled by extension."
|
||||
}
|
||||
return true, reason
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// HasExtensions returns true if the extension runner is configured and active.
|
||||
func (m *Kit) HasExtensions() bool {
|
||||
return m.extRunner != nil
|
||||
}
|
||||
|
||||
// Options configures Kit creation with optional overrides for model,
|
||||
// prompts, configuration, and behavior settings. All fields are optional
|
||||
// and will use CLI defaults if not specified.
|
||||
@@ -1063,6 +643,15 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
// Bridge extension events to SDK hooks.
|
||||
if agentResult.ExtRunner != nil {
|
||||
k.bridgeExtensions(agentResult.ExtRunner)
|
||||
|
||||
// Initialize extension context with minimal defaults. SDK users can call
|
||||
// Extensions().SetContext to override with richer implementations (TUI callbacks,
|
||||
// prompts, etc.). This ensures extensions never crash on nil function fields.
|
||||
k.Extensions().SetContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: k.modelString,
|
||||
Interactive: false, // SDK mode defaults to non-interactive
|
||||
})
|
||||
}
|
||||
|
||||
return k, nil
|
||||
@@ -1233,16 +822,16 @@ type TurnResult struct {
|
||||
// TotalUsage is the aggregate token usage across all steps in the turn
|
||||
// (includes tool-calling loop iterations). Nil if the provider didn't
|
||||
// report usage.
|
||||
TotalUsage *FantasyUsage
|
||||
TotalUsage *LLMUsage
|
||||
|
||||
// FinalUsage is the token usage from the last API call only. Use this
|
||||
// for context window fill estimation (InputTokens + OutputTokens ≈
|
||||
// current context size). Nil if unavailable.
|
||||
FinalUsage *FantasyUsage
|
||||
FinalUsage *LLMUsage
|
||||
|
||||
// Messages is the full updated conversation after the turn, including
|
||||
// any tool call/result messages added during the agent loop.
|
||||
Messages []FantasyMessage
|
||||
Messages []LLMMessage
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -1263,7 +852,7 @@ type SubagentConfig struct {
|
||||
SystemPrompt string
|
||||
|
||||
// Tools overrides the tool set. If nil, SubagentTools() is used (all
|
||||
// core tools except spawn_subagent, preventing infinite recursion).
|
||||
// core tools except subagent, preventing infinite recursion).
|
||||
Tools []Tool
|
||||
|
||||
// NoSession, when true, uses an in-memory ephemeral session. When false
|
||||
@@ -1281,17 +870,16 @@ type SubagentConfig struct {
|
||||
}
|
||||
|
||||
// SubagentResult contains the outcome of an in-process subagent execution.
|
||||
// Errors are returned as the error return value of Subagent(), not in this struct.
|
||||
type SubagentResult struct {
|
||||
// Response is the subagent's final text response.
|
||||
Response string
|
||||
// Error is set if the subagent failed (nil on success).
|
||||
Error error
|
||||
// SessionID is the subagent's session identifier (for replay).
|
||||
SessionID string
|
||||
// StopReason is the LLM's finish reason for the subagent's final turn.
|
||||
StopReason string
|
||||
// Usage contains token usage from the subagent's run.
|
||||
Usage *FantasyUsage
|
||||
Usage *LLMUsage
|
||||
// Elapsed is the total execution time.
|
||||
Elapsed time.Duration
|
||||
}
|
||||
@@ -1337,7 +925,7 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
systemPrompt = "You are a helpful coding assistant. Complete the task efficiently and thoroughly."
|
||||
}
|
||||
|
||||
// Default tools: everything except spawn_subagent.
|
||||
// Default tools: everything except subagent.
|
||||
tools := cfg.Tools
|
||||
if tools == nil {
|
||||
tools = SubagentTools()
|
||||
@@ -1359,10 +947,7 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
childOpts.Model = m.modelString
|
||||
child, err = New(ctx, childOpts)
|
||||
if err != nil {
|
||||
return &SubagentResult{
|
||||
Error: fmt.Errorf("failed to create subagent: %w", err),
|
||||
Elapsed: time.Since(start),
|
||||
}, err
|
||||
return nil, fmt.Errorf("failed to create subagent: %w", err)
|
||||
}
|
||||
// Prepend a note so the agent knows which model is actually running.
|
||||
cfg.Prompt = fmt.Sprintf(
|
||||
@@ -1370,10 +955,7 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
model, m.modelString, cfg.Prompt,
|
||||
)
|
||||
} else if err != nil {
|
||||
return &SubagentResult{
|
||||
Error: fmt.Errorf("failed to create subagent: %w", err),
|
||||
Elapsed: time.Since(start),
|
||||
}, err
|
||||
return nil, fmt.Errorf("failed to create subagent: %w", err)
|
||||
}
|
||||
defer func() { _ = child.Close() }()
|
||||
|
||||
@@ -1387,11 +969,7 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
return &SubagentResult{
|
||||
Error: err,
|
||||
SessionID: child.GetSessionID(),
|
||||
Elapsed: elapsed,
|
||||
}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
subResult := &SubagentResult{
|
||||
@@ -1430,14 +1008,13 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
case msg := <-steerCh:
|
||||
leftover = append(leftover, msg)
|
||||
default:
|
||||
goto drained
|
||||
m.steerMu.Lock()
|
||||
m.steerCh = nil
|
||||
m.leftoverSteer = leftover
|
||||
m.steerMu.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
drained:
|
||||
m.steerMu.Lock()
|
||||
m.steerCh = nil
|
||||
m.leftoverSteer = leftover
|
||||
m.steerMu.Unlock()
|
||||
}()
|
||||
ctx = agent.ContextWithSteerCh(ctx, steerCh)
|
||||
ctx = agent.ContextWithSteerConsumed(ctx, func(count int) {
|
||||
@@ -1445,7 +1022,7 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
})
|
||||
|
||||
// Inject the in-process subagent spawner into the context so the
|
||||
// spawn_subagent core tool can create child Kit instances without
|
||||
// subagent core tool can create child Kit instances without
|
||||
// importing pkg/kit (which would create an import cycle).
|
||||
ctx = core.WithSubagentSpawner(ctx, func(
|
||||
spawnCtx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration,
|
||||
@@ -1470,7 +1047,7 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
}
|
||||
sr := &core.SubagentSpawnResult{
|
||||
Response: result.Response,
|
||||
Error: result.Error,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
@@ -1532,6 +1109,14 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
},
|
||||
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
|
||||
// Emit step usage event for real-time cost tracking
|
||||
if viper.GetBool("debug") {
|
||||
charmlog.Debug("Kit.generate emitting StepUsageEvent",
|
||||
"input", inputTokens,
|
||||
"output", outputTokens,
|
||||
"cacheRead", cacheReadTokens,
|
||||
"cacheCreate", cacheCreationTokens,
|
||||
)
|
||||
}
|
||||
m.events.emit(StepUsageEvent{
|
||||
InputTokens: uint64(inputTokens),
|
||||
OutputTokens: uint64(outputTokens),
|
||||
@@ -1571,36 +1156,34 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
}
|
||||
|
||||
// Run BeforeTurn hooks — can modify the prompt, inject system/context messages.
|
||||
if m.beforeTurn.hasHooks() {
|
||||
if hookResult := m.beforeTurn.run(BeforeTurnHook{Prompt: prompt}); hookResult != nil {
|
||||
// Override prompt text in the last user message, preserving
|
||||
// any file parts (e.g. clipboard images).
|
||||
if hookResult.Prompt != nil {
|
||||
for i := len(preMessages) - 1; i >= 0; i-- {
|
||||
if preMessages[i].Role == fantasy.MessageRoleUser {
|
||||
files := extractFileParts(preMessages[i])
|
||||
preMessages[i] = fantasy.NewUserMessage(*hookResult.Prompt, files...)
|
||||
break
|
||||
}
|
||||
if hookResult := m.beforeTurn.run(BeforeTurnHook{Prompt: prompt}); hookResult != nil {
|
||||
// Override prompt text in the last user message, preserving
|
||||
// any file parts (e.g. clipboard images).
|
||||
if hookResult.Prompt != nil {
|
||||
for i := len(preMessages) - 1; i >= 0; i-- {
|
||||
if preMessages[i].Role == fantasy.MessageRoleUser {
|
||||
files := extractFileParts(preMessages[i])
|
||||
preMessages[i] = fantasy.NewUserMessage(*hookResult.Prompt, files...)
|
||||
break
|
||||
}
|
||||
}
|
||||
// Inject messages before the original preMessages.
|
||||
var injected []fantasy.Message
|
||||
if hookResult.SystemPrompt != nil {
|
||||
injected = append(injected, fantasy.NewSystemMessage(*hookResult.SystemPrompt))
|
||||
}
|
||||
if hookResult.InjectText != nil {
|
||||
injected = append(injected, fantasy.NewUserMessage(*hookResult.InjectText))
|
||||
}
|
||||
if len(injected) > 0 {
|
||||
preMessages = append(injected, preMessages...)
|
||||
}
|
||||
}
|
||||
// Inject messages before the original preMessages.
|
||||
var injected []fantasy.Message
|
||||
if hookResult.SystemPrompt != nil {
|
||||
injected = append(injected, fantasy.NewSystemMessage(*hookResult.SystemPrompt))
|
||||
}
|
||||
if hookResult.InjectText != nil {
|
||||
injected = append(injected, fantasy.NewUserMessage(*hookResult.InjectText))
|
||||
}
|
||||
if len(injected) > 0 {
|
||||
preMessages = append(injected, preMessages...)
|
||||
}
|
||||
}
|
||||
|
||||
// Persist pre-generation messages to tree session.
|
||||
for _, msg := range preMessages {
|
||||
_, _ = m.treeSession.AppendFantasyMessage(msg)
|
||||
_, _ = m.treeSession.AppendLLMMessage(msg)
|
||||
}
|
||||
|
||||
// Auto-compact if enabled and conversation is near the context limit.
|
||||
@@ -1609,13 +1192,11 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
}
|
||||
|
||||
// Build context from the tree so only the current branch is sent.
|
||||
messages := m.treeSession.GetFantasyMessages()
|
||||
messages := m.treeSession.GetLLMMessages()
|
||||
|
||||
// Run ContextPrepare hooks — extensions can filter, reorder, or inject messages.
|
||||
if m.contextPrepare.hasHooks() {
|
||||
if hookResult := m.contextPrepare.run(ContextPrepareHook{Messages: messages}); hookResult != nil && hookResult.Messages != nil {
|
||||
messages = hookResult.Messages
|
||||
}
|
||||
if hookResult := m.contextPrepare.run(ContextPrepareHook{Messages: messages}); hookResult != nil && hookResult.Messages != nil {
|
||||
messages = hookResult.Messages
|
||||
}
|
||||
|
||||
sentCount := len(messages)
|
||||
@@ -1634,14 +1215,12 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
// (pending) message or tool call is discarded.
|
||||
if result != nil && len(result.ConversationMessages) > sentCount {
|
||||
for _, msg := range result.ConversationMessages[sentCount:] {
|
||||
_, _ = m.treeSession.AppendFantasyMessage(msg)
|
||||
_, _ = m.treeSession.AppendLLMMessage(msg)
|
||||
}
|
||||
}
|
||||
m.events.emit(TurnEndEvent{Error: err})
|
||||
// Run AfterTurn hooks even on error.
|
||||
if m.afterTurn.hasHooks() {
|
||||
m.afterTurn.run(AfterTurnHook{Error: err})
|
||||
}
|
||||
m.afterTurn.run(AfterTurnHook{Error: err})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1652,7 +1231,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
// GetContextStats() see up-to-date token counts.
|
||||
if len(result.ConversationMessages) > sentCount {
|
||||
for _, msg := range result.ConversationMessages[sentCount:] {
|
||||
_, _ = m.treeSession.AppendFantasyMessage(msg)
|
||||
_, _ = m.treeSession.AppendLLMMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1672,9 +1251,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
m.events.emit(TurnEndEvent{Response: responseText, StopReason: stopReason})
|
||||
|
||||
// Run AfterTurn hooks.
|
||||
if m.afterTurn.hasHooks() {
|
||||
m.afterTurn.run(AfterTurnHook{Response: responseText})
|
||||
}
|
||||
m.afterTurn.run(AfterTurnHook{Response: responseText})
|
||||
|
||||
// Build TurnResult with usage stats.
|
||||
turnResult := &TurnResult{
|
||||
@@ -1736,7 +1313,7 @@ func (m *Kit) Steer(ctx context.Context, instruction string) (string, error) {
|
||||
// Returns an error if there are no previous messages in the session.
|
||||
func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
|
||||
// Verify there is conversation history to follow up on.
|
||||
if len(m.treeSession.GetFantasyMessages()) == 0 {
|
||||
if len(m.treeSession.GetLLMMessages()) == 0 {
|
||||
return "", fmt.Errorf("cannot follow up: no previous messages")
|
||||
}
|
||||
|
||||
@@ -1843,45 +1420,6 @@ func (m *Kit) PromptWithOptions(ctx context.Context, msg string, opts PromptOpti
|
||||
return result.Response, nil
|
||||
}
|
||||
|
||||
// PromptWithCallbacks sends a message with callbacks for monitoring tool
|
||||
// execution and streaming responses. Lifecycle events are also emitted to all
|
||||
// registered subscribers (via Subscribe).
|
||||
//
|
||||
// Deprecated: Use Subscribe/OnToolCall/OnToolResult/OnStreaming instead of
|
||||
// inline callbacks. PromptWithCallbacks is retained for backward compatibility.
|
||||
func (m *Kit) PromptWithCallbacks(
|
||||
ctx context.Context,
|
||||
message string,
|
||||
onToolCall func(name, args string),
|
||||
onToolResult func(name, args, result string, isError bool),
|
||||
onStreaming func(chunk string),
|
||||
) (string, error) {
|
||||
// Register temporary subscribers for the inline callbacks.
|
||||
var unsubs []func()
|
||||
if onToolCall != nil {
|
||||
unsubs = append(unsubs, m.OnToolCall(func(e ToolCallEvent) {
|
||||
onToolCall(e.ToolName, e.ToolArgs)
|
||||
}))
|
||||
}
|
||||
if onToolResult != nil {
|
||||
unsubs = append(unsubs, m.OnToolResult(func(e ToolResultEvent) {
|
||||
onToolResult(e.ToolName, e.ToolArgs, e.Result, e.IsError)
|
||||
}))
|
||||
}
|
||||
if onStreaming != nil {
|
||||
unsubs = append(unsubs, m.OnStreaming(func(e MessageUpdateEvent) {
|
||||
onStreaming(e.Chunk)
|
||||
}))
|
||||
}
|
||||
defer func() {
|
||||
for _, unsub := range unsubs {
|
||||
unsub()
|
||||
}
|
||||
}()
|
||||
|
||||
return m.Prompt(ctx, message)
|
||||
}
|
||||
|
||||
// PromptResult sends a message and returns the full turn result including
|
||||
// usage statistics and conversation messages. Use this instead of Prompt()
|
||||
// when you need more than just the response text.
|
||||
@@ -1894,7 +1432,7 @@ func (m *Kit) PromptResult(ctx context.Context, message string) (*TurnResult, er
|
||||
// PromptResultWithFiles sends a multimodal message (text + images) and returns
|
||||
// the full turn result. The files parameter carries binary file data (e.g.
|
||||
// clipboard images) that are included alongside the text in the user message.
|
||||
func (m *Kit) PromptResultWithFiles(ctx context.Context, message string, files []fantasy.FilePart) (*TurnResult, error) {
|
||||
func (m *Kit) PromptResultWithFiles(ctx context.Context, message string, files []LLMFilePart) (*TurnResult, error) {
|
||||
return m.runTurn(ctx, message, message, []fantasy.Message{
|
||||
fantasy.NewUserMessage(message, files...),
|
||||
})
|
||||
@@ -1915,7 +1453,7 @@ func (m *Kit) PromptResultWithMessages(ctx context.Context, messages []string) (
|
||||
promptLabel = promptLabel[:100] + "..."
|
||||
}
|
||||
|
||||
// Build fantasy messages from all strings
|
||||
// Build LLM messages from all strings
|
||||
var preMessages []fantasy.Message
|
||||
for _, msg := range messages {
|
||||
preMessages = append(preMessages, fantasy.NewUserMessage(msg))
|
||||
@@ -1960,6 +1498,9 @@ func (m *Kit) GetThinkingLevel() string {
|
||||
|
||||
// SetThinkingLevel changes the thinking level and recreates the agent with
|
||||
// the new thinking budget. Returns an error if provider recreation fails.
|
||||
//
|
||||
// With message-level caching, both thinking and caching work together.
|
||||
// Caching reduces costs by 60-90% for repeated context.
|
||||
func (m *Kit) SetThinkingLevel(ctx context.Context, level string) error {
|
||||
viper.Set("thinking-level", level)
|
||||
// Recreate agent with new thinking config by re-running SetModel
|
||||
|
||||
+7
-2
@@ -16,10 +16,15 @@ func GetSupportedProviders() []string {
|
||||
return models.GetGlobalRegistry().GetSupportedProviders()
|
||||
}
|
||||
|
||||
// GetFantasyProviders returns provider IDs that can be used with fantasy,
|
||||
// GetLLMProviders returns provider IDs that have LLM support,
|
||||
// either through a native provider or via openaicompat auto-routing.
|
||||
func GetLLMProviders() []string {
|
||||
return models.GetGlobalRegistry().GetLLMProviders()
|
||||
}
|
||||
|
||||
// Deprecated: Use GetLLMProviders instead.
|
||||
func GetFantasyProviders() []string {
|
||||
return models.GetGlobalRegistry().GetFantasyProviders()
|
||||
return GetLLMProviders()
|
||||
}
|
||||
|
||||
// GetModelsForProvider returns all known models for a provider.
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
)
|
||||
|
||||
@@ -86,3 +91,192 @@ func (m *Kit) SetSessionName(name string) error {
|
||||
_, err := m.treeSession.AppendSessionInfo(name)
|
||||
return err
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tree Navigation Bridge for Extensions (Phase 1)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// GetTreeNode returns a node by ID with full metadata and children.
|
||||
// Returns nil if entry not found or no tree session.
|
||||
func (m *Kit) GetTreeNode(entryID string) *TreeNode {
|
||||
if m.treeSession == nil {
|
||||
return nil
|
||||
}
|
||||
entry := m.treeSession.GetEntry(entryID)
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
return m.entryToTreeNode(entry)
|
||||
}
|
||||
|
||||
// GetCurrentBranch returns the path from root to current leaf as TreeNodes.
|
||||
func (m *Kit) GetCurrentBranch() []TreeNode {
|
||||
if m.treeSession == nil {
|
||||
return nil
|
||||
}
|
||||
branch := m.treeSession.GetBranch("")
|
||||
var nodes []TreeNode
|
||||
for _, entry := range branch {
|
||||
node := m.entryToTreeNode(entry)
|
||||
if node != nil {
|
||||
nodes = append(nodes, *node)
|
||||
}
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// GetChildren returns direct child IDs of an entry.
|
||||
func (m *Kit) GetChildren(parentID string) []string {
|
||||
if m.treeSession == nil {
|
||||
return nil
|
||||
}
|
||||
return m.treeSession.GetChildren(parentID)
|
||||
}
|
||||
|
||||
// NavigateTo branches/forks the session to the specified entry ID.
|
||||
// Returns an error if the session is unavailable or the entry ID is not found.
|
||||
func (m *Kit) NavigateTo(entryID string) error {
|
||||
if m.treeSession == nil {
|
||||
return fmt.Errorf("no tree session available")
|
||||
}
|
||||
return m.treeSession.Branch(entryID)
|
||||
}
|
||||
|
||||
// SummarizeBranch uses the LLM to summarize the conversation between two
|
||||
// entry IDs. Returns the summary text, or an error if the range is invalid,
|
||||
// the session is unavailable, or the LLM call fails.
|
||||
func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
|
||||
if m.treeSession == nil {
|
||||
return "", fmt.Errorf("no tree session available")
|
||||
}
|
||||
|
||||
// Get the branch and find the range
|
||||
branch := m.treeSession.GetBranch("")
|
||||
var startIdx, endIdx = -1, -1
|
||||
for i, entry := range branch {
|
||||
id := m.treeSession.EntryID(entry)
|
||||
if id == fromID {
|
||||
startIdx = i
|
||||
}
|
||||
if id == toID {
|
||||
endIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
if startIdx < 0 || endIdx < 0 || startIdx > endIdx {
|
||||
return "", fmt.Errorf("entry IDs not found or out of order in current branch")
|
||||
}
|
||||
|
||||
// Build text to summarize
|
||||
var content strings.Builder
|
||||
for i := startIdx; i <= endIdx; i++ {
|
||||
node := m.entryToTreeNode(branch[i])
|
||||
if node != nil && node.Content != "" {
|
||||
fmt.Fprintf(&content, "[%s] %s\n\n", node.Role, node.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if content.Len() == 0 {
|
||||
return "", fmt.Errorf("no content found in the specified range")
|
||||
}
|
||||
|
||||
// Use LLM to summarize
|
||||
resp, err := m.ExecuteCompletion(context.Background(), extensions.CompleteRequest{
|
||||
Model: "", // Use current model
|
||||
System: "You are a concise summarization assistant. Summarize the conversation in 2-3 sentences.",
|
||||
Prompt: content.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("summarization failed: %w", err)
|
||||
}
|
||||
return resp.Text, nil
|
||||
}
|
||||
|
||||
// CollapseBranch replaces a branch range with a summary entry.
|
||||
// Returns an error if the session is unavailable or the operation fails.
|
||||
func (m *Kit) CollapseBranch(fromID, toID, summary string) error {
|
||||
if m.treeSession == nil {
|
||||
return fmt.Errorf("no tree session available")
|
||||
}
|
||||
_, err := m.treeSession.AppendBranchSummary(fromID, summary)
|
||||
return err
|
||||
}
|
||||
|
||||
// entryToTreeNode converts a session entry to a TreeNode.
|
||||
func (m *Kit) entryToTreeNode(entry any) *TreeNode {
|
||||
switch e := entry.(type) {
|
||||
case *session.MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var content strings.Builder
|
||||
for _, p := range msg.Parts {
|
||||
switch pt := p.(type) {
|
||||
case message.TextContent:
|
||||
content.WriteString(pt.Text)
|
||||
case message.ReasoningContent:
|
||||
content.WriteString(pt.Thinking)
|
||||
case message.ToolCall:
|
||||
fmt.Fprintf(&content, "[tool_call: %s]", pt.Name)
|
||||
case message.ToolResult:
|
||||
fmt.Fprintf(&content, "[tool_result: %s]", pt.Content)
|
||||
}
|
||||
}
|
||||
return &TreeNode{
|
||||
ID: e.ID,
|
||||
ParentID: e.ParentID,
|
||||
Type: "message",
|
||||
Role: string(msg.Role),
|
||||
Content: content.String(),
|
||||
Model: msg.Model,
|
||||
Provider: msg.Provider,
|
||||
Timestamp: e.Timestamp.Format(time.RFC3339),
|
||||
Children: m.treeSession.GetChildren(e.ID),
|
||||
}
|
||||
case *session.BranchSummaryEntry:
|
||||
return &TreeNode{
|
||||
ID: e.ID,
|
||||
ParentID: e.ParentID,
|
||||
Type: "branch_summary",
|
||||
Content: e.Summary,
|
||||
Timestamp: e.Timestamp.Format(time.RFC3339),
|
||||
Children: m.treeSession.GetChildren(e.ID),
|
||||
}
|
||||
case *session.ModelChangeEntry:
|
||||
return &TreeNode{
|
||||
ID: e.ID,
|
||||
ParentID: e.ParentID,
|
||||
Type: "model_change",
|
||||
Content: fmt.Sprintf("Model changed to %s/%s", e.Provider, e.ModelID),
|
||||
Model: e.Provider + "/" + e.ModelID,
|
||||
Provider: e.Provider,
|
||||
Timestamp: e.Timestamp.Format(time.RFC3339),
|
||||
Children: m.treeSession.GetChildren(e.ID),
|
||||
}
|
||||
case *session.ExtensionDataEntry:
|
||||
return &TreeNode{
|
||||
ID: e.ID,
|
||||
ParentID: e.ParentID,
|
||||
Type: "extension_data",
|
||||
Content: fmt.Sprintf("Extension data: %s", e.ExtType),
|
||||
Timestamp: e.Timestamp.Format(time.RFC3339),
|
||||
Children: m.treeSession.GetChildren(e.ID),
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// TreeNode represents a node in the session tree for SDK consumers.
|
||||
type TreeNode struct {
|
||||
ID string
|
||||
ParentID string
|
||||
Type string // "message", "branch_summary", "model_change", "extension_data"
|
||||
Role string // for messages: "user", "assistant", "system", "tool"
|
||||
Content string
|
||||
Model string
|
||||
Provider string
|
||||
Timestamp string
|
||||
Children []string
|
||||
}
|
||||
|
||||
+70
-1
@@ -1,6 +1,11 @@
|
||||
package kit
|
||||
|
||||
import "github.com/mark3labs/kit/internal/skills"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/skills"
|
||||
)
|
||||
|
||||
// ==== Skills Types ====
|
||||
|
||||
@@ -67,3 +72,67 @@ func LoadPromptTemplate(path string) (*PromptTemplate, error) {
|
||||
func NewPromptBuilder(basePrompt string) *PromptBuilder {
|
||||
return skills.NewPromptBuilder(basePrompt)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Skill Bridge for Extensions (Phase 2)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// DiscoverSkillsForExtension finds skills in standard locations for extensions.
|
||||
// Returns skills in the extension-facing format. Results are cached per-Kit
|
||||
// instance to avoid reloading on every call.
|
||||
func (m *Kit) DiscoverSkillsForExtension() []extensions.Skill {
|
||||
cwd, _ := os.Getwd()
|
||||
|
||||
m.skillCache.mu.Lock()
|
||||
defer m.skillCache.mu.Unlock()
|
||||
if len(m.skillCache.skills) == 0 {
|
||||
m.skillCache.skills, _ = skills.LoadSkills(cwd)
|
||||
}
|
||||
return m.convertSkills(m.skillCache.skills)
|
||||
}
|
||||
|
||||
// LoadSkillForExtension loads a single skill file for extensions.
|
||||
func (m *Kit) LoadSkillForExtension(path string) (*extensions.Skill, string) {
|
||||
s, err := skills.LoadSkill(path)
|
||||
if err != nil {
|
||||
return nil, err.Error()
|
||||
}
|
||||
return m.convertSkill(s), ""
|
||||
}
|
||||
|
||||
// LoadSkillsFromDirForExtension loads all skills from a directory for extensions.
|
||||
func (m *Kit) LoadSkillsFromDirForExtension(dir string) extensions.SkillLoadResult {
|
||||
skillList, err := skills.LoadSkillsFromDir(dir)
|
||||
if err != nil {
|
||||
return extensions.SkillLoadResult{Error: err.Error()}
|
||||
}
|
||||
return extensions.SkillLoadResult{Skills: m.convertSkills(skillList)}
|
||||
}
|
||||
|
||||
// convertSkill converts internal skill to extension-facing format.
|
||||
func (m *Kit) convertSkill(s *skills.Skill) *extensions.Skill {
|
||||
return &extensions.Skill{
|
||||
Name: s.Name,
|
||||
Description: s.Description,
|
||||
Content: s.Content,
|
||||
Path: s.Path,
|
||||
Tags: s.Tags,
|
||||
When: s.When,
|
||||
}
|
||||
}
|
||||
|
||||
// convertSkills converts a slice of skills.
|
||||
func (m *Kit) convertSkills(skillList []*skills.Skill) []extensions.Skill {
|
||||
result := make([]extensions.Skill, 0, len(skillList))
|
||||
for _, s := range skillList {
|
||||
result = append(result, *m.convertSkill(s))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ClearSkillCache clears the skill cache for this Kit instance.
|
||||
func (m *Kit) ClearSkillCache() {
|
||||
m.skillCache.mu.Lock()
|
||||
defer m.skillCache.mu.Unlock()
|
||||
m.skillCache.skills = nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,457 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Template Parsing Bridge for Extensions (Phase 3)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// varRegex matches {{variable}} placeholders in templates.
|
||||
var varRegex = regexp.MustCompile(`\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}`)
|
||||
|
||||
// ParseTemplate extracts {{variables}} from template content.
|
||||
func ParseTemplate(name, content string) extensions.PromptTemplate {
|
||||
matches := varRegex.FindAllStringSubmatch(content, -1)
|
||||
vars := make([]string, 0, len(matches))
|
||||
seen := make(map[string]bool)
|
||||
for _, m := range matches {
|
||||
if len(m) > 1 && !seen[m[1]] {
|
||||
seen[m[1]] = true
|
||||
vars = append(vars, m[1])
|
||||
}
|
||||
}
|
||||
return extensions.PromptTemplate{
|
||||
Name: name,
|
||||
Content: content,
|
||||
Variables: vars,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderTemplate substitutes variables into template content.
|
||||
// Handles {{name}} and {{ name }} (any whitespace) placeholders.
|
||||
func RenderTemplate(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return varRegex.ReplaceAllStringFunc(tpl.Content, func(m string) string {
|
||||
sub := varRegex.FindStringSubmatch(m)
|
||||
if len(sub) > 1 {
|
||||
if v, ok := vars[sub[1]]; ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return m
|
||||
})
|
||||
}
|
||||
|
||||
// ParseArguments parses command-line style arguments.
|
||||
func ParseArguments(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
|
||||
result := extensions.ParseResult{
|
||||
Vars: make(map[string]string),
|
||||
Flags: make(map[string]string),
|
||||
}
|
||||
|
||||
fields := parseFields(input)
|
||||
if len(fields) == 0 {
|
||||
return result
|
||||
}
|
||||
|
||||
// First field is the command itself (if present); skip it.
|
||||
startIdx := 0
|
||||
if len(fields) > 0 && !strings.HasPrefix(fields[0], "-") {
|
||||
startIdx = 1
|
||||
}
|
||||
|
||||
// Parse flags
|
||||
i := startIdx
|
||||
for i < len(fields) {
|
||||
field := fields[i]
|
||||
|
||||
// Check for flags
|
||||
if strings.HasPrefix(field, "--") {
|
||||
flagName := field[2:]
|
||||
if varName, ok := pattern.Flags["--"+flagName]; ok {
|
||||
// Flag with value
|
||||
if i+1 < len(fields) && !strings.HasPrefix(fields[i+1], "-") {
|
||||
result.Flags["--"+flagName] = fields[i+1]
|
||||
result.Vars[varName] = fields[i+1]
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
// Boolean flag
|
||||
result.Flags["--"+flagName] = "true"
|
||||
result.Vars[varName] = "true"
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(field, "-") && len(field) > 1 {
|
||||
flagName := field[1:]
|
||||
if varName, ok := pattern.Flags["-"+flagName]; ok {
|
||||
// Flag with value
|
||||
if i+1 < len(fields) && !strings.HasPrefix(fields[i+1], "-") {
|
||||
result.Flags["-"+flagName] = fields[i+1]
|
||||
result.Vars[varName] = fields[i+1]
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
// Boolean flag
|
||||
result.Flags["-"+flagName] = "true"
|
||||
result.Vars[varName] = "true"
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
i++
|
||||
}
|
||||
|
||||
// Collect remaining as positional args and "rest"
|
||||
positional := make([]string, 0)
|
||||
i = startIdx
|
||||
for i < len(fields) {
|
||||
field := fields[i]
|
||||
if !strings.HasPrefix(field, "-") {
|
||||
// Check if this was consumed as a flag value
|
||||
consumed := false
|
||||
for _, v := range result.Vars {
|
||||
if v == field {
|
||||
// Might be consumed, check previous field
|
||||
if i > 0 {
|
||||
prev := fields[i-1]
|
||||
if strings.HasPrefix(prev, "-") {
|
||||
consumed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !consumed {
|
||||
positional = append(positional, field)
|
||||
}
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
// Map positional args
|
||||
for i, name := range pattern.Positional {
|
||||
if i < len(positional) {
|
||||
result.Vars[name] = positional[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Set rest
|
||||
if pattern.Rest != "" && len(positional) > len(pattern.Positional) {
|
||||
restStart := len(pattern.Positional)
|
||||
if restStart < len(positional) {
|
||||
result.Vars[pattern.Rest] = strings.Join(positional[restStart:], " ")
|
||||
}
|
||||
}
|
||||
|
||||
result.Rest = strings.Join(fields, " ")
|
||||
return result
|
||||
}
|
||||
|
||||
// SimpleParseArguments parses $1, $2, $@ style arguments.
|
||||
// Returns slice where [0]=full input, [1]=$1, [2]=$2, ... [n]=$@
|
||||
func SimpleParseArguments(input string, count int) []string {
|
||||
fields := parseFields(input)
|
||||
result := make([]string, 0, count+2)
|
||||
result = append(result, input) // [0] = full input
|
||||
|
||||
// [1]..[count] = positional args
|
||||
for i := range count {
|
||||
if i < len(fields) {
|
||||
result = append(result, fields[i])
|
||||
} else {
|
||||
result = append(result, "")
|
||||
}
|
||||
}
|
||||
|
||||
// [n] = $@ (all remaining)
|
||||
if len(fields) > count {
|
||||
result = append(result, strings.Join(fields[count:], " "))
|
||||
} else {
|
||||
result = append(result, "")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// parseFields splits input respecting quoted strings.
|
||||
func parseFields(input string) []string {
|
||||
var fields []string
|
||||
var current strings.Builder
|
||||
inQuote := false
|
||||
quoteChar := rune(0)
|
||||
|
||||
for _, r := range input {
|
||||
switch r {
|
||||
case '"', '\'':
|
||||
if !inQuote {
|
||||
inQuote = true
|
||||
quoteChar = r
|
||||
} else if r == quoteChar {
|
||||
inQuote = false
|
||||
quoteChar = 0
|
||||
} else {
|
||||
current.WriteRune(r)
|
||||
}
|
||||
case ' ', '\t':
|
||||
if inQuote {
|
||||
current.WriteRune(r)
|
||||
} else {
|
||||
if current.Len() > 0 {
|
||||
fields = append(fields, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
}
|
||||
default:
|
||||
current.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
fields = append(fields, current.String())
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
// EvaluateModelConditional checks if condition matches current model.
|
||||
// Condition supports wildcards: * matches any, ? matches single char.
|
||||
func EvaluateModelConditional(currentModel, condition string) bool {
|
||||
// Handle comma-separated conditions (OR logic)
|
||||
for c := range strings.SplitSeq(condition, ",") {
|
||||
c = strings.TrimSpace(c)
|
||||
if matchModelPattern(currentModel, c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// modelPatternCache caches compiled regexps for model glob patterns.
|
||||
var modelPatternCache sync.Map
|
||||
|
||||
// matchModelPattern matches a model against a pattern with wildcards.
|
||||
// Compiled regexps are cached to avoid recompilation on hot paths.
|
||||
func matchModelPattern(model, pattern string) bool {
|
||||
rePattern := "^" + strings.ReplaceAll(strings.ReplaceAll(pattern, "*", ".*"), "?", ".") + "$"
|
||||
var re *regexp.Regexp
|
||||
if v, ok := modelPatternCache.Load(rePattern); ok {
|
||||
re = v.(*regexp.Regexp)
|
||||
} else {
|
||||
compiled, err := regexp.Compile(rePattern)
|
||||
if err != nil {
|
||||
// Fallback: exact match
|
||||
return model == pattern
|
||||
}
|
||||
modelPatternCache.Store(rePattern, compiled)
|
||||
re = compiled
|
||||
}
|
||||
return re.MatchString(model)
|
||||
}
|
||||
|
||||
// RenderWithModelConditionals processes <if-model> blocks in content.
|
||||
func RenderWithModelConditionals(content, currentModel string) string {
|
||||
// Simple regex-based processor for <if-model> blocks
|
||||
// Supports: <if-model is="pattern">content</if-model>
|
||||
// And: <if-model is="pattern">content<else>other</if-model>
|
||||
|
||||
result := content
|
||||
|
||||
// Pattern for if-model blocks
|
||||
ifModelRegex := regexp.MustCompile(`(?s)<if-model\s+is="([^"]+)">(.*?)(?:<else>(.*?))?</if-model>`)
|
||||
|
||||
for {
|
||||
match := ifModelRegex.FindStringSubmatchIndex(result)
|
||||
if match == nil {
|
||||
break
|
||||
}
|
||||
|
||||
condition := result[match[2]:match[3]]
|
||||
ifContent := result[match[4]:match[5]]
|
||||
elseContent := ""
|
||||
if match[6] >= 0 && match[7] >= 0 {
|
||||
elseContent = result[match[6]:match[7]]
|
||||
}
|
||||
|
||||
var replacement string
|
||||
if EvaluateModelConditional(currentModel, condition) {
|
||||
replacement = ifContent
|
||||
} else {
|
||||
replacement = elseContent
|
||||
}
|
||||
|
||||
result = result[:match[0]] + replacement + result[match[1]:]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model Resolution Bridge for Extensions (Phase 4)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResolveModelChain attempts each model in order until one is available.
|
||||
func ResolveModelChain(preferences []string) extensions.ModelResolutionResult {
|
||||
result := extensions.ModelResolutionResult{
|
||||
Attempted: make([]string, 0, len(preferences)),
|
||||
}
|
||||
|
||||
registry := models.GetGlobalRegistry()
|
||||
|
||||
for _, pref := range preferences {
|
||||
pref = strings.TrimSpace(pref)
|
||||
result.Attempted = append(result.Attempted, pref)
|
||||
|
||||
// Parse model string
|
||||
provider, modelID, err := models.ParseModelString(pref)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if provider exists
|
||||
if registry.GetProviderInfo(provider) == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if model exists in registry
|
||||
modelInfo := registry.LookupModel(provider, modelID)
|
||||
if modelInfo == nil {
|
||||
// Try with just the model as bare name
|
||||
continue
|
||||
}
|
||||
|
||||
// Found available model
|
||||
result.Model = provider + "/" + modelID
|
||||
result.Capabilities = extensions.ModelCapabilities{
|
||||
Provider: provider,
|
||||
ModelID: modelID,
|
||||
ContextLimit: modelInfo.Limit.Context,
|
||||
OutputLimit: modelInfo.Limit.Output,
|
||||
Reasoning: modelInfo.Reasoning,
|
||||
Streaming: true, // Assume streaming support
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
result.Error = "no models in chain are available"
|
||||
return result
|
||||
}
|
||||
|
||||
// GetModelCapabilities returns capabilities for a specific model.
|
||||
// If model is empty, returns zero capabilities.
|
||||
func GetModelCapabilities(model string) (extensions.ModelCapabilities, string) {
|
||||
if model == "" {
|
||||
return extensions.ModelCapabilities{}, "no model specified"
|
||||
}
|
||||
|
||||
provider, modelID, err := models.ParseModelString(model)
|
||||
if err != nil {
|
||||
return extensions.ModelCapabilities{}, err.Error()
|
||||
}
|
||||
|
||||
registry := models.GetGlobalRegistry()
|
||||
modelInfo := registry.LookupModel(provider, modelID)
|
||||
if modelInfo == nil {
|
||||
return extensions.ModelCapabilities{}, "model not found in registry"
|
||||
}
|
||||
|
||||
return extensions.ModelCapabilities{
|
||||
Provider: provider,
|
||||
ModelID: modelID,
|
||||
ContextLimit: modelInfo.Limit.Context,
|
||||
OutputLimit: modelInfo.Limit.Output,
|
||||
Reasoning: modelInfo.Reasoning,
|
||||
Streaming: true,
|
||||
}, ""
|
||||
}
|
||||
|
||||
// CheckModelAvailable verifies if a model string is valid and provider exists.
|
||||
func CheckModelAvailable(model string) bool {
|
||||
provider, _, err := models.ParseModelString(model)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
registry := models.GetGlobalRegistry()
|
||||
if registry.GetProviderInfo(provider) == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Model doesn't need to be in registry - could be dynamic/Ollama
|
||||
return true
|
||||
}
|
||||
|
||||
// GetCurrentProvider extracts provider from model string.
|
||||
func GetCurrentProvider(model string) string {
|
||||
provider, _, _ := models.ParseModelString(model)
|
||||
return provider
|
||||
}
|
||||
|
||||
// GetCurrentModelID extracts model ID from model string.
|
||||
func GetCurrentModelID(model string) string {
|
||||
_, modelID, _ := models.ParseModelString(model)
|
||||
return modelID
|
||||
}
|
||||
|
||||
// JoinModel combines provider and model ID into a model string.
|
||||
func JoinModel(provider, modelID string) string {
|
||||
if provider == "" {
|
||||
return modelID
|
||||
}
|
||||
return provider + "/" + modelID
|
||||
}
|
||||
|
||||
// MatchModelGlob matches a model against a glob pattern.
|
||||
// Pattern can contain * (match any) and ? (match single).
|
||||
func MatchModelGlob(model, pattern string) bool {
|
||||
return matchModelPattern(model, pattern)
|
||||
}
|
||||
|
||||
// ExtractProviderFromPath extracts provider from a path-like model string.
|
||||
func ExtractProviderFromPath(model string) string {
|
||||
parts := strings.Split(model, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ExtractModelFromPath extracts model ID from a path-like model string.
|
||||
func ExtractModelFromPath(model string) string {
|
||||
parts := strings.Split(model, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[1]
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// IsBareModelID checks if a string is a bare model ID (no provider).
|
||||
func IsBareModelID(model string) bool {
|
||||
return !strings.Contains(model, "/")
|
||||
}
|
||||
|
||||
// AddProviderToModel adds a provider prefix to a bare model ID.
|
||||
func AddProviderToModel(provider, model string) string {
|
||||
if strings.Contains(model, "/") {
|
||||
return model // Already has provider
|
||||
}
|
||||
return provider + "/" + model
|
||||
}
|
||||
|
||||
// RemoveProviderFromModel removes the provider prefix from a model string.
|
||||
func RemoveProviderFromModel(model string) string {
|
||||
parts := strings.SplitN(model, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
return parts[1]
|
||||
}
|
||||
return model
|
||||
}
|
||||
+1
-1
@@ -52,7 +52,7 @@ func CodingTools(opts ...ToolOption) []Tool { return core.CodingTools(opts...) }
|
||||
// read, grep, find, ls.
|
||||
func ReadOnlyTools(opts ...ToolOption) []Tool { return core.ReadOnlyTools(opts...) }
|
||||
|
||||
// SubagentTools returns all core tools except spawn_subagent. Use this when
|
||||
// SubagentTools returns all core tools except subagent. Use this when
|
||||
// creating child Kit instances (in-process subagents) to prevent infinite
|
||||
// recursion.
|
||||
func SubagentTools(opts ...ToolOption) []Tool { return core.SubagentTools(opts...) }
|
||||
|
||||
+52
-29
@@ -1,6 +1,8 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
@@ -76,10 +78,6 @@ type MCPServerConfig = config.MCPServerConfig
|
||||
// AgentConfig holds configuration options for creating a new Agent.
|
||||
type AgentConfig = agent.AgentConfig
|
||||
|
||||
// GenerateResult contains the result and conversation history from an agent
|
||||
// interaction.
|
||||
type GenerateResult = agent.GenerateWithLoopResult
|
||||
|
||||
type (
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
ToolCallHandler = agent.ToolCallHandler
|
||||
@@ -128,18 +126,22 @@ type ModelsRegistry = models.ModelsRegistry
|
||||
// Ollama model loading. Signature: func(fn func() error) error.
|
||||
type SpinnerFunc = agent.SpinnerFunc
|
||||
|
||||
// ==== Fantasy Types (re-exported) ====
|
||||
// ==== LLM Types ====
|
||||
|
||||
// FantasyMessage is the underlying message type used by the fantasy agent
|
||||
// library. Re-exported so SDK users can work with fantasy types without a
|
||||
// direct import of charm.land/fantasy.
|
||||
type FantasyMessage = fantasy.Message
|
||||
// LLMMessage is the underlying message type used by the LLM agent
|
||||
// library. Re-exported so SDK users can work with LLM types without a
|
||||
// direct import of the underlying LLM library.
|
||||
type LLMMessage = fantasy.Message
|
||||
|
||||
// FantasyUsage contains token usage information from an LLM response.
|
||||
type FantasyUsage = fantasy.Usage
|
||||
// LLMUsage contains token usage information from an LLM response.
|
||||
type LLMUsage = fantasy.Usage
|
||||
|
||||
// FantasyResponse is the response type returned by the fantasy agent library.
|
||||
type FantasyResponse = fantasy.Response
|
||||
// LLMResponse is the response type returned by the LLM agent library.
|
||||
type LLMResponse = fantasy.Response
|
||||
|
||||
// LLMFilePart represents a file attachment (image, document, etc.) that can
|
||||
// be included in a prompt via PromptResultWithFiles.
|
||||
type LLMFilePart = fantasy.FilePart
|
||||
|
||||
// ==== Compaction Types (internal/compaction/) ====
|
||||
|
||||
@@ -151,27 +153,48 @@ type CompactionOptions = compaction.CompactionOptions
|
||||
|
||||
// ==== Constructor & Helper Functions ====
|
||||
|
||||
var (
|
||||
// ParseModelString parses a model string in "provider/model" format.
|
||||
ParseModelString = models.ParseModelString
|
||||
// CreateProvider creates a fantasy LanguageModel based on provider config.
|
||||
CreateProvider = models.CreateProvider
|
||||
// GetGlobalRegistry returns the global models registry instance.
|
||||
GetGlobalRegistry = models.GetGlobalRegistry
|
||||
// LoadSystemPrompt loads system prompt from file or returns string directly.
|
||||
LoadSystemPrompt = config.LoadSystemPrompt
|
||||
)
|
||||
// ParseModelString parses a model string in "provider/model" format.
|
||||
// Returns provider, modelID, and an error if the format is invalid.
|
||||
func ParseModelString(model string) (provider, modelID string, err error) {
|
||||
return models.ParseModelString(model)
|
||||
}
|
||||
|
||||
// CreateProvider creates a LanguageModel based on provider config.
|
||||
func CreateProvider(ctx context.Context, cfg *ProviderConfig) (*ProviderResult, error) {
|
||||
return models.CreateProvider(ctx, cfg)
|
||||
}
|
||||
|
||||
// GetGlobalRegistry returns the global models registry instance.
|
||||
func GetGlobalRegistry() *ModelsRegistry {
|
||||
return models.GetGlobalRegistry()
|
||||
}
|
||||
|
||||
// LoadSystemPrompt loads a system prompt from a file path, or returns the
|
||||
// string directly if it is not a valid file path.
|
||||
func LoadSystemPrompt(pathOrContent string) (string, error) {
|
||||
return config.LoadSystemPrompt(pathOrContent)
|
||||
}
|
||||
|
||||
// ==== Conversion Helpers ====
|
||||
|
||||
// ConvertToFantasyMessages converts an SDK message to the underlying fantasy
|
||||
// ConvertToLLMMessages converts an SDK message to the underlying LLM
|
||||
// messages used by the agent for LLM interactions.
|
||||
func ConvertToFantasyMessages(msg *Message) []fantasy.Message {
|
||||
return msg.ToFantasyMessages()
|
||||
func ConvertToLLMMessages(msg *Message) []fantasy.Message {
|
||||
return msg.ToLLMMessages()
|
||||
}
|
||||
|
||||
// ConvertFromFantasyMessage converts a fantasy message from the agent to an SDK
|
||||
// ConvertFromLLMMessage converts an LLM message from the agent to an SDK
|
||||
// message format for use in the SDK API.
|
||||
func ConvertFromFantasyMessage(msg fantasy.Message) Message {
|
||||
return message.FromFantasyMessage(msg)
|
||||
func ConvertFromLLMMessage(msg fantasy.Message) Message {
|
||||
return message.FromLLMMessage(msg)
|
||||
}
|
||||
|
||||
// Deprecated: Use ConvertToLLMMessages instead.
|
||||
func ConvertToFantasyMessages(msg *Message) []fantasy.Message {
|
||||
return ConvertToLLMMessages(msg)
|
||||
}
|
||||
|
||||
// Deprecated: Use ConvertFromLLMMessage instead.
|
||||
func ConvertFromFantasyMessage(msg fantasy.Message) Message {
|
||||
return ConvertFromLLMMessage(msg)
|
||||
}
|
||||
|
||||
@@ -49,12 +49,12 @@ func TestTypeExports(t *testing.T) {
|
||||
Role: kit.RoleUser,
|
||||
Parts: []kit.ContentPart{kit.TextContent{Text: "test"}},
|
||||
}
|
||||
fantasyMsgs := kit.ConvertToFantasyMessages(&userMsg)
|
||||
if len(fantasyMsgs) == 0 {
|
||||
t.Error("ConvertToFantasyMessages returned empty slice")
|
||||
llmMsgs := kit.ConvertToLLMMessages(&userMsg)
|
||||
if len(llmMsgs) == 0 {
|
||||
t.Error("ConvertToLLMMessages returned empty slice")
|
||||
}
|
||||
|
||||
roundTrip := kit.ConvertFromFantasyMessage(fantasyMsgs[0])
|
||||
roundTrip := kit.ConvertFromLLMMessage(llmMsgs[0])
|
||||
if roundTrip.Content() != "test" {
|
||||
t.Errorf("round-trip Content() = %q, want %q", roundTrip.Content(), "test")
|
||||
}
|
||||
|
||||
@@ -1210,6 +1210,129 @@ func applyMode(ctx ext.Context, active bool, tools []string) {
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Bridged SDK APIs (New)
|
||||
|
||||
Extensions can now access powerful internal SDK capabilities that enable advanced features like conversation tree navigation, dynamic skill loading, template parsing, and model resolution.
|
||||
|
||||
### Tree Navigation
|
||||
|
||||
Navigate the conversation tree, summarize branches, and implement "fresh context" loops:
|
||||
|
||||
```go
|
||||
// Get a specific node by ID with full metadata and children
|
||||
node := ctx.GetTreeNode("entry-id")
|
||||
// node.ID, node.ParentID, node.Type ("message"/"branch_summary"/etc)
|
||||
// node.Role, node.Content, node.Model, node.Children ([]string)
|
||||
|
||||
// Get the current branch from root to leaf
|
||||
branch := ctx.GetCurrentBranch() // []ext.TreeNode
|
||||
|
||||
// Get child entry IDs of a node
|
||||
children := ctx.GetChildren("entry-id") // []string
|
||||
|
||||
// Navigate/fork to a different entry in the tree
|
||||
result := ctx.NavigateTo("entry-id") // ext.TreeNavigationResult{Success, Error}
|
||||
|
||||
// Summarize a range of the branch using LLM
|
||||
summary := ctx.SummarizeBranch("from-id", "to-id") // string
|
||||
|
||||
// Collapse a branch range into a summary entry (fresh context primitive)
|
||||
result := ctx.CollapseBranch("from-id", "to-id", "summary text")
|
||||
```
|
||||
|
||||
### Skill Loading
|
||||
|
||||
Load and inject skills dynamically at runtime:
|
||||
|
||||
```go
|
||||
// Discover skills from standard locations
|
||||
result := ctx.DiscoverSkills() // ext.SkillLoadResult{Skills, Error}
|
||||
// Standard locations: ~/.config/kit/skills/, .kit/skills/, .agents/skills/
|
||||
|
||||
// Load a specific skill file
|
||||
skill, err := ctx.LoadSkill("/path/to/skill.md") // (*ext.Skill, error string)
|
||||
// skill.Name, skill.Description, skill.Content, skill.Tags, skill.When
|
||||
|
||||
// Load all skills from a directory
|
||||
result := ctx.LoadSkillsFromDir("/path/to/skills") // ext.SkillLoadResult
|
||||
|
||||
// Inject a skill as context (pre-loads for next turn)
|
||||
err := ctx.InjectSkillAsContext("skill-name") // error string
|
||||
|
||||
// Inject a skill file directly
|
||||
err := ctx.InjectRawSkillAsContext("/path/to/skill.md") // error string
|
||||
|
||||
// Get all discovered skills
|
||||
skills := ctx.GetAvailableSkills() // []ext.Skill
|
||||
```
|
||||
|
||||
### Template Parsing
|
||||
|
||||
Parse and render templates with variable substitution:
|
||||
|
||||
```go
|
||||
// Parse a template to extract {{variables}}
|
||||
tpl := ctx.ParseTemplate("name", "Hello {{name}}, welcome to {{place}}!")
|
||||
// tpl.Name, tpl.Content, tpl.Variables ([]string)
|
||||
|
||||
// Render a template with variable values
|
||||
vars := map[string]string{"name": "Alice", "place": "Kit"}
|
||||
rendered := ctx.RenderTemplate(tpl, vars) // "Hello Alice, welcome to Kit!"
|
||||
|
||||
// Parse command-line style arguments
|
||||
pattern := ext.ArgumentPattern{
|
||||
Positional: []string{"command", "target"}, // $1, $2
|
||||
Rest: "args", // $@
|
||||
Flags: map[string]string{"--loop": "loop", "-f": "force"},
|
||||
}
|
||||
result := ctx.ParseArguments("deploy staging --loop 5", pattern)
|
||||
// result.Vars["command"] = "deploy"
|
||||
// result.Vars["target"] = "staging"
|
||||
// result.Flags["--loop"] = "5"
|
||||
|
||||
// Simple positional argument parsing ($1, $2, $@)
|
||||
args := ctx.SimpleParseArguments("deploy staging --force", 2)
|
||||
// args[0] = "deploy staging --force" (full input)
|
||||
// args[1] = "deploy" ($1)
|
||||
// args[2] = "staging" ($2)
|
||||
// args[3] = "--force" ($@)
|
||||
|
||||
// Evaluate model conditionals with wildcards
|
||||
matches := ctx.EvaluateModelConditional("claude-*") // bool
|
||||
// Patterns: * matches any, ? matches single char, comma = OR
|
||||
|
||||
// Render content with <if-model> conditionals
|
||||
content := `<if-model is="claude-*">Hi Claude<else>Hi there</if-model>`
|
||||
rendered := ctx.RenderWithModelConditionals(content) // based on current model
|
||||
```
|
||||
|
||||
### Model Resolution
|
||||
|
||||
Resolve model fallback chains and query capabilities:
|
||||
|
||||
```go
|
||||
// Resolve a chain of model preferences (tries each until available)
|
||||
result := ctx.ResolveModelChain([]string{
|
||||
"anthropic/claude-opus-4",
|
||||
"anthropic/claude-sonnet-4",
|
||||
"openai/gpt-4o",
|
||||
})
|
||||
// result.Model (selected), result.Capabilities, result.Attempted, result.Error
|
||||
|
||||
// Get capabilities for a specific model
|
||||
caps, err := ctx.GetModelCapabilities("anthropic/claude-sonnet-4")
|
||||
// caps.Provider, caps.ModelID, caps.ContextLimit, caps.Reasoning, caps.Streaming
|
||||
|
||||
// Check if a model is available (provider exists)
|
||||
available := ctx.CheckModelAvailable("anthropic/claude-sonnet-4") // bool
|
||||
|
||||
// Get current provider/model ID
|
||||
provider := ctx.GetCurrentProvider() // "anthropic"
|
||||
modelID := ctx.GetCurrentModelID() // "claude-sonnet-4"
|
||||
```
|
||||
|
||||
## Key Files for Reference
|
||||
|
||||
- [`internal/extensions/api.go`](https://github.com/mark3labs/kit/blob/main/internal/extensions/api.go) — Complete API type definitions
|
||||
|
||||
+64
-26
@@ -119,17 +119,15 @@ result, err := host.PromptResult(ctx, "Analyze this file")
|
||||
// result.Response — assistant's text
|
||||
// result.StopReason — "stop", "length", "tool-calls", "error", etc.
|
||||
// result.SessionID — session UUID
|
||||
// result.TotalUsage — aggregate tokens across all steps (*kit.FantasyUsage)
|
||||
// result.TotalUsage — aggregate tokens across all steps (*kit.LLMUsage)
|
||||
// result.FinalUsage — tokens from last API call only
|
||||
// result.Messages — full updated conversation ([]kit.FantasyMessage)
|
||||
// result.Messages — full updated conversation ([]kit.LLMMessage)
|
||||
```
|
||||
|
||||
### Multimodal with file attachments
|
||||
|
||||
```go
|
||||
import "charm.land/fantasy"
|
||||
|
||||
files := []fantasy.FilePart{{
|
||||
files := []kit.LLMFilePart{{
|
||||
Name: "screenshot.png",
|
||||
MediaType: "image/png",
|
||||
Data: imageBytes,
|
||||
@@ -167,16 +165,6 @@ result, err := host.PromptResultWithMessages(ctx, []string{
|
||||
})
|
||||
```
|
||||
|
||||
### Legacy inline callbacks (deprecated — use event subscribers instead)
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(ctx, "List files",
|
||||
func(name, args string) { fmt.Printf("Tool: %s\n", name) },
|
||||
func(name, args, result string, isError bool) { /* tool result */ },
|
||||
func(chunk string) { fmt.Print(chunk) }, // streaming
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Event System
|
||||
@@ -252,6 +240,8 @@ unsub := host.Subscribe(func(e kit.Event) {
|
||||
| `response` | `ResponseEvent` | `Content` |
|
||||
| `compaction` | `CompactionEvent` | `Summary`, `OriginalTokens`, `CompactedTokens`, `MessagesRemoved`, `ReadFiles`, `ModifiedFiles` |
|
||||
| `reasoning_delta` | `ReasoningDeltaEvent` | `Delta` |
|
||||
| `step_usage` | `StepUsageEvent` | `InputTokens`, `OutputTokens`, `CacheReadTokens`, `CacheWriteTokens` |
|
||||
| `steer_consumed` | `SteerConsumedEvent` | `Count` |
|
||||
|
||||
### Tool kind constants
|
||||
|
||||
@@ -261,7 +251,7 @@ Tools are classified by kind for UI rendering:
|
||||
- `ToolKindEdit` = `"edit"` — edit, write
|
||||
- `ToolKindRead` = `"read"` — read, ls
|
||||
- `ToolKindSearch` = `"search"` — grep, find
|
||||
- `ToolKindSubagent` = `"agent"` — spawn_subagent
|
||||
- `ToolKindSubagent` = `"agent"` — subagent
|
||||
|
||||
---
|
||||
|
||||
@@ -318,7 +308,7 @@ host.OnAfterTurn(kit.HookPriorityNormal, func(h kit.AfterTurnHook) {
|
||||
|
||||
```go
|
||||
host.OnContextPrepare(kit.HookPriorityNormal, func(h kit.ContextPrepareHook) *kit.ContextPrepareResult {
|
||||
// h.Messages — []fantasy.Message (the full context being sent to the LLM)
|
||||
// h.Messages — []kit.LLMMessage (the full context being sent to the LLM)
|
||||
// Return nil to pass through, or replace entire context:
|
||||
return &kit.ContextPrepareResult{Messages: filteredMessages}
|
||||
})
|
||||
@@ -368,7 +358,7 @@ kit.NewLsTool(opts...) // directory listing
|
||||
kit.AllTools(opts...) // all 7 core tools
|
||||
kit.CodingTools(opts...) // bash, read, write, edit
|
||||
kit.ReadOnlyTools(opts...) // read, grep, find, ls
|
||||
kit.SubagentTools(opts...) // all except spawn_subagent (prevents recursion)
|
||||
kit.SubagentTools(opts...) // all except subagent (prevents recursion)
|
||||
```
|
||||
|
||||
### Tool options
|
||||
@@ -467,7 +457,7 @@ err = host.SetThinkingLevel(ctx, "medium") // recreates agent with new thinking
|
||||
```go
|
||||
models := host.GetAvailableModels() // []extensions.ModelInfoEntry
|
||||
providers := kit.GetSupportedProviders() // []string
|
||||
providers := kit.GetFantasyProviders() // providers usable with fantasy
|
||||
providers := kit.GetLLMProviders() // providers with LLM support
|
||||
models, _ := kit.GetModelsForProvider("anthropic") // map[string]kit.ModelInfo
|
||||
info := kit.LookupModel("anthropic", "claude-sonnet-4-5-20250929") // *kit.ModelInfo
|
||||
info := kit.GetProviderInfo("openai") // *kit.ProviderInfo (env vars, API URL)
|
||||
@@ -524,7 +514,7 @@ result, err := host.Subagent(ctx, kit.SubagentConfig{
|
||||
Prompt: "Analyze the test files and summarize coverage",
|
||||
Model: "anthropic/claude-haiku-3-5-20241022", // empty = parent's model
|
||||
SystemPrompt: "You are a test analysis expert.",
|
||||
Tools: nil, // nil = SubagentTools() (all except spawn_subagent)
|
||||
Tools: nil, // nil = SubagentTools() (all except subagent)
|
||||
NoSession: true, // ephemeral
|
||||
Timeout: 2 * time.Minute, // 0 = 5 minute default
|
||||
OnEvent: func(e kit.Event) {
|
||||
@@ -535,14 +525,14 @@ result, err := host.Subagent(ctx, kit.SubagentConfig{
|
||||
},
|
||||
})
|
||||
// result.Response, result.Error, result.SessionID, result.StopReason
|
||||
// result.Usage (*kit.FantasyUsage), result.Elapsed (time.Duration)
|
||||
// result.Usage (*kit.LLMUsage), result.Elapsed (time.Duration)
|
||||
```
|
||||
|
||||
### Subscribing to subagent events from parent
|
||||
|
||||
```go
|
||||
host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
if e.ToolName == "spawn_subagent" {
|
||||
if e.ToolName == "subagent" {
|
||||
host.SubscribeSubagent(e.ToolCallID, func(child kit.Event) {
|
||||
// Real-time events scoped to this subagent
|
||||
})
|
||||
@@ -552,6 +542,53 @@ host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
|
||||
---
|
||||
|
||||
## Extension API
|
||||
|
||||
The `Extensions()` method returns an `ExtensionAPI` interface that groups all extension-related functionality. This is the primary way to interact with extension state from the SDK.
|
||||
|
||||
```go
|
||||
extAPI := host.Extensions()
|
||||
|
||||
// Check if extensions are loaded
|
||||
if extAPI.HasExtensions() {
|
||||
// Context management
|
||||
extAPI.SetContext(extensions.Context{...})
|
||||
ctx := extAPI.GetContext()
|
||||
extAPI.UpdateContextModel("anthropic/claude-sonnet-4-5-20250929")
|
||||
|
||||
// Widgets, headers, footers
|
||||
extAPI.SetWidget(extensions.WidgetConfig{...})
|
||||
extAPI.RemoveWidget("widget-id")
|
||||
extAPI.SetHeader(extensions.HeaderFooterConfig{...})
|
||||
extAPI.SetFooter(extensions.HeaderFooterConfig{...})
|
||||
|
||||
// Status bar
|
||||
extAPI.SetStatus(extensions.StatusBarEntry{...})
|
||||
extAPI.RemoveStatus("key")
|
||||
|
||||
// Options
|
||||
extAPI.SetOption("name", "value")
|
||||
val := extAPI.GetOption("name")
|
||||
|
||||
// Tools
|
||||
tools := extAPI.GetToolInfos()
|
||||
extAPI.SetActiveTools([]string{"bash", "read"})
|
||||
|
||||
// Events
|
||||
extAPI.EmitSessionStart()
|
||||
extAPI.EmitModelChange("new/model", "old/model", "extension")
|
||||
extAPI.EmitCustomEvent("my-event", "data")
|
||||
|
||||
// Commands and lifecycle
|
||||
cmds := extAPI.Commands()
|
||||
err := extAPI.Reload()
|
||||
}
|
||||
```
|
||||
|
||||
All methods are no-ops when extensions are disabled (nil runner), so callers don't need nil checks.
|
||||
|
||||
---
|
||||
|
||||
## Authentication
|
||||
|
||||
```go
|
||||
@@ -603,15 +640,15 @@ kit.Config, kit.MCPServerConfig
|
||||
// Provider types
|
||||
kit.ProviderConfig, kit.ProviderResult, kit.ModelInfo, kit.ModelCost, kit.ModelLimit
|
||||
|
||||
// Fantasy types (from charm.land/fantasy)
|
||||
kit.FantasyMessage, kit.FantasyUsage, kit.FantasyResponse
|
||||
// LLM types (re-exported from the underlying LLM library)
|
||||
kit.LLMMessage, kit.LLMUsage, kit.LLMResponse, kit.LLMFilePart
|
||||
|
||||
// Compaction types
|
||||
kit.CompactionResult, kit.CompactionOptions
|
||||
|
||||
// Conversion helpers
|
||||
msgs := kit.ConvertToFantasyMessages(&msg) // SDK message → fantasy messages
|
||||
msg := kit.ConvertFromFantasyMessage(fMsg) // fantasy message → SDK message
|
||||
msgs := kit.ConvertToLLMMessages(&msg) // SDK message → LLM messages
|
||||
msg := kit.ConvertFromLLMMessage(fMsg) // LLM message → SDK message
|
||||
```
|
||||
|
||||
---
|
||||
@@ -759,6 +796,7 @@ kit.LoadConfigWithEnvSubstitution("/path/to/config.yml")
|
||||
## Key Files for Reference
|
||||
|
||||
- [`pkg/kit/kit.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/kit.go) — Kit struct, New(), Prompt methods, Subagent, Close
|
||||
- [`pkg/kit/extension_api.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/extension_api.go) — ExtensionAPI interface, kit.Extensions() accessor
|
||||
- [`pkg/kit/types.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/types.go) — Re-exported types from internal packages
|
||||
- [`pkg/kit/tools.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/tools.go) — Tool constructors and bundles
|
||||
- [`pkg/kit/events.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/events.go) — Event types, EventBus, typed subscribers
|
||||
|
||||
@@ -32,12 +32,12 @@ Key flags for subprocess usage:
|
||||
|
||||
Positional arguments are the prompt. `@file` arguments attach file content as context.
|
||||
|
||||
## Built-in spawn_subagent tool
|
||||
## Built-in subagent tool
|
||||
|
||||
Kit includes a built-in `spawn_subagent` tool that the LLM can use to delegate tasks to independent child agents:
|
||||
Kit includes a built-in `subagent` tool that the LLM can use to delegate tasks to independent child agents:
|
||||
|
||||
```
|
||||
spawn_subagent(
|
||||
subagent(
|
||||
task: "Analyze the test files and summarize coverage",
|
||||
model: "anthropic/claude-haiku-latest", // optional
|
||||
system_prompt: "You are a test analysis expert.", // optional
|
||||
@@ -61,7 +61,7 @@ result := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
|
||||
### Monitoring subagents from extensions
|
||||
|
||||
When the LLM (not the extension itself) spawns a subagent using the `spawn_subagent` tool, extensions can monitor its activity in real-time using three lifecycle event handlers:
|
||||
When the LLM (not the extension itself) spawns a subagent using the `subagent` tool, extensions can monitor its activity in real-time using three lifecycle event handlers:
|
||||
|
||||
```go
|
||||
// Track active subagents and display their output
|
||||
@@ -130,7 +130,7 @@ type SubagentEndEvent struct {
|
||||
}
|
||||
```
|
||||
|
||||
This enables building monitoring widgets that display real-time activity from all subagents spawned by the main agent. See the `subagent-monitor.go` example for a complete implementation with horizontal widget layouts and scrolling output.
|
||||
This enables building monitoring widgets that display real-time activity from all subagents spawned by the main agent.
|
||||
|
||||
## Go SDK subagents
|
||||
|
||||
@@ -147,11 +147,11 @@ result, err := host.Subagent(ctx, kit.SubagentConfig{
|
||||
|
||||
### Real-time subagent events
|
||||
|
||||
Use `SubscribeSubagent` to receive real-time events from LLM-initiated subagents (i.e., when the model uses the `spawn_subagent` tool). Register inside an `OnToolCall` handler using the tool call ID:
|
||||
Use `SubscribeSubagent` to receive real-time events from LLM-initiated subagents (i.e., when the model uses the `subagent` tool). Register inside an `OnToolCall` handler using the tool call ID:
|
||||
|
||||
```go
|
||||
host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
if e.ToolName == "spawn_subagent" {
|
||||
if e.ToolName == "subagent" {
|
||||
host.SubscribeSubagent(e.ToolCallID, func(event kit.Event) {
|
||||
switch ev := event.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
|
||||
@@ -21,7 +21,7 @@ Manage the local model database that maps provider names to API configurations.
|
||||
|
||||
```bash
|
||||
kit models [provider] # List available models (optionally filter by provider)
|
||||
kit models --all # Show all providers (not just Fantasy-compatible)
|
||||
kit models --all # Show all providers (not just LLM-compatible)
|
||||
kit update-models [source] # Update model database
|
||||
```
|
||||
|
||||
@@ -74,7 +74,7 @@ These commands are available inside the Kit TUI during an interactive session:
|
||||
| `/reset-usage` | Reset usage statistics |
|
||||
| `/tree` | Navigate session tree |
|
||||
| `/fork` | Branch from an earlier message |
|
||||
| `/new` | Start a new session |
|
||||
| `/new` | Start a new session (creates new session file) |
|
||||
| `/name [name]` | Set or show session display name |
|
||||
| `/resume` | Open session picker to switch sessions (alias: `/r`) |
|
||||
| `/session` | Show session info |
|
||||
@@ -95,9 +95,17 @@ Press **ESC twice** to cancel the current operation:
|
||||
|
||||
This ensures that `tool_use` and `tool_result` messages are always sent to the API as matched pairs, avoiding errors from orphaned tool calls.
|
||||
|
||||
## Prompt templates
|
||||
### Mid-turn steering
|
||||
|
||||
Create reusable prompt templates with shell-style argument substitution. Templates are loaded from `~/.kit/prompts/*.md` and `.kit/prompts/*.md`.
|
||||
Press **Ctrl+S** during streaming to inject a system-level instruction mid-turn. This allows you to steer the conversation direction without waiting for the model to finish:
|
||||
|
||||
- Works during streaming output
|
||||
- Sends a steering instruction as a system message
|
||||
- Model continues from the interruption point with the new guidance
|
||||
|
||||
Example: While the model is writing code, press Ctrl+S and type "Use async/await instead" to change the implementation approach.
|
||||
|
||||
## Prompt templates
|
||||
|
||||
### Creating templates
|
||||
|
||||
|
||||
@@ -96,9 +96,45 @@ mcpServers:
|
||||
|
||||
A legacy format with `transport`, `args`, `env`, and `headers` fields is also supported.
|
||||
|
||||
## Theme configuration
|
||||
## Custom models
|
||||
|
||||
Set theme colors inline or reference an external file:
|
||||
Define custom models in your `.kit.yml` for use with the `custom` provider. This is useful for self-hosted models or API endpoints not in the built-in database:
|
||||
|
||||
```yaml
|
||||
customModels:
|
||||
my-model:
|
||||
name: "My Custom Model"
|
||||
reasoning: true
|
||||
temperature: true
|
||||
cost:
|
||||
input: 0.002
|
||||
output: 0.004
|
||||
limit:
|
||||
context: 128000
|
||||
output: 32000
|
||||
```
|
||||
|
||||
### Custom model fields
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
|-------|------|----------|-------------|
|
||||
| `name` | string | Yes | Display name for the model |
|
||||
| `reasoning` | bool | No | Whether the model supports reasoning/thinking |
|
||||
| `temperature` | bool | No | Whether the model supports temperature adjustment |
|
||||
| `cost.input` | float | No | Cost per 1K input tokens |
|
||||
| `cost.output` | float | No | Cost per 1K output tokens |
|
||||
| `limit.context` | int | Yes | Maximum context window in tokens |
|
||||
| `limit.output` | int | No | Maximum output tokens |
|
||||
|
||||
Use with a custom provider URL:
|
||||
|
||||
```bash
|
||||
kit --provider-url "http://localhost:8080/v1" --model custom/my-model "Hello"
|
||||
```
|
||||
|
||||
When `--provider-url` is specified without `--model`, Kit defaults to `custom/custom` which has zero cost tracking and a 262K context window.
|
||||
|
||||
## Theme configuration
|
||||
|
||||
```yaml
|
||||
# Inline partial overrides (unspecified fields inherit from default)
|
||||
|
||||
@@ -239,7 +239,7 @@ result := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
|
||||
### Monitoring subagents spawned by the main agent
|
||||
|
||||
When the LLM uses the built-in `spawn_subagent` tool, extensions can monitor the subagent's activity in real-time using three lifecycle events:
|
||||
When the LLM uses the built-in `subagent` tool, extensions can monitor the subagent's activity in real-time using three lifecycle events:
|
||||
|
||||
```go
|
||||
// Subagent started
|
||||
@@ -283,7 +283,7 @@ api.OnSubagentEnd(func(e ext.SubagentEndEvent, ctx ext.Context) {
|
||||
})
|
||||
```
|
||||
|
||||
This enables building widgets that display real-time subagent activity. See the `subagent-monitor.go` example for a complete implementation showing horizontal widget layouts with scrolling output from multiple parallel subagents.
|
||||
This enables building widgets that display real-time subagent activity.
|
||||
|
||||
## LLM completion
|
||||
|
||||
@@ -334,3 +334,124 @@ api.OnCustomEvent("my-extension:data-ready", func(data any, ctx ext.Context) {
|
||||
// handle event
|
||||
})
|
||||
```
|
||||
|
||||
## Bridged SDK APIs
|
||||
|
||||
Extensions can access powerful internal SDK capabilities that enable advanced features like conversation tree navigation, dynamic skill loading, template parsing, and model resolution.
|
||||
|
||||
### Tree Navigation
|
||||
|
||||
Navigate the conversation tree, summarize branches, and implement "fresh context" loops:
|
||||
|
||||
```go
|
||||
// Get a specific node by ID with full metadata and children
|
||||
node := ctx.GetTreeNode("entry-id")
|
||||
// node.ID, node.ParentID, node.Type ("message"/"branch_summary"/etc)
|
||||
// node.Role, node.Content, node.Model, node.Children ([]string)
|
||||
|
||||
// Get the current branch from root to leaf
|
||||
branch := ctx.GetCurrentBranch() // []ext.TreeNode
|
||||
|
||||
// Get child entry IDs of a node
|
||||
children := ctx.GetChildren("entry-id") // []string
|
||||
|
||||
// Navigate/fork to a different entry in the tree
|
||||
result := ctx.NavigateTo("entry-id") // ext.TreeNavigationResult{Success, Error}
|
||||
|
||||
// Summarize a range of the branch using LLM
|
||||
summary := ctx.SummarizeBranch("from-id", "to-id") // string
|
||||
|
||||
// Collapse a branch range into a summary entry (fresh context primitive)
|
||||
result := ctx.CollapseBranch("from-id", "to-id", "summary text")
|
||||
```
|
||||
|
||||
### Skill Loading
|
||||
|
||||
Load and inject skills dynamically at runtime:
|
||||
|
||||
```go
|
||||
// Discover skills from standard locations
|
||||
result := ctx.DiscoverSkills() // ext.SkillLoadResult{Skills, Error}
|
||||
// Standard locations: ~/.config/kit/skills/, .kit/skills/, .agents/skills/
|
||||
|
||||
// Load a specific skill file
|
||||
skill, err := ctx.LoadSkill("/path/to/skill.md") // (*ext.Skill, error string)
|
||||
// skill.Name, skill.Description, skill.Content, skill.Tags, skill.When
|
||||
|
||||
// Load all skills from a directory
|
||||
result := ctx.LoadSkillsFromDir("/path/to/skills") // ext.SkillLoadResult
|
||||
|
||||
// Inject a skill as context (pre-loads for next turn)
|
||||
err := ctx.InjectSkillAsContext("skill-name") // error string
|
||||
|
||||
// Inject a skill file directly
|
||||
err := ctx.InjectRawSkillAsContext("/path/to/skill.md") // error string
|
||||
|
||||
// Get all discovered skills
|
||||
skills := ctx.GetAvailableSkills() // []ext.Skill
|
||||
```
|
||||
|
||||
### Template Parsing
|
||||
|
||||
Parse and render templates with variable substitution:
|
||||
|
||||
```go
|
||||
// Parse a template to extract {{variables}}
|
||||
tpl := ctx.ParseTemplate("name", "Hello {{name}}, welcome to {{place}}!")
|
||||
// tpl.Name, tpl.Content, tpl.Variables ([]string)
|
||||
|
||||
// Render a template with variable values
|
||||
vars := map[string]string{"name": "Alice", "place": "Kit"}
|
||||
rendered := ctx.RenderTemplate(tpl, vars) // "Hello Alice, welcome to Kit!"
|
||||
|
||||
// Parse command-line style arguments
|
||||
pattern := ext.ArgumentPattern{
|
||||
Positional: []string{"command", "target"}, // $1, $2
|
||||
Rest: "args", // $@
|
||||
Flags: map[string]string{"--loop": "loop", "-f": "force"},
|
||||
}
|
||||
result := ctx.ParseArguments("deploy staging --loop 5", pattern)
|
||||
// result.Vars["command"] = "deploy"
|
||||
// result.Vars["target"] = "staging"
|
||||
// result.Flags["--loop"] = "5"
|
||||
|
||||
// Simple positional argument parsing ($1, $2, $@)
|
||||
args := ctx.SimpleParseArguments("deploy staging --force", 2)
|
||||
// args[0] = "deploy staging --force" (full input)
|
||||
// args[1] = "deploy" ($1)
|
||||
// args[2] = "staging" ($2)
|
||||
// args[3] = "--force" ($@)
|
||||
|
||||
// Evaluate model conditionals with wildcards
|
||||
matches := ctx.EvaluateModelConditional("claude-*") // bool
|
||||
// Patterns: * matches any, ? matches single char, comma = OR
|
||||
|
||||
// Render content with <if-model> conditionals
|
||||
content := `<if-model is="claude-*">Hi Claude<else>Hi there</if-model>`
|
||||
rendered := ctx.RenderWithModelConditionals(content) // based on current model
|
||||
```
|
||||
|
||||
### Model Resolution
|
||||
|
||||
Resolve model fallback chains and query capabilities:
|
||||
|
||||
```go
|
||||
// Resolve a chain of model preferences (tries each until available)
|
||||
result := ctx.ResolveModelChain([]string{
|
||||
"anthropic/claude-opus-4",
|
||||
"anthropic/claude-sonnet-4",
|
||||
"openai/gpt-4o",
|
||||
})
|
||||
// result.Model (selected), result.Capabilities, result.Attempted, result.Error
|
||||
|
||||
// Get capabilities for a specific model
|
||||
caps, err := ctx.GetModelCapabilities("anthropic/claude-sonnet-4")
|
||||
// caps.Provider, caps.ModelID, caps.ContextLimit, caps.Reasoning, caps.Streaming
|
||||
|
||||
// Check if a model is available (provider exists)
|
||||
available := ctx.CheckModelAvailable("anthropic/claude-sonnet-4") // bool
|
||||
|
||||
// Get current provider/model ID
|
||||
provider := ctx.GetCurrentProvider() // "anthropic"
|
||||
modelID := ctx.GetCurrentModelID() // "claude-sonnet-4"
|
||||
```
|
||||
|
||||
@@ -51,6 +51,15 @@ Kit ships with a rich set of example extensions in the `examples/extensions/` di
|
||||
| [`summarize.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/summarize.go) | Conversation summarization |
|
||||
| [`lsp-diagnostics.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/lsp-diagnostics.go) | LSP diagnostic integration |
|
||||
|
||||
## Bridged SDK APIs
|
||||
|
||||
These examples demonstrate the new bridged SDK APIs that give extensions access to internal Kit capabilities:
|
||||
|
||||
| Extension | Description |
|
||||
|-----------|-------------|
|
||||
| [`conversation-manager.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/conversation-manager.go) | **NEW** Tree navigation (`GetTreeNode`, `GetCurrentBranch`, `NavigateTo`), branch summarization (`SummarizeBranch`), and fresh context loops (`CollapseBranch`) |
|
||||
| [`prompt-templates.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/prompt-templates.go) | **NEW** Frontmatter-driven templates with model fallback chains (`ResolveModelChain`), skill injection (`InjectSkillAsContext`), and template parsing (`ParseTemplate`, `RenderTemplate`) |
|
||||
|
||||
## Themes
|
||||
|
||||
| Extension | Description |
|
||||
@@ -64,7 +73,6 @@ Kit ships with a rich set of example extensions in the `examples/extensions/` di
|
||||
| [`kit-kit.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/kit-kit.go) | Kit-in-Kit sub-agent spawning |
|
||||
| [`subagent-widget.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/subagent-widget.go) | Multi-agent orchestration with status widget |
|
||||
| [`subagent-test.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/subagent-test.go) | Subagent testing utilities |
|
||||
| [`subagent-monitor.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/subagent-monitor.go) | Real-time monitoring widget for spawned subagents |
|
||||
|
||||
## Development
|
||||
|
||||
@@ -72,7 +80,6 @@ Kit ships with a rich set of example extensions in the `examples/extensions/` di
|
||||
|-----------|-------------|
|
||||
| [`dev-reload.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/dev-reload.go) | Development live-reload |
|
||||
| [`tool-logger_test.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/tool-logger_test.go) | Example extension tests (see [Testing](/extensions/testing)) |
|
||||
| [`subagent-monitor_test.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/subagent-monitor_test.go) | Subagent lifecycle event tests |
|
||||
| [`extension_test_template.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/extension_test_template.go) | Copy-and-paste test template for your extensions |
|
||||
|
||||
## Subdirectory extensions
|
||||
|
||||
+1
-1
@@ -13,7 +13,7 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
|
||||
## Features
|
||||
|
||||
- **Multi-Provider LLM Support** — Anthropic, OpenAI, Google Gemini, Ollama, Azure OpenAI, AWS Bedrock, OpenRouter, and more
|
||||
- **Built-in Core Tools** — bash, read, write, edit, grep, find, ls, spawn_subagent with no MCP overhead
|
||||
- **Built-in Core Tools** — bash, read, write, edit, grep, find, ls, subagent with no MCP overhead
|
||||
- **MCP Integration** — Connect external MCP servers for expanded capabilities
|
||||
- **Extension System** — Write custom tools, commands, widgets, and UI modifications in Go
|
||||
- **Interactive TUI** — Rich terminal interface powered by Bubble Tea with streaming, syntax highlighting, and custom rendering
|
||||
|
||||
@@ -139,7 +139,7 @@ When `--provider-url` is provided without `--model`, Kit automatically defaults
|
||||
kit --provider-url "http://localhost:8080/v1" "Hello"
|
||||
```
|
||||
|
||||
The `custom/custom` model has zero cost, 262K context window, and supports reasoning. It routes through fantasy's `openaicompat` provider and accepts any OpenAI-compatible API endpoint.
|
||||
The `custom/custom` model has zero cost, 262K context window, and supports reasoning. It routes through the `openaicompat` provider and accepts any OpenAI-compatible API endpoint.
|
||||
|
||||
Optionally set `CUSTOM_API_KEY` environment variable or use `--provider-api-key` for endpoints requiring authentication.
|
||||
|
||||
|
||||
@@ -5,48 +5,6 @@ description: Monitor tool calls and streaming output with the Kit Go SDK.
|
||||
|
||||
# Callbacks
|
||||
|
||||
## PromptWithCallbacks
|
||||
|
||||
The `PromptWithCallbacks` method provides real-time visibility into tool calls and streaming output:
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(
|
||||
ctx,
|
||||
"List files in current directory",
|
||||
func(name, args string) {
|
||||
// Called when the model invokes a tool
|
||||
fmt.Println("Calling tool:", name)
|
||||
},
|
||||
func(name, args, result string, isError bool) {
|
||||
// Called when a tool returns its result
|
||||
if isError {
|
||||
fmt.Println("Tool failed:", name)
|
||||
}
|
||||
},
|
||||
func(chunk string) {
|
||||
// Called for each streaming text chunk
|
||||
fmt.Print(chunk)
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
### Callback signatures
|
||||
|
||||
| Callback | Signature | When |
|
||||
|----------|-----------|------|
|
||||
| `onToolCall` | `func(name, args string)` | Model requests a tool call |
|
||||
| `onToolResult` | `func(name, args, result string, isError bool)` | Tool execution completes |
|
||||
| `onStreaming` | `func(chunk string)` | Streaming text chunk received |
|
||||
|
||||
Any callback can be `nil` if you don't need it:
|
||||
|
||||
```go
|
||||
// Only care about streaming output
|
||||
response, err := host.PromptWithCallbacks(ctx, "Hello", nil, nil, func(chunk string) {
|
||||
fmt.Print(chunk)
|
||||
})
|
||||
```
|
||||
|
||||
## Event-based monitoring
|
||||
|
||||
For more granular control, use the event subscription API:
|
||||
@@ -116,11 +74,11 @@ The first argument is a priority (lower = runs first).
|
||||
|
||||
## Subagent event monitoring
|
||||
|
||||
Monitor real-time events from LLM-initiated subagents (when the model uses the `spawn_subagent` tool):
|
||||
Monitor real-time events from LLM-initiated subagents (when the model uses the `subagent` tool):
|
||||
|
||||
```go
|
||||
host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
if e.ToolName == "spawn_subagent" {
|
||||
if e.ToolName == "subagent" {
|
||||
host.SubscribeSubagent(e.ToolCallID, func(event kit.Event) {
|
||||
// Receives the same event types as Subscribe(), scoped to the child agent
|
||||
switch ev := event.(type) {
|
||||
|
||||
@@ -62,7 +62,6 @@ The SDK provides several prompt variants:
|
||||
| Method | Description |
|
||||
|--------|-------------|
|
||||
| `Prompt(ctx, message)` | Simple prompt, returns response string |
|
||||
| `PromptWithCallbacks(ctx, message, ...)` | With tool call and streaming callbacks |
|
||||
| `PromptWithOptions(ctx, message, opts)` | With per-call options |
|
||||
| `PromptResult(ctx, message)` | Returns full `TurnResult` with usage stats |
|
||||
| `PromptResultWithFiles(ctx, message, files)` | Multimodal with file attachments |
|
||||
|
||||
+11
-3
@@ -30,12 +30,20 @@ When conversations grow long, Kit can compact them to free up context window spa
|
||||
|
||||
Use `/compact [focus]` to manually compact, or enable `--auto-compact` to compact automatically near the context limit.
|
||||
|
||||
## Auto-cleanup
|
||||
|
||||
Kit automatically cleans up empty sessions on shutdown and when using `/resume`. A session is considered empty if it has no messages beyond the initial system prompt. This prevents cluttering your sessions directory with unused files.
|
||||
|
||||
To start fresh without creating a session file at all, use ephemeral mode:
|
||||
|
||||
```bash
|
||||
kit --no-session
|
||||
```
|
||||
|
||||
## Resuming sessions
|
||||
|
||||
### Continue most recent
|
||||
|
||||
Resume the most recent session for the current directory:
|
||||
|
||||
```bash
|
||||
kit --continue
|
||||
kit -c
|
||||
@@ -73,7 +81,7 @@ These slash commands are available during an interactive session:
|
||||
| `/share` | Upload session to GitHub Gist and get a shareable viewer URL |
|
||||
| `/tree` | Navigate the session tree |
|
||||
| `/fork` | Branch from an earlier message |
|
||||
| `/new` | Start a fresh session |
|
||||
| `/new` | Start a new session (creates new session file) |
|
||||
|
||||
## Ephemeral mode
|
||||
|
||||
|
||||
@@ -1566,7 +1566,7 @@ a:hover { text-decoration: underline; }
|
||||
'grep': '🔍',
|
||||
'find': '📁',
|
||||
'ls': '📂',
|
||||
'spawn_subagent': '🤖',
|
||||
'subagent': '🤖',
|
||||
'fetch': '🌐',
|
||||
'todo': '✅'
|
||||
};
|
||||
@@ -1612,7 +1612,7 @@ a:hover { text-decoration: underline; }
|
||||
headerLabel = formatLsHeader(input);
|
||||
bodyHtml = renderGenericBody(input, result);
|
||||
break;
|
||||
case 'spawn_subagent':
|
||||
case 'subagent':
|
||||
headerLabel = formatSubagentHeader(input);
|
||||
bodyHtml = renderSubagentBody(input, result);
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user