mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-20 14:20:34 +00:00
Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f36166bee5 | |||
| 879e81f9b5 | |||
| 727b42acfe | |||
| 4830981570 | |||
| dcfebafcc5 | |||
| 1f5c103667 | |||
| 4caa8ba3dc | |||
| 15ef8ad78b | |||
| 551f2710d9 | |||
| 67bda5cad5 | |||
| 01d7d754ef | |||
| c6304f1e92 | |||
| bc3c733ae3 | |||
| 428ee2b8be | |||
| eb1d7fd07e | |||
| 1e3e5cafd3 | |||
| 0b93e58fb9 | |||
| 2bb01ed72c | |||
| b6ecc36ea1 | |||
| d4f27bc912 | |||
| f12e195390 | |||
| b68b3dd0bf | |||
| 48521bf76d | |||
| 16df3a738c | |||
| 9d0b8c8cef | |||
| d9326fcf21 | |||
| 22c479277e | |||
| 8ae204f12f | |||
| 8b1665a4ce | |||
| 941f1daf0b | |||
| ab7e2bda61 | |||
| 741520927c | |||
| 4c1bda9541 | |||
| 3b69b13556 | |||
| 83a959a379 | |||
| 3491e05e9e | |||
| 0a54a8aa05 | |||
| 3cb3e5dba1 | |||
| 31966c469f | |||
| f03625d6e5 | |||
| d06641dc0a | |||
| bbf1106e27 | |||
| babed03a3d | |||
| 1cd074836f | |||
| ab3ce260c8 | |||
| 8e8cc3946d | |||
| e18e36625e | |||
| be55bc03f1 | |||
| 09919b6307 | |||
| 7a2de4cc3c | |||
| acd7fd7f45 |
@@ -0,0 +1,304 @@
|
||||
//go:build ignore
|
||||
|
||||
// subagent-monitor — live horizontal widget strip for spawned subagents
|
||||
//
|
||||
// Subscribes to subagents spawned by the main Kit agent and displays a
|
||||
// single widget just above the input box. Each subagent occupies one column
|
||||
// in a side-by-side horizontal layout. Columns show scrolling real-time
|
||||
// output as the subagent works. When a subagent finishes its column is
|
||||
// removed automatically.
|
||||
//
|
||||
// Yaegi-safe design notes:
|
||||
// - No sync.Mutex (Yaegi has reflection issues with sync primitives)
|
||||
// - No channels in maps (Yaegi panics on range over map[string]chan)
|
||||
// - All ctx.* calls guarded with nil checks
|
||||
// - Simple data structures only
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per-subagent state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type submonEntry struct {
|
||||
id int
|
||||
callID string
|
||||
task string
|
||||
lines []string
|
||||
started time.Time
|
||||
elapsed time.Duration
|
||||
}
|
||||
|
||||
const (
|
||||
submonColWidth = 34 // visible character width per column
|
||||
submonMaxLines = 5 // scrolling output lines per column
|
||||
submonColGap = 2 // spaces between columns
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Package-level state - all simple types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
submonCtx ext.Context
|
||||
submonHasCtx bool
|
||||
submonEntries []*submonEntry
|
||||
submonNextID int
|
||||
)
|
||||
|
||||
func submonInit() {
|
||||
submonEntries = nil
|
||||
submonNextID = 1
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// String helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func submonPad(s string, w int) string {
|
||||
r := []rune(s)
|
||||
if len(r) >= w {
|
||||
return string(r[:w])
|
||||
}
|
||||
return s + strings.Repeat(" ", w-len(r))
|
||||
}
|
||||
|
||||
func submonTrunc(s string, w int) string {
|
||||
r := []rune(s)
|
||||
if len(r) <= w {
|
||||
return s
|
||||
}
|
||||
if w <= 1 {
|
||||
return "…"
|
||||
}
|
||||
return string(r[:w-1]) + "…"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Widget rendering
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func submonRenderColumn(e *submonEntry) []string {
|
||||
var rows []string
|
||||
|
||||
// Calculate elapsed time on-demand to avoid race conditions with ticker
|
||||
elapsed := e.elapsed
|
||||
if elapsed == 0 && !e.started.IsZero() {
|
||||
elapsed = time.Since(e.started)
|
||||
}
|
||||
secs := int(elapsed.Seconds())
|
||||
timeStr := fmt.Sprintf("%ds", secs)
|
||||
taskMax := submonColWidth - len(timeStr) - 3
|
||||
taskPart := submonTrunc(e.task, taskMax)
|
||||
header := fmt.Sprintf("#%d %s %s", e.id, taskPart, timeStr)
|
||||
rows = append(rows, submonPad(header, submonColWidth))
|
||||
|
||||
display := e.lines
|
||||
if len(display) > submonMaxLines {
|
||||
display = display[len(display)-submonMaxLines:]
|
||||
}
|
||||
for _, l := range display {
|
||||
rows = append(rows, submonPad(" "+submonTrunc(l, submonColWidth-2), submonColWidth))
|
||||
}
|
||||
for len(rows) < submonMaxLines+1 {
|
||||
if len(rows) == 1 && len(e.lines) == 0 {
|
||||
rows = append(rows, submonPad(" waiting…", submonColWidth))
|
||||
} else {
|
||||
rows = append(rows, strings.Repeat(" ", submonColWidth))
|
||||
}
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
func submonBuildWidget() string {
|
||||
if len(submonEntries) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
numCols := len(submonEntries)
|
||||
numRows := submonMaxLines + 1
|
||||
cols := make([][]string, numCols)
|
||||
for i, e := range submonEntries {
|
||||
rows := submonRenderColumn(e)
|
||||
col := make([]string, numRows)
|
||||
for j := 0; j < numRows; j++ {
|
||||
if j < len(rows) {
|
||||
col[j] = rows[j]
|
||||
} else {
|
||||
col[j] = strings.Repeat(" ", submonColWidth)
|
||||
}
|
||||
}
|
||||
cols[i] = col
|
||||
}
|
||||
|
||||
gap := strings.Repeat(" ", submonColGap)
|
||||
var sb strings.Builder
|
||||
for row := 0; row < numRows; row++ {
|
||||
for ci := range cols {
|
||||
if ci > 0 {
|
||||
sb.WriteString(gap)
|
||||
}
|
||||
sb.WriteString(cols[ci][row])
|
||||
}
|
||||
if row < numRows-1 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func submonPushWidget() {
|
||||
if !submonHasCtx {
|
||||
return
|
||||
}
|
||||
if submonCtx.SetWidget == nil {
|
||||
return
|
||||
}
|
||||
|
||||
text := submonBuildWidget()
|
||||
if len(submonEntries) == 0 {
|
||||
if submonCtx.RemoveWidget != nil {
|
||||
submonCtx.RemoveWidget("submon")
|
||||
}
|
||||
return
|
||||
}
|
||||
submonCtx.SetWidget(ext.WidgetConfig{
|
||||
ID: "submon",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: text},
|
||||
Style: ext.WidgetStyle{BorderColor: "#89b4fa"},
|
||||
Priority: 0,
|
||||
})
|
||||
}
|
||||
|
||||
func submonAppendLine(e *submonEntry, line string) {
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if strings.TrimSpace(line) == "" {
|
||||
return
|
||||
}
|
||||
e.lines = append(e.lines, line)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Init
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func Init(api ext.API) {
|
||||
submonInit()
|
||||
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
submonInit()
|
||||
if ctx.RemoveWidget != nil {
|
||||
ctx.RemoveWidget("submon")
|
||||
}
|
||||
})
|
||||
|
||||
api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
})
|
||||
|
||||
// ── SubagentStart ────────────────────────────────────────────────────────
|
||||
api.OnSubagentStart(func(e ext.SubagentStartEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
id := submonNextID
|
||||
submonNextID++
|
||||
entry := &submonEntry{
|
||||
id: id,
|
||||
callID: e.ToolCallID,
|
||||
task: e.Task,
|
||||
started: time.Now(),
|
||||
}
|
||||
submonEntries = append(submonEntries, entry)
|
||||
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SubagentChunk ────────────────────────────────────────────────────────
|
||||
api.OnSubagentChunk(func(e ext.SubagentChunkEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
var entry *submonEntry
|
||||
for _, en := range submonEntries {
|
||||
if en.callID == e.ToolCallID {
|
||||
entry = en
|
||||
break
|
||||
}
|
||||
}
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch e.ChunkType {
|
||||
case "text":
|
||||
for _, line := range strings.Split(e.Content, "\n") {
|
||||
submonAppendLine(entry, line)
|
||||
}
|
||||
case "tool_call":
|
||||
submonAppendLine(entry, "→ "+e.ToolName)
|
||||
case "tool_execution_start":
|
||||
submonAppendLine(entry, "⚙ "+e.ToolName)
|
||||
case "tool_result":
|
||||
if e.IsError {
|
||||
submonAppendLine(entry, "✗ "+e.ToolName)
|
||||
} else {
|
||||
submonAppendLine(entry, "✓ "+e.ToolName)
|
||||
}
|
||||
}
|
||||
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SubagentEnd ──────────────────────────────────────────────────────────
|
||||
api.OnSubagentEnd(func(e ext.SubagentEndEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
var entry *submonEntry
|
||||
for _, en := range submonEntries {
|
||||
if en.callID == e.ToolCallID {
|
||||
entry = en
|
||||
break
|
||||
}
|
||||
}
|
||||
if entry != nil {
|
||||
entry.elapsed = time.Since(entry.started)
|
||||
if e.ErrorMsg != "" {
|
||||
submonAppendLine(entry, "✗ "+submonTrunc(e.ErrorMsg, submonColWidth-2))
|
||||
}
|
||||
}
|
||||
|
||||
submonPushWidget()
|
||||
|
||||
// Remove the entry immediately (no goroutine to avoid races)
|
||||
newEntries := submonEntries[:0]
|
||||
for _, en := range submonEntries {
|
||||
if en.callID != e.ToolCallID {
|
||||
newEntries = append(newEntries, en)
|
||||
}
|
||||
}
|
||||
submonEntries = newEntries
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SessionShutdown ──────────────────────────────────────────────────────
|
||||
api.OnSessionShutdown(func(_ ext.SessionShutdownEvent, ctx ext.Context) {
|
||||
submonInit()
|
||||
// Guard ctx access - may be nil during shutdown
|
||||
if ctx.RemoveWidget != nil {
|
||||
ctx.RemoveWidget("submon")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -18,7 +18,7 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
|
||||
## Features
|
||||
|
||||
- **Multi-Provider LLM Support**: Anthropic, OpenAI, Google Gemini, Ollama, Azure OpenAI, AWS Bedrock, OpenRouter, and more
|
||||
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls, spawn_subagent - no MCP overhead
|
||||
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls, subagent - no MCP overhead
|
||||
- **MCP Integration**: Connect external MCP servers for expanded capabilities
|
||||
- **Extension System**: Write custom tools, commands, widgets, and UI modifications in Go
|
||||
- **Theming**: 22 built-in color themes (KITT, Catppuccin, Dracula, Nord, etc.) with runtime switching, persistence, and custom theme files
|
||||
@@ -287,7 +287,7 @@ kit -e examples/extensions/minimal.go
|
||||
|
||||
### Extension Capabilities
|
||||
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
|
||||
|
||||
**Custom Components**:
|
||||
- **Tools**: Add new tools the LLM can invoke
|
||||
@@ -307,6 +307,12 @@ kit -e examples/extensions/minimal.go
|
||||
- **Themes**: Register and switch color themes via `RegisterTheme`, `SetTheme`, `ListThemes`
|
||||
- **Custom Events**: Inter-extension communication via `EmitCustomEvent`
|
||||
|
||||
**Bridged SDK APIs** (NEW): Extensions can now access internal SDK capabilities:
|
||||
- **Tree Navigation**: Navigate conversation history (`GetTreeNode`, `GetCurrentBranch`, `NavigateTo`), summarize branches (`SummarizeBranch`), and implement fresh context loops (`CollapseBranch`)
|
||||
- **Skill Loading**: Dynamically load and inject skills at runtime (`LoadSkill`, `DiscoverSkills`, `InjectSkillAsContext`)
|
||||
- **Template Parsing**: Parse and render templates with `{{variables}}` (`ParseTemplate`, `RenderTemplate`), parse CLI-style arguments (`ParseArguments`, `SimpleParseArguments`), and evaluate model conditionals (`EvaluateModelConditional`, `RenderWithModelConditionals`)
|
||||
- **Model Resolution**: Resolve model fallback chains (`ResolveModelChain`), query model capabilities (`GetModelCapabilities`, `CheckModelAvailable`), and extract provider/model ID (`GetCurrentProvider`, `GetCurrentModelID`)
|
||||
|
||||
### Extension Examples
|
||||
|
||||
See the `examples/extensions/` directory:
|
||||
@@ -318,6 +324,7 @@ See the `examples/extensions/` directory:
|
||||
- `compact-notify.go` - Notification on compaction
|
||||
- `confirm-destructive.go` - Confirm destructive operations
|
||||
- `context-inject.go` - Inject context into conversations
|
||||
- `conversation-manager.go` - **NEW** Tree navigation, branch summarization, and fresh context loops
|
||||
- `custom-editor-demo.go` - Vim-like modal editor
|
||||
- `dev-reload.go` - Development live-reload
|
||||
- `header-footer-demo.go` - Custom headers and footers
|
||||
@@ -332,6 +339,7 @@ See the `examples/extensions/` directory:
|
||||
- `plan-mode.go` - Read-only planning mode
|
||||
- `project-rules.go` - Project-specific rules
|
||||
- `prompt-demo.go` - Interactive prompts (select/confirm/input)
|
||||
- `prompt-templates.go` - **NEW** Frontmatter-driven templates with model switching and skill injection
|
||||
- `protected-paths.go` - Path protection for sensitive files
|
||||
- `subagent-widget.go` - Multi-agent orchestration with status widget
|
||||
- `subagent-test.go` - Subagent testing utilities
|
||||
@@ -535,23 +543,26 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
### With Callbacks
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(
|
||||
unsub := host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
println("Calling tool:", e.ToolName)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
if e.IsError {
|
||||
println("Tool failed:", e.ToolName)
|
||||
}
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
print(e.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
|
||||
response, err := host.Prompt(
|
||||
ctx,
|
||||
"List files in current directory",
|
||||
func(name, args string) {
|
||||
// Tool call started
|
||||
println("Calling tool:", name)
|
||||
},
|
||||
func(name, args, result string, isError bool) {
|
||||
// Tool call completed
|
||||
if isError {
|
||||
println("Tool failed:", name)
|
||||
}
|
||||
},
|
||||
func(chunk string) {
|
||||
// Streaming text chunk
|
||||
print(chunk)
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
@@ -76,6 +76,12 @@
|
||||
"name": "opencode",
|
||||
"url": "https://github.com/anomalyco/opencode",
|
||||
"branch": "dev"
|
||||
},
|
||||
{
|
||||
"type": "git",
|
||||
"name": "herald",
|
||||
"url": "https://github.com/indaco/herald",
|
||||
"branch": "main"
|
||||
}
|
||||
],
|
||||
"model": "claude-haiku-4-5",
|
||||
|
||||
+300
-7
@@ -1,9 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/huh/v2"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
@@ -14,7 +18,7 @@ import (
|
||||
// authCmd represents the auth command for managing AI provider authentication.
|
||||
// This command provides subcommands for login, logout, and status checking
|
||||
// of authentication credentials for various AI providers, with OAuth support
|
||||
// for providers like Anthropic.
|
||||
// for providers like Anthropic and OpenAI.
|
||||
var authCmd = &cobra.Command{
|
||||
Use: "auth",
|
||||
Short: "Manage authentication credentials for AI providers",
|
||||
@@ -25,9 +29,11 @@ using OAuth flows. Stored credentials take precedence over environment variables
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI API (OAuth and API key)
|
||||
|
||||
Examples:
|
||||
kit auth login anthropic
|
||||
kit auth login openai
|
||||
kit auth logout anthropic
|
||||
kit auth status`,
|
||||
}
|
||||
@@ -46,9 +52,11 @@ environment variables when making API calls.
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI ChatGPT Plus/Pro (Codex OAuth)
|
||||
|
||||
Example:
|
||||
kit auth login anthropic`,
|
||||
kit auth login anthropic
|
||||
kit auth login openai`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogin,
|
||||
}
|
||||
@@ -61,14 +69,16 @@ var authLogoutCmd = &cobra.Command{
|
||||
Short: "Remove stored authentication credentials for a provider",
|
||||
Long: `Remove stored authentication credentials for an AI provider.
|
||||
|
||||
This will delete the stored API key for the specified provider. You will need
|
||||
to use environment variables or command-line flags for authentication after logout.
|
||||
This will delete the stored API key or OAuth credentials for the specified provider.
|
||||
You will need to use environment variables or command-line flags for authentication after logout.
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API
|
||||
- openai: OpenAI API
|
||||
|
||||
Example:
|
||||
kit auth logout anthropic`,
|
||||
kit auth logout anthropic
|
||||
kit auth logout openai`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogout,
|
||||
}
|
||||
@@ -101,8 +111,10 @@ func runAuthLogin(cmd *cobra.Command, args []string) error {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return loginAnthropic()
|
||||
case "openai":
|
||||
return loginOpenAI()
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,8 +124,10 @@ func runAuthLogout(cmd *cobra.Command, args []string) error {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return logoutAnthropic()
|
||||
case "openai":
|
||||
return logoutOpenAI()
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,8 +171,44 @@ func runAuthStatus(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check OpenAI credentials
|
||||
fmt.Print("\nOpenAI: ")
|
||||
if hasOpenAICreds, err := cm.HasOpenAICredentials(); err != nil {
|
||||
fmt.Printf("Error checking credentials: %v\n", err)
|
||||
} else if hasOpenAICreds {
|
||||
if creds, err := cm.GetOpenAICredentials(); err != nil {
|
||||
fmt.Printf("Error reading credentials: %v\n", err)
|
||||
} else {
|
||||
authType := "API Key"
|
||||
status := "✓ Authenticated"
|
||||
|
||||
if creds.Type == "oauth" {
|
||||
authType = "OAuth (ChatGPT/Codex)"
|
||||
if creds.IsExpired() {
|
||||
status = "⚠️ Token expired (will refresh automatically)"
|
||||
} else if creds.NeedsRefresh() {
|
||||
status = "⚠️ Token expires soon (will refresh automatically)"
|
||||
}
|
||||
}
|
||||
|
||||
accountInfo := ""
|
||||
if creds.Type == "oauth" && creds.AccountID != "" {
|
||||
accountInfo = fmt.Sprintf(" [%s]", creds.AccountID)
|
||||
}
|
||||
|
||||
fmt.Printf("%s (%s%s, stored %s)\n", status, authType, accountInfo, creds.CreatedAt.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
} else {
|
||||
fmt.Println("✗ Not authenticated")
|
||||
// Check if environment variable is set
|
||||
if os.Getenv("OPENAI_API_KEY") != "" {
|
||||
fmt.Println(" (OPENAI_API_KEY environment variable is set)")
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\nTo authenticate with a provider:")
|
||||
fmt.Println(" kit auth login anthropic")
|
||||
fmt.Println(" kit auth login openai")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -282,3 +332,246 @@ func logoutAnthropic() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loginOpenAI() error {
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
// Check if already authenticated
|
||||
if hasAuth, err := cm.HasOpenAICredentials(); err == nil && hasAuth {
|
||||
var reauth bool
|
||||
err := huh.NewConfirm().
|
||||
Title("You are already authenticated with OpenAI (ChatGPT/Codex)").
|
||||
Description("Do you want to re-authenticate?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&reauth).
|
||||
Run()
|
||||
if err != nil || !reauth {
|
||||
fmt.Println("Authentication cancelled.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create OAuth client
|
||||
client := auth.NewOpenAIOAuthClient()
|
||||
|
||||
// Generate authorization URL
|
||||
fmt.Println("🔐 Starting OAuth authentication with OpenAI (ChatGPT/Codex)...")
|
||||
fmt.Println("This will open your browser to authenticate with your ChatGPT account.")
|
||||
fmt.Println()
|
||||
|
||||
authData, err := client.GetAuthorizationURL()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate authorization URL: %w", err)
|
||||
}
|
||||
|
||||
// Start local callback server
|
||||
callbackServer, err := startOpenAICallbackServer(authData.State)
|
||||
if err != nil {
|
||||
fmt.Printf("⚠️ Could not start local callback server: %v\n", err)
|
||||
fmt.Println("Falling back to manual code entry.")
|
||||
}
|
||||
if callbackServer != nil {
|
||||
defer callbackServer.Close()
|
||||
}
|
||||
|
||||
// Display URL and try to open browser
|
||||
fmt.Println("📱 Opening your browser for authentication...")
|
||||
fmt.Println("If the browser doesn't open automatically, please visit this URL:")
|
||||
fmt.Printf("\n%s\n\n", authData.URL)
|
||||
|
||||
// Try to open browser
|
||||
auth.TryOpenBrowser(authData.URL)
|
||||
|
||||
// Wait for callback or manual input
|
||||
var code string
|
||||
if callbackServer != nil {
|
||||
fmt.Println("Waiting for browser authentication...")
|
||||
select {
|
||||
case callbackCode := <-callbackServer.CodeChan:
|
||||
if callbackCode != "" {
|
||||
code = callbackCode
|
||||
fmt.Println("✓ Received authorization code from browser callback.")
|
||||
}
|
||||
case <-time.After(2 * time.Minute):
|
||||
fmt.Println("\n⏱️ Timeout waiting for browser callback.")
|
||||
callbackServer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// If no code from callback, prompt for manual entry
|
||||
if code == "" {
|
||||
fmt.Println("\nAfter authorizing, paste the callback URL or authorization code below.")
|
||||
fmt.Println("(The callback URL will look like: http://localhost:1455/auth/callback?code=...&state=...)")
|
||||
fmt.Println()
|
||||
|
||||
var input string
|
||||
err = huh.NewInput().
|
||||
Title("Callback URL or Code").
|
||||
Description("Paste the full callback URL or just the authorization code").
|
||||
Value(&input).
|
||||
Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read input: %w", err)
|
||||
}
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
if input == "" {
|
||||
return fmt.Errorf("authorization code cannot be empty")
|
||||
}
|
||||
|
||||
// Parse the input (could be full URL or just code)
|
||||
parsedCode, parsedState := auth.ParseOpenAIAuthorizationInput(input)
|
||||
if parsedCode == "" {
|
||||
return fmt.Errorf("could not extract authorization code from input")
|
||||
}
|
||||
|
||||
// Validate state if provided
|
||||
if parsedState != "" && parsedState != authData.State {
|
||||
return fmt.Errorf("state mismatch - possible security issue")
|
||||
}
|
||||
code = parsedCode
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
fmt.Println("\n🔄 Exchanging authorization code for access token...")
|
||||
creds, err := client.ExchangeCode(code, authData.Verifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exchange authorization code: %w", err)
|
||||
}
|
||||
|
||||
// Store the credentials
|
||||
if err := cm.SetOpenAIOAuthCredentials(creds); err != nil {
|
||||
return fmt.Errorf("failed to store credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✅ Successfully authenticated with OpenAI (ChatGPT/Codex)!")
|
||||
fmt.Printf("📁 Credentials stored in: %s\n", cm.GetCredentialsPath())
|
||||
fmt.Printf("👤 Account ID: %s\n", creds.AccountID)
|
||||
fmt.Println("\n🎉 Your OAuth credentials will now be used for OpenAI API calls.")
|
||||
fmt.Println("💡 You can check your authentication status with: kit auth status")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callbackServer holds the HTTP server and channel for receiving the OAuth callback
|
||||
type callbackServer struct {
|
||||
Server *http.Server
|
||||
CodeChan chan string
|
||||
State string
|
||||
}
|
||||
|
||||
// Close shuts down the callback server
|
||||
func (cs *callbackServer) Close() {
|
||||
if cs.Server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = cs.Server.Shutdown(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// startOpenAICallbackServer starts a local HTTP server to receive the OAuth callback
|
||||
func startOpenAICallbackServer(expectedState string) (*callbackServer, error) {
|
||||
codeChan := make(chan string, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
server := &http.Server{
|
||||
Addr: "127.0.0.1:1455",
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check state
|
||||
state := r.URL.Query().Get("state")
|
||||
if state != expectedState {
|
||||
http.Error(w, "State mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
http.Error(w, "Missing authorization code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Send code to channel
|
||||
select {
|
||||
case codeChan <- code:
|
||||
default:
|
||||
}
|
||||
|
||||
// Return success page
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Authentication Successful</title></head>
|
||||
<body style="font-family: sans-serif; text-align: center; padding: 50px;">
|
||||
<h1>✓ Authentication Successful</h1>
|
||||
<p>You can close this window and return to the terminal.</p>
|
||||
</body>
|
||||
</html>`)
|
||||
})
|
||||
|
||||
// Try to start server
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:1455")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("port 1455 not available: %w", err)
|
||||
}
|
||||
_ = listener.Close()
|
||||
|
||||
go func() {
|
||||
_ = server.ListenAndServe()
|
||||
}()
|
||||
|
||||
return &callbackServer{
|
||||
Server: server,
|
||||
CodeChan: codeChan,
|
||||
State: expectedState,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func logoutOpenAI() error {
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
// Check if authenticated
|
||||
hasAuth, err := cm.HasOpenAICredentials()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check authentication status: %w", err)
|
||||
}
|
||||
|
||||
if !hasAuth {
|
||||
fmt.Println("You are not currently authenticated with OpenAI.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Confirm logout
|
||||
var confirm bool
|
||||
err = huh.NewConfirm().
|
||||
Title("Remove OpenAI credentials").
|
||||
Description("Are you sure you want to remove your stored credentials?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&confirm).
|
||||
Run()
|
||||
if err != nil || !confirm {
|
||||
fmt.Println("Logout cancelled.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove credentials
|
||||
if err := cm.RemoveOpenAICredentials(); err != nil {
|
||||
return fmt.Errorf("failed to remove credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ Successfully logged out from OpenAI!")
|
||||
fmt.Println("You will need to use environment variables or command-line flags for authentication.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+450
-10
@@ -13,6 +13,7 @@ import (
|
||||
"charm.land/fantasy"
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -798,19 +799,29 @@ func runNormalMode(ctx context.Context) error {
|
||||
appInstance := app.New(appOpts, messages)
|
||||
defer appInstance.Close()
|
||||
|
||||
// Buffer for extension messages during startup (printed after startup banner).
|
||||
var startupExtensionMessages []string
|
||||
|
||||
// Set up extension context and emit SessionStart.
|
||||
if kitInstance.HasExtensions() {
|
||||
cwd, _ := os.Getwd()
|
||||
kitInstance.SetExtensionContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: positionalPrompt == "",
|
||||
Print: func(text string) { appInstance.PrintFromExtension("", text) },
|
||||
PrintInfo: func(text string) { appInstance.PrintFromExtension("info", text) },
|
||||
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: positionalPrompt == "",
|
||||
Print: func(text string) {
|
||||
// Capture messages during startup, print after startup banner.
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintInfo: func(text string) {
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintError: func(text string) {
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.Steer(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.SetExtensionWidget(config)
|
||||
@@ -955,6 +966,24 @@ func runNormalMode(ctx context.Context) error {
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
@@ -1078,8 +1107,392 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
return nil, extResult, err
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: kitInstance.GetChildren,
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != "" {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: kitInstance.SummarizeBranch,
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != "" {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
// Find skill by name
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
// Inject via SendMessage as a system context message
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: kitInstance.DiscoverSkillsForExtension,
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
ParseTemplate: kit.ParseTemplate,
|
||||
RenderTemplate: kit.RenderTemplate,
|
||||
ParseArguments: kit.ParseArguments,
|
||||
SimpleParseArguments: kit.SimpleParseArguments,
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.GetExtensionContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.GetExtensionContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
ResolveModelChain: kit.ResolveModelChain,
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: kit.CheckModelAvailable,
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.GetExtensionContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.GetExtensionContext().Model)
|
||||
},
|
||||
})
|
||||
kitInstance.EmitSessionStart()
|
||||
|
||||
// Restore normal print functions for runtime use.
|
||||
kitInstance.SetExtensionContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: positionalPrompt == "",
|
||||
Print: func(text string) { appInstance.PrintFromExtension("", text) },
|
||||
PrintInfo: func(text string) { appInstance.PrintFromExtension("info", text) },
|
||||
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.SetExtensionWidget(config)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.RemoveExtensionWidget(id)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.SetExtensionHeader(config)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.RemoveExtensionHeader()
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.SetExtensionFooter(config)
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.RemoveExtensionFooter()
|
||||
appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
},
|
||||
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
},
|
||||
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
// In-process subagent via SDK.
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
// Bridge SDK events to extension SubagentEvents.
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(ctx, sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: result.Error,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: kitInstance.GetChildren,
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != "" {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: kitInstance.SummarizeBranch,
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != "" {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
ParseTemplate: kit.ParseTemplate,
|
||||
RenderTemplate: kit.RenderTemplate,
|
||||
ParseArguments: kit.ParseArguments,
|
||||
SimpleParseArguments: kit.SimpleParseArguments,
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.GetExtensionContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.GetExtensionContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
ResolveModelChain: kit.ResolveModelChain,
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: kit.CheckModelAvailable,
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.GetExtensionContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.GetExtensionContext().Model)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Convert extension commands to UI-layer type for the interactive TUI.
|
||||
@@ -1152,6 +1565,24 @@ func runNormalMode(ctx context.Context) error {
|
||||
// this callback runs synchronously inside BubbleTea's Update(), and
|
||||
// NotifyModelChanged calls prog.Send() which deadlocks. The UI layer
|
||||
// updates m.providerName and m.modelName directly after setModel returns.
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
emitModelChangeForUI := func(newModel, previousModel, source string) {
|
||||
@@ -1185,7 +1616,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return fmt.Errorf("--quiet requires a prompt")
|
||||
}
|
||||
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI)
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, startupExtensionMessages)
|
||||
}
|
||||
|
||||
// runNonInteractiveModeApp executes a single prompt via the app layer and exits,
|
||||
@@ -1241,7 +1672,7 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
|
||||
|
||||
// If --no-exit was requested, hand off to the interactive TUI.
|
||||
if noExit {
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession)
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1339,7 +1770,7 @@ func writeJSONError(err error) {
|
||||
// 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit).
|
||||
//
|
||||
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []ui.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error) error {
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []ui.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, startupExtensionMessages []string) error {
|
||||
// Determine terminal size; fall back gracefully.
|
||||
termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd()))
|
||||
if err != nil || termWidth == 0 {
|
||||
@@ -1389,6 +1820,15 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
|
||||
// Print startup info to stdout before Bubble Tea takes over the screen.
|
||||
appModel.PrintStartupInfo()
|
||||
|
||||
// Print any extension messages that were captured during startup.
|
||||
if len(startupExtensionMessages) > 0 {
|
||||
fmt.Println()
|
||||
for _, msg := range startupExtensionMessages {
|
||||
fmt.Println(msg)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
program := tea.NewProgram(appModel)
|
||||
|
||||
// Register the program with the app layer so agent events are sent to the TUI.
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
//go:build ignore
|
||||
|
||||
// bridge_demo.go - Demonstrates the new bridged SDK APIs for extensions.
|
||||
// This extension showcases tree navigation, skill loading, template parsing,
|
||||
// and model resolution capabilities.
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
discoveredSkills []ext.Skill
|
||||
currentBranch []ext.TreeNode
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Register /tree-info command to demonstrate tree navigation
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "tree-info",
|
||||
Description: "Show current conversation tree information",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
info := fmt.Sprintf("Current branch has %d nodes:\n", len(branch))
|
||||
for i, node := range branch {
|
||||
info += fmt.Sprintf(" [%d] %s (%s): %s...\n", i, node.Type, node.ID[:8], truncate(node.Content, 40))
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /discover-skills command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "discover-skills",
|
||||
Description: "Discover and list available skills",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
result := ctx.DiscoverSkills()
|
||||
if result.Error != "" {
|
||||
return "", fmt.Errorf("discovery failed: %s", result.Error)
|
||||
}
|
||||
discoveredSkills = result.Skills
|
||||
|
||||
info := fmt.Sprintf("Discovered %d skills:\n", len(result.Skills))
|
||||
for _, s := range result.Skills {
|
||||
info += fmt.Sprintf(" - %s: %s\n", s.Name, s.Description)
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /parse-template command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "parse-template",
|
||||
Description: "Parse a template and show extracted variables",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
args = "Hello {{name}}, welcome to {{place}}!"
|
||||
}
|
||||
tpl := ctx.ParseTemplate("demo", args)
|
||||
info := fmt.Sprintf("Template: %s\nVariables: %v", tpl.Content, tpl.Variables)
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /render-template command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "render-template",
|
||||
Description: "Render a template with variables (usage: /render-template name=John place=Kit)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
tpl := ctx.ParseTemplate("demo", "Hello {{name}}, welcome to {{place}}!")
|
||||
vars := ctx.ParseArguments(args, ext.ArgumentPattern{
|
||||
Flags: map[string]string{"name": "name", "place": "place"},
|
||||
})
|
||||
rendered := ctx.RenderTemplate(tpl, vars.Vars)
|
||||
ctx.PrintInfo("Rendered: " + rendered)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /check-model command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "check-model",
|
||||
Description: "Check model capabilities and availability",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
model := args
|
||||
if model == "" {
|
||||
model = ctx.Model
|
||||
}
|
||||
|
||||
available := ctx.CheckModelAvailable(model)
|
||||
caps, err := ctx.GetModelCapabilities(model)
|
||||
|
||||
info := fmt.Sprintf("Model: %s\n", model)
|
||||
info += fmt.Sprintf("Available: %v\n", available)
|
||||
if err == "" {
|
||||
info += fmt.Sprintf("Provider: %s\n", caps.Provider)
|
||||
info += fmt.Sprintf("Context Limit: %d\n", caps.ContextLimit)
|
||||
info += fmt.Sprintf("Reasoning: %v\n", caps.Reasoning)
|
||||
} else {
|
||||
info += fmt.Sprintf("Error: %s\n", err)
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /resolve-chain command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "resolve-chain",
|
||||
Description: "Resolve a model chain (usage: /resolve-chain claude-opus,gpt-4o,claude-sonnet)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
args = "anthropic/claude-opus-4,anthropic/claude-sonnet-4,openai/gpt-4o"
|
||||
}
|
||||
prefs := ctx.SimpleParseArguments(args, 1)
|
||||
chain := []string{}
|
||||
if len(prefs) > 1 {
|
||||
// Split the first arg by comma
|
||||
for _, p := range strings.Split(prefs[1], ",") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
chain = append(chain, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := ctx.ResolveModelChain(chain)
|
||||
info, _ := json.MarshalIndent(result, "", " ")
|
||||
ctx.PrintInfo("Resolution Result:\n" + string(info))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /test-conditional command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "test-conditional",
|
||||
Description: "Test model conditional rendering",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
content := `<if-model is="claude-*">This is for Claude models<else>This is for other models</if-model>`
|
||||
rendered := ctx.RenderWithModelConditionals(content)
|
||||
ctx.PrintInfo("Input: " + content)
|
||||
ctx.PrintInfo("Output: " + rendered)
|
||||
ctx.PrintInfo(fmt.Sprintf("Current model matches 'claude-*': %v", ctx.EvaluateModelConditional("claude-*")))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// OnSessionStart: discover skills automatically
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
result := ctx.DiscoverSkills()
|
||||
if result.Error == "" && len(result.Skills) > 0 {
|
||||
discoveredSkills = result.Skills
|
||||
ctx.SetStatus("bridge-demo", fmt.Sprintf("%d skills", len(result.Skills)), 50)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
//go:build ignore
|
||||
|
||||
// conversation-manager.go - Advanced conversation tree navigation and management.
|
||||
// This extension demonstrates:
|
||||
// - Tree navigation (GetTreeNode, GetCurrentBranch, NavigateTo)
|
||||
// - Branch summarization and collapsing
|
||||
// - Interactive tree exploration
|
||||
//
|
||||
// Commands:
|
||||
// /tree - Show conversation tree structure
|
||||
// /branch - Show current branch path
|
||||
// /goto <entry-id> - Navigate to a specific entry
|
||||
// /summarize <n> - Summarize last N messages
|
||||
// /fresh-context - Collapse branch and start fresh
|
||||
// /loop <n> <prompt> - Execute prompt N times with fresh context each iteration
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
loopActive bool
|
||||
loopCount int
|
||||
loopCurrent int
|
||||
loopPrompt string
|
||||
loopStartNode string
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// /tree - Show tree structure
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "tree",
|
||||
Description: "Show conversation tree structure",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
showTree(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /branch - Show current branch
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "branch",
|
||||
Description: "Show current conversation branch",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
showBranch(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /goto - Navigate to entry
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "goto",
|
||||
Description: "Navigate to a specific entry ID (usage: /goto <entry-id>)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
ctx.PrintError("Usage: /goto <entry-id>")
|
||||
return "", nil
|
||||
}
|
||||
result := ctx.NavigateTo(args)
|
||||
if !result.Success {
|
||||
ctx.PrintError(fmt.Sprintf("Navigation failed: %s", result.Error))
|
||||
return "", nil
|
||||
}
|
||||
ctx.PrintInfo(fmt.Sprintf("Navigated to entry: %s", args))
|
||||
|
||||
// Show the node we navigated to
|
||||
node := ctx.GetTreeNode(args)
|
||||
if node != nil {
|
||||
ctx.PrintInfo(fmt.Sprintf("Entry type: %s, Role: %s", node.Type, node.Role))
|
||||
}
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /summarize - Summarize recent messages
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "summarize",
|
||||
Description: "Summarize last N messages (usage: /summarize [n=5])",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
n := 5
|
||||
if args != "" {
|
||||
if parsed, err := strconv.Atoi(args); err == nil && parsed > 0 {
|
||||
n = parsed
|
||||
}
|
||||
}
|
||||
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) < 2 {
|
||||
ctx.PrintError("Not enough messages to summarize")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Find range to summarize
|
||||
startIdx := len(branch) - n - 1
|
||||
if startIdx < 0 {
|
||||
startIdx = 0
|
||||
}
|
||||
endIdx := len(branch) - 1
|
||||
|
||||
fromID := branch[startIdx].ID
|
||||
toID := branch[endIdx].ID
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Summarizing messages %d to %d...", startIdx, endIdx))
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
|
||||
if summary == "" {
|
||||
ctx.PrintError("Failed to generate summary")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#89b4fa",
|
||||
Subtitle: "conversation-manager · Summary",
|
||||
})
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /fresh-context - Collapse and restart
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "fresh-context",
|
||||
Description: "Collapse conversation to summary and start fresh",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) < 3 {
|
||||
ctx.PrintError("Not enough context to collapse")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Keep first message (system), summarize rest
|
||||
fromID := branch[1].ID
|
||||
toID := branch[len(branch)-1].ID
|
||||
|
||||
ctx.PrintInfo("Generating summary for context collapse...")
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
|
||||
if summary == "" {
|
||||
ctx.PrintError("Failed to generate summary")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Collapse the branch
|
||||
result := ctx.CollapseBranch(fromID, toID, summary)
|
||||
if !result.Success {
|
||||
ctx.PrintError(fmt.Sprintf("Collapse failed: %s", result.Error))
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ctx.PrintInfo("Context collapsed. Starting fresh with summary.")
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "conversation-manager · Collapsed Context",
|
||||
})
|
||||
|
||||
// Set a widget showing we're in fresh mode
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "fresh-context",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: "🌱 Fresh Context Mode - Previous conversation collapsed"},
|
||||
Style: ext.WidgetStyle{BorderColor: "#a6e3a1"},
|
||||
})
|
||||
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /loop - Execute with fresh context each iteration
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "loop",
|
||||
Description: "Execute prompt N times with fresh context (usage: /loop 5 analyze this code)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if loopActive {
|
||||
ctx.PrintError("Loop already in progress. Wait for completion.")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Parse arguments
|
||||
parts := strings.SplitN(args, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
ctx.PrintError("Usage: /loop <count> <prompt>")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
count, err := strconv.Atoi(parts[0])
|
||||
if err != nil || count <= 0 || count > 10 {
|
||||
ctx.PrintError("Invalid count (must be 1-10)")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
loopCount = count
|
||||
loopCurrent = 0
|
||||
loopPrompt = parts[1]
|
||||
loopActive = true
|
||||
|
||||
// Store current branch position
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) > 0 {
|
||||
loopStartNode = branch[len(branch)-1].ID
|
||||
}
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Starting loop: %d iterations", loopCount))
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "loop-progress",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: fmt.Sprintf("🔄 Loop: 0/%d - %s", loopCount, loopPrompt)},
|
||||
Style: ext.WidgetStyle{BorderColor: "#fab387"},
|
||||
})
|
||||
|
||||
// Start first iteration
|
||||
executeLoopIteration(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// OnAgentEnd handles loop continuation
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if !loopActive {
|
||||
return
|
||||
}
|
||||
|
||||
loopCurrent++
|
||||
|
||||
if loopCurrent >= loopCount {
|
||||
// Loop complete
|
||||
loopActive = false
|
||||
ctx.RemoveWidget("loop-progress")
|
||||
ctx.PrintInfo(fmt.Sprintf("✅ Loop complete: %d/%d iterations", loopCurrent, loopCount))
|
||||
|
||||
// Show final summary
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) > 0 && loopStartNode != "" {
|
||||
summary := ctx.SummarizeBranch(loopStartNode, branch[len(branch)-1].ID)
|
||||
if summary != "" {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "conversation-manager · Loop Summary",
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Update progress
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "loop-progress",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: fmt.Sprintf("🔄 Loop: %d/%d - %s", loopCurrent, loopCount, loopPrompt)},
|
||||
Style: ext.WidgetStyle{BorderColor: "#fab387"},
|
||||
})
|
||||
|
||||
// Collapse previous iteration for fresh context
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) >= 2 {
|
||||
// Find the user messages (look for the one before the last assistant message)
|
||||
// We want to collapse from the user message that started this iteration
|
||||
// to the last assistant response
|
||||
var collapseStartIdx = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if branch[i].Role == "assistant" {
|
||||
// Found the last assistant message, now find the user message before it
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
if branch[j].Role == "user" {
|
||||
collapseStartIdx = j
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if collapseStartIdx >= 0 {
|
||||
fromID := branch[collapseStartIdx].ID
|
||||
toID := branch[len(branch)-1].ID
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Collapsing iteration %d for fresh context...", loopCurrent))
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
if summary != "" {
|
||||
result := ctx.CollapseBranch(fromID, toID, summary)
|
||||
if result.Success {
|
||||
ctx.PrintInfo("Context collapsed successfully")
|
||||
} else {
|
||||
ctx.PrintError(fmt.Sprintf("Collapse failed: %s", result.Error))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Small delay to let UI update
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Trigger next iteration
|
||||
executeLoopIteration(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// showTree displays the conversation tree structure
|
||||
func showTree(ctx ext.Context) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) == 0 {
|
||||
ctx.PrintInfo("Tree is empty")
|
||||
return
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Conversation Tree (%d nodes):\n\n", len(branch)))
|
||||
|
||||
for i, node := range branch {
|
||||
prefix := " "
|
||||
if i == len(branch)-1 {
|
||||
prefix = "▶ " // Current node
|
||||
} else {
|
||||
prefix = " "
|
||||
}
|
||||
|
||||
roleIcon := "💬"
|
||||
switch node.Role {
|
||||
case "user":
|
||||
roleIcon = "👤"
|
||||
case "assistant":
|
||||
roleIcon = "🤖"
|
||||
case "system":
|
||||
roleIcon = "⚙️"
|
||||
}
|
||||
|
||||
content := truncate(node.Content, 50)
|
||||
if node.Type == "branch_summary" {
|
||||
roleIcon = "📋"
|
||||
content = "[Summary] " + truncate(node.Content, 40)
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("%s%s %s: %s (%s...)\n", prefix, roleIcon, node.Role, node.ID[:8], content))
|
||||
|
||||
// Show children count if any
|
||||
children := ctx.GetChildren(node.ID)
|
||||
if len(children) > 0 {
|
||||
output.WriteString(fmt.Sprintf(" └─ %d branch(es)\n", len(children)))
|
||||
}
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: output.String(),
|
||||
BorderColor: "#89b4fa",
|
||||
Subtitle: "conversation-manager · Tree View",
|
||||
})
|
||||
}
|
||||
|
||||
// showBranch displays the current branch path
|
||||
func showBranch(ctx ext.Context) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) == 0 {
|
||||
ctx.PrintInfo("No active branch")
|
||||
return
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Current Branch (%d nodes from root to leaf):\n\n", len(branch)))
|
||||
|
||||
for i, node := range branch {
|
||||
marker := " "
|
||||
if i == len(branch)-1 {
|
||||
marker = "▶ " // Current leaf
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("%s[%d] %s (%s): %s\n",
|
||||
marker, i, node.Type, node.ID[:8], truncate(node.Content, 40)))
|
||||
}
|
||||
|
||||
// Show current node details
|
||||
leaf := branch[len(branch)-1]
|
||||
output.WriteString(fmt.Sprintf("\nCurrent Leaf:\n"))
|
||||
output.WriteString(fmt.Sprintf(" ID: %s\n", leaf.ID))
|
||||
output.WriteString(fmt.Sprintf(" Type: %s\n", leaf.Type))
|
||||
output.WriteString(fmt.Sprintf(" Role: %s\n", leaf.Role))
|
||||
output.WriteString(fmt.Sprintf(" Model: %s\n", leaf.Model))
|
||||
output.WriteString(fmt.Sprintf(" Children: %d\n", len(leaf.Children)))
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: output.String(),
|
||||
BorderColor: "#cba6f7",
|
||||
Subtitle: "conversation-manager · Branch View",
|
||||
})
|
||||
}
|
||||
|
||||
// executeLoopIteration triggers the next loop iteration
|
||||
func executeLoopIteration(ctx ext.Context) {
|
||||
iterationPrompt := fmt.Sprintf("[%d/%d] %s", loopCurrent+1, loopCount, loopPrompt)
|
||||
ctx.SendMessage(iterationPrompt)
|
||||
}
|
||||
|
||||
// truncate helper
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
@@ -908,7 +908,7 @@ func summarizeToolAction(toolName string, inputJSON string) string {
|
||||
return "searching " + getStr("pattern", "text")
|
||||
case "ls":
|
||||
return "listing " + getStr("path", "directory")
|
||||
case "spawn_subagent":
|
||||
case "subagent":
|
||||
return "spawning subagent"
|
||||
default:
|
||||
return "using " + toolName
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
//go:build ignore
|
||||
|
||||
// prompt-templates.go - Frontmatter-driven prompt templates with model switching.
|
||||
// This extension demonstrates the new bridged SDK APIs:
|
||||
// - Tree navigation for conversation management
|
||||
// - Template parsing with {{variable}} substitution
|
||||
// - Model resolution with fallback chains
|
||||
// - Skill injection
|
||||
//
|
||||
// Usage:
|
||||
// 1. Create ~/.config/kit/prompts/debug.md with frontmatter:
|
||||
// ---
|
||||
// description: Debug Python code
|
||||
// model: claude-sonnet-4-20250514
|
||||
// skill: python
|
||||
// ---
|
||||
// Help me debug this Python code: {{input}}
|
||||
//
|
||||
// 2. In Kit: /debug my_script.py
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// PromptTemplate represents a loaded template with frontmatter
|
||||
type PromptTemplate struct {
|
||||
Name string
|
||||
Description string
|
||||
Model string
|
||||
Skill string
|
||||
Content string
|
||||
Variables []string
|
||||
Path string
|
||||
}
|
||||
|
||||
var (
|
||||
templates = make(map[string]PromptTemplate)
|
||||
templateDir string
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Determine template directory
|
||||
home, _ := os.UserHomeDir()
|
||||
templateDir = filepath.Join(home, ".config", "kit", "prompts")
|
||||
|
||||
// Ensure directory exists
|
||||
os.MkdirAll(templateDir, 0755)
|
||||
|
||||
// Register commands
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "reload-templates",
|
||||
Description: "Reload prompt templates from disk",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
loadTemplates(ctx)
|
||||
ctx.PrintInfo(fmt.Sprintf("Loaded %d templates from %s", len(templates), templateDir))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Dynamic template commands are registered after loading
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
loadTemplates(ctx)
|
||||
registerTemplateCommands(api, ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// loadTemplates discovers and loads all template files
|
||||
func loadTemplates(ctx ext.Context) {
|
||||
templates = make(map[string]PromptTemplate)
|
||||
|
||||
entries, err := os.ReadDir(templateDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(templateDir, entry.Name())
|
||||
tpl, err := loadTemplateFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
name := strings.TrimSuffix(entry.Name(), ".md")
|
||||
templates[name] = tpl
|
||||
}
|
||||
}
|
||||
|
||||
// loadTemplateFile parses a template with YAML frontmatter
|
||||
func loadTemplateFile(path string) (PromptTemplate, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return PromptTemplate{}, err
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
tpl := PromptTemplate{Path: path}
|
||||
|
||||
// Parse frontmatter
|
||||
if strings.HasPrefix(content, "---") {
|
||||
parts := strings.SplitN(content[3:], "---", 2)
|
||||
if len(parts) == 2 {
|
||||
frontmatter := strings.TrimSpace(parts[0])
|
||||
body := strings.TrimSpace(parts[1])
|
||||
|
||||
// Simple line-by-line frontmatter parsing
|
||||
for _, line := range strings.Split(frontmatter, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
key, value, found := strings.Cut(line, ":")
|
||||
if found {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
switch key {
|
||||
case "description":
|
||||
tpl.Description = value
|
||||
case "model":
|
||||
tpl.Model = value
|
||||
case "skill":
|
||||
tpl.Skill = value
|
||||
}
|
||||
}
|
||||
}
|
||||
tpl.Content = body
|
||||
} else {
|
||||
tpl.Content = content
|
||||
}
|
||||
} else {
|
||||
tpl.Content = content
|
||||
}
|
||||
|
||||
// Parse {{variables}} using simple string parsing
|
||||
// (Can't use ctx.ParseTemplate here since we're in Init, not a handler)
|
||||
var vars []string
|
||||
for {
|
||||
start := strings.Index(tpl.Content, "{{")
|
||||
if start == -1 {
|
||||
break
|
||||
}
|
||||
end := strings.Index(tpl.Content[start:], "}}")
|
||||
if end == -1 {
|
||||
break
|
||||
}
|
||||
varName := strings.TrimSpace(tpl.Content[start+2 : start+end])
|
||||
vars = append(vars, varName)
|
||||
tpl.Content = tpl.Content[:start] + "{{" + varName + "}}" + tpl.Content[start+end+2:]
|
||||
}
|
||||
tpl.Variables = vars
|
||||
|
||||
return tpl, nil
|
||||
}
|
||||
|
||||
// registerTemplateCommands dynamically registers commands for each template
|
||||
func registerTemplateCommands(api ext.API, ctx ext.Context) {
|
||||
for name, tpl := range templates {
|
||||
// Skip if already registered (we'd need to track this)
|
||||
tplCopy := tpl // Capture for closure
|
||||
nameCopy := name
|
||||
|
||||
// Build description with metadata
|
||||
desc := tplCopy.Description
|
||||
if desc == "" {
|
||||
desc = fmt.Sprintf("Run %s template", nameCopy)
|
||||
}
|
||||
if tplCopy.Model != "" {
|
||||
desc += fmt.Sprintf(" [%s", tplCopy.Model)
|
||||
if tplCopy.Skill != "" {
|
||||
desc += fmt.Sprintf(" +%s", tplCopy.Skill)
|
||||
}
|
||||
desc += "]"
|
||||
}
|
||||
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: nameCopy,
|
||||
Description: desc,
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
return executeTemplate(ctx, tplCopy, args)
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// executeTemplate runs a template with the given arguments
|
||||
func executeTemplate(ctx ext.Context, tpl PromptTemplate, args string) (string, error) {
|
||||
// Store original model for restoration
|
||||
originalModel := ctx.Model
|
||||
|
||||
// 1. Resolve and switch model if specified
|
||||
if tpl.Model != "" {
|
||||
// Parse model chain (comma-separated)
|
||||
preferences := strings.Split(tpl.Model, ",")
|
||||
for i := range preferences {
|
||||
preferences[i] = strings.TrimSpace(preferences[i])
|
||||
}
|
||||
|
||||
result := ctx.ResolveModelChain(preferences)
|
||||
if result.Error != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Model resolution failed: %s", result.Error))
|
||||
// Continue with current model
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Switching to model: %s", result.Model))
|
||||
if err := ctx.SetModel(result.Model); err != nil {
|
||||
ctx.PrintError(fmt.Sprintf("Failed to switch model: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Inject skill if specified
|
||||
if tpl.Skill != "" {
|
||||
err := ctx.InjectSkillAsContext(tpl.Skill)
|
||||
if err != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Skill injection failed: %s", err))
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Injected skill: %s", tpl.Skill))
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Parse and render template
|
||||
parsed := ctx.ParseTemplate(tpl.Name, tpl.Content)
|
||||
|
||||
// Build variable map
|
||||
vars := make(map[string]string)
|
||||
|
||||
// Simple argument parsing: first arg is $1 (input), rest is $@
|
||||
if len(parsed.Variables) > 0 {
|
||||
argsList := ctx.SimpleParseArguments(args, len(parsed.Variables))
|
||||
for i, varName := range parsed.Variables {
|
||||
if i < len(parsed.Variables) && i+1 < len(argsList) {
|
||||
vars[varName] = argsList[i+1]
|
||||
}
|
||||
}
|
||||
// If single variable, use full args
|
||||
if len(parsed.Variables) == 1 && vars[parsed.Variables[0]] == "" {
|
||||
vars[parsed.Variables[0]] = args
|
||||
}
|
||||
}
|
||||
|
||||
// Render with model conditionals
|
||||
content := ctx.RenderWithModelConditionals(tpl.Content)
|
||||
rendered := ctx.RenderTemplate(ext.PromptTemplate{Name: tpl.Name, Content: content, Variables: parsed.Variables}, vars)
|
||||
|
||||
// 4. Send the rendered prompt
|
||||
ctx.SendMessage(rendered)
|
||||
|
||||
// 5. Schedule model restoration after turn completes
|
||||
// We use a goroutine to wait and restore
|
||||
if tpl.Model != "" && originalModel != "" {
|
||||
go func() {
|
||||
// Note: In a real implementation, we'd use OnAgentEnd event
|
||||
// For now, the user can manually switch back
|
||||
ctx.SetStatus("template-mode", fmt.Sprintf("Template: %s (model will restore)", tpl.Name), 20)
|
||||
}()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Executing template: %s", tpl.Name), nil
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// TestSubagentMonitor_SessionStart verifies OnSessionStart initializes state
|
||||
// without panicking and properly guards nil ctx calls.
|
||||
func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Emit SessionStart - should not panic even with nil ctx functions
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_SubagentLifecycle verifies the full subagent lifecycle
|
||||
// creates entries and emits widget updates.
|
||||
func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Start session
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit SubagentStart
|
||||
_, err = harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit a few chunks
|
||||
for i := range 3 {
|
||||
_, err = harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
ChunkType: "text",
|
||||
Content: fmt.Sprintf("line %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Emit tool call chunk
|
||||
_, err = harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
ChunkType: "tool_call",
|
||||
ToolName: "bash",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk tool_call should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit SubagentEnd
|
||||
_, err = harness.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
Response: "done",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentEnd should not error: %v", err)
|
||||
}
|
||||
|
||||
// Give time for cleanup goroutine
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_MultipleSubagents verifies multiple parallel subagents.
|
||||
func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Start 3 subagents
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Emit chunks for each
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
ChunkType: "text",
|
||||
Content: fmt.Sprintf("output from agent %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// End all subagents
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
Response: "completed",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentEnd %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_SessionShutdown verifies shutdown doesn't panic
|
||||
// even with nil ctx functions.
|
||||
func TestSubagentMonitor_SessionShutdown(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Start then shutdown
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Start a subagent
|
||||
_, err = harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Shutdown - should not panic even with active subagent
|
||||
_, err = harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionShutdown should not error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,7 @@ func Init(api ext.API) {
|
||||
"Subagent Test Extension loaded\n\n" +
|
||||
"/subtest <task> Spawn blocking subagent\n" +
|
||||
"/subbg <task> Spawn background subagent\n\n" +
|
||||
"The LLM can also use the spawn_subagent tool.")
|
||||
"The LLM can also use the subagent tool.")
|
||||
})
|
||||
|
||||
api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) {
|
||||
|
||||
@@ -5,7 +5,7 @@ go 1.26.1
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.0.0
|
||||
charm.land/bubbletea/v2 v2.0.2
|
||||
charm.land/fantasy v0.16.0
|
||||
charm.land/fantasy v0.17.1
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.2
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
@@ -13,7 +13,7 @@ require (
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/coder/acp-go-sdk v0.6.3
|
||||
github.com/mark3labs/mcp-go v0.45.0
|
||||
github.com/mark3labs/mcp-go v0.46.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
@@ -23,14 +23,14 @@ require (
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/auth v0.18.2 // indirect
|
||||
cloud.google.com/go/auth v0.19.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect
|
||||
@@ -45,8 +45,6 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.2 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
@@ -56,9 +54,9 @@ require (
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
@@ -77,18 +75,18 @@ require (
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/invopop/jsonschema v0.13.0 // indirect
|
||||
github.com/indaco/herald v0.9.0 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.2.12 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.6 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.18 // indirect
|
||||
github.com/mailru/easyjson v0.9.2 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
@@ -96,7 +94,7 @@ require (
|
||||
github.com/muesli/mango-pflag v0.2.0 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/roff v0.1.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
@@ -105,10 +103,9 @@ require (
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.7.17 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.6 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
|
||||
@@ -122,7 +119,7 @@ require (
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.272.0 // indirect
|
||||
google.golang.org/api v0.273.0 // indirect
|
||||
google.golang.org/genai v1.51.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect
|
||||
google.golang.org/grpc v1.79.3 // indirect
|
||||
|
||||
@@ -2,16 +2,16 @@ charm.land/bubbles/v2 v2.0.0 h1:tE3eK/pHjmtrDiRdoC9uGNLgpopOd8fjhEe31B/ai5s=
|
||||
charm.land/bubbles/v2 v2.0.0/go.mod h1:rCHoleP2XhU8um45NTuOWBPNVHxnkXKTiZqcclL/qOI=
|
||||
charm.land/bubbletea/v2 v2.0.2 h1:4CRtRnuZOdFDTWSff9r8QFt/9+z6Emubz3aDMnf/dx0=
|
||||
charm.land/bubbletea/v2 v2.0.2/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
|
||||
charm.land/fantasy v0.16.0 h1:vE/6sR9nPcSD8qXJXX6wR8NXjtWlBVAzwQmTh5pHVrs=
|
||||
charm.land/fantasy v0.16.0/go.mod h1:VZjpXVh7IgeiIzGQybEnKzd68ofDsRj94+kzH1ZCAfQ=
|
||||
charm.land/fantasy v0.17.1 h1:SQzfnyJPDuQWt6e//KKmQmEEXdqHMC0IZz10XwkLcEM=
|
||||
charm.land/fantasy v0.17.1/go.mod h1:FF5ALCCHETacHJPBqU42CtwMInYQ0ul52fdzIHQMbQk=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs=
|
||||
charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
|
||||
cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M=
|
||||
cloud.google.com/go/auth v0.19.0 h1:DGYwtbcsGsT1ywuxsIoWi1u/vlks0moIblQHgSDgQkQ=
|
||||
cloud.google.com/go/auth v0.19.0/go.mod h1:2Aph7BT2KnaSFOM0JDPyiYgNh6PL9vGMiP8CUIXZ+IY=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
@@ -36,8 +36,8 @@ github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8=
|
||||
@@ -70,10 +70,6 @@ github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ
|
||||
github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk=
|
||||
github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
|
||||
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
@@ -104,14 +100,14 @@ github.com/charmbracelet/x/conpty v0.1.1 h1:s1bUxjoi7EpqiXysVtC+a8RrvPPNcNvAjfi4
|
||||
github.com/charmbracelet/x/conpty v0.1.1/go.mod h1:OmtR77VODEFbiTzGE9G1XiRJAga6011PIm4u5fTNZpk=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd h1:eStB6uX52pgrm6TxQcEKctPrEC+a/9ubJC+P671idOc=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca h1:62yAoS1Ynbuzwcn1LkNBxi3IMF5p0E0cHCoaLOOmN9w=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd h1:U8xj0UXwqHzO+UYHZJopKF+gWaQEW8oj60fmiq9TFY4=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca h1:QQoyQLgUzojMNWHVHToN6d9qTvT0KWtxUKIRPx/Ox5o=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -173,14 +169,16 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 h1:fYQaUOiGwll0cGj7jmHT/0nPlcrZDFPrZRhTsoCr8hE=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0/go.mod h1:w2ROXVdfGEVFXzmlciUU4EdjHgWvB5h2n6x/8XSTTJA=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 h1:NIKVuLhDlIV74muWlsMM4CcQZqN6JJ20Qcxd9YMuYcs=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
@@ -189,8 +187,8 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
|
||||
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
|
||||
github.com/indaco/herald v0.9.0 h1:LrAfXEHkKz8WmctUKdndppIU/qFpylSbZ8galS0DVAc=
|
||||
github.com/indaco/herald v0.9.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/kaptinlin/go-i18n v0.2.12 h1:ywDsvb4KDFddMC2dpI/rrIzGU2mWUSvHmWUm9BMsdl4=
|
||||
github.com/kaptinlin/go-i18n v0.2.12/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
|
||||
@@ -207,10 +205,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mailru/easyjson v0.9.2 h1:dX8U45hQsZpxd80nLvDGihsQ/OxlvTkVUXH2r/8cb2M=
|
||||
github.com/mailru/easyjson v0.9.2/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/mark3labs/mcp-go v0.45.0 h1:s0S8qR/9fWaQ3pHxz7pm1uQ0DrswoSnRIxKIjbiQtkc=
|
||||
github.com/mark3labs/mcp-go v0.45.0/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw=
|
||||
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
@@ -234,8 +230,8 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
|
||||
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
|
||||
@@ -279,14 +275,12 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/traefik/yaegi v0.16.1 h1:f1De3DVJqIDKmnasUF6MwmWv1dSEEat0wcpXhD2On3E=
|
||||
github.com/traefik/yaegi v0.16.1/go.mod h1:4eVhbPb3LnD2VigQjhYbEJ69vDRFdT2HQNrXx8eEwUY=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.7.17 h1:p36OVWwRb246iHxA/U4p8OPEpOTESm4n+g+8t0EE5uA=
|
||||
github.com/yuin/goldmark v1.7.17/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs=
|
||||
github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
@@ -328,10 +322,14 @@ golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.272.0 h1:eLUQZGnAS3OHn31URRf9sAmRk3w2JjMx37d2k8AjJmA=
|
||||
google.golang.org/api v0.272.0/go.mod h1:wKjowi5LNJc5qarNvDCvNQBn3rVK8nSy6jg2SwRwzIA=
|
||||
google.golang.org/api v0.273.0 h1:r/Bcv36Xa/te1ugaN1kdJ5LoA5Wj/cL+a4gj6FiPBjQ=
|
||||
google.golang.org/api v0.273.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg=
|
||||
google.golang.org/genai v1.51.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
|
||||
|
||||
+63
-10
@@ -70,6 +70,11 @@ type ReasoningDeltaHandler func(delta string)
|
||||
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
|
||||
type ToolOutputHandler = core.ToolOutputCallback
|
||||
|
||||
// StepUsageHandler is a function type for handling token usage after each
|
||||
// complete step in a multi-step agent turn. This enables real-time cost
|
||||
// tracking during long-running tool-calling conversations.
|
||||
type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64)
|
||||
|
||||
// Agent represents an AI agent with core tool integration using the fantasy library.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
|
||||
// fantasy.AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
@@ -178,7 +183,8 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
|
||||
// Pass generation parameters when available.
|
||||
if agentConfig.ModelConfig != nil {
|
||||
if agentConfig.ModelConfig.MaxTokens > 0 {
|
||||
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
|
||||
if agentConfig.ModelConfig.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
|
||||
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(agentConfig.ModelConfig.MaxTokens)))
|
||||
}
|
||||
if agentConfig.ModelConfig.Temperature != nil {
|
||||
@@ -225,7 +231,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
|
||||
onResponse, onToolCallContent, nil, nil, nil)
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the fantasy agent with streaming and callbacks.
|
||||
@@ -237,6 +243,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onStreamingResponse StreamingResponseHandler,
|
||||
onReasoningDelta ReasoningDeltaHandler,
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
|
||||
// Inject tool output handler into context for use by core tools (e.g., bash).
|
||||
@@ -250,8 +257,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// field so Fantasy includes them in the API request.
|
||||
prompt, files, history := splitPromptAndHistory(messages)
|
||||
|
||||
// Track current tool call info for callbacks
|
||||
var currentToolName string
|
||||
// Track current tool call args for callbacks
|
||||
var currentToolArgs string
|
||||
|
||||
// Use the streaming path when streaming is enabled OR when any callbacks are
|
||||
@@ -269,7 +275,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
var completedStepMessages []fantasy.Message
|
||||
|
||||
// Use fantasy's streaming agent
|
||||
result, err := a.fantasyAgent.Stream(ctx, fantasy.AgentStreamCall{
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
Prompt: prompt,
|
||||
Files: files,
|
||||
Messages: history,
|
||||
@@ -301,7 +307,6 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
currentToolName = tc.ToolName
|
||||
currentToolArgs = tc.Input
|
||||
|
||||
// Notify about the tool call
|
||||
@@ -351,9 +356,58 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if text != "" && len(toolCalls) > 0 && onToolCallContent != nil {
|
||||
onToolCallContent(text)
|
||||
}
|
||||
// Emit step usage for real-time cost tracking
|
||||
if onStepUsage != nil {
|
||||
onStepUsage(step.Usage.InputTokens, step.Usage.OutputTokens,
|
||||
step.Usage.CacheReadTokens, step.Usage.CacheCreationTokens)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// If a steer channel is attached to the context, wire up a
|
||||
// PrepareStep function that drains the channel between steps
|
||||
// and injects pending steer messages as user messages before
|
||||
// the next LLM call. This enables graceful mid-turn steering
|
||||
// without cancelling in-progress tool execution.
|
||||
if steerCh := steerChFromContext(ctx); steerCh != nil {
|
||||
onConsumed := steerConsumedFromContext(ctx)
|
||||
streamCall.PrepareStep = func(
|
||||
stepCtx context.Context,
|
||||
opts fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
// Drain all pending steer messages (non-blocking).
|
||||
var steered []string
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
steered = append(steered, msg)
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
result := fantasy.PrepareStepResult{
|
||||
Model: opts.Model,
|
||||
Messages: opts.Messages,
|
||||
}
|
||||
if len(steered) > 0 {
|
||||
// Inject each steer message as a user message so the
|
||||
// LLM sees the redirection on the next step.
|
||||
for _, text := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(text))
|
||||
}
|
||||
// Notify that steer messages were consumed.
|
||||
if onConsumed != nil {
|
||||
onConsumed(len(steered))
|
||||
}
|
||||
}
|
||||
return stepCtx, result, nil
|
||||
}
|
||||
}
|
||||
|
||||
result, err := a.fantasyAgent.Stream(ctx, streamCall)
|
||||
if err != nil {
|
||||
// On cancellation (or any error), return a partial result
|
||||
// containing messages from completed steps so the caller can
|
||||
@@ -396,8 +450,6 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
_ = currentToolName // satisfy compiler for non-streaming path
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
}
|
||||
|
||||
@@ -617,7 +669,8 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
}
|
||||
|
||||
// Pass generation parameters when available.
|
||||
if config.MaxTokens > 0 {
|
||||
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
|
||||
if config.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
|
||||
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(config.MaxTokens)))
|
||||
}
|
||||
if config.Temperature != nil {
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
package agent
|
||||
|
||||
import "context"
|
||||
|
||||
// steerChKey is the context key for the steer channel.
|
||||
type steerChKey struct{}
|
||||
|
||||
// steerConsumedKey is the context key for the steer-consumed callback.
|
||||
type steerConsumedKey struct{}
|
||||
|
||||
// ContextWithSteerCh returns a new context with the steer channel attached.
|
||||
// The agent's PrepareStep function checks this channel between steps and
|
||||
// injects any pending steer messages as user messages before the next LLM call.
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan string) context.Context {
|
||||
return context.WithValue(ctx, steerChKey{}, ch)
|
||||
}
|
||||
|
||||
// ContextWithSteerConsumed returns a new context with a callback that fires
|
||||
// when steer messages are consumed by PrepareStep. The count argument is the
|
||||
// number of messages injected in this batch.
|
||||
func ContextWithSteerConsumed(ctx context.Context, fn func(count int)) context.Context {
|
||||
return context.WithValue(ctx, steerConsumedKey{}, fn)
|
||||
}
|
||||
|
||||
// steerChFromContext extracts the steer channel from the context, or nil.
|
||||
func steerChFromContext(ctx context.Context) <-chan string {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan string)
|
||||
return ch
|
||||
}
|
||||
|
||||
// steerConsumedFromContext extracts the steer-consumed callback, or nil.
|
||||
func steerConsumedFromContext(ctx context.Context) func(int) {
|
||||
fn, _ := ctx.Value(steerConsumedKey{}).(func(int))
|
||||
return fn
|
||||
}
|
||||
+199
-34
@@ -3,7 +3,10 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
@@ -159,11 +162,57 @@ func (a *App) QueueLength() int {
|
||||
return len(a.queue)
|
||||
}
|
||||
|
||||
// Steer cancels the current agent step (if running), clears the queue, and
|
||||
// sends a new message that will execute as soon as the current step finishes
|
||||
// cancelling. If the agent is idle, the message executes immediately.
|
||||
// This is the "steer" delivery mode for SendMessage.
|
||||
func (a *App) Steer(prompt string) {
|
||||
// Steer injects a steering message into the currently running agent turn.
|
||||
// If the agent is in a multi-step tool loop, the message is delivered after
|
||||
// the current tool execution finishes but before the next LLM call (graceful
|
||||
// mid-turn injection via Fantasy's PrepareStep). If the agent is streaming
|
||||
// a text-only response (no pending tool calls), the message waits until the
|
||||
// response completes and then executes as the next turn.
|
||||
//
|
||||
// If the agent is idle, the message starts executing immediately (same as Run).
|
||||
//
|
||||
// Returns the number of pending steer/queue items (0 = started immediately,
|
||||
// >0 = injected/queued). The caller must update UI state based on the return
|
||||
// value — Steer does NOT send events to the program to avoid deadlocking
|
||||
// when called from within Update().
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) Steer(prompt string) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
a.mu.Unlock()
|
||||
return 0
|
||||
}
|
||||
|
||||
if !a.busy {
|
||||
// Not busy — start immediately, same as Run().
|
||||
item := queueItem{Prompt: prompt}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
go a.drainQueue(item)
|
||||
return 0
|
||||
}
|
||||
|
||||
a.mu.Unlock()
|
||||
|
||||
// Agent is busy — inject via the SDK's steer channel. The message
|
||||
// will be picked up by PrepareStep between agent steps (after tool
|
||||
// execution, before next LLM call). If PrepareStep doesn't fire
|
||||
// (text-only response), drainQueue will pick it up after the turn.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.InjectSteer(prompt)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// InterruptAndSend cancels the current agent step (if running), clears the
|
||||
// queue, and sends a new message that will execute as soon as the current
|
||||
// step finishes cancelling. If the agent is idle, the message executes
|
||||
// immediately. This is the hard-cancel delivery mode used by extensions'
|
||||
// CancelAndSend.
|
||||
func (a *App) InterruptAndSend(prompt string) {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -226,6 +275,10 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
_ = old.Close()
|
||||
}
|
||||
a.opts.TreeSession = ts
|
||||
// Also update the kit SDK's tree session so messages are persisted correctly.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.SetTreeSession(ts)
|
||||
}
|
||||
// Reload messages from new session.
|
||||
a.store.Clear()
|
||||
if ts != nil {
|
||||
@@ -401,6 +454,13 @@ func (a *App) Close() {
|
||||
|
||||
// Wait for background goroutines.
|
||||
a.wg.Wait()
|
||||
|
||||
// Clean up empty session file on shutdown.
|
||||
if ts := a.opts.TreeSession; ts != nil && ts.IsEmpty() {
|
||||
if path := ts.GetFilePath(); path != "" {
|
||||
_ = os.Remove(path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -434,6 +494,24 @@ func (a *App) drainQueue(first queueItem) {
|
||||
// Process all collected items as a single batch
|
||||
a.runQueueBatch(items)
|
||||
|
||||
// Drain any unconsumed steer messages from the SDK channel.
|
||||
// These arrive when the user steered during a text-only response
|
||||
// (no tool calls, so PrepareStep didn't fire for a second step).
|
||||
// They go to the front of the queue so they run next.
|
||||
if a.opts.Kit != nil {
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
a.mu.Lock()
|
||||
steerItems := make([]queueItem, len(leftover))
|
||||
for i, text := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: text}
|
||||
}
|
||||
a.queue = append(steerItems, a.queue...)
|
||||
a.mu.Unlock()
|
||||
// Notify UI about the consumed steer messages.
|
||||
a.sendEvent(SteerConsumedEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// Check if more items were queued while we were processing
|
||||
a.mu.Lock()
|
||||
hasMore := len(a.queue) > 0
|
||||
@@ -522,9 +600,10 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -544,8 +623,9 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker.
|
||||
a.updateUsageFromTurnResult(result, prompt)
|
||||
// Update usage tracker. If per-step usage was already recorded from
|
||||
// StepUsageEvent callbacks, avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -569,9 +649,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -604,8 +685,8 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
messages = append(messages, item.Prompt)
|
||||
}
|
||||
|
||||
// TODO: Handle file attachments in batch mode
|
||||
// For now, files are ignored in batch mode (rare edge case)
|
||||
// File attachments are not supported in batch mode; fall back to
|
||||
// processing only the first item that carries files.
|
||||
if hasFiles {
|
||||
// If files exist, fall back to processing just the first item with files
|
||||
for _, item := range items {
|
||||
@@ -626,8 +707,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker (using last item's prompt for tracking).
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt)
|
||||
// Update usage tracker (using last item's prompt for fallback estimation).
|
||||
// If per-step usage was already recorded from StepUsageEvent callbacks,
|
||||
// avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -644,9 +727,10 @@ func (a *App) sendEvent(msg tea.Msg) {
|
||||
}
|
||||
|
||||
// subscribeSDKEvents registers temporary SDK event subscribers that convert
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. Returns an
|
||||
// unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. When stepUsageSeen
|
||||
// is provided, it is set to true after any non-zero StepUsageEvent is observed.
|
||||
// Returns an unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Bool) func() {
|
||||
k := a.opts.Kit
|
||||
var unsubs []func()
|
||||
|
||||
@@ -678,6 +762,10 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
Chunk: ev.Chunk,
|
||||
IsStderr: ev.IsStderr,
|
||||
})
|
||||
case kit.SteerConsumedEvent:
|
||||
sendFn(SteerConsumedEvent{})
|
||||
case kit.StepUsageEvent:
|
||||
a.recordStepUsage(ev, stepUsageSeen)
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -847,29 +935,106 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
}
|
||||
}
|
||||
|
||||
// recordStepUsage applies token/cost usage reported for a completed step.
|
||||
// Step usage events arrive even when a turn is later cancelled, so this keeps
|
||||
// the usage widget accurate on all stop paths.
|
||||
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool) {
|
||||
hasUsage := ev.InputTokens > 0 || ev.OutputTokens > 0 || ev.CacheReadTokens > 0 || ev.CacheWriteTokens > 0
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] recordStepUsage: hasUsage=%v input=%d output=%d cacheRead=%d cacheWrite=%d",
|
||||
hasUsage, ev.InputTokens, ev.OutputTokens, ev.CacheReadTokens, ev.CacheWriteTokens)
|
||||
}
|
||||
if !hasUsage {
|
||||
return
|
||||
}
|
||||
if stepUsageSeen != nil {
|
||||
stepUsageSeen.Store(true)
|
||||
}
|
||||
if a.opts.UsageTracker == nil {
|
||||
return
|
||||
}
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(ev.InputTokens),
|
||||
int(ev.OutputTokens),
|
||||
int(ev.CacheReadTokens),
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// NOTE: We do NOT call SetContextTokens here. Context fill is set once
|
||||
// at turn completion via updateUsageFromTurnResult using FinalUsage.InputTokens,
|
||||
// which reflects the full accumulated context. Per-step context tokens would
|
||||
// cause the display to jump around during multi-step tool calls.
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
// configured UsageTracker. This is the SDK-path equivalent of updateUsage.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string) {
|
||||
// configured UsageTracker. Called once per turn after the turn completes.
|
||||
//
|
||||
// When sawStepUsage is true, totals were already accumulated incrementally via
|
||||
// StepUsageEvent callbacks; in that case this method only updates context fill.
|
||||
// Otherwise it falls back to TotalUsage from the API response.
|
||||
//
|
||||
// NOTE: We only use ACTUAL token counts from API responses for cost tracking.
|
||||
// Estimation is never used for costs - only API-reported tokens are accurate.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string, sawStepUsage bool) {
|
||||
if a.opts.UsageTracker == nil || result == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if result.TotalUsage != nil {
|
||||
inputTokens := int(result.TotalUsage.InputTokens)
|
||||
outputTokens := int(result.TotalUsage.OutputTokens)
|
||||
if inputTokens > 0 && outputTokens > 0 {
|
||||
cacheReadTokens := int(result.TotalUsage.CacheReadTokens)
|
||||
cacheWriteTokens := int(result.TotalUsage.CacheCreationTokens)
|
||||
a.opts.UsageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
// Debug logging for token tracking
|
||||
if a.opts.Debug {
|
||||
if result.TotalUsage != nil {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult TotalUsage: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.TotalUsage.InputTokens, result.TotalUsage.OutputTokens,
|
||||
result.TotalUsage.CacheReadTokens, result.TotalUsage.CacheCreationTokens)
|
||||
} else {
|
||||
a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response)
|
||||
return
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: TotalUsage=nil")
|
||||
}
|
||||
if result.FinalUsage != nil {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult FinalUsage: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.FinalUsage.InputTokens, result.FinalUsage.OutputTokens,
|
||||
result.FinalUsage.CacheReadTokens, result.FinalUsage.CacheCreationTokens)
|
||||
} else {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: FinalUsage=nil")
|
||||
}
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: sawStepUsage=%v", sawStepUsage)
|
||||
}
|
||||
|
||||
if result.FinalUsage != nil {
|
||||
if ct := int(result.FinalUsage.InputTokens) + int(result.FinalUsage.OutputTokens); ct > 0 {
|
||||
a.opts.UsageTracker.SetContextTokens(ct)
|
||||
// --- Accumulate cost/token totals for the session ---
|
||||
// Only use actual API-reported tokens for cost tracking.
|
||||
// If sawStepUsage is true, totals were already updated via StepUsageEvent.
|
||||
// Check any token field > 0 (not just InputTokens) because cached prompts
|
||||
// can result in InputTokens=0 while OutputTokens>0 (OpenAI-compatible behavior).
|
||||
hasTotalUsage := result.TotalUsage != nil &&
|
||||
(result.TotalUsage.InputTokens > 0 ||
|
||||
result.TotalUsage.OutputTokens > 0 ||
|
||||
result.TotalUsage.CacheReadTokens > 0 ||
|
||||
result.TotalUsage.CacheCreationTokens > 0)
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: hasTotalUsage=%v", hasTotalUsage)
|
||||
}
|
||||
if !sawStepUsage && hasTotalUsage {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: calling UpdateUsage input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.TotalUsage.InputTokens, result.TotalUsage.OutputTokens,
|
||||
result.TotalUsage.CacheReadTokens, result.TotalUsage.CacheCreationTokens)
|
||||
}
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(result.TotalUsage.InputTokens),
|
||||
int(result.TotalUsage.OutputTokens),
|
||||
int(result.TotalUsage.CacheReadTokens),
|
||||
int(result.TotalUsage.CacheCreationTokens),
|
||||
)
|
||||
}
|
||||
|
||||
// --- Context window fill (drives the % bar) ---
|
||||
// Use FinalUsage.InputTokens as the context window fill. The API's InputTokens
|
||||
// already includes the full conversation history (system prompt + all previous
|
||||
// messages + current user message). Adding OutputTokens would double-count since
|
||||
// the output becomes part of the input for the next turn.
|
||||
if result.FinalUsage != nil && result.FinalUsage.InputTokens > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: calling SetContextTokens=%d (FinalUsage.InputTokens)",
|
||||
result.FinalUsage.InputTokens)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(int(result.FinalUsage.InputTokens))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
@@ -14,6 +16,47 @@ import (
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
type usageUpdaterStub struct {
|
||||
mu sync.Mutex
|
||||
|
||||
updateCalls int
|
||||
estimateCalls int
|
||||
contextCalls int
|
||||
|
||||
lastUpdateInput int
|
||||
lastUpdateOutput int
|
||||
lastUpdateCacheRead int
|
||||
lastUpdateCacheWrite int
|
||||
lastContextTokens int
|
||||
lastEstimateInput string
|
||||
lastEstimateOutput string
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.updateCalls++
|
||||
s.lastUpdateInput = inputTokens
|
||||
s.lastUpdateOutput = outputTokens
|
||||
s.lastUpdateCacheRead = cacheReadTokens
|
||||
s.lastUpdateCacheWrite = cacheWriteTokens
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.estimateCalls++
|
||||
s.lastEstimateInput = inputText
|
||||
s.lastEstimateOutput = outputText
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) SetContextTokens(tokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.contextCalls++
|
||||
s.lastContextTokens = tokens
|
||||
}
|
||||
|
||||
// turnResult builds a minimal TurnResult with response text t.
|
||||
func turnResult(t string) *kit.TurnResult {
|
||||
return &kit.TurnResult{Response: t}
|
||||
@@ -489,3 +532,133 @@ func TestQueueLength_reflects(t *testing.T) {
|
||||
t.Fatalf("expected 3, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecordStepUsage_updatesTracker verifies that per-step usage updates are
|
||||
// recorded immediately for cost tracking. Context tokens are NOT updated here
|
||||
// (only via updateUsageFromTurnResult) to avoid display jumps during multi-step
|
||||
// tool calls.
|
||||
func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.recordStepUsage(kit.StepUsageEvent{
|
||||
InputTokens: 120,
|
||||
OutputTokens: 45,
|
||||
CacheReadTokens: 5,
|
||||
CacheWriteTokens: 2,
|
||||
}, nil)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 1 {
|
||||
t.Fatalf("expected 1 update call, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.lastUpdateInput != 120 || usage.lastUpdateOutput != 45 || usage.lastUpdateCacheRead != 5 || usage.lastUpdateCacheWrite != 2 {
|
||||
t.Fatalf("unexpected usage update payload: in=%d out=%d cache_read=%d cache_write=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput, usage.lastUpdateCacheRead, usage.lastUpdateCacheWrite)
|
||||
}
|
||||
// Context tokens should NOT be updated by recordStepUsage (only by updateUsageFromTurnResult)
|
||||
if usage.contextCalls != 0 {
|
||||
t.Fatalf("expected 0 context token updates from recordStepUsage, got %d", usage.contextCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen ensures we avoid
|
||||
// double-counting totals once StepUsageEvent-based updates were already applied.
|
||||
func TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
InputTokens: 999,
|
||||
OutputTokens: 111,
|
||||
CacheReadTokens: 7,
|
||||
CacheCreationTokens: 3,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{InputTokens: 456},
|
||||
}, "prompt", true)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 0 {
|
||||
t.Fatalf("expected no total usage update when sawStepUsage=true, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.estimateCalls != 0 {
|
||||
t.Fatalf("expected no estimate update when sawStepUsage=true, got %d", usage.estimateCalls)
|
||||
}
|
||||
// Context tokens should be InputTokens only (456)
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 456 {
|
||||
t.Fatalf("expected final context tokens=456 (InputTokens only), got calls=%d tokens=%d", usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero verifies that usage
|
||||
// is recorded when InputTokens=0 but OutputTokens>0 (OpenAI-compatible cache behavior).
|
||||
func TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
// Simulate OpenAI-compatible behavior: all prompt tokens cached, InputTokens=0
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
InputTokens: 0, // All cached - subtracted from prompt
|
||||
OutputTokens: 150, // Actual generated tokens
|
||||
CacheReadTokens: 500, // Cache hit
|
||||
CacheCreationTokens: 0,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{InputTokens: 0, OutputTokens: 150},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 1 {
|
||||
t.Fatalf("expected 1 update call when InputTokens=0 but OutputTokens>0, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.lastUpdateInput != 0 || usage.lastUpdateOutput != 150 {
|
||||
t.Fatalf("expected input=0 output=150, got input=%d output=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput)
|
||||
}
|
||||
if usage.lastUpdateCacheRead != 500 {
|
||||
t.Fatalf("expected cache_read=500, got %d", usage.lastUpdateCacheRead)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly verifies that context
|
||||
// window fill uses InputTokens only (not input+output). The API's InputTokens
|
||||
// already includes the full conversation history; adding output would double-count.
|
||||
func TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 200,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{
|
||||
InputTokens: 1000, // Full context including history
|
||||
OutputTokens: 200,
|
||||
},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
// Context tokens should be InputTokens only (1000), not input+output (1200)
|
||||
// because InputTokens already includes the full conversation history
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 1000 {
|
||||
t.Fatalf("expected context tokens=1000 (InputTokens only), got calls=%d tokens=%d",
|
||||
usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +141,12 @@ type CompactErrorEvent struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// SteerConsumedEvent is sent when one or more steering messages have been
|
||||
// consumed — either injected mid-turn via PrepareStep, or drained into the
|
||||
// queue after a turn completes. The TUI uses this to clear the steering
|
||||
// badge from the display.
|
||||
type SteerConsumedEvent struct{}
|
||||
|
||||
// ModelChangedEvent is sent when an extension changes the active model via
|
||||
// ctx.SetModel. The TUI updates the model name shown in the status bar and
|
||||
// message attribution.
|
||||
|
||||
@@ -10,9 +10,10 @@ import (
|
||||
)
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
// Currently supports Anthropic credentials with both OAuth and API key authentication methods.
|
||||
// Currently supports Anthropic and OpenAI credentials with both OAuth and API key authentication methods.
|
||||
type CredentialStore struct {
|
||||
Anthropic *AnthropicCredentials `json:"anthropic,omitempty"`
|
||||
OpenAI *OpenAICredentials `json:"openai,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicCredentials holds Anthropic API credentials supporting both OAuth
|
||||
@@ -28,13 +29,44 @@ type AnthropicCredentials struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// OpenAICredentials holds OpenAI API credentials supporting both OAuth
|
||||
// and API key authentication methods. The Type field indicates which authentication
|
||||
// method is being used. For OAuth, tokens are stored with expiration timestamps
|
||||
// for automatic refresh. For API keys, only the key itself is stored.
|
||||
type OpenAICredentials struct {
|
||||
Type string `json:"type"` // "oauth" or "api_key"
|
||||
APIKey string `json:"api_key,omitempty"` // For API key auth
|
||||
AccessToken string `json:"access_token,omitempty"` // For OAuth
|
||||
RefreshToken string `json:"refresh_token,omitempty"` // For OAuth
|
||||
ExpiresAt int64 `json:"expires_at,omitempty"` // For OAuth
|
||||
AccountID string `json:"account_id,omitempty"` // For OAuth (ChatGPT account ID)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// oauthTokenExpired reports whether an OAuth token with the given type and
|
||||
// expiry unix timestamp is past its expiry. Returns false for API key
|
||||
// credentials or when no expiry is set.
|
||||
func oauthTokenExpired(credType string, expiresAt int64) bool {
|
||||
if credType != "oauth" || expiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= expiresAt
|
||||
}
|
||||
|
||||
// oauthTokenNeedsRefresh reports whether an OAuth token will expire within the
|
||||
// next 5 minutes, allowing proactive refresh before it becomes invalid.
|
||||
// Returns false for API key credentials or when no expiry is set.
|
||||
func oauthTokenNeedsRefresh(credType string, expiresAt int64) bool {
|
||||
if credType != "oauth" || expiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (expiresAt - 300) // 5 minutes buffer
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) IsExpired() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= c.ExpiresAt
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
@@ -42,10 +74,21 @@ func (c *AnthropicCredentials) IsExpired() bool {
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) NeedsRefresh() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) IsExpired() bool {
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
// will expire within the next 5 minutes. This allows for proactive token refresh
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) NeedsRefresh() bool {
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// CredentialManager handles secure storage and retrieval of authentication credentials.
|
||||
@@ -212,6 +255,142 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAICredentials stores OpenAI API key credentials. It validates the
|
||||
// API key format before storing. The API key must start with "sk-" and be
|
||||
// at least 20 characters long. Returns an error if the API key is invalid or
|
||||
// if storage fails.
|
||||
func (cm *CredentialManager) SetOpenAICredentials(apiKey string) error {
|
||||
if err := validateOpenAIAPIKey(apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = &OpenAICredentials{
|
||||
Type: "api_key",
|
||||
APIKey: apiKey,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetOpenAICredentials retrieves stored OpenAI credentials. Returns nil if
|
||||
// no credentials are stored. The returned credentials may be either OAuth or API
|
||||
// key type, check the Type field to determine which.
|
||||
func (cm *CredentialManager) GetOpenAICredentials() (*OpenAICredentials, error) {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return store.OpenAI, nil
|
||||
}
|
||||
|
||||
// RemoveOpenAICredentials removes stored OpenAI credentials from storage.
|
||||
// If this was the only credential stored, the entire credentials file is removed.
|
||||
// Returns an error if the removal fails.
|
||||
func (cm *CredentialManager) RemoveOpenAICredentials() error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = nil
|
||||
|
||||
// If store is empty, remove the file entirely
|
||||
if store.Anthropic == nil && store.OpenAI == nil {
|
||||
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// HasOpenAICredentials checks if valid OpenAI credentials are stored.
|
||||
// Returns true if either a non-empty OAuth access token or API key is present,
|
||||
// false otherwise. Returns an error if credentials cannot be loaded.
|
||||
func (cm *CredentialManager) HasOpenAICredentials() (bool, error) {
|
||||
creds, err := cm.GetOpenAICredentials()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if creds == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check based on credential type
|
||||
switch creds.Type {
|
||||
case "oauth":
|
||||
return creds.AccessToken != "", nil
|
||||
case "api_key":
|
||||
return creds.APIKey != "", nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAIOAuthCredentials stores OpenAI OAuth credentials in the credential manager's secure storage.
|
||||
// The credentials should include access token, refresh token, and expiration information.
|
||||
// Returns an error if the credentials cannot be saved.
|
||||
func (cm *CredentialManager) SetOpenAIOAuthCredentials(creds *OpenAICredentials) error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = creds
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetValidOpenAIAccessToken returns a valid access token for API requests. For OAuth credentials,
|
||||
// it automatically refreshes the token if it's expired or about to expire. For API key
|
||||
// credentials, it simply returns the API key. Returns an error if no credentials are found,
|
||||
// if token refresh fails, or if the credential type is unknown.
|
||||
func (cm *CredentialManager) GetValidOpenAIAccessToken() (string, error) {
|
||||
creds, err := cm.GetOpenAICredentials()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if creds == nil {
|
||||
return "", fmt.Errorf("no credentials found")
|
||||
}
|
||||
|
||||
// For API key auth, return the API key
|
||||
if creds.Type == "api_key" {
|
||||
return creds.APIKey, nil
|
||||
}
|
||||
|
||||
// For OAuth, check if token needs refresh
|
||||
if creds.Type == "oauth" {
|
||||
if creds.NeedsRefresh() {
|
||||
// Refresh the token
|
||||
client := NewOpenAIOAuthClient()
|
||||
newCreds, err := client.RefreshToken(creds.RefreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Update stored credentials
|
||||
if err := cm.SetOpenAIOAuthCredentials(newCreds); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
return newCreds.AccessToken, nil
|
||||
}
|
||||
|
||||
return creds.AccessToken, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unknown credential type: %s", creds.Type)
|
||||
}
|
||||
|
||||
// GetCredentialsPath returns the absolute path to the credentials JSON file.
|
||||
// This is useful for debugging or displaying the storage location to users.
|
||||
func (cm *CredentialManager) GetCredentialsPath() string {
|
||||
@@ -238,6 +417,26 @@ func validateAnthropicAPIKey(apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOpenAIAPIKey validates the format of an OpenAI API key
|
||||
func validateOpenAIAPIKey(apiKey string) error {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
|
||||
if apiKey == "" {
|
||||
return fmt.Errorf("API key cannot be empty")
|
||||
}
|
||||
|
||||
// OpenAI API keys typically start with "sk-" and are quite long
|
||||
if !strings.HasPrefix(apiKey, "sk-") {
|
||||
return fmt.Errorf("invalid OpenAI API key format (should start with 'sk-')")
|
||||
}
|
||||
|
||||
if len(apiKey) < 20 {
|
||||
return fmt.Errorf("API key appears to be too short")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
|
||||
// 1. Command-line flag value (highest priority)
|
||||
// 2. Stored credentials (OAuth or API key)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -30,6 +31,7 @@ type OAuthClient struct {
|
||||
type AuthData struct {
|
||||
URL string
|
||||
Verifier string
|
||||
State string // Optional state parameter for CSRF protection
|
||||
}
|
||||
|
||||
// NewOAuthClient creates a new OAuth client configured for Anthropic's OAuth service.
|
||||
@@ -199,6 +201,270 @@ func (c *OAuthClient) parseCodeAndState(code string) (parsedCode, parsedState st
|
||||
return
|
||||
}
|
||||
|
||||
// OpenAIOAuthClient handles OAuth 2.0 authentication flow with OpenAI Codex (ChatGPT Plus/Pro).
|
||||
// This uses OpenAI's auth0-based OAuth service for ChatGPT account authentication.
|
||||
type OpenAIOAuthClient struct {
|
||||
ClientID string
|
||||
AuthorizeURL string
|
||||
TokenURL string
|
||||
RedirectURI string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OAuth client configured for OpenAI Codex OAuth.
|
||||
// This uses the public client ID for CLI applications with PKCE for security.
|
||||
func NewOpenAIOAuthClient() *OpenAIOAuthClient {
|
||||
return &OpenAIOAuthClient{
|
||||
// Public client ID for OpenAI Codex CLI OAuth
|
||||
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
AuthorizeURL: "https://auth.openai.com/oauth/authorize",
|
||||
TokenURL: "https://auth.openai.com/oauth/token",
|
||||
RedirectURI: "http://localhost:1455/auth/callback",
|
||||
Scopes: "openid profile email offline_access",
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthorizationURL generates a complete authorization URL for the OAuth flow with
|
||||
// PKCE parameters. Returns an AuthData structure containing the URL for user
|
||||
// authentication and the PKCE verifier for the subsequent code exchange.
|
||||
func (c *OpenAIOAuthClient) GetAuthorizationURL() (*AuthData, error) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
|
||||
}
|
||||
|
||||
// Generate random state
|
||||
stateBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(stateBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
state := fmt.Sprintf("%x", stateBytes)
|
||||
|
||||
params := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {c.ClientID},
|
||||
"redirect_uri": {c.RedirectURI},
|
||||
"scope": {c.Scopes},
|
||||
"code_challenge": {challenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
"id_token_add_organizations": {"true"},
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
"originator": {"kit"},
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s?%s", c.AuthorizeURL, params.Encode())
|
||||
|
||||
return &AuthData{
|
||||
URL: authURL,
|
||||
Verifier: verifier,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges an authorization code for access and refresh tokens.
|
||||
// The code parameter should be the authorization code received from the OAuth callback.
|
||||
// The verifier parameter must be the same PKCE verifier generated during GetAuthorizationURL.
|
||||
// Returns OpenAICredentials containing the tokens, expiration, and account ID.
|
||||
func (c *OpenAIOAuthClient) ExchangeCode(code, verifier string) (*OpenAICredentials, error) {
|
||||
return c.exchangeAuthorizationCode(code, verifier, c.RedirectURI)
|
||||
}
|
||||
|
||||
// exchangeAuthorizationCode performs the token exchange with the OAuth server
|
||||
func (c *OpenAIOAuthClient) exchangeAuthorizationCode(code, verifier, redirectUri string) (*OpenAICredentials, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {c.ClientID},
|
||||
"code": {code},
|
||||
"code_verifier": {verifier},
|
||||
"redirect_uri": {redirectUri},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", c.TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make token request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" || tokenResp.RefreshToken == "" {
|
||||
return nil, fmt.Errorf("token response missing required fields")
|
||||
}
|
||||
|
||||
// Extract account ID from JWT token
|
||||
accountID := extractOpenAIAccountID(tokenResp.AccessToken)
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("failed to extract account ID from token")
|
||||
}
|
||||
|
||||
return &OpenAICredentials{
|
||||
Type: "oauth",
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
CreatedAt: time.Now(),
|
||||
AccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired or expiring access token using a refresh token.
|
||||
// Returns new OpenAICredentials with updated access token, refresh token (may be
|
||||
// rotated), and new expiration timestamp. Returns an error if the refresh fails or
|
||||
// the refresh token is invalid.
|
||||
func (c *OpenAIOAuthClient) RefreshToken(refreshToken string) (*OpenAICredentials, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"client_id": {c.ClientID},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", c.TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make refresh request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode refresh response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" || tokenResp.RefreshToken == "" {
|
||||
return nil, fmt.Errorf("refresh response missing required fields")
|
||||
}
|
||||
|
||||
// Extract account ID from JWT token
|
||||
accountID := extractOpenAIAccountID(tokenResp.AccessToken)
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("failed to extract account ID from refreshed token")
|
||||
}
|
||||
|
||||
return &OpenAICredentials{
|
||||
Type: "oauth",
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
CreatedAt: time.Now(),
|
||||
AccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractOpenAIAccountID extracts the ChatGPT account ID from a JWT access token.
|
||||
// The account ID is stored in the claim path https://api.openai.com/auth.chatgpt_account_id
|
||||
func extractOpenAIAccountID(token string) string {
|
||||
// JWT tokens are base64-encoded JSON payloads
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Decode payload (second part)
|
||||
payload := parts[1]
|
||||
// Add padding if needed
|
||||
if len(payload)%4 != 0 {
|
||||
payload += strings.Repeat("=", 4-len(payload)%4)
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Navigate to the claim path: https://api.openai.com/auth.chatgpt_account_id
|
||||
authPath, ok := claims["https://api.openai.com/auth"].(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
accountID, ok := authPath["chatgpt_account_id"].(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return accountID
|
||||
}
|
||||
|
||||
// ParseOpenAIAuthorizationInput parses various forms of authorization input:
|
||||
// - Full callback URL: http://localhost:1455/auth/callback?code=xxx&state=yyy
|
||||
// - Code#State format: abc123#state456
|
||||
// - Query string: code=abc123&state=state456
|
||||
// - Just the code: abc123
|
||||
func ParseOpenAIAuthorizationInput(input string) (code, state string) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Try parsing as URL
|
||||
if strings.HasPrefix(input, "http") {
|
||||
if u, err := url.Parse(input); err == nil {
|
||||
return u.Query().Get("code"), u.Query().Get("state")
|
||||
}
|
||||
}
|
||||
|
||||
// Try code#state format
|
||||
if strings.Contains(input, "#") {
|
||||
parts := strings.SplitN(input, "#", 2)
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
|
||||
// Try query string format
|
||||
if strings.Contains(input, "code=") {
|
||||
if values, err := url.ParseQuery(input); err == nil {
|
||||
return values.Get("code"), values.Get("state")
|
||||
}
|
||||
}
|
||||
|
||||
// Assume it's just the code
|
||||
return input, ""
|
||||
}
|
||||
|
||||
// SetOAuthCredentials stores OAuth credentials in the credential manager's secure storage.
|
||||
// The credentials should include access token, refresh token, and expiration information.
|
||||
// Returns an error if the credentials cannot be saved.
|
||||
|
||||
@@ -403,10 +403,9 @@ func FilepathOr[T any](key string, value *T) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filepath.Join(home, absPath[2:])
|
||||
absPath = filepath.Join(home, absPath[2:])
|
||||
}
|
||||
if !filepath.IsAbs(absPath) {
|
||||
// base := GetConfigPath()
|
||||
base := configPath
|
||||
if base == "" {
|
||||
fmt.Fprintf(os.Stderr, "unable to build relative path to config.")
|
||||
|
||||
+234
-44
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
@@ -13,19 +14,45 @@ import (
|
||||
udiff "github.com/aymanbagabas/go-udiff"
|
||||
)
|
||||
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
// Edit represents a single replacement in a multi-edit operation.
|
||||
type Edit struct {
|
||||
OldText string `json:"old_text"`
|
||||
NewText string `json:"new_text"`
|
||||
}
|
||||
|
||||
// editArgs holds the arguments for the edit tool.
|
||||
// Supports both single-edit mode (old_text/new_text) and multi-edit mode (edits array).
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
OldText string `json:"old_text"` // Single-edit mode
|
||||
NewText string `json:"new_text"` // Single-edit mode
|
||||
Edits []Edit `json:"edits"` // Multi-edit mode
|
||||
}
|
||||
|
||||
// replacement represents a normalized edit ready for processing.
|
||||
type replacement struct {
|
||||
oldText string // normalized old text for matching
|
||||
newText string // normalized new text
|
||||
originalOld string // original old text for metadata
|
||||
originalNew string // original new text for metadata
|
||||
index int // index in the original edits array (for error messages)
|
||||
}
|
||||
|
||||
// matchedReplacement represents a replacement with its match location.
|
||||
type matchedReplacement struct {
|
||||
replacement
|
||||
start int // start index in normalized content
|
||||
end int // end index in normalized content
|
||||
usedFuzzyMatch bool // true if fuzzy matching was used
|
||||
}
|
||||
|
||||
// NewEditTool creates the edit core tool.
|
||||
func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
cfg := ApplyOptions(opts)
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "edit",
|
||||
Description: "Edit a file by replacing exact text. The old_text must match exactly (including whitespace). Use this for precise, surgical edits. Fails if old_text is not found or matches multiple locations.",
|
||||
Description: "Edit a file by replacing exact text. Supports single edit via old_text/new_text, or multiple edits via the edits array. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Parameters: map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
@@ -33,14 +60,32 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace (must match exactly)",
|
||||
"description": "Exact text to find and replace (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text to replace the old text with",
|
||||
"description": "New text to replace the old text with (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"edits": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Array of edits for multi-region replacement. Each edit must have unique, non-overlapping old_text. All matches are against the original file content.",
|
||||
"items": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace for this edit",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text for this edit",
|
||||
},
|
||||
},
|
||||
"required": []string{"old_text", "new_text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "old_text", "new_text"},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeEdit(ctx, call, cfg.WorkDir)
|
||||
@@ -51,7 +96,7 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args editArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("path, old_text, and new_text parameters are required"), nil
|
||||
return fantasy.NewTextErrorResponse("failed to parse arguments: " + err.Error()), nil
|
||||
}
|
||||
if args.Path == "" {
|
||||
return fantasy.NewTextErrorResponse("path parameter is required"), nil
|
||||
@@ -69,56 +114,201 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
|
||||
content := string(contentBytes)
|
||||
|
||||
// Normalize line endings for matching
|
||||
normalized := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
normalizedOld := strings.ReplaceAll(args.OldText, "\r\n", "\n")
|
||||
|
||||
// Try exact match first
|
||||
count := strings.Count(normalized, normalizedOld)
|
||||
|
||||
// If no exact match, try fuzzy matching
|
||||
if count == 0 {
|
||||
if idx, matchLen := fuzzyMatch(normalized, normalizedOld); idx >= 0 {
|
||||
// Apply fuzzy match — the matched text is the original content slice
|
||||
matchedText := normalized[idx : idx+matchLen]
|
||||
newContent := normalized[:idx] + args.NewText + normalized[idx+matchLen:]
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, matchedText, args.NewText)), nil
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("old_text not found in %s", args.Path)), nil
|
||||
// Normalize and validate input
|
||||
replacements, err := normalizeEditInput(args)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("found %d matches for old_text in %s. Provide more context to identify the correct match.", count, args.Path)), nil
|
||||
// Apply all edits
|
||||
newContent, applied, err := applyEdits(content, replacements)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Apply the edit
|
||||
newContent := strings.Replace(normalized, normalizedOld, args.NewText, 1)
|
||||
|
||||
// Write the file
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, normalizedOld, args.NewText)), nil
|
||||
// Generate diff
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
diff := generateDiff(absPath, normalizedContent, newContent)
|
||||
|
||||
// Build response with fuzzy match indication
|
||||
fuzzyCount := 0
|
||||
for _, m := range applied {
|
||||
if m.usedFuzzyMatch {
|
||||
fuzzyCount++
|
||||
}
|
||||
}
|
||||
|
||||
var msg string
|
||||
if len(applied) == 1 {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff)
|
||||
}
|
||||
} else {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied %d edits (%d fuzzy) to %s\n%s", len(applied), fuzzyCount, args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied %d edits to %s\n%s", len(applied), args.Path, diff)
|
||||
}
|
||||
}
|
||||
|
||||
resp := fantasy.NewTextResponse(msg)
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, applied)), nil
|
||||
}
|
||||
|
||||
// normalizeEditInput validates and normalizes the edit input.
|
||||
// Returns error if both single-edit and multi-edit modes are used.
|
||||
func normalizeEditInput(args editArgs) ([]replacement, error) {
|
||||
singleMode := args.OldText != "" || args.NewText != ""
|
||||
multiMode := len(args.Edits) > 0
|
||||
|
||||
if singleMode && multiMode {
|
||||
return nil, fmt.Errorf("cannot use old_text/new_text together with edits array")
|
||||
}
|
||||
|
||||
if !singleMode && !multiMode {
|
||||
return nil, fmt.Errorf("must provide either old_text/new_text or edits array")
|
||||
}
|
||||
|
||||
if singleMode {
|
||||
if args.OldText == "" {
|
||||
return nil, fmt.Errorf("old_text is required when using single-edit mode")
|
||||
}
|
||||
if args.NewText == "" {
|
||||
return nil, fmt.Errorf("new_text is required when using single-edit mode")
|
||||
}
|
||||
return []replacement{{
|
||||
oldText: strings.ReplaceAll(args.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(args.NewText, "\r\n", "\n"),
|
||||
originalOld: args.OldText,
|
||||
originalNew: args.NewText,
|
||||
index: 0,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Multi-edit mode
|
||||
var reps []replacement
|
||||
for i, edit := range args.Edits {
|
||||
if edit.OldText == "" {
|
||||
return nil, fmt.Errorf("edits[%d].old_text is required", i)
|
||||
}
|
||||
reps = append(reps, replacement{
|
||||
oldText: strings.ReplaceAll(edit.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(edit.NewText, "\r\n", "\n"),
|
||||
originalOld: edit.OldText,
|
||||
originalNew: edit.NewText,
|
||||
index: i,
|
||||
})
|
||||
}
|
||||
return reps, nil
|
||||
}
|
||||
|
||||
// applyEdits applies multiple replacements to the content.
|
||||
// All matches are against the original content (non-incremental).
|
||||
// Returns the new content, the applied matches, and any error.
|
||||
func applyEdits(content string, edits []replacement) (string, []matchedReplacement, error) {
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
|
||||
// Find all matches
|
||||
var matched []matchedReplacement
|
||||
for _, edit := range edits {
|
||||
m, err := findMatch(normalizedContent, edit)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
matched = append(matched, *m)
|
||||
}
|
||||
|
||||
// Sort by position
|
||||
sort.Slice(matched, func(i, j int) bool {
|
||||
return matched[i].start < matched[j].start
|
||||
})
|
||||
|
||||
// Check for overlaps
|
||||
for i := 1; i < len(matched); i++ {
|
||||
if matched[i-1].end > matched[i].start {
|
||||
return "", nil, fmt.Errorf("edits[%d] and edits[%d] overlap; merge them into a single edit",
|
||||
matched[i-1].index, matched[i].index)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply edits in reverse order (end to start) to maintain stable offsets
|
||||
result := normalizedContent
|
||||
for i := len(matched) - 1; i >= 0; i-- {
|
||||
m := matched[i]
|
||||
result = result[:m.start] + m.newText + result[m.end:]
|
||||
}
|
||||
|
||||
return result, matched, nil
|
||||
}
|
||||
|
||||
// findMatch finds a unique match for the edit in the content.
|
||||
// Returns error if not found or ambiguous.
|
||||
func findMatch(content string, edit replacement) (*matchedReplacement, error) {
|
||||
// Try exact match first
|
||||
count := strings.Count(content, edit.oldText)
|
||||
|
||||
if count == 0 {
|
||||
// Try fuzzy match
|
||||
idx, matchLen := fuzzyMatch(content, edit.oldText)
|
||||
if idx < 0 {
|
||||
return nil, fmt.Errorf("edits[%d]: could not find old_text in file. The text must match exactly (including whitespace)", edit.index)
|
||||
}
|
||||
// Use the matched text from content for the replacement
|
||||
matchedText := content[idx : idx+matchLen]
|
||||
return &matchedReplacement{
|
||||
replacement: replacement{
|
||||
oldText: matchedText,
|
||||
newText: edit.newText,
|
||||
originalOld: edit.originalOld,
|
||||
originalNew: edit.originalNew,
|
||||
index: edit.index,
|
||||
},
|
||||
start: idx,
|
||||
end: idx + matchLen,
|
||||
usedFuzzyMatch: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return nil, fmt.Errorf("found %d matches for edits[%d].old_text; each old_text must be unique, provide more context to identify the correct match", count, edit.index)
|
||||
}
|
||||
|
||||
// Single exact match
|
||||
idx := strings.Index(content, edit.oldText)
|
||||
return &matchedReplacement{
|
||||
replacement: edit,
|
||||
start: idx,
|
||||
end: idx + len(edit.oldText),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// editDiffMeta builds the structured metadata attached to edit tool responses.
|
||||
func editDiffMeta(path, oldText, newText string) map[string]any {
|
||||
func editDiffMeta(path string, applied []matchedReplacement) map[string]any {
|
||||
var diffBlocks []map[string]any
|
||||
totalAdditions, totalDeletions := 0, 0
|
||||
|
||||
for _, m := range applied {
|
||||
diffBlocks = append(diffBlocks, map[string]any{
|
||||
"old_text": m.originalOld,
|
||||
"new_text": m.originalNew,
|
||||
})
|
||||
totalAdditions += strings.Count(m.originalNew, "\n") + 1
|
||||
totalDeletions += strings.Count(m.originalOld, "\n") + 1
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"file_diffs": []map[string]any{{
|
||||
"path": path,
|
||||
"additions": strings.Count(newText, "\n") + 1,
|
||||
"deletions": strings.Count(oldText, "\n") + 1,
|
||||
"diff_blocks": []map[string]any{{
|
||||
"old_text": oldText,
|
||||
"new_text": newText,
|
||||
}},
|
||||
"path": path,
|
||||
"additions": totalAdditions,
|
||||
"deletions": totalDeletions,
|
||||
"diff_blocks": diffBlocks,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -715,3 +715,315 @@ func TestExecuteEdit_MetadataContainsFileDiffs(t *testing.T) {
|
||||
t.Fatal("file_diffs should be a non-empty array")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Multi-edit tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Basic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "multi.txt")
|
||||
writeFileOrFail(t, path, "line1\nline2\nline3\nline4\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "line1", NewText: "LINE1"},
|
||||
{OldText: "line3", NewText: "LINE3"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "LINE1") {
|
||||
t.Error("first edit not applied: missing LINE1")
|
||||
}
|
||||
if !strings.Contains(gotStr, "LINE3") {
|
||||
t.Error("second edit not applied: missing LINE3")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line2") {
|
||||
t.Error("line2 was modified but should be untouched")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line4") {
|
||||
t.Error("line4 was modified but should be untouched")
|
||||
}
|
||||
|
||||
// Check response mentions multiple edits
|
||||
if !strings.Contains(resp.Content, "2 edits") {
|
||||
t.Errorf("response should mention '2 edits', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NonIncrementalMatching(t *testing.T) {
|
||||
// All edits are matched against the original content, not incrementally
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "noninc.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
want := "AAA\nBBB\nccc\n"
|
||||
if gotStr != want {
|
||||
t.Errorf("got %q, want %q", gotStr, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_OverlapDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "overlap.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "hello world", NewText: "GOODBYE"}, // Overlaps with first edit
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for overlapping edits")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "overlap") {
|
||||
t.Errorf("expected 'overlap' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_DuplicateDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "dup.txt")
|
||||
writeFileOrFail(t, path, "hello\nworld\nhello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "world", NewText: "WORLD"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for ambiguous old_text (duplicate matches)")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "unique") {
|
||||
t.Errorf("expected 'unique' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello\nworld\nhello\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "notfound.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "nonexistent", NewText: "REPLACEMENT"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for not found")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "edits[0]") {
|
||||
t.Errorf("expected 'edits[0]' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_EmptyArray(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for empty edits array")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "mixed.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(map[string]any{
|
||||
"path": path,
|
||||
"old_text": "hello",
|
||||
"new_text": "HELLO",
|
||||
"edits": []Edit{
|
||||
{OldText: "hello", NewText: "HI"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error when mixing single and multi-edit modes")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "cannot use") {
|
||||
t.Errorf("expected 'cannot use' in error, got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_FuzzyMatch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "fuzzy_multi.txt")
|
||||
// File has trailing whitespace
|
||||
original := "func foo() { \n\treturn 1 \n}\nfunc bar() { \n\treturn 2 \n}\n"
|
||||
writeFileOrFail(t, path, original)
|
||||
|
||||
// Search without trailing whitespace (common LLM behavior)
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "func foo() {\n\treturn 1\n}", NewText: "func foo() {\n\treturn 10\n}"},
|
||||
{OldText: "func bar() {\n\treturn 2\n}", NewText: "func bar() {\n\treturn 20\n}"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "return 10") {
|
||||
t.Error("first edit not applied")
|
||||
}
|
||||
if !strings.Contains(gotStr, "return 20") {
|
||||
t.Error("second edit not applied")
|
||||
}
|
||||
|
||||
// Response should mention fuzzy match
|
||||
if !strings.Contains(resp.Content, "fuzzy") {
|
||||
t.Errorf("response should mention 'fuzzy', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Metadata(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "meta_multi.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
var meta map[string]any
|
||||
if err := json.Unmarshal([]byte(resp.Metadata), &meta); err != nil {
|
||||
t.Fatalf("metadata is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
diffs, ok := meta["file_diffs"].([]any)
|
||||
if !ok || len(diffs) == 0 {
|
||||
t.Fatal("metadata missing file_diffs")
|
||||
}
|
||||
|
||||
firstDiff, ok := diffs[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("first diff is not an object")
|
||||
}
|
||||
|
||||
// Check that diff_blocks contains both edits
|
||||
diffBlocks, ok := firstDiff["diff_blocks"].([]any)
|
||||
if !ok || len(diffBlocks) != 2 {
|
||||
t.Fatalf("expected 2 diff_blocks, got %d", len(diffBlocks))
|
||||
}
|
||||
|
||||
// Verify each block has old_text and new_text
|
||||
for i, block := range diffBlocks {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("diff_block[%d] is not an object", i)
|
||||
}
|
||||
if _, ok := b["old_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing old_text", i)
|
||||
}
|
||||
if _, ok := b["new_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing new_text", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,14 +28,14 @@ type SubagentSpawnResult struct {
|
||||
// SubagentSpawnFunc is a callback that spawns an in-process subagent. The
|
||||
// parent Kit instance injects this into the context so the core tool can
|
||||
// call back without importing pkg/kit (which would create a cycle).
|
||||
// The toolCallID parameter is the LLM-assigned ID of the spawn_subagent
|
||||
// The toolCallID parameter is the LLM-assigned ID of the subagent
|
||||
// tool call, enabling the parent to correlate subagent events.
|
||||
type SubagentSpawnFunc func(ctx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration) (*SubagentSpawnResult, error)
|
||||
|
||||
type subagentCtxKey struct{}
|
||||
|
||||
// WithSubagentSpawner stores a spawn function in the context so that the
|
||||
// spawn_subagent core tool can create in-process subagents.
|
||||
// subagent core tool can create in-process subagents.
|
||||
func WithSubagentSpawner(ctx context.Context, fn SubagentSpawnFunc) context.Context {
|
||||
return context.WithValue(ctx, subagentCtxKey{}, fn)
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func getSubagentSpawner(ctx context.Context) SubagentSpawnFunc {
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// spawn_subagent tool
|
||||
// subagent tool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type subagentArgs struct {
|
||||
@@ -59,11 +59,11 @@ type subagentArgs struct {
|
||||
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// NewSubagentTool creates the spawn_subagent core tool.
|
||||
// NewSubagentTool creates the subagent core tool.
|
||||
func NewSubagentTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "spawn_subagent",
|
||||
Name: "subagent",
|
||||
Description: `Spawn a subagent to perform a task autonomously.
|
||||
|
||||
The subagent runs as a separate in-process Kit instance with full tool access
|
||||
|
||||
@@ -86,7 +86,7 @@ func ReadOnlyTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
}
|
||||
}
|
||||
|
||||
// SubagentTools returns all core tools except spawn_subagent. This prevents
|
||||
// SubagentTools returns all core tools except subagent. This prevents
|
||||
// infinite recursion when a subagent is itself a Kit instance.
|
||||
func SubagentTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
return []fantasy.AgentTool{
|
||||
|
||||
+319
-1
@@ -572,6 +572,102 @@ type Context struct {
|
||||
// })
|
||||
// // handle.Kill() to cancel, handle.Wait() to block
|
||||
SpawnSubagent func(SubagentConfig) (*SubagentHandle, *SubagentResult, error)
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// GetTreeNode returns a node by ID with full metadata and children.
|
||||
// Returns nil if entry not found.
|
||||
GetTreeNode func(entryID string) *TreeNode
|
||||
|
||||
// GetCurrentBranch returns the path from root to current leaf.
|
||||
// Each node contains full metadata (unlike GetMessages which flattens).
|
||||
GetCurrentBranch func() []TreeNode
|
||||
|
||||
// GetChildren returns direct child IDs of an entry.
|
||||
GetChildren func(entryID string) []string
|
||||
|
||||
// NavigateTo branches/forks the session to the specified entry ID.
|
||||
// Equivalent to SDK's Branch() but for extensions.
|
||||
NavigateTo func(entryID string) TreeNavigationResult
|
||||
|
||||
// SummarizeBranch uses LLM to summarize a branch range.
|
||||
// Returns summary text or error string (empty if success).
|
||||
SummarizeBranch func(fromID, toID string) string
|
||||
|
||||
// CollapseBranch replaces a branch range with a summary entry.
|
||||
// This is the "fresh context" primitive for context window management.
|
||||
CollapseBranch func(fromID, toID, summary string) TreeNavigationResult
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// LoadSkill loads a single skill file from path.
|
||||
// Parses YAML frontmatter, returns skill with content ready for injection.
|
||||
LoadSkill func(path string) (*Skill, string)
|
||||
|
||||
// LoadSkillsFromDir discovers and loads all skills from a directory.
|
||||
LoadSkillsFromDir func(dir string) SkillLoadResult
|
||||
|
||||
// DiscoverSkills finds skills in standard locations.
|
||||
// Checks ~/.config/kit/skills/, .kit/skills/, .agents/skills/
|
||||
DiscoverSkills func() SkillLoadResult
|
||||
|
||||
// InjectSkillAsContext sends a skill's content as a system message.
|
||||
// Looks up skill by name from discovered skills.
|
||||
InjectSkillAsContext func(skillName string) string
|
||||
|
||||
// InjectRawSkillAsContext loads and immediately injects a skill file.
|
||||
InjectRawSkillAsContext func(path string) string
|
||||
|
||||
// GetAvailableSkills returns all currently loaded/discovered skills.
|
||||
GetAvailableSkills func() []Skill
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// ParseTemplate extracts {{variables}} from template content.
|
||||
ParseTemplate func(name, content string) PromptTemplate
|
||||
|
||||
// RenderTemplate substitutes variables into template content.
|
||||
RenderTemplate func(tpl PromptTemplate, vars map[string]string) string
|
||||
|
||||
// ParseArguments parses command-line style arguments.
|
||||
ParseArguments func(input string, pattern ArgumentPattern) ParseResult
|
||||
|
||||
// SimpleParseArguments parses $1, $2, $@ style arguments.
|
||||
// Returns slice where [0]=full input, [1]=$1, [2]=$2, ... [n]=$@
|
||||
SimpleParseArguments func(input string, count int) []string
|
||||
|
||||
// EvaluateModelConditional checks if condition matches current model.
|
||||
// Condition supports wildcards: * matches any, ? matches single char.
|
||||
EvaluateModelConditional func(condition string) bool
|
||||
|
||||
// RenderWithModelConditionals processes <if-model> blocks in content.
|
||||
RenderWithModelConditionals func(content string) string
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// ResolveModelChain attempts each model in order until one is available.
|
||||
ResolveModelChain func(preferences []string) ModelResolutionResult
|
||||
|
||||
// GetModelCapabilities returns capabilities for a specific model.
|
||||
// If model is empty, uses current model.
|
||||
GetModelCapabilities func(model string) (ModelCapabilities, string)
|
||||
|
||||
// CheckModelAvailable verifies if a model string is valid.
|
||||
CheckModelAvailable func(model string) bool
|
||||
|
||||
// GetCurrentProvider returns just the provider part of current model.
|
||||
GetCurrentProvider func() string
|
||||
|
||||
// GetCurrentModelID returns just the model ID part of current model.
|
||||
GetCurrentModelID func() string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -598,6 +694,148 @@ type SessionMessage struct {
|
||||
Timestamp string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tree navigation types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TreeNode represents a node in the session tree for navigation.
|
||||
// Extensions use this to traverse conversation history and implement
|
||||
// features like "fresh context" loops and branch summarization.
|
||||
type TreeNode struct {
|
||||
// ID is the unique entry identifier.
|
||||
ID string
|
||||
// ParentID links this entry to its parent (empty if root).
|
||||
ParentID string
|
||||
// Type is the entry type: "message", "branch_summary", "model_change", "extension_data", "tool_execution".
|
||||
Type string
|
||||
// Role is the message role for message entries: "user", "assistant", "system", "tool".
|
||||
Role string
|
||||
// Content is the text content or summary.
|
||||
Content string
|
||||
// Model is the model that generated this (for assistant messages).
|
||||
Model string
|
||||
// Provider is the provider used.
|
||||
Provider string
|
||||
// Timestamp is the RFC3339-formatted creation time.
|
||||
Timestamp string
|
||||
// Children is the list of child entry IDs for tree traversal.
|
||||
Children []string
|
||||
}
|
||||
|
||||
// TreeNavigationResult reports success or failure of tree operations.
|
||||
type TreeNavigationResult struct {
|
||||
// Success is true if the operation completed.
|
||||
Success bool
|
||||
// Error describes what went wrong (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Skill types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Skill represents a loaded skill file with parsed YAML frontmatter.
|
||||
type Skill struct {
|
||||
// Name is the human-readable identifier.
|
||||
Name string
|
||||
// Description summarizes what this skill provides.
|
||||
Description string
|
||||
// Content is the markdown body (frontmatter stripped).
|
||||
Content string
|
||||
// Path is the absolute filesystem path.
|
||||
Path string
|
||||
// Tags are optional labels for categorization.
|
||||
Tags []string
|
||||
// When controls automatic inclusion: "always", "on-demand", or file-glob.
|
||||
When string
|
||||
}
|
||||
|
||||
// SkillLoadResult reports skills loaded from a directory.
|
||||
type SkillLoadResult struct {
|
||||
// Skills is the list of loaded skills.
|
||||
Skills []Skill
|
||||
// Error describes loading failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Template parsing types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// PromptTemplate represents a parsed template with variable placeholders.
|
||||
type PromptTemplate struct {
|
||||
// Name is the template identifier.
|
||||
Name string
|
||||
// Content is the original template content.
|
||||
Content string
|
||||
// Variables are the extracted {{variable}} names.
|
||||
Variables []string
|
||||
}
|
||||
|
||||
// ArgumentPattern defines how to parse command arguments.
|
||||
type ArgumentPattern struct {
|
||||
// Positional names for $1, $2, etc.
|
||||
Positional []string
|
||||
// Rest is the variable name for $@ (all remaining).
|
||||
Rest string
|
||||
// Flags maps flag names to variable names (e.g., "--loop" -> "loop").
|
||||
Flags map[string]string
|
||||
}
|
||||
|
||||
// ParseResult reports argument parsing outcome.
|
||||
type ParseResult struct {
|
||||
// Vars maps variable names to values for positional args.
|
||||
Vars map[string]string
|
||||
// Flags maps flag names to values.
|
||||
Flags map[string]string
|
||||
// Rest is remaining unparsed text.
|
||||
Rest string
|
||||
// Error describes parsing failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ModelConditional represents an <if-model> block for evaluation.
|
||||
type ModelConditional struct {
|
||||
// Condition is the model pattern (e.g., "claude-*", "anthropic/*").
|
||||
Condition string
|
||||
// Content is rendered if condition matches.
|
||||
Content string
|
||||
// Else is rendered if condition doesn't match.
|
||||
Else string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model resolution types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ModelCapabilities describes what a model supports.
|
||||
type ModelCapabilities struct {
|
||||
// Provider is the provider ID (e.g., "anthropic").
|
||||
Provider string
|
||||
// ModelID is the model identifier (e.g., "claude-sonnet-4-20250929").
|
||||
ModelID string
|
||||
// ContextLimit is the maximum context window in tokens.
|
||||
ContextLimit int
|
||||
// OutputLimit is the maximum output tokens.
|
||||
OutputLimit int
|
||||
// Reasoning indicates if the model supports reasoning/thinking.
|
||||
Reasoning bool
|
||||
// Streaming indicates if the model supports streaming.
|
||||
Streaming bool
|
||||
}
|
||||
|
||||
// ModelResolutionResult reports model chain resolution outcome.
|
||||
type ModelResolutionResult struct {
|
||||
// Model is the selected model in "provider/model" format.
|
||||
Model string
|
||||
// Capabilities describes the selected model.
|
||||
Capabilities ModelCapabilities
|
||||
// Attempted lists models tried before success.
|
||||
Attempted []string
|
||||
// Error describes resolution failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ExtensionEntry represents persisted extension data stored in the session.
|
||||
// Extensions use AppendEntry to save custom state and GetEntries to retrieve
|
||||
// it on session resume.
|
||||
@@ -750,6 +988,9 @@ type API struct {
|
||||
registerOption func(OptionDef)
|
||||
registerShortcutFn func(ShortcutDef, func(Context))
|
||||
registerMessageRendererFn func(MessageRendererConfig)
|
||||
onSubagentStart func(func(SubagentStartEvent, Context))
|
||||
onSubagentChunk func(func(SubagentChunkEvent, Context))
|
||||
onSubagentEnd func(func(SubagentEndEvent, Context))
|
||||
}
|
||||
|
||||
// OnToolCall registers a handler that fires before a tool executes.
|
||||
@@ -781,6 +1022,27 @@ func (a *API) OnToolResult(handler func(ToolResultEvent, Context) *ToolResultRes
|
||||
a.onToolResult(handler)
|
||||
}
|
||||
|
||||
// OnSubagentStart registers a handler that fires when a subagent tool
|
||||
// call begins executing. Use the ToolCallID to correlate with subsequent
|
||||
// OnSubagentChunk and OnSubagentEnd events for the same subagent.
|
||||
func (a *API) OnSubagentStart(handler func(SubagentStartEvent, Context)) {
|
||||
a.onSubagentStart(handler)
|
||||
}
|
||||
|
||||
// OnSubagentChunk registers a handler for real-time events from a running
|
||||
// subagent. ChunkType identifies the kind of event ("text", "tool_call",
|
||||
// "tool_result", "tool_execution_start", "tool_execution_end", etc.).
|
||||
// Correlate with OnSubagentStart via the ToolCallID field.
|
||||
func (a *API) OnSubagentChunk(handler func(SubagentChunkEvent, Context)) {
|
||||
a.onSubagentChunk(handler)
|
||||
}
|
||||
|
||||
// OnSubagentEnd registers a handler that fires when a subagent call
|
||||
// completes. ErrorMsg is non-empty when the subagent failed.
|
||||
func (a *API) OnSubagentEnd(handler func(SubagentEndEvent, Context)) {
|
||||
a.onSubagentEnd(handler)
|
||||
}
|
||||
|
||||
// OnInput registers a handler that fires when user input is received.
|
||||
// Return a non-nil InputResult to transform or handle the input.
|
||||
func (a *API) OnInput(handler func(InputEvent, Context) *InputResult) {
|
||||
@@ -1781,9 +2043,65 @@ type BeforeCompactResult struct {
|
||||
func (BeforeCompactResult) isResult() {}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Theme types (exposed to Yaegi — concrete structs, string hex colors)
|
||||
// Subagent lifecycle events (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentStartEvent fires when a subagent tool call begins executing.
|
||||
type SubagentStartEvent struct {
|
||||
// ToolCallID is the LLM-assigned ID of the subagent tool call.
|
||||
// Use this to correlate SubagentChunkEvent and SubagentEndEvent.
|
||||
ToolCallID string
|
||||
// Task is the task description passed to the subagent.
|
||||
Task string
|
||||
}
|
||||
|
||||
func (e SubagentStartEvent) Type() EventType { return SubagentStart }
|
||||
|
||||
// SubagentChunkEvent fires for each real-time event from a running subagent.
|
||||
// Type field indicates the kind of event; read the relevant fields accordingly.
|
||||
type SubagentChunkEvent struct {
|
||||
// ToolCallID matches the SubagentStartEvent.ToolCallID for this subagent.
|
||||
ToolCallID string
|
||||
// Task is the task description (repeated for convenience).
|
||||
Task string
|
||||
// ChunkType identifies the event kind:
|
||||
// "text" — LLM text chunk (read Content)
|
||||
// "reasoning" — reasoning/thinking delta (read Content)
|
||||
// "tool_call" — subagent called a tool (read ToolName, ToolArgs)
|
||||
// "tool_result" — tool returned a result (read ToolName, ToolResult, IsError)
|
||||
// "tool_execution_start" — tool began executing (read ToolName)
|
||||
// "tool_execution_end" — tool finished executing (read ToolName)
|
||||
// "turn_start" — subagent turn began
|
||||
// "turn_end" — subagent turn ended
|
||||
ChunkType string
|
||||
// Content carries text for "text" and "reasoning" chunk types.
|
||||
Content string
|
||||
// ToolName is set on tool-related chunk types.
|
||||
ToolName string
|
||||
// ToolArgs is the JSON-encoded tool arguments for "tool_call" chunks.
|
||||
ToolArgs string
|
||||
// ToolResult is the tool output for "tool_result" chunks.
|
||||
ToolResult string
|
||||
// IsError is true when a "tool_result" chunk represents an error.
|
||||
IsError bool
|
||||
}
|
||||
|
||||
func (e SubagentChunkEvent) Type() EventType { return SubagentChunk }
|
||||
|
||||
// SubagentEndEvent fires when a subagent tool call completes.
|
||||
type SubagentEndEvent struct {
|
||||
// ToolCallID matches the SubagentStartEvent.ToolCallID for this subagent.
|
||||
ToolCallID string
|
||||
// Task is the task description.
|
||||
Task string
|
||||
// Response is the subagent's final text response (empty on error).
|
||||
Response string
|
||||
// ErrorMsg is non-empty when the subagent failed.
|
||||
ErrorMsg string
|
||||
}
|
||||
|
||||
func (e SubagentEndEvent) Type() EventType { return SubagentEnd }
|
||||
|
||||
// ThemeColor is an adaptive color pair with light and dark hex values.
|
||||
// Either field may be empty to inherit from the default theme.
|
||||
type ThemeColor struct {
|
||||
|
||||
@@ -71,6 +71,18 @@ const (
|
||||
// BeforeCompact fires before context compaction runs. Handlers can
|
||||
// cancel compaction by returning Cancel=true.
|
||||
BeforeCompact EventType = "before_compact"
|
||||
|
||||
// SubagentStart fires when a subagent tool call begins executing.
|
||||
// Carries the tool call ID and the task description.
|
||||
SubagentStart EventType = "subagent_start"
|
||||
|
||||
// SubagentChunk fires for each real-time event emitted by a running
|
||||
// subagent: text chunks, tool calls, tool results, etc.
|
||||
SubagentChunk EventType = "subagent_chunk"
|
||||
|
||||
// SubagentEnd fires when a subagent tool call completes (success
|
||||
// or error). Carries the final response and any error message.
|
||||
SubagentEnd EventType = "subagent_end"
|
||||
)
|
||||
|
||||
// AllEventTypes returns every supported event type.
|
||||
@@ -82,6 +94,7 @@ func AllEventTypes() []EventType {
|
||||
SessionStart, SessionShutdown,
|
||||
ModelChange, ContextPrepare,
|
||||
BeforeFork, BeforeSessionSwitch, BeforeCompact,
|
||||
SubagentStart, SubagentChunk, SubagentEnd,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import "testing"
|
||||
|
||||
func TestAllEventTypes_Count(t *testing.T) {
|
||||
all := AllEventTypes()
|
||||
if len(all) != 18 {
|
||||
t.Fatalf("expected 18 event types, got %d", len(all))
|
||||
if len(all) != 21 {
|
||||
t.Fatalf("expected 21 event types, got %d", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,6 +55,9 @@ func TestEventType_TypeMethod(t *testing.T) {
|
||||
{BeforeForkEvent{TargetID: "abc"}, BeforeFork},
|
||||
{BeforeSessionSwitchEvent{Reason: "new"}, BeforeSessionSwitch},
|
||||
{BeforeCompactEvent{EstimatedTokens: 1000}, BeforeCompact},
|
||||
{SubagentStartEvent{ToolCallID: "x", Task: "t"}, SubagentStart},
|
||||
{SubagentChunkEvent{ToolCallID: "x", ChunkType: "text"}, SubagentChunk},
|
||||
{SubagentEndEvent{ToolCallID: "x"}, SubagentEnd},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -580,6 +580,24 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
registerShortcutFn: func(def ShortcutDef, handler func(Context)) {
|
||||
ext.Shortcuts = append(ext.Shortcuts, ShortcutEntry{Def: def, Handler: handler})
|
||||
},
|
||||
onSubagentStart: func(h func(SubagentStartEvent, Context)) {
|
||||
reg(SubagentStart, func(e Event, c Context) Result {
|
||||
h(e.(SubagentStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentChunk: func(h func(SubagentChunkEvent, Context)) {
|
||||
reg(SubagentChunk, func(e Event, c Context) Result {
|
||||
h(e.(SubagentChunkEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentEnd: func(h func(SubagentEndEvent, Context)) {
|
||||
reg(SubagentEnd, func(e Event, c Context) Result {
|
||||
h(e.(SubagentEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// Call Init — the extension registers its handlers, tools, commands.
|
||||
|
||||
@@ -56,11 +56,261 @@ func NewRunner(exts []LoadedExtension) *Runner {
|
||||
}
|
||||
|
||||
// SetContext updates the runtime context (session ID, model, etc.) that is
|
||||
// passed to every handler invocation. Thread-safe.
|
||||
// passed to every handler invocation. Nil function fields are replaced with
|
||||
// safe no-ops so extension handlers never panic on a missing callback.
|
||||
// Thread-safe.
|
||||
func (r *Runner) SetContext(ctx Context) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.ctx = ctx
|
||||
r.ctx = normalizeContext(ctx)
|
||||
}
|
||||
|
||||
// normalizeContext replaces nil function fields in ctx with no-op stubs so
|
||||
// that extension handlers can call any ctx method without a nil-function panic.
|
||||
func normalizeContext(ctx Context) Context {
|
||||
if ctx.Print == nil {
|
||||
ctx.Print = func(string) {}
|
||||
}
|
||||
if ctx.PrintInfo == nil {
|
||||
ctx.PrintInfo = func(string) {}
|
||||
}
|
||||
if ctx.PrintError == nil {
|
||||
ctx.PrintError = func(string) {}
|
||||
}
|
||||
if ctx.PrintBlock == nil {
|
||||
ctx.PrintBlock = func(PrintBlockOpts) {}
|
||||
}
|
||||
if ctx.SendMessage == nil {
|
||||
ctx.SendMessage = func(string) {}
|
||||
}
|
||||
if ctx.CancelAndSend == nil {
|
||||
ctx.CancelAndSend = func(string) {}
|
||||
}
|
||||
if ctx.SetWidget == nil {
|
||||
ctx.SetWidget = func(WidgetConfig) {}
|
||||
}
|
||||
if ctx.RemoveWidget == nil {
|
||||
ctx.RemoveWidget = func(string) {}
|
||||
}
|
||||
if ctx.SetHeader == nil {
|
||||
ctx.SetHeader = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveHeader == nil {
|
||||
ctx.RemoveHeader = func() {}
|
||||
}
|
||||
if ctx.SetFooter == nil {
|
||||
ctx.SetFooter = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveFooter == nil {
|
||||
ctx.RemoveFooter = func() {}
|
||||
}
|
||||
if ctx.PromptSelect == nil {
|
||||
ctx.PromptSelect = func(PromptSelectConfig) PromptSelectResult {
|
||||
return PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptConfirm == nil {
|
||||
ctx.PromptConfirm = func(PromptConfirmConfig) PromptConfirmResult {
|
||||
return PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptInput == nil {
|
||||
ctx.PromptInput = func(PromptInputConfig) PromptInputResult {
|
||||
return PromptInputResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptMultiSelect == nil {
|
||||
ctx.PromptMultiSelect = func(PromptMultiSelectConfig) PromptMultiSelectResult {
|
||||
return PromptMultiSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.ShowOverlay == nil {
|
||||
ctx.ShowOverlay = func(OverlayConfig) OverlayResult {
|
||||
return OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
}
|
||||
if ctx.SetEditor == nil {
|
||||
ctx.SetEditor = func(EditorConfig) {}
|
||||
}
|
||||
if ctx.ResetEditor == nil {
|
||||
ctx.ResetEditor = func() {}
|
||||
}
|
||||
if ctx.SetEditorText == nil {
|
||||
ctx.SetEditorText = func(string) {}
|
||||
}
|
||||
if ctx.SetUIVisibility == nil {
|
||||
ctx.SetUIVisibility = func(UIVisibility) {}
|
||||
}
|
||||
if ctx.SetStatus == nil {
|
||||
ctx.SetStatus = func(string, string, int) {}
|
||||
}
|
||||
if ctx.RemoveStatus == nil {
|
||||
ctx.RemoveStatus = func(string) {}
|
||||
}
|
||||
if ctx.GetContextStats == nil {
|
||||
ctx.GetContextStats = func() ContextStats { return ContextStats{} }
|
||||
}
|
||||
if ctx.GetMessages == nil {
|
||||
ctx.GetMessages = func() []SessionMessage { return nil }
|
||||
}
|
||||
if ctx.GetSessionPath == nil {
|
||||
ctx.GetSessionPath = func() string { return "" }
|
||||
}
|
||||
if ctx.AppendEntry == nil {
|
||||
ctx.AppendEntry = func(string, string) (string, error) { return "", nil }
|
||||
}
|
||||
if ctx.GetEntries == nil {
|
||||
ctx.GetEntries = func(string) []ExtensionEntry { return nil }
|
||||
}
|
||||
if ctx.GetOption == nil {
|
||||
ctx.GetOption = func(string) string { return "" }
|
||||
}
|
||||
if ctx.SetOption == nil {
|
||||
ctx.SetOption = func(string, string) {}
|
||||
}
|
||||
if ctx.SetModel == nil {
|
||||
ctx.SetModel = func(string) error { return nil }
|
||||
}
|
||||
if ctx.GetAvailableModels == nil {
|
||||
ctx.GetAvailableModels = func() []ModelInfoEntry { return nil }
|
||||
}
|
||||
if ctx.EmitCustomEvent == nil {
|
||||
ctx.EmitCustomEvent = func(string, string) {}
|
||||
}
|
||||
if ctx.GetAllTools == nil {
|
||||
ctx.GetAllTools = func() []ToolInfo { return nil }
|
||||
}
|
||||
if ctx.SetActiveTools == nil {
|
||||
ctx.SetActiveTools = func([]string) {}
|
||||
}
|
||||
if ctx.Exit == nil {
|
||||
ctx.Exit = func() {}
|
||||
}
|
||||
if ctx.Complete == nil {
|
||||
ctx.Complete = func(CompleteRequest) (CompleteResponse, error) {
|
||||
return CompleteResponse{}, nil
|
||||
}
|
||||
}
|
||||
if ctx.SuspendTUI == nil {
|
||||
ctx.SuspendTUI = func(callback func()) error { callback(); return nil }
|
||||
}
|
||||
if ctx.RenderMessage == nil {
|
||||
ctx.RenderMessage = func(string, string) {}
|
||||
}
|
||||
if ctx.RegisterTheme == nil {
|
||||
ctx.RegisterTheme = func(string, ThemeColorConfig) {}
|
||||
}
|
||||
if ctx.SetTheme == nil {
|
||||
ctx.SetTheme = func(string) error { return nil }
|
||||
}
|
||||
if ctx.ListThemes == nil {
|
||||
ctx.ListThemes = func() []string { return nil }
|
||||
}
|
||||
if ctx.ReloadExtensions == nil {
|
||||
ctx.ReloadExtensions = func() error { return nil }
|
||||
}
|
||||
if ctx.SpawnSubagent == nil {
|
||||
ctx.SpawnSubagent = func(SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.GetTreeNode == nil {
|
||||
ctx.GetTreeNode = func(string) *TreeNode { return nil }
|
||||
}
|
||||
if ctx.GetCurrentBranch == nil {
|
||||
ctx.GetCurrentBranch = func() []TreeNode { return nil }
|
||||
}
|
||||
if ctx.GetChildren == nil {
|
||||
ctx.GetChildren = func(string) []string { return nil }
|
||||
}
|
||||
if ctx.NavigateTo == nil {
|
||||
ctx.NavigateTo = func(string) TreeNavigationResult {
|
||||
return TreeNavigationResult{Success: false, Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
if ctx.SummarizeBranch == nil {
|
||||
ctx.SummarizeBranch = func(string, string) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
if ctx.CollapseBranch == nil {
|
||||
ctx.CollapseBranch = func(string, string, string) TreeNavigationResult {
|
||||
return TreeNavigationResult{Success: false, Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.LoadSkill == nil {
|
||||
ctx.LoadSkill = func(string) (*Skill, string) { return nil, "" }
|
||||
}
|
||||
if ctx.LoadSkillsFromDir == nil {
|
||||
ctx.LoadSkillsFromDir = func(string) SkillLoadResult { return SkillLoadResult{} }
|
||||
}
|
||||
if ctx.DiscoverSkills == nil {
|
||||
ctx.DiscoverSkills = func() SkillLoadResult { return SkillLoadResult{} }
|
||||
}
|
||||
if ctx.InjectSkillAsContext == nil {
|
||||
ctx.InjectSkillAsContext = func(string) string { return "" }
|
||||
}
|
||||
if ctx.InjectRawSkillAsContext == nil {
|
||||
ctx.InjectRawSkillAsContext = func(string) string { return "" }
|
||||
}
|
||||
if ctx.GetAvailableSkills == nil {
|
||||
ctx.GetAvailableSkills = func() []Skill { return nil }
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.ParseTemplate == nil {
|
||||
ctx.ParseTemplate = func(string, string) PromptTemplate { return PromptTemplate{} }
|
||||
}
|
||||
if ctx.RenderTemplate == nil {
|
||||
ctx.RenderTemplate = func(PromptTemplate, map[string]string) string { return "" }
|
||||
}
|
||||
if ctx.ParseArguments == nil {
|
||||
ctx.ParseArguments = func(string, ArgumentPattern) ParseResult { return ParseResult{} }
|
||||
}
|
||||
if ctx.SimpleParseArguments == nil {
|
||||
ctx.SimpleParseArguments = func(string, int) []string { return nil }
|
||||
}
|
||||
if ctx.EvaluateModelConditional == nil {
|
||||
ctx.EvaluateModelConditional = func(string) bool { return false }
|
||||
}
|
||||
if ctx.RenderWithModelConditionals == nil {
|
||||
ctx.RenderWithModelConditionals = func(string) string { return "" }
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.ResolveModelChain == nil {
|
||||
ctx.ResolveModelChain = func([]string) ModelResolutionResult {
|
||||
return ModelResolutionResult{Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
if ctx.GetModelCapabilities == nil {
|
||||
ctx.GetModelCapabilities = func(string) (ModelCapabilities, string) {
|
||||
return ModelCapabilities{}, "not implemented"
|
||||
}
|
||||
}
|
||||
if ctx.CheckModelAvailable == nil {
|
||||
ctx.CheckModelAvailable = func(string) bool { return false }
|
||||
}
|
||||
if ctx.GetCurrentProvider == nil {
|
||||
ctx.GetCurrentProvider = func() string { return "" }
|
||||
}
|
||||
if ctx.GetCurrentModelID == nil {
|
||||
ctx.GetCurrentModelID = func() string { return "" }
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// GetContext returns a snapshot of the current runtime context. Thread-safe.
|
||||
|
||||
@@ -119,10 +119,33 @@ func Symbols() interp.Exports {
|
||||
"SubagentHandle": reflect.ValueOf((*SubagentHandle)(nil)),
|
||||
"SubagentEvent": reflect.ValueOf((*SubagentEvent)(nil)),
|
||||
|
||||
// Subagent lifecycle events
|
||||
"SubagentStartEvent": reflect.ValueOf((*SubagentStartEvent)(nil)),
|
||||
"SubagentChunkEvent": reflect.ValueOf((*SubagentChunkEvent)(nil)),
|
||||
"SubagentEndEvent": reflect.ValueOf((*SubagentEndEvent)(nil)),
|
||||
|
||||
// Theme types
|
||||
"ThemeColor": reflect.ValueOf((*ThemeColor)(nil)),
|
||||
"ThemeColorConfig": reflect.ValueOf((*ThemeColorConfig)(nil)),
|
||||
|
||||
// Tree navigation types
|
||||
"TreeNode": reflect.ValueOf((*TreeNode)(nil)),
|
||||
"TreeNavigationResult": reflect.ValueOf((*TreeNavigationResult)(nil)),
|
||||
|
||||
// Skill types
|
||||
"Skill": reflect.ValueOf((*Skill)(nil)),
|
||||
"SkillLoadResult": reflect.ValueOf((*SkillLoadResult)(nil)),
|
||||
|
||||
// Template parsing types
|
||||
"PromptTemplate": reflect.ValueOf((*PromptTemplate)(nil)),
|
||||
"ArgumentPattern": reflect.ValueOf((*ArgumentPattern)(nil)),
|
||||
"ParseResult": reflect.ValueOf((*ParseResult)(nil)),
|
||||
"ModelConditional": reflect.ValueOf((*ModelConditional)(nil)),
|
||||
|
||||
// Model resolution types
|
||||
"ModelCapabilities": reflect.ValueOf((*ModelCapabilities)(nil)),
|
||||
"ModelResolutionResult": reflect.ValueOf((*ModelResolutionResult)(nil)),
|
||||
|
||||
// Event structs
|
||||
"ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)),
|
||||
"ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)),
|
||||
|
||||
@@ -171,5 +171,23 @@ func NewTestAPI(ext *LoadedExtension) API {
|
||||
registerMessageRendererFn: func(config MessageRendererConfig) {
|
||||
ext.MessageRenderers = append(ext.MessageRenderers, config)
|
||||
},
|
||||
onSubagentStart: func(h func(SubagentStartEvent, Context)) {
|
||||
reg(SubagentStart, func(e Event, c Context) Result {
|
||||
h(e.(SubagentStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentChunk: func(h func(SubagentChunkEvent, Context)) {
|
||||
reg(SubagentChunk, func(e Event, c Context) Result {
|
||||
h(e.(SubagentChunkEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentEnd: func(h func(SubagentEndEvent, Context)) {
|
||||
reg(SubagentEnd, func(e Event, c Context) Result {
|
||||
h(e.(SubagentEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ var coreToolKinds = map[string]string{
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"spawn_subagent": "agent",
|
||||
"subagent": "agent",
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
|
||||
@@ -4,11 +4,38 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// sanitizeToolCallID ensures the ID matches Anthropic's required pattern:
|
||||
// ^[a-zA-Z0-9_-]+$ (alphanumeric, underscores, and hyphens only).
|
||||
// Invalid characters are replaced with underscores.
|
||||
func sanitizeToolCallID(id string) string {
|
||||
var sb strings.Builder
|
||||
for _, r := range id {
|
||||
switch {
|
||||
case (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z'):
|
||||
sb.WriteRune(r)
|
||||
case r >= '0' && r <= '9':
|
||||
sb.WriteRune(r)
|
||||
case r == '_' || r == '-':
|
||||
sb.WriteRune(r)
|
||||
default:
|
||||
// Replace invalid characters with underscore
|
||||
sb.WriteByte('_')
|
||||
}
|
||||
}
|
||||
result := sb.String()
|
||||
// Ensure non-empty (Anthropic requires at least one character)
|
||||
if result == "" {
|
||||
return "tool_0"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ContentPart is the marker interface for all message content block types.
|
||||
// A message contains a heterogeneous slice of ContentPart values, enabling
|
||||
// rich structured messages that carry text, reasoning, tool calls, tool
|
||||
@@ -312,7 +339,7 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
// Add tool calls
|
||||
for _, tc := range m.ToolCalls() {
|
||||
parts = append(parts, fantasy.ToolCallPart{
|
||||
ToolCallID: tc.ID,
|
||||
ToolCallID: sanitizeToolCallID(tc.ID),
|
||||
ToolName: tc.Name,
|
||||
Input: tc.Input,
|
||||
})
|
||||
@@ -340,7 +367,7 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
}
|
||||
}
|
||||
parts = append(parts, fantasy.ToolResultPart{
|
||||
ToolCallID: result.ToolCallID,
|
||||
ToolCallID: sanitizeToolCallID(result.ToolCallID),
|
||||
Output: output,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeToolCallID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "valid alphanumeric ID",
|
||||
input: "call_123abc",
|
||||
expected: "call_123abc",
|
||||
},
|
||||
{
|
||||
name: "ID with dots (OpenCode/Kimi style)",
|
||||
input: "call.123.abc",
|
||||
expected: "call_123_abc",
|
||||
},
|
||||
{
|
||||
name: "ID with colons",
|
||||
input: "tool:123:abc",
|
||||
expected: "tool_123_abc",
|
||||
},
|
||||
{
|
||||
name: "ID with special characters",
|
||||
input: "tool@#$%^&*()",
|
||||
expected: "tool_________",
|
||||
},
|
||||
{
|
||||
name: "Anthropic style ID (already valid)",
|
||||
input: "toolu_0123456789ABCDEF",
|
||||
expected: "toolu_0123456789ABCDEF",
|
||||
},
|
||||
{
|
||||
name: "OpenAI style ID (already valid)",
|
||||
input: "call_O17Uplv4lJvD6DVdIvFFeRMw",
|
||||
expected: "call_O17Uplv4lJvD6DVdIvFFeRMw",
|
||||
},
|
||||
{
|
||||
name: "ID with hyphens",
|
||||
input: "my-tool-call-123",
|
||||
expected: "my-tool-call-123",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "tool_0",
|
||||
},
|
||||
{
|
||||
name: "only special characters",
|
||||
input: "@#$%",
|
||||
expected: "____",
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid",
|
||||
input: "call_123.abc-def@ghi",
|
||||
expected: "call_123_abc-def_ghi",
|
||||
},
|
||||
{
|
||||
name: "Unicode characters",
|
||||
input: "tool_日本語",
|
||||
expected: "tool____",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitizeToolCallID(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("sanitizeToolCallID(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeToolCallID_MatchesAnthropicPattern(t *testing.T) {
|
||||
// Test that sanitized IDs match Anthropic's required pattern: ^[a-zA-Z0-9_-]+$
|
||||
// This is a simplified check - in reality the pattern allows alphanumeric, underscore, hyphen
|
||||
testIDs := []string{
|
||||
"call.123.abc",
|
||||
"tool:123:def",
|
||||
"id@#$%^&*()",
|
||||
"mixed.valid-id_test",
|
||||
"",
|
||||
}
|
||||
|
||||
for _, id := range testIDs {
|
||||
sanitized := sanitizeToolCallID(id)
|
||||
|
||||
// Verify each character is valid
|
||||
for i, r := range sanitized {
|
||||
valid := (r >= 'a' && r <= 'z') ||
|
||||
(r >= 'A' && r <= 'Z') ||
|
||||
(r >= '0' && r <= '9') ||
|
||||
r == '_' ||
|
||||
r == '-'
|
||||
|
||||
if !valid {
|
||||
t.Errorf("sanitizeToolCallID(%q) = %q, contains invalid character at position %d: %q",
|
||||
id, sanitized, i, string(r))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify non-empty
|
||||
if sanitized == "" {
|
||||
t.Errorf("sanitizeToolCallID(%q) returned empty string", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -17,15 +17,21 @@ type modelsDBProvider struct {
|
||||
|
||||
// modelsDBModel represents a model entry from models.dev/api.json.
|
||||
type modelsDBModel struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Family string `json:"family,omitempty"`
|
||||
Attachment bool `json:"attachment"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
ToolCall bool `json:"tool_call"`
|
||||
Temperature bool `json:"temperature"`
|
||||
Cost modelsDBCost `json:"cost"`
|
||||
Limit modelsDBLimit `json:"limit"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Family string `json:"family,omitempty"`
|
||||
Attachment bool `json:"attachment"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
ToolCall bool `json:"tool_call"`
|
||||
Temperature bool `json:"temperature"`
|
||||
Cost modelsDBCost `json:"cost"`
|
||||
Limit modelsDBLimit `json:"limit"`
|
||||
Provider *modelsDBModelProvider `json:"provider,omitempty"` // Model-specific provider override
|
||||
}
|
||||
|
||||
// modelsDBModelProvider represents a provider reference within a model.
|
||||
type modelsDBModelProvider struct {
|
||||
NPM string `json:"npm"`
|
||||
}
|
||||
|
||||
// modelsDBCost represents model pricing from models.dev.
|
||||
|
||||
@@ -169,6 +169,9 @@ type ProviderResult struct {
|
||||
// ProviderOptions contains provider-specific options to be passed to the
|
||||
// fantasy agent (e.g. OpenAI Responses API reasoning options).
|
||||
ProviderOptions fantasy.ProviderOptions
|
||||
// SkipMaxOutputTokens indicates that this provider doesn't support the
|
||||
// max_output_tokens parameter (e.g., OpenAI Codex OAuth API).
|
||||
SkipMaxOutputTokens bool
|
||||
}
|
||||
|
||||
// ParseModelString parses a model string in "provider/model" format (e.g. "anthropic/claude-sonnet-4-5").
|
||||
@@ -263,14 +266,22 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
// autoRouteProvider attempts to create a provider by looking up its npm package
|
||||
// in the models.dev database and routing through the appropriate fantasy provider.
|
||||
// For openai-compatible providers, it uses the api URL from models.dev.
|
||||
// Models may have a provider override that specifies a different npm package than
|
||||
// the provider's default (e.g., opencode's claude-opus-4-6 uses @ai-sdk/anthropic).
|
||||
func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, modelName string, registry *ModelsRegistry) (*ProviderResult, error) {
|
||||
providerInfo := registry.GetProviderInfo(provider)
|
||||
if providerInfo == nil {
|
||||
return nil, fmt.Errorf("unsupported provider: %s (not found in model database)", provider)
|
||||
}
|
||||
|
||||
// Check for model-specific provider override
|
||||
npmPackage := providerInfo.NPM
|
||||
if modelInfo := registry.LookupModel(provider, modelName); modelInfo != nil && modelInfo.ProviderNPM != "" {
|
||||
npmPackage = modelInfo.ProviderNPM
|
||||
}
|
||||
|
||||
// Determine the fantasy provider for this npm package
|
||||
fantasyProvider := npmToFantasyProvider[providerInfo.NPM]
|
||||
fantasyProvider := npmToFantasyProvider[npmPackage]
|
||||
if fantasyProvider == "" && providerInfo.API != "" {
|
||||
// Unknown npm but has API URL → route through openaicompat
|
||||
fantasyProvider = "openaicompat"
|
||||
@@ -290,7 +301,7 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
|
||||
}
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no fantasy mapping)", provider, providerInfo.NPM)
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no fantasy mapping)", provider, npmPackage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -348,7 +359,10 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf
|
||||
opts = append(opts, anthropic.WithAPIKey(apiKey))
|
||||
|
||||
if config.ProviderURL != "" {
|
||||
opts = append(opts, anthropic.WithBaseURL(config.ProviderURL))
|
||||
// The anthropic client appends "/v1/messages" to the base URL.
|
||||
// If the provider URL ends with "/v1", strip it to avoid double "/v1/v1" paths.
|
||||
baseURL := strings.TrimSuffix(config.ProviderURL, "/v1")
|
||||
opts = append(opts, anthropic.WithBaseURL(baseURL))
|
||||
}
|
||||
|
||||
if config.TLSSkipVerify {
|
||||
@@ -610,13 +624,52 @@ func createVertexAnthropicProvider(ctx context.Context, config *ProviderConfig,
|
||||
|
||||
func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
apiKey := config.ProviderAPIKey
|
||||
source := "command-line flag"
|
||||
var accountID string
|
||||
var isCodexOAuth bool
|
||||
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("OPENAI_API_KEY")
|
||||
}
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("OpenAI API key not provided. Use --provider-api-key flag or OPENAI_API_KEY environment variable")
|
||||
// Check stored credentials first
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err == nil {
|
||||
if creds, err := cm.GetOpenAICredentials(); err == nil && creds != nil {
|
||||
if creds.Type == "oauth" && creds.AccessToken != "" {
|
||||
// For OAuth, get a valid access token (may refresh if needed)
|
||||
token, err := cm.GetValidOpenAIAccessToken()
|
||||
if err == nil && token != "" {
|
||||
apiKey = token
|
||||
accountID = creds.AccountID
|
||||
isCodexOAuth = true
|
||||
source = "stored Codex OAuth credentials"
|
||||
}
|
||||
} else if creds.Type == "api_key" && creds.APIKey != "" {
|
||||
apiKey = creds.APIKey
|
||||
source = "stored API key"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to environment variable
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("OPENAI_API_KEY")
|
||||
source = "OPENAI_API_KEY environment variable"
|
||||
}
|
||||
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("OpenAI API key not provided. Use 'kit auth login openai', --provider-api-key flag, or OPENAI_API_KEY environment variable")
|
||||
}
|
||||
|
||||
if os.Getenv("DEBUG") != "" || os.Getenv("KIT_DEBUG") != "" {
|
||||
fmt.Fprintf(os.Stderr, "Using OpenAI API key from: %s\n", source)
|
||||
}
|
||||
|
||||
// For Codex OAuth, use the ChatGPT backend API with custom headers
|
||||
if isCodexOAuth {
|
||||
return createOpenAICodexProvider(ctx, config, modelName, apiKey, accountID)
|
||||
}
|
||||
|
||||
// Regular OpenAI API key flow
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithAPIKey(apiKey))
|
||||
opts = append(opts, openai.WithUseResponsesAPI())
|
||||
@@ -645,6 +698,135 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
|
||||
}
|
||||
|
||||
// createOpenAICodexProvider creates a provider for ChatGPT/Codex OAuth tokens.
|
||||
// Uses the chatgpt.com/backend-api/codex endpoint with special headers.
|
||||
func createOpenAICodexProvider(ctx context.Context, config *ProviderConfig, modelName, token, accountID string) (*ProviderResult, error) {
|
||||
// Check for spark models which are not accessible via OAuth
|
||||
if detectCodexModelFamily(modelName) == "gpt-codex-spark" {
|
||||
return nil, fmt.Errorf("gpt-codex-spark models are not accessible via ChatGPT OAuth. " +
|
||||
"These models require special access or a different authentication method. " +
|
||||
"Please use regular Codex models like 'openai/gpt-5.3-codex' instead")
|
||||
}
|
||||
|
||||
// Use the ChatGPT backend API with /codex path
|
||||
baseURL := "https://chatgpt.com/backend-api/codex"
|
||||
if config.ProviderURL != "" {
|
||||
baseURL = config.ProviderURL
|
||||
}
|
||||
|
||||
// Build custom HTTP client with required headers
|
||||
httpClient := createCodexHTTPClient(token, accountID, config.TLSSkipVerify)
|
||||
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithAPIKey(token))
|
||||
opts = append(opts, openai.WithBaseURL(baseURL))
|
||||
opts = append(opts, openai.WithUseResponsesAPI())
|
||||
opts = append(opts, openai.WithHTTPClient(httpClient))
|
||||
|
||||
provider, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex provider: %w", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex model: %w", err)
|
||||
}
|
||||
|
||||
providerOpts := buildCodexProviderOptions(config, modelName)
|
||||
|
||||
return &ProviderResult{
|
||||
Model: model,
|
||||
ProviderOptions: providerOpts,
|
||||
SkipMaxOutputTokens: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildCodexProviderOptions returns fantasy.ProviderOptions configured for
|
||||
// OpenAI Codex API. The Codex API requires the system prompt to be passed
|
||||
// as 'instructions' rather than as a system message.
|
||||
func buildCodexProviderOptions(config *ProviderConfig, modelName string) fantasy.ProviderOptions {
|
||||
store := false
|
||||
opts := &openai.ResponsesProviderOptions{
|
||||
Store: &store,
|
||||
}
|
||||
|
||||
if config.SystemPrompt != "" {
|
||||
opts.Instructions = &config.SystemPrompt
|
||||
}
|
||||
|
||||
if openai.IsResponsesReasoningModel(modelName) {
|
||||
opts.ReasoningEffort = thinkingLevelToReasoningEffort(config.ThinkingLevel)
|
||||
}
|
||||
|
||||
return fantasy.ProviderOptions{openai.Name: opts}
|
||||
}
|
||||
|
||||
// detectCodexModelFamily determines the model family from the model name
|
||||
func detectCodexModelFamily(modelName string) string {
|
||||
modelName = strings.ToLower(modelName)
|
||||
if strings.Contains(modelName, "spark") {
|
||||
return "gpt-codex-spark"
|
||||
}
|
||||
if strings.Contains(modelName, "codex-mini") || strings.Contains(modelName, "mini-latest") {
|
||||
return "gpt-codex-mini"
|
||||
}
|
||||
if strings.Contains(modelName, "codex") {
|
||||
return "gpt-codex"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// createCodexHTTPClient creates an HTTP client with headers required for ChatGPT/Codex API
|
||||
func createCodexHTTPClient(token, accountID string, skipVerify bool) *http.Client {
|
||||
var base http.RoundTripper
|
||||
if skipVerify {
|
||||
base = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &codexTransport{
|
||||
base: base,
|
||||
token: token,
|
||||
accountID: accountID,
|
||||
},
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// codexTransport is a custom RoundTripper that adds ChatGPT/Codex specific headers
|
||||
type codexTransport struct {
|
||||
base http.RoundTripper
|
||||
token string
|
||||
accountID string
|
||||
}
|
||||
|
||||
func (t *codexTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Add required headers for ChatGPT/Codex API
|
||||
// These headers mimic the official pi client to avoid Cloudflare blocking
|
||||
newReq.Header.Set("Authorization", "Bearer "+t.token)
|
||||
if t.accountID != "" {
|
||||
newReq.Header.Set("chatgpt-account-id", t.accountID)
|
||||
}
|
||||
newReq.Header.Set("originator", "kit")
|
||||
newReq.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36")
|
||||
newReq.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
newReq.Header.Set("Accept", "text/event-stream")
|
||||
newReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
newReq.Header.Set("Cache-Control", "no-cache")
|
||||
newReq.Header.Set("Pragma", "no-cache")
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
apiKey := firstNonEmpty(
|
||||
config.ProviderAPIKey,
|
||||
|
||||
@@ -22,6 +22,7 @@ type ModelInfo struct {
|
||||
Temperature bool
|
||||
Cost Cost
|
||||
Limit Limit
|
||||
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
||||
}
|
||||
|
||||
// Cost represents the pricing information for a model.
|
||||
@@ -78,6 +79,10 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
for providerID, dp := range dbProviders {
|
||||
modelsMap := make(map[string]ModelInfo, len(dp.Models))
|
||||
for modelID, dm := range dp.Models {
|
||||
providerNPM := ""
|
||||
if dm.Provider != nil {
|
||||
providerNPM = dm.Provider.NPM
|
||||
}
|
||||
modelsMap[modelID] = ModelInfo{
|
||||
ID: dm.ID,
|
||||
Name: dm.Name,
|
||||
@@ -94,6 +99,7 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
Context: dm.Limit.Context,
|
||||
Output: dm.Limit.Output,
|
||||
},
|
||||
ProviderNPM: providerNPM,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -219,6 +225,15 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
|
||||
}
|
||||
}
|
||||
|
||||
// For openai, check stored credentials (OAuth / API key)
|
||||
if provider == "openai" {
|
||||
if cm, err := auth.NewCredentialManager(); err == nil {
|
||||
if has, _ := cm.HasOpenAICredentials(); has {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
envVars, err := r.getRequiredEnvVars(provider)
|
||||
if err != nil {
|
||||
// Unknown provider — nothing to validate
|
||||
|
||||
@@ -96,6 +96,7 @@ func ListAllSessions() ([]SessionInfo, error) {
|
||||
}
|
||||
|
||||
// listSessionsInDir reads all .jsonl files in a directory and extracts session info.
|
||||
// Empty sessions (no messages) are automatically cleaned up and not returned.
|
||||
func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
@@ -117,6 +118,11 @@ func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
if err != nil {
|
||||
continue // skip malformed session files
|
||||
}
|
||||
// Clean up and skip empty sessions (no messages)
|
||||
if info.MessageCount == 0 {
|
||||
_ = os.Remove(path)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, *info)
|
||||
}
|
||||
|
||||
|
||||
@@ -628,6 +628,11 @@ func (tm *TreeManager) MessageCount() int {
|
||||
return count
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the session has no messages (only header).
|
||||
func (tm *TreeManager) IsEmpty() bool {
|
||||
return tm.MessageCount() == 0
|
||||
}
|
||||
|
||||
// Close closes the underlying file handle.
|
||||
func (tm *TreeManager) Close() error {
|
||||
tm.mu.Lock()
|
||||
|
||||
@@ -127,9 +127,7 @@ func (p *MCPConnectionPool) GetConnection(ctx context.Context, serverName string
|
||||
return conn, nil
|
||||
} else {
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Connection %s unhealthy, removing", serverName))
|
||||
}
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Connection %s unhealthy, removing", serverName))
|
||||
}
|
||||
_ = conn.client.Close()
|
||||
delete(p.connections, serverName)
|
||||
|
||||
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -70,7 +71,7 @@ func TestMCPToolManager_LoadTools_GracefulFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
// The error should mention that all servers failed
|
||||
if err != nil && !contains(err.Error(), "all MCP servers failed") {
|
||||
if err != nil && !strings.Contains(err.Error(), "all MCP servers failed") {
|
||||
t.Errorf("Expected error message to mention all servers failed, got: %v", err)
|
||||
}
|
||||
|
||||
@@ -460,12 +461,4 @@ func sliceEqual(a, b []any) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Helper function to check if a string contains a substring
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -349,7 +349,7 @@ func TestStreamComponent_SpinnerKeepsRunningDuringStreaming(t *testing.T) {
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "hello"})
|
||||
|
||||
// Flush pending chunks (simulates the 16ms tick firing).
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{})
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
|
||||
if !c.spinning {
|
||||
t.Fatal("expected spinning=true after first chunk")
|
||||
@@ -376,7 +376,7 @@ func TestStreamComponent_ChunkAccumulation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Flush pending chunks (simulates the 16ms tick firing).
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{})
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
|
||||
got := c.streamContent.String()
|
||||
want := "Hello, world!"
|
||||
@@ -396,6 +396,7 @@ func TestStreamComponent_ToolExecution_IsStarting_ShowsSpinner(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
_, cmd := c.Update(app.ToolExecutionEvent{
|
||||
ToolCallID: "call-exec-1",
|
||||
ToolName: "exec_tool",
|
||||
IsStarting: true,
|
||||
})
|
||||
@@ -403,8 +404,9 @@ func TestStreamComponent_ToolExecution_IsStarting_ShowsSpinner(t *testing.T) {
|
||||
if !c.spinning {
|
||||
t.Fatal("expected spinning=true during tool execution")
|
||||
}
|
||||
if len(c.activeTools) != 1 || !strings.Contains(c.activeTools[0], "exec_tool") {
|
||||
t.Fatalf("expected activeTools to contain tool name, got %v", c.activeTools)
|
||||
tools := c.activeToolDisplays()
|
||||
if len(tools) != 1 || !strings.Contains(tools[0], "exec_tool") {
|
||||
t.Fatalf("expected activeTools to contain tool name, got %v", tools)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Fatal("expected tick cmd from ToolExecutionEvent{IsStarting:true}")
|
||||
@@ -418,11 +420,13 @@ func TestStreamComponent_ToolExecution_NotStarting_KeepsSpinning(t *testing.T) {
|
||||
c = sendStreamMsg(c, app.SpinnerEvent{Show: true})
|
||||
// Simulate a tool starting
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{
|
||||
ToolCallID: "call-some-1",
|
||||
ToolName: "some_tool",
|
||||
IsStarting: true,
|
||||
})
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{
|
||||
ToolCallID: "call-some-1",
|
||||
ToolName: "some_tool",
|
||||
IsStarting: false,
|
||||
})
|
||||
@@ -440,9 +444,9 @@ func TestStreamComponent_ParallelToolExecution(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
// Start three tools in parallel
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "grep", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "find", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read", ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-grep", ToolName: "grep", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-find", ToolName: "find", IsStarting: true})
|
||||
|
||||
if len(c.activeTools) != 3 {
|
||||
t.Fatalf("expected 3 active tools, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
@@ -455,19 +459,44 @@ func TestStreamComponent_ParallelToolExecution(t *testing.T) {
|
||||
}
|
||||
|
||||
// Finish one tool
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "grep", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-grep", ToolName: "grep", IsStarting: false})
|
||||
if len(c.activeTools) != 2 {
|
||||
t.Fatalf("expected 2 active tools after one finished, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
}
|
||||
|
||||
// Finish remaining tools
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "read", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "find", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read", ToolName: "read", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-find", ToolName: "find", IsStarting: false})
|
||||
if len(c.activeTools) != 0 {
|
||||
t.Fatalf("expected 0 active tools after all finished, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_ParallelSameToolName_UsesToolCallID verifies finishing one
|
||||
// tool call does not remove another concurrent call with the same tool name.
|
||||
func TestStreamComponent_ParallelSameToolName_UsesToolCallID(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-1", ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-2", ToolName: "read", IsStarting: true})
|
||||
|
||||
tools := c.activeToolDisplays()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("expected 2 active read calls, got %d (%v)", len(tools), tools)
|
||||
}
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-1", ToolName: "read", IsStarting: false})
|
||||
tools = c.activeToolDisplays()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 active read call after finishing one ID, got %d (%v)", len(tools), tools)
|
||||
}
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-2", ToolName: "read", IsStarting: false})
|
||||
if len(c.activeToolDisplays()) != 0 {
|
||||
t.Fatalf("expected no active tools after finishing both IDs, got %v", c.activeToolDisplays())
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TestStreamComponent_GetRenderedContent verifies the method returns rendered
|
||||
// text when content is accumulated, and empty string when not.
|
||||
@@ -621,3 +650,43 @@ func TestStreamComponent_StaleTick_Discarded(t *testing.T) {
|
||||
t.Fatal("current-gen tick should reschedule")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_StaleFlushTick_Discarded verifies that flush ticks from a
|
||||
// previous generation (e.g. pre-Reset) are ignored.
|
||||
func TestStreamComponent_StaleFlushTick_Discarded(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
// Start a pending flush and capture its generation.
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "old"})
|
||||
staleGen := c.flushGeneration
|
||||
if !c.flushPending {
|
||||
t.Fatal("precondition: expected flushPending=true after first chunk")
|
||||
}
|
||||
|
||||
// Reset should invalidate in-flight flush ticks.
|
||||
c.Reset()
|
||||
if c.flushGeneration == staleGen {
|
||||
t.Fatal("expected flushGeneration to change after Reset")
|
||||
}
|
||||
|
||||
// New content in a new generation.
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "new"})
|
||||
if got := c.pendingStream.String(); got != "new" {
|
||||
t.Fatalf("expected pendingStream='new', got %q", got)
|
||||
}
|
||||
|
||||
// Stale flush tick should be ignored.
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: staleGen})
|
||||
if got := c.pendingStream.String(); got != "new" {
|
||||
t.Fatalf("stale flush tick should not commit pending stream, got %q", got)
|
||||
}
|
||||
|
||||
// Current generation flush should commit.
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
if got := c.pendingStream.String(); got != "" {
|
||||
t.Fatalf("expected pendingStream empty after current flush, got %q", got)
|
||||
}
|
||||
if got := c.streamContent.String(); got != "new" {
|
||||
t.Fatalf("expected streamContent='new' after current flush, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
+8
-9
@@ -179,9 +179,8 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) {
|
||||
}
|
||||
|
||||
// UpdateUsageFromResponse records token usage using metadata from the fantasy
|
||||
// response when available. Falls back to text-based estimation if the metadata is
|
||||
// missing or appears unreliable. This provides more accurate usage tracking when
|
||||
// providers supply token count information.
|
||||
// response. Only actual API-reported tokens are used for cost tracking.
|
||||
// If the provider doesn't report token counts, no usage is recorded.
|
||||
func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText string) {
|
||||
if c.usageTracker == nil {
|
||||
return
|
||||
@@ -191,19 +190,19 @@ func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText stri
|
||||
inputTokens := int(usage.InputTokens)
|
||||
outputTokens := int(usage.OutputTokens)
|
||||
|
||||
// Validate that the metadata seems reasonable
|
||||
if inputTokens > 0 && outputTokens > 0 {
|
||||
// Only use actual API-reported tokens for cost tracking.
|
||||
// We intentionally do NOT estimate tokens - estimation is inaccurate
|
||||
// and should never be used for cost calculations.
|
||||
if inputTokens > 0 {
|
||||
cacheReadTokens := int(usage.CacheReadTokens)
|
||||
cacheWriteTokens := int(usage.CacheCreationTokens)
|
||||
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
// Per-response usage is a single API call, so it represents the
|
||||
// actual context window fill level.
|
||||
c.usageTracker.SetContextTokens(inputTokens + outputTokens)
|
||||
} else {
|
||||
// Fallback to estimation if no metadata is available.
|
||||
// EstimateAndUpdateUsage sets context tokens internally.
|
||||
c.usageTracker.EstimateAndUpdateUsage(inputText, response.Content.Text())
|
||||
}
|
||||
// If inputTokens is 0, the provider didn't report usage - we skip recording
|
||||
// rather than estimating, to ensure cost accuracy.
|
||||
}
|
||||
|
||||
// DisplayUsageAfterResponse renders and displays token usage information immediately
|
||||
|
||||
@@ -127,30 +127,6 @@ func (r *CompactRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolCallMessage renders a tool call notification in compact format, showing
|
||||
// the tool being executed with its arguments in a single line. The tool name is
|
||||
// highlighted and arguments are displayed in a muted color for visual distinction.
|
||||
func (r *CompactRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("[")
|
||||
label := lipgloss.NewStyle().Foreground(theme.Tool).Bold(true).Render(toolName)
|
||||
|
||||
// Format args for compact display
|
||||
argsDisplay := r.formatToolArgs(toolArgs)
|
||||
if argsDisplay != "" {
|
||||
argsDisplay = lipgloss.NewStyle().Foreground(theme.Muted).Render(argsDisplay)
|
||||
}
|
||||
|
||||
line := fmt.Sprintf("%s %s %s", symbol, label, argsDisplay)
|
||||
|
||||
return UIMessage{
|
||||
Type: ToolCallMessage,
|
||||
Content: line,
|
||||
Height: 1,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolMessage renders a unified tool block in compact format, combining
|
||||
// the tool invocation header (icon + display name + params) with the execution
|
||||
// result body. Status is indicated by icon: checkmark for success, cross for error.
|
||||
|
||||
@@ -292,11 +292,6 @@ func ApplyGradient(text string, colorA, colorB color.Color) string {
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// CreateGradientText creates styled text with a gradient effect between two colors.
|
||||
func CreateGradientText(text string, startColor, endColor color.Color) string {
|
||||
return ApplyGradient(text, startColor, endColor)
|
||||
}
|
||||
|
||||
// Compact styling utilities
|
||||
|
||||
// StyleCompactSymbol creates a lipgloss style for message type indicators in
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// FileSuggestion represents a single file or directory suggestion for the @
|
||||
@@ -345,44 +344,16 @@ func scoreFilePath(query, path string) int {
|
||||
}
|
||||
|
||||
// Fuzzy character match on basename.
|
||||
if score := fuzzyCharMatch(query, baseNameLower); score > 0 {
|
||||
if score := fuzzyCharacterMatch(query, baseNameLower); score > 0 {
|
||||
return score
|
||||
}
|
||||
|
||||
// Fuzzy character match on full path.
|
||||
if score := fuzzyCharMatch(query, pathLower); score > 0 {
|
||||
if score := fuzzyCharacterMatch(query, pathLower); score > 0 {
|
||||
return score - 50
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// fuzzyCharMatch performs character-by-character fuzzy matching. Returns a
|
||||
// positive score if all query characters appear in order in the target.
|
||||
func fuzzyCharMatch(query, target string) int {
|
||||
if utf8.RuneCountInString(query) > utf8.RuneCountInString(target) {
|
||||
return 0
|
||||
}
|
||||
|
||||
qRunes := []rune(query)
|
||||
tRunes := []rune(target)
|
||||
qi := 0
|
||||
score := 100
|
||||
consecutive := 0
|
||||
|
||||
for ti := 0; ti < len(tRunes) && qi < len(qRunes); ti++ {
|
||||
if tRunes[ti] == qRunes[qi] {
|
||||
qi++
|
||||
consecutive++
|
||||
score += consecutive * 5
|
||||
} else {
|
||||
consecutive = 0
|
||||
score -= 2
|
||||
}
|
||||
}
|
||||
|
||||
if qi < len(qRunes) {
|
||||
return 0
|
||||
}
|
||||
return score
|
||||
}
|
||||
|
||||
+11
-7
@@ -113,19 +113,23 @@ func fuzzyScore(query string, cmd *SlashCommand) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// fuzzyCharacterMatch performs character-by-character fuzzy matching
|
||||
// fuzzyCharacterMatch performs character-by-character fuzzy matching using
|
||||
// rune-safe iteration so multi-byte Unicode characters are handled correctly.
|
||||
// Returns a positive score if all query runes appear in order within target.
|
||||
func fuzzyCharacterMatch(query, target string) int {
|
||||
if len(query) > len(target) {
|
||||
qRunes := []rune(query)
|
||||
tRunes := []rune(target)
|
||||
if len(qRunes) > len(tRunes) {
|
||||
return 0
|
||||
}
|
||||
|
||||
queryIdx := 0
|
||||
qi := 0
|
||||
score := 100
|
||||
consecutiveMatches := 0
|
||||
|
||||
for i := 0; i < len(target) && queryIdx < len(query); i++ {
|
||||
if target[i] == query[queryIdx] {
|
||||
queryIdx++
|
||||
for ti := 0; ti < len(tRunes) && qi < len(qRunes); ti++ {
|
||||
if tRunes[ti] == qRunes[qi] {
|
||||
qi++
|
||||
consecutiveMatches++
|
||||
score += consecutiveMatches * 10
|
||||
} else {
|
||||
@@ -135,7 +139,7 @@ func fuzzyCharacterMatch(query, target string) int {
|
||||
}
|
||||
|
||||
// Must match all characters in query
|
||||
if queryIdx < len(query) {
|
||||
if qi < len(qRunes) {
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
+14
-1
@@ -65,6 +65,10 @@ type InputComponent struct {
|
||||
// hideHint suppresses the "enter submit · ctrl+j..." hint text.
|
||||
hideHint bool
|
||||
|
||||
// agentBusy indicates the agent is currently working. When true, the
|
||||
// hint text shows steering shortcut (Ctrl+S) instead of submit.
|
||||
agentBusy bool
|
||||
|
||||
// pendingImages holds clipboard images attached to the next submission.
|
||||
// Images are added via Ctrl+V and cleared on submit or Ctrl+U.
|
||||
pendingImages []ImageAttachment
|
||||
@@ -514,7 +518,16 @@ func (s *InputComponent) View() tea.View {
|
||||
// Adapt hint text to available width (accounting for left padding of 3).
|
||||
var hint string
|
||||
availableHintWidth := s.width - 3
|
||||
if availableHintWidth >= 67 {
|
||||
if s.agentBusy {
|
||||
// When the agent is working, show steering shortcut.
|
||||
if availableHintWidth >= 55 {
|
||||
hint = "enter queue • ctrl+s steer • esc esc cancel"
|
||||
} else if availableHintWidth >= 35 {
|
||||
hint = "↵ queue • ^S steer • esc×2 cancel"
|
||||
} else {
|
||||
hint = "^S steer"
|
||||
}
|
||||
} else if availableHintWidth >= 67 {
|
||||
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
|
||||
} else if availableHintWidth >= 40 {
|
||||
hint = "↵ submit • ctrl+j newline • ctrl+v image"
|
||||
|
||||
+118
-331
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/indaco/herald"
|
||||
)
|
||||
|
||||
// ansiEscapeRe matches ANSI escape sequences used for terminal styling.
|
||||
@@ -22,9 +23,9 @@ const (
|
||||
UserMessage MessageType = iota
|
||||
AssistantMessage
|
||||
ToolMessage
|
||||
ToolCallMessage // New type for showing tool calls in progress
|
||||
SystemMessage // New type for KIT system messages (help, tools, etc.)
|
||||
ErrorMessage // New type for error messages
|
||||
ToolCallMessage
|
||||
SystemMessage
|
||||
ErrorMessage
|
||||
)
|
||||
|
||||
// UIMessage encapsulates a fully rendered message ready for display in the UI,
|
||||
@@ -40,29 +41,14 @@ type UIMessage struct {
|
||||
Streaming bool
|
||||
}
|
||||
|
||||
// Helper functions to get theme colors
|
||||
// getTheme returns the current theme (helper for compact_renderer.go)
|
||||
func getTheme() Theme {
|
||||
return GetTheme()
|
||||
}
|
||||
|
||||
// toolDisplayNames maps raw tool names to human-friendly display names.
|
||||
var toolDisplayNames = map[string]string{
|
||||
"bash": "Bash",
|
||||
"read": "Read",
|
||||
"write": "Write",
|
||||
"edit": "Edit",
|
||||
"grep": "Grep",
|
||||
"find": "Find",
|
||||
"ls": "Ls",
|
||||
"run_shell_cmd": "Bash",
|
||||
}
|
||||
|
||||
// toolDisplayName returns a human-friendly display name for a tool.
|
||||
// Falls back to capitalizing the first letter of the raw name.
|
||||
// toolDisplayName returns a human-friendly display name for a tool,
|
||||
// title-casing the first letter of the raw name.
|
||||
func toolDisplayName(rawName string) string {
|
||||
if display, ok := toolDisplayNames[rawName]; ok {
|
||||
return display
|
||||
}
|
||||
if rawName != "" {
|
||||
return strings.ToUpper(rawName[:1]) + rawName[1:]
|
||||
}
|
||||
@@ -70,8 +56,6 @@ func toolDisplayName(rawName string) string {
|
||||
}
|
||||
|
||||
// formatToolParams formats tool input parameters for inline header display.
|
||||
// Extracts the primary parameter (command/filePath) first, then shows
|
||||
// remaining params as (key=val, ...). Truncates to maxWidth.
|
||||
func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
args := strings.TrimSpace(toolArgs)
|
||||
if args == "" || args == "{}" {
|
||||
@@ -80,7 +64,6 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(args), ¶ms); err != nil {
|
||||
// Fallback: strip braces and return raw content
|
||||
args = strings.TrimPrefix(args, "{")
|
||||
args = strings.TrimSuffix(args, "}")
|
||||
args = strings.TrimSpace(args)
|
||||
@@ -94,7 +77,6 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Identify primary parameter by checking known keys in priority order
|
||||
primaryKeys := []string{"command", "filePath", "path", "pattern", "query", "url"}
|
||||
var primaryKey string
|
||||
var primaryVal string
|
||||
@@ -111,14 +93,13 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
result.WriteString(primaryVal)
|
||||
}
|
||||
|
||||
// Collect remaining parameters, skipping body-content keys (already
|
||||
// rendered in the tool body) and any values that are too large.
|
||||
bodyKeys := map[string]bool{
|
||||
"content": true,
|
||||
"old_text": true,
|
||||
"new_text": true,
|
||||
"oldText": true,
|
||||
"newText": true,
|
||||
"edits": true,
|
||||
"todos": true,
|
||||
}
|
||||
var remaining []string
|
||||
@@ -154,65 +135,35 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
}
|
||||
|
||||
// MessageRenderer handles the formatting and rendering of different message types
|
||||
// with consistent styling, markdown support, and appropriate visual hierarchies
|
||||
// for the standard (non-compact) display mode.
|
||||
type MessageRenderer struct {
|
||||
width int
|
||||
debug bool
|
||||
|
||||
// getToolRenderer returns extension-provided rendering overrides for a
|
||||
// specific tool. May be nil if no extensions are loaded. Used in
|
||||
// RenderToolMessage to check for custom header/body formatting before
|
||||
// falling back to builtin renderers.
|
||||
width int
|
||||
debug bool
|
||||
ty *herald.Typography
|
||||
getToolRenderer func(toolName string) *ToolRendererData
|
||||
}
|
||||
|
||||
// newMessageRenderer creates and initializes a new MessageRenderer with the specified
|
||||
// terminal width and debug mode setting. The width parameter determines line wrapping
|
||||
// and layout calculations.
|
||||
// newMessageRenderer creates and initializes a new MessageRenderer
|
||||
func newMessageRenderer(width int, debug bool) *MessageRenderer {
|
||||
return &MessageRenderer{
|
||||
width: width,
|
||||
debug: debug,
|
||||
ty: createTypography(GetTheme()),
|
||||
}
|
||||
}
|
||||
|
||||
// SetWidth updates the terminal width for the renderer, affecting how content
|
||||
// is wrapped and formatted in subsequent render operations.
|
||||
// SetWidth updates the terminal width for the renderer
|
||||
func (r *MessageRenderer) SetWidth(width int) {
|
||||
r.width = width
|
||||
}
|
||||
|
||||
// RenderUserMessage renders a user's input message with distinctive right-aligned
|
||||
// formatting, including the system username, timestamp, and markdown-rendered content.
|
||||
// The message is displayed with a colored right border for visual distinction.
|
||||
// RenderUserMessage renders a user's input message using herald Tip alert
|
||||
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
// Only run markdown rendering when the message contains code spans or
|
||||
// fenced code blocks. Plain text is rendered directly so that newlines
|
||||
// are preserved without the extra paragraph spacing glamour adds.
|
||||
var messageContent string
|
||||
if strings.Contains(content, "`") {
|
||||
// Glamour treats single \n as a soft break, so convert to paragraph
|
||||
// breaks and collapse the resulting blank lines after rendering.
|
||||
mdContent := strings.ReplaceAll(content, "\n", "\n\n")
|
||||
messageContent = r.renderMarkdown(mdContent, r.width-8)
|
||||
messageContent = removeBlankLines(messageContent)
|
||||
} else {
|
||||
messageContent = content
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
fullContent := strings.TrimSuffix(messageContent, "\n")
|
||||
|
||||
// Left border with Blue color for user messages.
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Info),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
rendered := r.ty.Tip(content)
|
||||
rendered = lipgloss.NewStyle().MarginBottom(1).Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: UserMessage,
|
||||
@@ -222,12 +173,8 @@ func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderAssistantMessage renders an AI assistant's response with left-aligned formatting,
|
||||
// including the model name, timestamp, and markdown-rendered content. Empty responses
|
||||
// are ignored and return an empty message. The message features a colored left border
|
||||
// for visual distinction.
|
||||
// RenderAssistantMessage renders an AI assistant's response
|
||||
func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
|
||||
// Ignore empty responses - don't render anything
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
@@ -237,17 +184,9 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
}
|
||||
}
|
||||
|
||||
theme := getTheme()
|
||||
messageContent := r.renderMarkdown(content, r.width-8)
|
||||
fullContent := strings.TrimSuffix(messageContent, "\n")
|
||||
|
||||
// Left border with Primary (Mauve) color for assistant messages.
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithBorderColor(theme.Primary),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
// Use markdown rendering with Chroma syntax highlighting
|
||||
rendered := toMarkdown(content, r.width-4)
|
||||
rendered = lipgloss.NewStyle().MarginBottom(1).Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
@@ -257,30 +196,14 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
}
|
||||
}
|
||||
|
||||
// RenderSystemMessage renders KIT system messages such as help text, command outputs,
|
||||
// and informational notifications. These messages are displayed with a distinctive system
|
||||
// color border and "KIT System" label to differentiate them from user and AI content.
|
||||
// RenderSystemMessage renders KIT system messages using herald Note alert
|
||||
func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
var messageContent string
|
||||
if strings.TrimSpace(content) == "" {
|
||||
messageContent = "No content available"
|
||||
} else if strings.Contains(content, "`") {
|
||||
messageContent = r.renderMarkdown(content, r.width-8)
|
||||
} else {
|
||||
messageContent = content
|
||||
content = "No content available"
|
||||
}
|
||||
|
||||
fullContent := "◇ " + strings.TrimSuffix(messageContent, "\n")
|
||||
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithNoBorder(),
|
||||
WithForeground(theme.Muted),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
rendered := r.ty.Note(content)
|
||||
rendered = lipgloss.NewStyle().MarginBottom(1).Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
@@ -290,27 +213,9 @@ func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugMessage renders diagnostic and debugging information with special formatting
|
||||
// including a debug icon, colored border, and structured layout. Debug messages are only
|
||||
// displayed when debug mode is enabled and help developers troubleshoot issues.
|
||||
// RenderDebugMessage renders diagnostic and debugging information
|
||||
func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := getTheme()
|
||||
style := baseStyle.
|
||||
Width(r.width - 3).
|
||||
BorderLeft(true).
|
||||
Foreground(theme.Muted).
|
||||
BorderForeground(theme.Tool).
|
||||
BorderStyle(lipgloss.ThickBorder()).
|
||||
PaddingLeft(1).
|
||||
MarginLeft(2).
|
||||
MarginBottom(1)
|
||||
|
||||
header := baseStyle.
|
||||
Foreground(theme.Tool).
|
||||
Bold(true).
|
||||
Render("🔍 Debug Output")
|
||||
header := r.ty.H6("🔍 Debug Output")
|
||||
|
||||
lines := strings.Split(message, "\n")
|
||||
var formattedLines []string
|
||||
@@ -320,87 +225,52 @@ func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time
|
||||
}
|
||||
}
|
||||
|
||||
content := baseStyle.
|
||||
Foreground(theme.Muted).
|
||||
Render(strings.Join(formattedLines, "\n"))
|
||||
|
||||
fullContent := lipgloss.JoinVertical(lipgloss.Left,
|
||||
content := r.ty.Compose(
|
||||
header,
|
||||
content,
|
||||
r.ty.P(strings.Join(formattedLines, "\n")),
|
||||
)
|
||||
content = lipgloss.NewStyle().MarginBottom(1).Render(content)
|
||||
|
||||
return UIMessage{
|
||||
Content: style.Render(fullContent),
|
||||
Height: lipgloss.Height(style.Render(fullContent)),
|
||||
Content: content,
|
||||
Height: lipgloss.Height(content),
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugConfigMessage renders configuration settings in a formatted debug display
|
||||
// with key-value pairs shown in a structured layout. Used to display runtime configuration
|
||||
// for debugging purposes with a distinctive icon and border styling.
|
||||
// RenderDebugConfigMessage renders configuration settings
|
||||
func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := getTheme()
|
||||
style := baseStyle.
|
||||
Width(r.width - 1).
|
||||
BorderLeft(true).
|
||||
Foreground(theme.Muted).
|
||||
BorderForeground(theme.Tool).
|
||||
BorderStyle(lipgloss.ThickBorder()).
|
||||
PaddingLeft(1)
|
||||
|
||||
header := baseStyle.
|
||||
Foreground(theme.Tool).
|
||||
Bold(true).
|
||||
Render("🔧 Debug Configuration")
|
||||
header := r.ty.H6("🔧 Debug Configuration")
|
||||
|
||||
var configLines []string
|
||||
for key, value := range config {
|
||||
if value != nil {
|
||||
configLines = append(configLines, fmt.Sprintf(" %s: %v", key, value))
|
||||
configLines = append(configLines, fmt.Sprintf("%s: %v", key, value))
|
||||
}
|
||||
}
|
||||
|
||||
configContent := baseStyle.
|
||||
Foreground(theme.Muted).
|
||||
Render(strings.Join(configLines, "\n"))
|
||||
|
||||
parts := []string{header}
|
||||
var content string
|
||||
if len(configLines) > 0 {
|
||||
parts = append(parts, configContent)
|
||||
content = r.ty.Compose(
|
||||
header,
|
||||
r.ty.P(strings.Join(configLines, "\n")),
|
||||
)
|
||||
} else {
|
||||
content = header
|
||||
}
|
||||
|
||||
rendered := style.Render(
|
||||
lipgloss.JoinVertical(lipgloss.Left, parts...),
|
||||
)
|
||||
content = lipgloss.NewStyle().MarginBottom(1).Render(content)
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Content: content,
|
||||
Height: lipgloss.Height(content),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderErrorMessage renders error notifications with distinctive red coloring and
|
||||
// bold text to ensure visibility. Error messages include timestamp information and
|
||||
// are displayed with an error-colored border for immediate recognition.
|
||||
// RenderErrorMessage renders error notifications
|
||||
func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
errorContent := lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true).
|
||||
Render(errorMsg)
|
||||
|
||||
rendered := renderContentBlock(
|
||||
errorContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Error),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
rendered := r.ty.Caution(errorMsg)
|
||||
rendered = lipgloss.NewStyle().MarginBottom(1).Render(rendered)
|
||||
|
||||
return UIMessage{
|
||||
Type: ErrorMessage,
|
||||
@@ -410,93 +280,18 @@ func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolCallMessage renders a notification that a tool is being executed, showing
|
||||
// the tool name, formatted arguments (if any), and execution timestamp. The message
|
||||
// uses tool-specific coloring to distinguish it from regular conversation messages.
|
||||
func (r *MessageRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage {
|
||||
// Format timestamp
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
|
||||
// Format arguments with better presentation
|
||||
theme := getTheme()
|
||||
var argsContent string
|
||||
if toolArgs != "" && toolArgs != "{}" {
|
||||
argsContent = lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true).
|
||||
Render(fmt.Sprintf("Arguments: %s", r.formatToolArgs(toolArgs)))
|
||||
}
|
||||
|
||||
// Create info line
|
||||
info := fmt.Sprintf(" Executing %s (%s)", toolName, timeStr)
|
||||
|
||||
// Combine parts
|
||||
var fullContent string
|
||||
if argsContent != "" {
|
||||
fullContent = argsContent + "\n" +
|
||||
lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(info)
|
||||
} else {
|
||||
fullContent = lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(info)
|
||||
}
|
||||
|
||||
// Use the new block renderer
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Tool),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
|
||||
return UIMessage{
|
||||
Type: ToolCallMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolMessage renders a unified tool block combining the tool invocation
|
||||
// header (icon + display name + params) with the execution result body. The
|
||||
// border color indicates status: green for success, red for error. This replaces
|
||||
// the previous two-block approach (separate call + result blocks).
|
||||
// RenderToolMessage renders a unified tool block
|
||||
func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
// Resolve extension renderer once for all overrides.
|
||||
var extRd *ToolRendererData
|
||||
if r.getToolRenderer != nil {
|
||||
extRd = r.getToolRenderer(toolName)
|
||||
}
|
||||
|
||||
// --- Header: [icon] [name] [params] ---
|
||||
var icon string
|
||||
borderColor := theme.Success
|
||||
iconColor := theme.Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
borderColor = theme.Error
|
||||
iconColor = theme.Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// Extension can override border color (applies to both success and error).
|
||||
if extRd != nil && extRd.BorderColor != "" {
|
||||
borderColor = lipgloss.Color(extRd.BorderColor)
|
||||
}
|
||||
|
||||
iconStr := lipgloss.NewStyle().Foreground(iconColor).Bold(true).Render(icon)
|
||||
|
||||
// Extension can override display name.
|
||||
displayName := toolDisplayName(toolName)
|
||||
if extRd != nil && extRd.DisplayName != "" {
|
||||
displayName = extRd.DisplayName
|
||||
}
|
||||
nameStr := lipgloss.NewStyle().Foreground(theme.Info).Bold(true).Render(displayName)
|
||||
|
||||
// Format params with width budget for the header line.
|
||||
// Check extension renderer for custom header params first.
|
||||
paramBudget := max(r.width-10-len(displayName), 20)
|
||||
var params string
|
||||
if extRd != nil && extRd.RenderHeader != nil {
|
||||
@@ -506,69 +301,70 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
params = formatToolParams(toolArgs, paramBudget)
|
||||
}
|
||||
|
||||
header := iconStr + " " + nameStr
|
||||
if params != "" {
|
||||
header += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
var icon string
|
||||
iconColor := GetTheme().Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
iconColor = GetTheme().Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// --- Body: check extension renderer first, then builtin, then default ---
|
||||
// Style the tool name with color
|
||||
theme := GetTheme()
|
||||
nameColor := theme.Info
|
||||
if isError {
|
||||
nameColor = theme.Error
|
||||
}
|
||||
styledName := lipgloss.NewStyle().Foreground(nameColor).Bold(true).Render(displayName)
|
||||
styledIcon := lipgloss.NewStyle().Foreground(iconColor).Render(icon)
|
||||
|
||||
// Build the content: icon + name + params on first line, then body
|
||||
headerLine := styledIcon + " " + styledName
|
||||
if params != "" {
|
||||
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
}
|
||||
|
||||
// Get body content
|
||||
var body string
|
||||
if extRd != nil && extRd.RenderBody != nil {
|
||||
body = extRd.RenderBody(toolResult, isError, r.width-8)
|
||||
// Apply markdown rendering if requested and body is non-empty.
|
||||
if body != "" && extRd.BodyMarkdown {
|
||||
body = strings.TrimSuffix(toMarkdown(body, r.width-8), "\n")
|
||||
}
|
||||
}
|
||||
if body == "" {
|
||||
if isError {
|
||||
body = lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Render(toolResult)
|
||||
body = r.formatToolResult(toolName, toolResult)
|
||||
} else {
|
||||
body = renderToolBody(toolName, toolArgs, toolResult, r.width-8)
|
||||
if body == "" {
|
||||
body = r.formatToolResult(toolName, toolResult, r.width-8)
|
||||
body = r.formatToolResult(toolName, toolResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(body) == "" {
|
||||
body = lipgloss.NewStyle().
|
||||
Italic(true).
|
||||
Foreground(theme.Muted).
|
||||
Render("(no output)")
|
||||
body = r.ty.Italic("(no output)")
|
||||
}
|
||||
|
||||
// Combine header + body into a single block.
|
||||
fullContent := header + "\n\n" + strings.TrimSuffix(body, "\n")
|
||||
|
||||
// Build rendering options; extension can override background.
|
||||
blockOpts := []renderingOption{
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(borderColor),
|
||||
WithMarginBottom(1),
|
||||
}
|
||||
if extRd != nil && extRd.Background != "" {
|
||||
blockOpts = append(blockOpts, WithBackground(lipgloss.Color(extRd.Background)))
|
||||
}
|
||||
|
||||
rendered := renderContentBlock(
|
||||
fullContent,
|
||||
r.width,
|
||||
blockOpts...,
|
||||
// Compose: icon + name + params, then body
|
||||
fullContent := r.ty.Compose(
|
||||
headerLine,
|
||||
"",
|
||||
body,
|
||||
)
|
||||
fullContent = lipgloss.NewStyle().MarginBottom(1).Render(fullContent)
|
||||
|
||||
return UIMessage{
|
||||
Type: ToolMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Content: fullContent,
|
||||
Height: lipgloss.Height(fullContent),
|
||||
}
|
||||
}
|
||||
|
||||
// formatToolArgs formats tool arguments for display
|
||||
func (r *MessageRenderer) formatToolArgs(args string) string {
|
||||
// Remove outer braces and clean up JSON formatting
|
||||
args = strings.TrimSpace(args)
|
||||
if strings.HasPrefix(args, "{") && strings.HasSuffix(args, "}") {
|
||||
args = strings.TrimPrefix(args, "{")
|
||||
@@ -576,12 +372,10 @@ func (r *MessageRenderer) formatToolArgs(args string) string {
|
||||
args = strings.TrimSpace(args)
|
||||
}
|
||||
|
||||
// If it's empty after cleanup, return a placeholder
|
||||
if args == "" {
|
||||
return "(no arguments)"
|
||||
}
|
||||
|
||||
// Truncate if too long, but skip truncation in debug mode
|
||||
if !r.debug {
|
||||
maxLen := 100
|
||||
if len(args) > maxLen {
|
||||
@@ -593,10 +387,7 @@ func (r *MessageRenderer) formatToolArgs(args string) string {
|
||||
}
|
||||
|
||||
// formatToolResult formats tool results based on tool type
|
||||
func (r *MessageRenderer) formatToolResult(toolName, result string, width int) string {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
// Truncate very long results only if not in debug mode
|
||||
func (r *MessageRenderer) formatToolResult(toolName, result string) string {
|
||||
if !r.debug {
|
||||
maxLines := 10
|
||||
lines := strings.Split(result, "\n")
|
||||
@@ -605,51 +396,47 @@ func (r *MessageRenderer) formatToolResult(toolName, result string, width int) s
|
||||
}
|
||||
}
|
||||
|
||||
// Format bash/command output with better formatting
|
||||
if strings.Contains(toolName, "bash") || strings.Contains(toolName, "command") || strings.Contains(toolName, "shell") || toolName == "run_shell_cmd" {
|
||||
theme := getTheme()
|
||||
|
||||
// Split result into sections if it contains both stdout and stderr
|
||||
if strings.Contains(toolName, "bash") || strings.Contains(toolName, "command") ||
|
||||
strings.Contains(toolName, "shell") {
|
||||
if strings.Contains(result, "<stdout>") || strings.Contains(result, "<stderr>") {
|
||||
return r.formatBashOutput(result, width, theme)
|
||||
return parseBashOutput(result, GetTheme())
|
||||
}
|
||||
|
||||
// For simple output, just render as monospace text with proper line breaks
|
||||
return baseStyle.
|
||||
Width(width).
|
||||
Foreground(theme.Muted).
|
||||
Render(result)
|
||||
}
|
||||
|
||||
// For other tools, render as muted text
|
||||
theme := getTheme()
|
||||
return baseStyle.
|
||||
Width(width).
|
||||
Foreground(theme.Muted).
|
||||
Render(result)
|
||||
return result
|
||||
}
|
||||
|
||||
// formatBashOutput formats bash command output with proper section handling.
|
||||
// Delegates tag parsing to the shared parseBashOutput helper.
|
||||
func (r *MessageRenderer) formatBashOutput(result string, width int, theme Theme) string {
|
||||
parsed := parseBashOutput(result, theme)
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
Foreground(theme.Muted).
|
||||
Render(parsed)
|
||||
}
|
||||
|
||||
// renderMarkdown renders markdown content using glamour
|
||||
func (r *MessageRenderer) renderMarkdown(content string, width int) string {
|
||||
rendered := toMarkdown(content, width)
|
||||
return strings.TrimSuffix(rendered, "\n")
|
||||
// createTypography creates a typography instance from theme
|
||||
func createTypography(theme Theme) *herald.Typography {
|
||||
return herald.New(
|
||||
herald.WithPalette(herald.ColorPalette{
|
||||
Primary: theme.Primary,
|
||||
Secondary: theme.Secondary,
|
||||
Tertiary: theme.Info,
|
||||
Accent: theme.Accent,
|
||||
Highlight: theme.Highlight,
|
||||
Muted: theme.Muted,
|
||||
Text: theme.Text,
|
||||
Surface: theme.Background,
|
||||
Base: theme.CodeBg,
|
||||
}),
|
||||
herald.WithAlertPalette(herald.AlertPalette{
|
||||
Note: theme.Info,
|
||||
Tip: theme.Success,
|
||||
Important: theme.Accent,
|
||||
Warning: theme.Warning,
|
||||
Caution: theme.Error,
|
||||
}),
|
||||
herald.WithCodeLineNumbers(true),
|
||||
// Customize alert labels
|
||||
herald.WithAlertLabel(herald.AlertNote, "Info"),
|
||||
herald.WithAlertLabel(herald.AlertTip, "You"),
|
||||
herald.WithAlertLabel(herald.AlertWarning, "Working"),
|
||||
herald.WithAlertLabel(herald.AlertCaution, "Error"),
|
||||
)
|
||||
}
|
||||
|
||||
// removeBlankLines removes lines that are visually blank from rendered output.
|
||||
// Glamour wraps every character (including padding spaces) with ANSI color
|
||||
// codes, so we must strip escape sequences before checking whether a line is
|
||||
// empty. This collapses paragraph spacing so user messages render without
|
||||
// extra vertical gaps.
|
||||
func removeBlankLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
filtered := lines[:0]
|
||||
|
||||
+244
-52
@@ -3,11 +3,11 @@ package ui
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
@@ -83,6 +83,9 @@ type AppController interface {
|
||||
// GetTreeSession returns the tree session manager, or nil if tree sessions
|
||||
// are not enabled. Used by slash commands like /tree, /fork, /session.
|
||||
GetTreeSession() *session.TreeManager
|
||||
// SwitchTreeSession replaces the active tree session with a new one,
|
||||
// closing the old session. Used by /new to create a completely fresh session.
|
||||
SwitchTreeSession(ts *session.TreeManager)
|
||||
// SendEvent sends a tea.Msg to the program asynchronously. Safe to call
|
||||
// from any goroutine. Used by extension command goroutines to deliver
|
||||
// results back to the TUI without going through tea.Cmd (which can stall
|
||||
@@ -98,6 +101,12 @@ type AppController interface {
|
||||
// alongside the text. Returns the current queue depth (0 = started
|
||||
// immediately, >0 = queued).
|
||||
RunWithFiles(prompt string, files []fantasy.FilePart) int
|
||||
// Steer injects a steering message into the currently running agent
|
||||
// turn. If the agent is busy, the message is delivered between steps
|
||||
// (after current tool finishes, before next LLM call). If idle, the
|
||||
// message starts executing immediately. Returns 0 if started
|
||||
// immediately, >0 if injected/pending.
|
||||
Steer(prompt string) int
|
||||
}
|
||||
|
||||
// SkillItem holds display metadata about a loaded skill for the startup
|
||||
@@ -415,6 +424,11 @@ type AppModel struct {
|
||||
// the input and move to scrollback when the agent picks them up.
|
||||
queuedMessages []string
|
||||
|
||||
// steeringMessages stores the text of prompts that were sent as steer
|
||||
// messages (injected mid-turn via Ctrl+S). Rendered with a "STEERING"
|
||||
// badge above the input. Cleared when the steer is consumed.
|
||||
steeringMessages []string
|
||||
|
||||
// pendingUserPrints holds user messages that have been consumed from the
|
||||
// queue but not yet printed to scrollback. They are deferred until
|
||||
// SpinnerEvent{Show: true} so the previous assistant response can be
|
||||
@@ -569,8 +583,10 @@ type AppModel struct {
|
||||
streamingBashStderr []string
|
||||
// streamingBashMaxLines caps how many lines to accumulate to prevent memory issues.
|
||||
streamingBashMaxLines int
|
||||
// streamingMu protects the streaming bash output fields from concurrent access.
|
||||
streamingMu sync.RWMutex
|
||||
// streaming bash fields are only mutated/read from the Bubble Tea event loop
|
||||
// (Update/View), so no mutex is required here.
|
||||
// streamingBashCommand holds the command being executed for display as a header.
|
||||
streamingBashCommand string
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -770,28 +786,29 @@ func (m *AppModel) PrintStartupInfo() {
|
||||
return
|
||||
}
|
||||
|
||||
render := func(text string) string {
|
||||
return m.renderer.RenderSystemMessage(text, time.Now()).Content
|
||||
}
|
||||
// Create typography instance for startup rendering
|
||||
ty := createTypography(GetTheme())
|
||||
|
||||
fmt.Println()
|
||||
|
||||
// Build the combined startup content.
|
||||
var lines []string
|
||||
// Build key-value pairs for startup info
|
||||
var pairs [][2]string
|
||||
|
||||
if m.providerName != "" && m.modelName != "" {
|
||||
lines = append(lines, fmt.Sprintf("Model loaded: %s (%s)", m.providerName, m.modelName))
|
||||
pairs = append(pairs, [2]string{"Model", fmt.Sprintf("%s (%s)", m.providerName, m.modelName)})
|
||||
}
|
||||
|
||||
if m.loadingMessage != "" {
|
||||
lines = append(lines, m.loadingMessage)
|
||||
pairs = append(pairs, [2]string{"Status", m.loadingMessage})
|
||||
}
|
||||
|
||||
// Context — loaded AGENTS.md files.
|
||||
if len(m.contextPaths) > 0 {
|
||||
for _, p := range m.contextPaths {
|
||||
lines = append(lines, fmt.Sprintf("Context: %s", tildeHome(p)))
|
||||
contextStr := tildeHome(m.contextPaths[0])
|
||||
if len(m.contextPaths) > 1 {
|
||||
contextStr += fmt.Sprintf(" +%d more", len(m.contextPaths)-1)
|
||||
}
|
||||
pairs = append(pairs, [2]string{"Context", contextStr})
|
||||
}
|
||||
|
||||
// Skills — listed by name.
|
||||
@@ -800,21 +817,23 @@ func (m *AppModel) PrintStartupInfo() {
|
||||
for i, si := range m.skillItems {
|
||||
names[i] = si.Name
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("Skills: %s", strings.Join(names, ", ")))
|
||||
pairs = append(pairs, [2]string{"Skills", strings.Join(names, ", ")})
|
||||
}
|
||||
|
||||
// Extension tool count (only shown when > 0).
|
||||
if m.extensionToolCount > 0 {
|
||||
lines = append(lines, fmt.Sprintf("Loaded %d extension tools", m.extensionToolCount))
|
||||
pairs = append(pairs, [2]string{"Extensions", fmt.Sprintf("%d tools", m.extensionToolCount)})
|
||||
}
|
||||
|
||||
// MCP tool count (only shown when > 0).
|
||||
if m.mcpToolCount > 0 {
|
||||
lines = append(lines, fmt.Sprintf("Loaded %d tools from MCP servers", m.mcpToolCount))
|
||||
pairs = append(pairs, [2]string{"MCP", fmt.Sprintf("%d tools", m.mcpToolCount)})
|
||||
}
|
||||
|
||||
if len(lines) > 0 {
|
||||
fmt.Println(render(strings.Join(lines, "\n\n")))
|
||||
if len(pairs) > 0 {
|
||||
rendered := ty.KVGroup(pairs)
|
||||
rendered = lipgloss.NewStyle().MarginBottom(1).Render(rendered)
|
||||
fmt.Println(rendered)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1070,6 +1089,45 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
// In other states pass ESC through to children below.
|
||||
|
||||
case "ctrl+s":
|
||||
// Steer: inject the current input as a steering message into the
|
||||
// running agent turn. Only active during stateWorking — in input
|
||||
// state, Ctrl+S is passed through to children (no-op by default).
|
||||
if m.state == stateWorking && m.appCtrl != nil {
|
||||
var text string
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
text = strings.TrimSpace(ic.textarea.Value())
|
||||
}
|
||||
if text != "" {
|
||||
// Clear the input and push to history.
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.pushHistory(text)
|
||||
ic.textarea.SetValue("")
|
||||
}
|
||||
|
||||
// Preprocess @file references.
|
||||
processedText := text
|
||||
if m.cwd != "" {
|
||||
processedText = ProcessFileAttachments(text, m.cwd)
|
||||
}
|
||||
|
||||
// Inject the steer message.
|
||||
sLen := m.appCtrl.Steer(processedText)
|
||||
if sLen > 0 {
|
||||
m.steeringMessages = append(m.steeringMessages, text)
|
||||
m.distributeHeight()
|
||||
} else {
|
||||
// Started immediately (agent was idle).
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, text)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
}
|
||||
|
||||
// Route key events to the focused child. Check for editor
|
||||
@@ -1316,6 +1374,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// rendered when the ToolResultEvent arrives.
|
||||
m.flushStreamContent()
|
||||
|
||||
// For bash commands, extract and store the command for the streaming output header.
|
||||
if msg.ToolName == "bash" {
|
||||
var args struct {
|
||||
Command string `json:"command"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(msg.ToolArgs), &args); err == nil && args.Command != "" {
|
||||
m.streamingBashCommand = args.Command
|
||||
}
|
||||
}
|
||||
|
||||
case app.ToolExecutionEvent:
|
||||
// Pass to stream component for execution spinner display.
|
||||
if m.stream != nil {
|
||||
@@ -1327,10 +1395,9 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Buffer tool result for scrollback.
|
||||
m.printToolResult(msg)
|
||||
// Clear streaming bash output since tool completed.
|
||||
m.streamingMu.Lock()
|
||||
m.streamingBashOutput = nil
|
||||
m.streamingBashStderr = nil
|
||||
m.streamingMu.Unlock()
|
||||
m.streamingBashCommand = ""
|
||||
// Start spinner again while waiting for the next LLM response.
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: true})
|
||||
@@ -1339,7 +1406,6 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
case app.ToolOutputEvent:
|
||||
// Accumulate streaming bash output for display.
|
||||
m.streamingMu.Lock()
|
||||
if msg.IsStderr {
|
||||
m.streamingBashStderr = append(m.streamingBashStderr, msg.Chunk)
|
||||
// Cap stderr lines to prevent memory issues.
|
||||
@@ -1353,7 +1419,6 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.streamingBashOutput = m.streamingBashOutput[len(m.streamingBashOutput)-m.streamingBashMaxLines:]
|
||||
}
|
||||
}
|
||||
m.streamingMu.Unlock()
|
||||
|
||||
case app.ToolCallContentEvent:
|
||||
// In streaming mode this text was already delivered via StreamChunkEvents
|
||||
@@ -1389,6 +1454,38 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
m.distributeHeight()
|
||||
|
||||
case app.SteerConsumedEvent:
|
||||
// Steering messages were consumed — either injected mid-turn via
|
||||
// PrepareStep, or drained into the queue after a text-only turn.
|
||||
//
|
||||
// Two cases:
|
||||
//
|
||||
// 1. Mid-turn (stateWorking, PrepareStep fired): no SpinnerEvent{Show:
|
||||
// true} will follow within this turn, so we cannot rely on
|
||||
// flushStreamAndPendingUserMessages() being called. Flush any live
|
||||
// stream content first (assistant text up to the steer point), then
|
||||
// render the steering user messages immediately to scrollback.
|
||||
//
|
||||
// 2. Post-turn (text-only response, drained after StepComplete): a
|
||||
// SpinnerEvent{Show: true} for the next turn is already in flight.
|
||||
// Defer to pendingUserPrints so the previous assistant response is
|
||||
// flushed first, preserving chronological order.
|
||||
if m.state == stateWorking {
|
||||
// Case 1: mid-turn — flush + print immediately.
|
||||
m.flushStreamContent()
|
||||
for _, text := range m.steeringMessages {
|
||||
m.printUserMessage(text)
|
||||
}
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
} else {
|
||||
// Case 2: post-turn — defer so SpinnerEvent orders correctly.
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, m.steeringMessages...)
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
}
|
||||
|
||||
case app.StepCompleteEvent:
|
||||
// Keep stream content visible in the view — don't flush to scrollback
|
||||
// yet. Flushing + resetting in the same frame would shrink the view
|
||||
@@ -1641,6 +1738,7 @@ func (m *AppModel) View() tea.View {
|
||||
// Propagate hint visibility to the input component before rendering.
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.hideHint = vis.HideInputHint
|
||||
ic.agentBusy = m.state == stateWorking
|
||||
}
|
||||
|
||||
// When a prompt is active, it replaces the input area for consistency
|
||||
@@ -1742,20 +1840,23 @@ func (m *AppModel) renderStream() string {
|
||||
|
||||
// renderStreamingBashOutput renders accumulated streaming bash output (stdout + stderr)
|
||||
// below the LLM streaming text. Returns empty string if no bash output is present.
|
||||
// Lines are truncated to the terminal width and capped to maxBashLines to prevent
|
||||
// long-running commands from blowing up the TUI layout.
|
||||
func (m *AppModel) renderStreamingBashOutput(theme Theme) string {
|
||||
m.streamingMu.RLock()
|
||||
stdoutLines := make([]string, len(m.streamingBashOutput))
|
||||
copy(stdoutLines, m.streamingBashOutput)
|
||||
stderrLines := make([]string, len(m.streamingBashStderr))
|
||||
copy(stderrLines, m.streamingBashStderr)
|
||||
m.streamingMu.RUnlock()
|
||||
command := m.streamingBashCommand
|
||||
|
||||
if len(stdoutLines) == 0 && len(stderrLines) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
const lineIndent = " "
|
||||
width := m.width - 2 // Account for indent and padding
|
||||
lineWidth := max(m.width-2-len(lineIndent), 20)
|
||||
// Account for PaddingLeft(1) on the output/stderr styles.
|
||||
maxLineChars := lineWidth - 1
|
||||
|
||||
outputStyle := lipgloss.NewStyle().
|
||||
Background(theme.CodeBg).
|
||||
@@ -1766,17 +1867,59 @@ func (m *AppModel) renderStreamingBashOutput(theme Theme) string {
|
||||
Background(theme.CodeBg).
|
||||
PaddingLeft(1)
|
||||
|
||||
// Header style for the command - muted text with a subtle indicator.
|
||||
headerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
PaddingLeft(1)
|
||||
|
||||
// Cap displayed lines to maxBashLines (show the tail, since streaming
|
||||
// output is most useful at the end). The buffer itself is larger to
|
||||
// preserve context, but we only render the last N lines.
|
||||
totalLines := len(stdoutLines) + len(stderrLines)
|
||||
var hiddenCount int
|
||||
if totalLines > maxBashLines {
|
||||
hiddenCount = totalLines - maxBashLines
|
||||
// Trim from stdout first (older output), then stderr.
|
||||
remaining := maxBashLines
|
||||
if len(stderrLines) >= remaining {
|
||||
stdoutLines = nil
|
||||
stderrLines = stderrLines[len(stderrLines)-remaining:]
|
||||
} else {
|
||||
remaining -= len(stderrLines)
|
||||
if len(stdoutLines) > remaining {
|
||||
stdoutLines = stdoutLines[len(stdoutLines)-remaining:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var lines []string
|
||||
|
||||
// Command header - show the bash command being executed.
|
||||
if command != "" {
|
||||
headerText := fmt.Sprintf("$ %s", command)
|
||||
headerContent := headerStyle.Width(lineWidth).Render(truncateLine(headerText, maxLineChars))
|
||||
lines = append(lines, lineIndent+headerContent)
|
||||
}
|
||||
|
||||
// Truncation hint at the top.
|
||||
if hiddenCount > 0 {
|
||||
hint := fmt.Sprintf("...(%d more lines above)", hiddenCount)
|
||||
hintContent := outputStyle.Width(lineWidth).
|
||||
Foreground(theme.Muted).Italic(true).Render(hint)
|
||||
lines = append(lines, lineIndent+hintContent)
|
||||
}
|
||||
|
||||
// Render stdout lines.
|
||||
for _, line := range stdoutLines {
|
||||
styled := outputStyle.Width(width - len(lineIndent)).Render(line)
|
||||
line = truncateLine(strings.TrimRight(line, "\n"), maxLineChars)
|
||||
styled := outputStyle.Width(lineWidth).Render(line)
|
||||
lines = append(lines, lineIndent+styled)
|
||||
}
|
||||
|
||||
// Render stderr lines with error styling.
|
||||
for _, line := range stderrLines {
|
||||
styled := stderrStyle.Width(width - len(lineIndent)).Render(line)
|
||||
line = truncateLine(strings.TrimRight(line, "\n"), maxLineChars)
|
||||
styled := stderrStyle.Width(lineWidth).Render(line)
|
||||
lines = append(lines, lineIndent+styled)
|
||||
}
|
||||
|
||||
@@ -1901,16 +2044,26 @@ func (m *AppModel) cycleThinkingLevel() {
|
||||
go func() { _ = SaveThinkingLevelPreference(next) }()
|
||||
}
|
||||
|
||||
// renderSeparator renders the separator line with an optional queue count badge.
|
||||
// renderSeparator renders the separator line with an optional queue/steer count badge.
|
||||
func (m *AppModel) renderSeparator() string {
|
||||
theme := GetTheme()
|
||||
lineStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
queueLen := len(m.queuedMessages)
|
||||
steerLen := len(m.steeringMessages)
|
||||
|
||||
if queueLen > 0 {
|
||||
badge := lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Render(fmt.Sprintf("%d queued", queueLen))
|
||||
if steerLen > 0 || queueLen > 0 {
|
||||
var parts []string
|
||||
if steerLen > 0 {
|
||||
parts = append(parts, lipgloss.NewStyle().
|
||||
Foreground(theme.Warning).
|
||||
Render(fmt.Sprintf("%d steering", steerLen)))
|
||||
}
|
||||
if queueLen > 0 {
|
||||
parts = append(parts, lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Render(fmt.Sprintf("%d queued", queueLen)))
|
||||
}
|
||||
badge := strings.Join(parts, " ")
|
||||
|
||||
// Fill the separator with dashes up to the badge.
|
||||
dashWidth := max(m.width-lipgloss.Width(badge)-1, 0)
|
||||
@@ -2009,27 +2162,47 @@ func (m *AppModel) renderHeaderFooter(getter func() *WidgetData) string {
|
||||
return renderContentBlock(data.Text, m.width, opts...)
|
||||
}
|
||||
|
||||
// renderQueuedMessages renders queued prompts as styled content blocks with a
|
||||
// "QUEUED" badge, anchored between the separator and input. Each message is
|
||||
// displayed in a bordered block matching the overall message styling.
|
||||
// renderQueuedMessages renders queued and steering prompts as styled content
|
||||
// blocks with badges, anchored between the separator and input. Steering
|
||||
// messages use a distinct "STEERING" badge to differentiate from queued ones.
|
||||
func (m *AppModel) renderQueuedMessages() string {
|
||||
if len(m.queuedMessages) == 0 {
|
||||
if len(m.queuedMessages) == 0 && len(m.steeringMessages) == 0 {
|
||||
return ""
|
||||
}
|
||||
theme := GetTheme()
|
||||
badge := CreateBadge("QUEUED", theme.Accent)
|
||||
|
||||
var blocks []string
|
||||
for _, msg := range m.queuedMessages {
|
||||
content := msg + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Muted),
|
||||
)
|
||||
blocks = append(blocks, rendered)
|
||||
|
||||
// Render steering messages first (higher priority).
|
||||
if len(m.steeringMessages) > 0 {
|
||||
badge := CreateBadge("STEERING", theme.Warning)
|
||||
for _, msg := range m.steeringMessages {
|
||||
content := msg + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Warning),
|
||||
)
|
||||
blocks = append(blocks, rendered)
|
||||
}
|
||||
}
|
||||
|
||||
// Render queued messages.
|
||||
if len(m.queuedMessages) > 0 {
|
||||
badge := CreateBadge("QUEUED", theme.Accent)
|
||||
for _, msg := range m.queuedMessages {
|
||||
content := msg + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Muted),
|
||||
)
|
||||
blocks = append(blocks, rendered)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(blocks, "\n")
|
||||
}
|
||||
|
||||
@@ -2100,6 +2273,7 @@ func (m *AppModel) handleSlashCommand(sc *SlashCommand) tea.Cmd {
|
||||
m.appCtrl.ClearQueue()
|
||||
}
|
||||
m.queuedMessages = m.queuedMessages[:0]
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
|
||||
case "/tree":
|
||||
@@ -2246,7 +2420,7 @@ func (m *AppModel) printHelpMessage() {
|
||||
"**Navigation:**\n" +
|
||||
"- `/tree`: Navigate session tree (switch branches)\n" +
|
||||
"- `/fork`: Branch from an earlier message\n" +
|
||||
"- `/new`: Start a new branch (preserves history)\n" +
|
||||
"- `/new`: Start a new session (discards context, saves old session)\n" +
|
||||
"- `/resume`: Open session picker to switch sessions\n" +
|
||||
"- `/name <name>`: Set a display name for this session\n\n" +
|
||||
"**System:**\n" +
|
||||
@@ -2287,7 +2461,9 @@ func (m *AppModel) printHelpMessage() {
|
||||
"- `!!command`: Run shell command, output excluded from LLM context\n\n" +
|
||||
"**Keys:**\n" +
|
||||
"- `Ctrl+C`: Exit at any time\n" +
|
||||
"- `ESC` (x2): Cancel ongoing LLM generation\n\n" +
|
||||
"- `ESC` (x2): Cancel ongoing LLM generation\n" +
|
||||
"- `Ctrl+S`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
|
||||
"- `Enter` (while working): Queue message for after the agent finishes\n\n" +
|
||||
"You can also just type your message to chat with the AI assistant."
|
||||
m.printSystemMessage(help)
|
||||
}
|
||||
@@ -2812,7 +2988,8 @@ func (m *AppModel) handleForkCommand() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleNewCommand starts a fresh session by resetting the tree leaf.
|
||||
// handleNewCommand starts a completely new session (Pi-style /new behavior).
|
||||
// Creates a new session file, discarding all context from the previous conversation.
|
||||
func (m *AppModel) handleNewCommand() tea.Cmd {
|
||||
// Emit before-session-switch event in a goroutine so that extension
|
||||
// handlers can call blocking operations (e.g. ctx.PromptConfirm) without
|
||||
@@ -2835,6 +3012,8 @@ func (m *AppModel) handleNewCommand() tea.Cmd {
|
||||
|
||||
// performNewSession performs the actual session reset. Called either directly
|
||||
// (when no before-hook exists) or after the async hook completes.
|
||||
// Matches Pi behavior: creates a completely new session file, discarding all
|
||||
// context from the previous conversation.
|
||||
func (m *AppModel) performNewSession() tea.Cmd {
|
||||
ts := m.appCtrl.GetTreeSession()
|
||||
if ts == nil {
|
||||
@@ -2842,15 +3021,28 @@ func (m *AppModel) performNewSession() tea.Cmd {
|
||||
if m.appCtrl != nil {
|
||||
m.appCtrl.ClearMessages()
|
||||
}
|
||||
// Reset usage statistics for fresh session
|
||||
if m.usageTracker != nil {
|
||||
m.usageTracker.Reset()
|
||||
}
|
||||
m.printSystemMessage("Conversation cleared. Starting fresh.")
|
||||
return nil
|
||||
}
|
||||
|
||||
ts.ResetLeaf()
|
||||
if m.appCtrl != nil {
|
||||
m.appCtrl.ClearMessages()
|
||||
// Create a brand new session file (Pi-style /new behavior)
|
||||
newTs, err := session.CreateTreeSession(m.cwd)
|
||||
if err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to create new session: %v", err))
|
||||
return nil
|
||||
}
|
||||
m.printSystemMessage("New branch started. Previous conversation is preserved in the tree.")
|
||||
|
||||
// Switch to the new session, closing the old one
|
||||
m.appCtrl.SwitchTreeSession(newTs)
|
||||
// Reset usage statistics for the new session
|
||||
if m.usageTracker != nil {
|
||||
m.usageTracker.Reset()
|
||||
}
|
||||
m.printSystemMessage("New session started. Previous conversation saved.")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -54,6 +54,10 @@ func (s *stubAppController) GetTreeSession() *session.TreeManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAppController) SwitchTreeSession(_ *session.TreeManager) {
|
||||
// no-op in tests
|
||||
}
|
||||
|
||||
func (s *stubAppController) SendEvent(_ tea.Msg) {
|
||||
// no-op in tests
|
||||
}
|
||||
@@ -67,6 +71,11 @@ func (s *stubAppController) RunWithFiles(prompt string, _ []fantasy.FilePart) in
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
func (s *stubAppController) Steer(prompt string) int {
|
||||
s.runCalls = append(s.runCalls, prompt)
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Stub child components
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -679,6 +688,57 @@ func TestToolResult_clearsStreamingBashOutput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallStarted_extractsBashCommand verifies that ToolCallStartedEvent
|
||||
// extracts the bash command from ToolArgs and stores it for the streaming output header.
|
||||
func TestToolCallStarted_extractsBashCommand(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
// Send ToolCallStartedEvent with bash command.
|
||||
m = sendMsg(m, app.ToolCallStartedEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "bash",
|
||||
ToolArgs: `{"command":"ls -la /home"}`,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "ls -la /home" {
|
||||
t.Fatalf("expected streamingBashCommand='ls -la /home', got %q", m.streamingBashCommand)
|
||||
}
|
||||
|
||||
// ToolResultEvent should clear the command.
|
||||
m = sendMsg(m, app.ToolResultEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "bash",
|
||||
ToolArgs: `{"command":"ls -la /home"}`,
|
||||
Result: "output",
|
||||
IsError: false,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "" {
|
||||
t.Fatalf("expected streamingBashCommand cleared, got %q", m.streamingBashCommand)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallStarted_nonBashTool_doesNotSetCommand verifies that non-bash tools
|
||||
// do not set the streamingBashCommand field.
|
||||
func TestToolCallStarted_nonBashTool_doesNotSetCommand(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
// Send ToolCallStartedEvent with a non-bash tool.
|
||||
m = sendMsg(m, app.ToolCallStartedEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "read",
|
||||
ToolArgs: `{"file":"/etc/passwd"}`,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "" {
|
||||
t.Fatalf("expected streamingBashCommand to remain empty for non-bash tools, got %q", m.streamingBashCommand)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStepError_printCmd verifies that StepErrorEvent with a non-nil error
|
||||
// produces a non-nil cmd (the tea.Println call for the error message).
|
||||
func TestStepError_printCmd(t *testing.T) {
|
||||
|
||||
+136
-94
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/indaco/herald"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
)
|
||||
|
||||
@@ -79,7 +80,12 @@ func streamSpinnerTickCmd(generation uint64) tea.Cmd {
|
||||
// streamFlushTickMsg fires when it's time to commit pending chunks to the
|
||||
// main content builders and trigger a re-render. This coalesces rapid
|
||||
// streaming chunks into fewer expensive markdown re-renders.
|
||||
type streamFlushTickMsg struct{}
|
||||
//
|
||||
// generation ties the tick to the pending flush session that created it so
|
||||
// stale ticks from a prior Reset() are discarded.
|
||||
type streamFlushTickMsg struct {
|
||||
generation uint64
|
||||
}
|
||||
|
||||
// streamFlushInterval is the coalescing window for stream chunks. Chunks
|
||||
// arriving within this window are batched into a single render pass.
|
||||
@@ -89,9 +95,9 @@ const streamFlushInterval = 16 * time.Millisecond
|
||||
|
||||
// streamFlushTickCmd returns a tea.Cmd that fires streamFlushTickMsg after
|
||||
// the coalescing interval.
|
||||
func streamFlushTickCmd() tea.Cmd {
|
||||
func streamFlushTickCmd(generation uint64) tea.Cmd {
|
||||
return tea.Tick(streamFlushInterval, func(_ time.Time) tea.Msg {
|
||||
return streamFlushTickMsg{}
|
||||
return streamFlushTickMsg{generation: generation}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -149,9 +155,11 @@ type StreamComponent struct {
|
||||
// spinnerFrame is the current frame index.
|
||||
spinnerFrame int
|
||||
|
||||
// activeTools tracks the names of tools currently executing in parallel.
|
||||
// When multiple tools run concurrently, all are displayed in the spinner.
|
||||
activeTools []string
|
||||
// activeTools maps ToolCallID -> display label for currently running tools.
|
||||
activeTools map[string]string
|
||||
|
||||
// activeToolOrder preserves deterministic display order for active tools.
|
||||
activeToolOrder []string
|
||||
|
||||
// streamContent holds committed streaming text (flushed from pending).
|
||||
streamContent strings.Builder
|
||||
@@ -172,6 +180,10 @@ type StreamComponent struct {
|
||||
// the same coalescing window.
|
||||
flushPending bool
|
||||
|
||||
// flushGeneration is incremented when stream state resets so stale flush
|
||||
// ticks from a previous step can be discarded.
|
||||
flushGeneration uint64
|
||||
|
||||
// renderCache holds the last rendered output string. Reused by View()
|
||||
// between flush ticks to avoid redundant markdown re-parsing.
|
||||
renderCache string
|
||||
@@ -190,14 +202,8 @@ type StreamComponent struct {
|
||||
// reasoningDuration holds the total reasoning time, frozen when streaming text begins.
|
||||
reasoningDuration time.Duration
|
||||
|
||||
// messageRenderer renders assistant messages in standard mode.
|
||||
messageRenderer *MessageRenderer
|
||||
|
||||
// compactRenderer renders assistant messages in compact mode.
|
||||
compactRenderer *CompactRenderer
|
||||
|
||||
// compactMode selects which renderer to use.
|
||||
compactMode bool
|
||||
// renderer renders streaming assistant text in either compact or standard mode.
|
||||
renderer Renderer
|
||||
|
||||
// modelName is displayed in the streaming text header.
|
||||
modelName string
|
||||
@@ -211,6 +217,9 @@ type StreamComponent struct {
|
||||
// height constrains the render output to at most this many lines.
|
||||
// 0 means unconstrained.
|
||||
height int
|
||||
|
||||
// ty provides typography functions for rendering text.
|
||||
ty *herald.Typography
|
||||
}
|
||||
|
||||
// NewStreamComponent creates a new StreamComponent ready to be embedded in AppModel.
|
||||
@@ -218,13 +227,20 @@ func NewStreamComponent(compactMode bool, width int, modelName string) *StreamCo
|
||||
if width == 0 {
|
||||
width = 80
|
||||
}
|
||||
|
||||
var renderer Renderer
|
||||
if compactMode {
|
||||
renderer = NewCompactRenderer(width, false)
|
||||
} else {
|
||||
renderer = newMessageRenderer(width, false)
|
||||
}
|
||||
|
||||
return &StreamComponent{
|
||||
spinnerFrames: knightRiderFrames(),
|
||||
compactMode: compactMode,
|
||||
modelName: modelName,
|
||||
messageRenderer: newMessageRenderer(width, false),
|
||||
compactRenderer: NewCompactRenderer(width, false),
|
||||
width: width,
|
||||
spinnerFrames: knightRiderFrames(),
|
||||
modelName: modelName,
|
||||
renderer: renderer,
|
||||
width: width,
|
||||
ty: createTypography(GetTheme()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,11 +267,13 @@ func (s *StreamComponent) Reset() {
|
||||
s.spinnerGeneration++ // invalidate any in-flight tick commands
|
||||
s.spinnerFrame = 0
|
||||
s.activeTools = nil
|
||||
s.activeToolOrder = nil
|
||||
s.streamContent.Reset()
|
||||
s.reasoningContent.Reset()
|
||||
s.pendingStream.Reset()
|
||||
s.pendingReasoning.Reset()
|
||||
s.flushPending = false
|
||||
s.flushGeneration++
|
||||
s.renderCache = ""
|
||||
s.renderDirty = false
|
||||
s.timestamp = time.Time{}
|
||||
@@ -282,7 +300,8 @@ func (s *StreamComponent) GetRenderedContent() string {
|
||||
|
||||
text := s.streamContent.String()
|
||||
if text != "" {
|
||||
sections = append(sections, s.renderStreamingText(text))
|
||||
rendered := s.renderStreamingText(text)
|
||||
sections = append(sections, rendered)
|
||||
}
|
||||
|
||||
if len(sections) == 0 {
|
||||
@@ -322,8 +341,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
case tea.WindowSizeMsg:
|
||||
s.width = msg.Width
|
||||
s.messageRenderer.SetWidth(s.width)
|
||||
s.compactRenderer.SetWidth(s.width)
|
||||
if s.renderer != nil {
|
||||
s.renderer.SetWidth(s.width)
|
||||
}
|
||||
// Invalidate render cache — width change affects wrapping/styling.
|
||||
s.renderCache = ""
|
||||
s.renderDirty = true
|
||||
@@ -359,6 +379,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
case streamFlushTickMsg:
|
||||
if msg.generation != s.flushGeneration {
|
||||
break
|
||||
}
|
||||
s.flushPending = false
|
||||
s.commitPending()
|
||||
|
||||
@@ -373,7 +396,7 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.pendingReasoning.WriteString(msg.Delta)
|
||||
if !s.flushPending {
|
||||
s.flushPending = true
|
||||
return s, streamFlushTickCmd()
|
||||
return s, streamFlushTickCmd(s.flushGeneration)
|
||||
}
|
||||
|
||||
case app.StreamChunkEvent:
|
||||
@@ -388,14 +411,25 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.pendingStream.WriteString(msg.Content)
|
||||
if !s.flushPending {
|
||||
s.flushPending = true
|
||||
return s, streamFlushTickCmd()
|
||||
return s, streamFlushTickCmd(s.flushGeneration)
|
||||
}
|
||||
|
||||
case app.ToolExecutionEvent:
|
||||
toolID := msg.ToolCallID
|
||||
if toolID == "" {
|
||||
// Defensive fallback for older/third-party emitters that may omit
|
||||
// ToolCallID. Best-effort only: same-name+args concurrent calls can
|
||||
// still collide without a stable ID.
|
||||
toolID = fmt.Sprintf("%s|%s", msg.ToolName, msg.ToolArgs)
|
||||
}
|
||||
if msg.IsStarting {
|
||||
// Add tool to active list for parallel execution display.
|
||||
toolDisplay := formatToolExecutionMessage(msg.ToolName, msg.ToolArgs)
|
||||
s.activeTools = append(s.activeTools, toolDisplay)
|
||||
if s.activeTools == nil {
|
||||
s.activeTools = make(map[string]string)
|
||||
}
|
||||
if _, exists := s.activeTools[toolID]; !exists {
|
||||
s.activeToolOrder = append(s.activeToolOrder, toolID)
|
||||
}
|
||||
s.activeTools[toolID] = formatToolExecutionMessage(msg.ToolName)
|
||||
s.spinnerFrame = 0
|
||||
if !s.spinning {
|
||||
s.phase = streamPhaseActive
|
||||
@@ -404,9 +438,10 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return s, streamSpinnerTickCmd(s.spinnerGeneration)
|
||||
}
|
||||
} else {
|
||||
// Tool finished — remove from active list but keep spinning if others remain.
|
||||
toolDisplay := formatToolExecutionMessage(msg.ToolName, msg.ToolArgs)
|
||||
s.activeTools = removeFromSlice(s.activeTools, toolDisplay)
|
||||
if s.activeTools != nil {
|
||||
delete(s.activeTools, toolID)
|
||||
}
|
||||
s.activeToolOrder = removeToolID(s.activeToolOrder, toolID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -415,7 +450,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// View implements tea.Model. Renders the current stream region content.
|
||||
func (s *StreamComponent) View() tea.View {
|
||||
return tea.NewView(s.render())
|
||||
fullContent := s.render()
|
||||
visibleContent := s.viewContent(fullContent)
|
||||
return tea.NewView(visibleContent)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -458,54 +495,51 @@ func (s *StreamComponent) render() string {
|
||||
|
||||
content := strings.Join(sections, "\n")
|
||||
|
||||
// Clamp to height if constrained: keep the last h lines so the most
|
||||
// recent output is always visible.
|
||||
if s.height > 0 && content != "" {
|
||||
lines := strings.Split(content, "\n")
|
||||
if len(lines) > s.height {
|
||||
lines = lines[len(lines)-s.height:]
|
||||
content = strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Cache FULL content without height clamping.
|
||||
// Height clamping is applied in View() for display only.
|
||||
s.renderCache = content
|
||||
s.renderDirty = false
|
||||
return content
|
||||
}
|
||||
|
||||
// renderReasoningBlock renders the reasoning/thinking content in a surface-tinted
|
||||
// box. When collapsed, shows the last 10 lines with a truncation hint. When
|
||||
// viewContent returns the visible portion of content based on height constraint.
|
||||
// This is called by View() to get the slice that fits in the terminal.
|
||||
func (s *StreamComponent) viewContent(fullContent string) string {
|
||||
if s.height > 0 && fullContent != "" {
|
||||
lines := strings.Split(fullContent, "\n")
|
||||
if len(lines) > s.height {
|
||||
// Keep only the last h lines so the most recent output is visible.
|
||||
lines = lines[len(lines)-s.height:]
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
return fullContent
|
||||
}
|
||||
|
||||
// renderReasoningBlock renders the reasoning/thinking content using blockquote.
|
||||
// When collapsed, shows the last 10 lines with a truncation hint. When
|
||||
// expanded, shows all lines. Includes a "Thought for Xs" duration footer.
|
||||
func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
theme := GetTheme()
|
||||
maxWidth := max(s.width-4, 20)
|
||||
|
||||
lines := strings.Split(strings.TrimRight(reasoning, "\n"), "\n")
|
||||
|
||||
contentStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.MutedBorder).
|
||||
Italic(true)
|
||||
|
||||
var parts []string
|
||||
|
||||
// When collapsed and content exceeds 10 lines, show only the last 10
|
||||
// with a truncation hint (matching iteratr's thinking block pattern).
|
||||
// with a truncation hint.
|
||||
const maxCollapsedLines = 10
|
||||
if !s.thinkingVisible && len(lines) > maxCollapsedLines {
|
||||
hidden := len(lines) - maxCollapsedLines
|
||||
hintStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted).
|
||||
Background(theme.MutedBorder).
|
||||
Italic(true)
|
||||
parts = append(parts, hintStyle.Render(fmt.Sprintf("... (%d lines hidden)", hidden)))
|
||||
parts = append(parts, s.ty.Italic(fmt.Sprintf("... (%d lines hidden)", hidden)))
|
||||
lines = lines[len(lines)-maxCollapsedLines:]
|
||||
}
|
||||
|
||||
// Render reasoning text.
|
||||
parts = append(parts, contentStyle.Width(maxWidth).Render(strings.Join(lines, "\n")))
|
||||
// Main content using Italic with Muted color for visual distinction.
|
||||
content := strings.TrimLeft(strings.Join(lines, "\n"), " \t\n")
|
||||
theme := GetTheme()
|
||||
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
parts = append(parts, mutedStyle.Render(s.ty.Italic(content)))
|
||||
|
||||
// Duration footer.
|
||||
// Duration footer with VeryMuted label and Accent duration.
|
||||
var duration time.Duration
|
||||
if s.reasoningDuration > 0 {
|
||||
duration = s.reasoningDuration
|
||||
@@ -519,21 +553,21 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.1fs", duration.Seconds())
|
||||
}
|
||||
footer := lipgloss.NewStyle().Foreground(theme.VeryMuted).Background(theme.MutedBorder).Render("Thought for ") +
|
||||
lipgloss.NewStyle().Foreground(theme.Info).Background(theme.MutedBorder).Render(durationStr)
|
||||
parts = append(parts, footer)
|
||||
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
|
||||
durationStyled := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
|
||||
parts = append(parts, label+durationStyled)
|
||||
}
|
||||
|
||||
innerContent := strings.Join(parts, "\n")
|
||||
|
||||
// Wrap in box with surface background for visual distinction.
|
||||
boxStyle := lipgloss.NewStyle().
|
||||
Background(theme.MutedBorder). // Surface0 (#313244)
|
||||
PaddingLeft(1).
|
||||
Width(maxWidth + 2).
|
||||
MarginBottom(1)
|
||||
|
||||
return boxStyle.Render(innerContent)
|
||||
// Concatenate parts with newline between blockquote and footer
|
||||
var result string
|
||||
if len(parts) == 1 {
|
||||
result = parts[0]
|
||||
} else if len(parts) == 2 {
|
||||
result = parts[0] + "\n" + parts[1]
|
||||
} else {
|
||||
result = strings.Join(parts, "\n")
|
||||
}
|
||||
return lipgloss.NewStyle().MarginBottom(1).Render(result)
|
||||
}
|
||||
|
||||
// SetThinkingVisible sets whether reasoning blocks are shown or collapsed.
|
||||
@@ -559,7 +593,8 @@ func (s *StreamComponent) SpinnerView() string {
|
||||
return ""
|
||||
}
|
||||
frame := s.spinnerFrames[s.spinnerFrame%len(s.spinnerFrames)]
|
||||
if len(s.activeTools) == 0 {
|
||||
tools := s.activeToolDisplays()
|
||||
if len(tools) == 0 {
|
||||
return " " + frame
|
||||
}
|
||||
theme := GetTheme()
|
||||
@@ -569,10 +604,10 @@ func (s *StreamComponent) SpinnerView() string {
|
||||
|
||||
// Format active tools list
|
||||
var toolsMsg string
|
||||
if len(s.activeTools) == 1 {
|
||||
toolsMsg = s.activeTools[0]
|
||||
if len(tools) == 1 {
|
||||
toolsMsg = tools[0]
|
||||
} else {
|
||||
toolsMsg = "Running: " + strings.Join(s.activeTools, ", ")
|
||||
toolsMsg = "Running: " + strings.Join(tools, ", ")
|
||||
}
|
||||
return " " + frame + " " + msgStyle.Render(toolsMsg)
|
||||
}
|
||||
@@ -584,30 +619,37 @@ func (s *StreamComponent) renderStreamingText(text string) string {
|
||||
if ts.IsZero() {
|
||||
ts = time.Now()
|
||||
}
|
||||
|
||||
if s.compactMode {
|
||||
msg := s.compactRenderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
return msg.Content
|
||||
if s.renderer == nil {
|
||||
return text
|
||||
}
|
||||
msg := s.messageRenderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
msg := s.renderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
return msg.Content
|
||||
}
|
||||
|
||||
// removeFromSlice removes the first occurrence of a string from a slice.
|
||||
func removeFromSlice(slice []string, s string) []string {
|
||||
for i, v := range slice {
|
||||
if v == s {
|
||||
return append(slice[:i], slice[i+1:]...)
|
||||
func (s *StreamComponent) activeToolDisplays() []string {
|
||||
if len(s.activeTools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(s.activeToolOrder))
|
||||
for _, id := range s.activeToolOrder {
|
||||
if display, ok := s.activeTools[id]; ok {
|
||||
out = append(out, display)
|
||||
}
|
||||
}
|
||||
return slice
|
||||
return out
|
||||
}
|
||||
|
||||
// removeToolID removes the first occurrence of a tool ID from a slice.
|
||||
func removeToolID(ids []string, id string) []string {
|
||||
for i, v := range ids {
|
||||
if v == id {
|
||||
return append(ids[:i], ids[i+1:]...)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// formatToolExecutionMessage creates a descriptive spinner message for tool execution.
|
||||
// For spawn_subagent, it shows simply as "Subagent" with optional task preview.
|
||||
func formatToolExecutionMessage(toolName, toolArgs string) string {
|
||||
if toolName == "spawn_subagent" {
|
||||
return "Subagent"
|
||||
}
|
||||
func formatToolExecutionMessage(toolName string) string {
|
||||
return toolName
|
||||
}
|
||||
|
||||
@@ -46,12 +46,12 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
if body := renderWriteBody(toolArgs, toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
case toolName == "bash" || toolName == "run_shell_cmd" ||
|
||||
case toolName == "bash" || toolName == "grep" || toolName == "find" ||
|
||||
strings.Contains(toolName, "shell") || strings.Contains(toolName, "command"):
|
||||
if body := renderBashBody(toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
case toolName == "spawn_subagent":
|
||||
case toolName == "subagent":
|
||||
if body := renderSubagentBody(toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
@@ -64,21 +64,44 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderEditBody renders a side-by-side diff from old_text/new_text in toolArgs.
|
||||
// Supports both single-edit mode and multi-edit mode (edits array).
|
||||
func renderEditBody(toolArgs, toolResult string, width int) string {
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
// Check for multi-edit mode (edits array)
|
||||
if editsArr, ok := args["edits"].([]any); ok && len(editsArr) > 0 {
|
||||
var results []string
|
||||
for _, edit := range editsArr {
|
||||
if e, ok := edit.(map[string]any); ok {
|
||||
oldText, _ := e["old_text"].(string)
|
||||
newText, _ := e["new_text"].(string)
|
||||
if oldText != "" || newText != "" {
|
||||
diff := renderDiffBlock(oldText, newText, startLine, width)
|
||||
if diff != "" {
|
||||
results = append(results, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(results) > 0 {
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Single-edit mode (legacy)
|
||||
oldText, _ := args["old_text"].(string)
|
||||
newText, _ := args["new_text"].(string)
|
||||
if oldText == "" && newText == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
return renderDiffBlock(oldText, newText, startLine, width)
|
||||
}
|
||||
|
||||
@@ -754,10 +777,10 @@ func renderToolBodyCompact(toolName, toolArgs, toolResult string, width int) str
|
||||
return renderReadCompact(toolResult)
|
||||
case toolName == "write":
|
||||
return renderWriteCompact(toolArgs)
|
||||
case toolName == "bash" || toolName == "run_shell_cmd" ||
|
||||
case toolName == "bash" || toolName == "grep" || toolName == "find" ||
|
||||
strings.Contains(toolName, "shell") || strings.Contains(toolName, "command"):
|
||||
return renderBashCompact(toolResult, width)
|
||||
case toolName == "spawn_subagent":
|
||||
case toolName == "subagent":
|
||||
return renderSubagentCompact(toolResult)
|
||||
}
|
||||
return ""
|
||||
@@ -916,8 +939,8 @@ func renderBashCompact(toolResult string, width int) string {
|
||||
// Subagent tool renderers — show only summary, not full output
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderSubagentBody renders a clean summary of subagent results.
|
||||
// Extracts timing/token info and shows only a brief summary instead of raw output.
|
||||
// renderSubagentBody renders a clean summary of subagent results with bash-style
|
||||
// background styling for consistency with other tools.
|
||||
func renderSubagentBody(toolResult string, width int) string {
|
||||
theme := getTheme()
|
||||
result := strings.TrimSpace(toolResult)
|
||||
@@ -937,9 +960,19 @@ func renderSubagentBody(toolResult string, width int) string {
|
||||
// First line is always the status summary
|
||||
statusLine := lines[0]
|
||||
|
||||
// Build a clean summary
|
||||
var summary strings.Builder
|
||||
summary.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(statusLine))
|
||||
// Build content lines for display with bash-style background
|
||||
outputStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
errorStyle := lipgloss.NewStyle().Foreground(theme.Error).Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
const lineIndent = " "
|
||||
lineWidth := max(width-len(lineIndent), 20)
|
||||
maxLineChars := lineWidth - 1 // account for PaddingLeft(1)
|
||||
|
||||
var contentLines []string
|
||||
|
||||
// Add status line
|
||||
styledStatus := outputStyle.Width(lineWidth).Render(truncateLine(statusLine, maxLineChars))
|
||||
contentLines = append(contentLines, lineIndent+styledStatus)
|
||||
|
||||
// For successful results, extract a brief preview of the actual result
|
||||
if strings.Contains(statusLine, "successfully") {
|
||||
@@ -947,25 +980,45 @@ func renderSubagentBody(toolResult string, width int) string {
|
||||
if _, resultContent, found := strings.Cut(result, "Result:\n"); found {
|
||||
resultContent = strings.TrimSpace(resultContent)
|
||||
if resultContent != "" {
|
||||
// Show first 3 meaningful lines as preview
|
||||
preview := extractSubagentPreview(resultContent, 3, width-4)
|
||||
if preview != "" {
|
||||
summary.WriteString("\n\n")
|
||||
summary.WriteString(lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true).
|
||||
Render(preview))
|
||||
// Show first few meaningful lines as preview
|
||||
previewLines := extractSubagentPreviewLines(resultContent, 5, maxLineChars)
|
||||
if len(previewLines) > 0 {
|
||||
// Add blank separator line
|
||||
blankLine := outputStyle.Width(lineWidth).Render("")
|
||||
contentLines = append(contentLines, lineIndent+blankLine)
|
||||
|
||||
for _, line := range previewLines {
|
||||
styled := outputStyle.Width(lineWidth).Render(line)
|
||||
contentLines = append(contentLines, lineIndent+styled)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For failed results, show error info
|
||||
if _, errorContent, found := strings.Cut(result, "Error:\n"); found {
|
||||
errorContent = strings.TrimSpace(errorContent)
|
||||
if errorContent != "" {
|
||||
previewLines := extractSubagentPreviewLines(errorContent, 3, maxLineChars)
|
||||
if len(previewLines) > 0 {
|
||||
blankLine := outputStyle.Width(lineWidth).Render("")
|
||||
contentLines = append(contentLines, lineIndent+blankLine)
|
||||
|
||||
for _, line := range previewLines {
|
||||
styled := errorStyle.Width(lineWidth).Render(line)
|
||||
contentLines = append(contentLines, lineIndent+styled)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return summary.String()
|
||||
return strings.Join(contentLines, "\n")
|
||||
}
|
||||
|
||||
// extractSubagentPreview extracts the first N non-empty lines from content,
|
||||
// truncating each line to maxWidth.
|
||||
func extractSubagentPreview(content string, maxLines, maxWidth int) string {
|
||||
// extractSubagentPreviewLines extracts the first N non-empty lines from content,
|
||||
// truncating each line to maxWidth. Returns as a slice of strings.
|
||||
func extractSubagentPreviewLines(content string, maxLines, maxWidth int) []string {
|
||||
lines := strings.Split(content, "\n")
|
||||
var preview []string
|
||||
|
||||
@@ -984,12 +1037,6 @@ func extractSubagentPreview(content string, maxLines, maxWidth int) string {
|
||||
}
|
||||
}
|
||||
|
||||
if len(preview) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := strings.Join(preview, "\n")
|
||||
|
||||
// Count remaining lines for "more" indicator
|
||||
totalLines := 0
|
||||
for _, line := range lines {
|
||||
@@ -998,10 +1045,10 @@ func extractSubagentPreview(content string, maxLines, maxWidth int) string {
|
||||
}
|
||||
}
|
||||
if totalLines > maxLines {
|
||||
result += fmt.Sprintf("\n...(%d more lines)", totalLines-maxLines)
|
||||
preview = append(preview, fmt.Sprintf("...(%d more lines)", totalLines-maxLines))
|
||||
}
|
||||
|
||||
return result
|
||||
return preview
|
||||
}
|
||||
|
||||
// renderSubagentCompact returns a brief one-line summary for subagent results.
|
||||
|
||||
@@ -134,13 +134,23 @@ func (ut *UsageTracker) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
}
|
||||
|
||||
// SetContextTokens records the approximate current context window utilization.
|
||||
// This should be set from the final API call's input + output tokens (i.e.
|
||||
// FinalResponse.Usage) rather than the aggregate TotalUsage, because TotalUsage
|
||||
// This should be set from FinalUsage.InputTokens, which already includes the
|
||||
// full conversation history (system prompt + all previous messages). Do NOT
|
||||
// add OutputTokens as that would double-count (output becomes input next turn).
|
||||
// Use FinalResponse.Usage rather than aggregate TotalUsage, because TotalUsage
|
||||
// sums across all tool-calling steps and overstates the actual window fill level.
|
||||
func (ut *UsageTracker) SetContextTokens(tokens int) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
ut.contextTokens = tokens
|
||||
// Track the maximum context seen so far. In multi-step tool calls,
|
||||
// FinalUsage.InputTokens may reflect only the last step's input, which
|
||||
// can be smaller than previous steps. We want to show the largest context
|
||||
// the model has processed in this session.
|
||||
if tokens > ut.contextTokens {
|
||||
ut.contextTokens = tokens
|
||||
}
|
||||
// If tokens < current, we keep the larger value (no-op)
|
||||
// This prevents the display from dropping during multi-step tool calls.
|
||||
}
|
||||
|
||||
// RenderUsageInfo generates a formatted string displaying current usage statistics
|
||||
@@ -151,10 +161,6 @@ func (ut *UsageTracker) RenderUsageInfo() string {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
if ut.sessionStats.RequestCount == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
// Display the current context window token count (from the last API call),
|
||||
@@ -266,3 +272,14 @@ func (ut *UsageTracker) SetWidth(width int) {
|
||||
defer ut.mu.Unlock()
|
||||
ut.width = width
|
||||
}
|
||||
|
||||
// UpdateModelInfo updates the model information and OAuth status when the model
|
||||
// is switched mid-session. This ensures token costs and context limits are
|
||||
// calculated correctly for the new model.
|
||||
func (ut *UsageTracker) UpdateModelInfo(modelInfo *models.ModelInfo, provider string, isOAuth bool) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
ut.modelInfo = modelInfo
|
||||
ut.provider = provider
|
||||
ut.isOAuth = isOAuth
|
||||
}
|
||||
|
||||
@@ -67,3 +67,62 @@ func TestUsageTracker_RenderUsageInfo_OAuth(t *testing.T) {
|
||||
t.Errorf("Expected regular rendered output to show actual cost, got: %s", regularRendered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTracker_RenderUsageInfo_StartupState(t *testing.T) {
|
||||
// Create a mock model info with costs and context limit
|
||||
modelInfo := &models.ModelInfo{
|
||||
ID: "claude-3-5-sonnet-20241022",
|
||||
Name: "Claude 3.5 Sonnet v2",
|
||||
Cost: models.Cost{
|
||||
Input: 3.0,
|
||||
Output: 15.0,
|
||||
},
|
||||
Limit: models.Limit{
|
||||
Context: 200000,
|
||||
Output: 8192,
|
||||
},
|
||||
}
|
||||
|
||||
// Test startup state (no requests made yet) - Regular API key
|
||||
regularTracker := NewUsageTracker(modelInfo, "anthropic", 80, false)
|
||||
rendered := stripAnsi(regularTracker.RenderUsageInfo())
|
||||
|
||||
// Should NOT return empty string on startup
|
||||
if rendered == "" {
|
||||
t.Errorf("Expected non-empty output on startup, got empty string")
|
||||
}
|
||||
|
||||
// Should show 0 tokens
|
||||
if !strings.Contains(rendered, "Tokens: 0") {
|
||||
t.Errorf("Expected 'Tokens: 0' on startup, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Should NOT show percentage when tokens are 0
|
||||
if strings.Contains(rendered, "(%") {
|
||||
t.Errorf("Expected no percentage on startup with 0 tokens, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Should show $0.0000 cost for regular API key
|
||||
if !strings.Contains(rendered, "Cost: $0.0000") {
|
||||
t.Errorf("Expected 'Cost: $0.0000' on startup, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Test startup state (no requests made yet) - OAuth
|
||||
oauthTracker := NewUsageTracker(modelInfo, "anthropic", 80, true)
|
||||
oauthRendered := stripAnsi(oauthTracker.RenderUsageInfo())
|
||||
|
||||
// Should NOT return empty string on startup
|
||||
if oauthRendered == "" {
|
||||
t.Errorf("Expected non-empty output on startup for OAuth, got empty string")
|
||||
}
|
||||
|
||||
// Should show 0 tokens for OAuth
|
||||
if !strings.Contains(oauthRendered, "Tokens: 0") {
|
||||
t.Errorf("Expected 'Tokens: 0' on startup for OAuth, got: %s", oauthRendered)
|
||||
}
|
||||
|
||||
// Should show $0.00 cost for OAuth
|
||||
if !strings.Contains(oauthRendered, "Cost: $0.00") {
|
||||
t.Errorf("Expected 'Cost: $0.00' on startup for OAuth, got: %s", oauthRendered)
|
||||
}
|
||||
}
|
||||
|
||||
+20
-15
@@ -71,22 +71,28 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
Monitor tool execution in real-time:
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(
|
||||
unsub := host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
fmt.Printf("Calling tool: %s\n", e.ToolName)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
if e.IsError {
|
||||
fmt.Printf("Tool %s failed: %s\n", e.ToolName, e.Result)
|
||||
} else {
|
||||
fmt.Printf("Tool %s succeeded\n", e.ToolName)
|
||||
}
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
fmt.Print(e.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
|
||||
response, err := host.Prompt(
|
||||
ctx,
|
||||
"List files in the current directory",
|
||||
func(name, args string) {
|
||||
fmt.Printf("Calling tool: %s\n", name)
|
||||
},
|
||||
func(name, args, result string, isError bool) {
|
||||
if isError {
|
||||
fmt.Printf("Tool %s failed: %s\n", name, result)
|
||||
} else {
|
||||
fmt.Printf("Tool %s succeeded\n", name)
|
||||
}
|
||||
},
|
||||
func(chunk string) {
|
||||
fmt.Print(chunk) // Stream output
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
@@ -125,7 +131,6 @@ host.ClearSession()
|
||||
|
||||
- `New(ctx, opts)` - Create new Kit instance
|
||||
- `Prompt(ctx, message)` - Send message and get response
|
||||
- `PromptWithCallbacks(ctx, message, ...)` - Send message with progress callbacks
|
||||
- `LoadSession(path)` - Load session from file
|
||||
- `SaveSession(path)` - Save session to file
|
||||
- `ClearSession()` - Clear conversation history
|
||||
|
||||
@@ -9,6 +9,10 @@ type CredentialManager = auth.CredentialManager
|
||||
// and API key authentication methods.
|
||||
type AnthropicCredentials = auth.AnthropicCredentials
|
||||
|
||||
// OpenAICredentials holds OpenAI API credentials supporting both OAuth
|
||||
// and API key authentication methods.
|
||||
type OpenAICredentials = auth.OpenAICredentials
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
type CredentialStore = auth.CredentialStore
|
||||
|
||||
@@ -42,3 +46,34 @@ func GetAnthropicAPIKey() string {
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// HasOpenAICredentials checks if valid OpenAI credentials are stored
|
||||
// (either OAuth token or API key).
|
||||
func HasOpenAICredentials() bool {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
has, err := cm.HasOpenAICredentials()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return has
|
||||
}
|
||||
|
||||
// GetOpenAIAPIKey resolves the OpenAI API key using the standard
|
||||
// resolution order: stored credentials -> OPENAI_API_KEY env var.
|
||||
// Returns an empty string if no key is found.
|
||||
func GetOpenAIAPIKey() string {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// Try to get valid access token (handles OAuth refresh)
|
||||
token, err := cm.GetValidOpenAIAccessToken()
|
||||
if err == nil && token != "" {
|
||||
return token
|
||||
}
|
||||
// Fall back to environment variable
|
||||
return ""
|
||||
}
|
||||
|
||||
+32
-5
@@ -41,6 +41,10 @@ const (
|
||||
EventReasoningDelta EventType = "reasoning_delta"
|
||||
// EventToolOutput fires when a tool produces streaming output chunks.
|
||||
EventToolOutput EventType = "tool_output"
|
||||
EventStepUsage EventType = "step_usage"
|
||||
// EventSteerConsumed fires when one or more steering messages have been
|
||||
// injected into the agent turn via PrepareStep.
|
||||
EventSteerConsumed EventType = "steer_consumed"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -66,7 +70,7 @@ const (
|
||||
ToolKindEdit = "edit" // File modification (edit, write)
|
||||
ToolKindRead = "read" // File reading (read, ls)
|
||||
ToolKindSearch = "search" // Content/file search (grep, find)
|
||||
ToolKindSubagent = "agent" // Subagent spawning (spawn_subagent)
|
||||
ToolKindSubagent = "agent" // Subagent spawning (subagent)
|
||||
)
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind. MCP and extension
|
||||
@@ -79,7 +83,7 @@ var coreToolKinds = map[string]string{
|
||||
"ls": ToolKindRead,
|
||||
"grep": ToolKindSearch,
|
||||
"find": ToolKindSearch,
|
||||
"spawn_subagent": ToolKindSubagent,
|
||||
"subagent": ToolKindSubagent,
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
@@ -212,7 +216,7 @@ type ToolResultEvent struct {
|
||||
// ToolResultMetadata carries structured data from tool executions.
|
||||
type ToolResultMetadata struct {
|
||||
FileDiffs []FileDiffInfo `json:"file_diffs,omitempty"` // Present for edit/write tools
|
||||
SubagentSessionID string `json:"subagent_session_id,omitempty"` // Present for spawn_subagent tool
|
||||
SubagentSessionID string `json:"subagent_session_id,omitempty"` // Present for subagent tool
|
||||
}
|
||||
|
||||
// FileDiffInfo describes a file modification from an edit or write tool.
|
||||
@@ -249,6 +253,19 @@ type ResponseEvent struct {
|
||||
// EventType implements Event.
|
||||
func (e ResponseEvent) EventType() EventType { return EventResponse }
|
||||
|
||||
// StepUsageEvent fires after each complete step in a multi-step agent turn,
|
||||
// carrying the token usage for that specific step. This enables real-time
|
||||
// cost tracking during long-running tool-calling conversations.
|
||||
type StepUsageEvent struct {
|
||||
InputTokens uint64
|
||||
OutputTokens uint64
|
||||
CacheReadTokens uint64
|
||||
CacheWriteTokens uint64
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e StepUsageEvent) EventType() EventType { return EventStepUsage }
|
||||
|
||||
// CompactionEvent fires after a successful compaction.
|
||||
type CompactionEvent struct {
|
||||
Summary string
|
||||
@@ -262,6 +279,16 @@ type CompactionEvent struct {
|
||||
// EventType implements Event.
|
||||
func (e CompactionEvent) EventType() EventType { return EventCompaction }
|
||||
|
||||
// SteerConsumedEvent fires when one or more steering messages have been
|
||||
// injected into the agent turn via PrepareStep. The Count indicates how
|
||||
// many messages were consumed in this batch.
|
||||
type SteerConsumedEvent struct {
|
||||
Count int
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e SteerConsumedEvent) EventType() EventType { return EventSteerConsumed }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EventBus
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -430,13 +457,13 @@ func (s *subagentListenerSet) emit(event Event) {
|
||||
//
|
||||
// The listener receives the same event types as Subscribe() (ToolCallEvent,
|
||||
// MessageUpdateEvent, etc.) but scoped to the child agent's activity. If the
|
||||
// tool call ID doesn't correspond to an active or future spawn_subagent call,
|
||||
// tool call ID doesn't correspond to an active or future subagent call,
|
||||
// the listener simply never fires.
|
||||
//
|
||||
// Typical usage — register inside an OnToolCall handler:
|
||||
//
|
||||
// kit.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
// if e.ToolName == "spawn_subagent" {
|
||||
// if e.ToolName == "subagent" {
|
||||
// kit.SubscribeSubagent(e.ToolCallID, func(child kit.Event) {
|
||||
// // real-time subagent events
|
||||
// })
|
||||
|
||||
@@ -2,6 +2,7 @@ package kit
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
@@ -119,6 +120,125 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
})
|
||||
}
|
||||
|
||||
// --- Subagent lifecycle events ---
|
||||
// When an extension registers OnSubagentStart/Chunk/End handlers, bridge
|
||||
// the SDK's per-subagent event stream (SubscribeSubagent) into the
|
||||
// extension runner.
|
||||
//
|
||||
// Flow:
|
||||
// ToolExecutionStartEvent(subagent) → emit SubagentStartEvent
|
||||
// → SubscribeSubagent → emit SubagentChunkEvents
|
||||
// ToolResultEvent(subagent) → emit SubagentEndEvent
|
||||
//
|
||||
// We use ToolExecutionStart (not ToolCall) for SubagentStart because that
|
||||
// is when the subagent actually begins running. We use ToolResult for
|
||||
// SubagentEnd because that carries the final response text.
|
||||
wantsSubagent := runner.HasHandlers(extensions.SubagentStart) ||
|
||||
runner.HasHandlers(extensions.SubagentChunk) ||
|
||||
runner.HasHandlers(extensions.SubagentEnd)
|
||||
|
||||
if wantsSubagent {
|
||||
// taskByCallID tracks the task description extracted from ToolCall input,
|
||||
// keyed by toolCallID. Populated on ToolCall, consumed on ToolResult.
|
||||
taskByCallID := make(map[string]string)
|
||||
var taskMu = &taskMutex{}
|
||||
|
||||
// Intercept ToolCall to capture the task and subscribe to child events.
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolCallEvent)
|
||||
if !ok || ev.ToolName != "subagent" {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract task from parsed args.
|
||||
task := ""
|
||||
if ev.ParsedArgs != nil {
|
||||
if t, ok := ev.ParsedArgs["task"].(string); ok {
|
||||
task = t
|
||||
}
|
||||
}
|
||||
taskMu.set(taskByCallID, ev.ToolCallID, task)
|
||||
|
||||
// Subscribe to child events so we can forward them as SubagentChunkEvents.
|
||||
if runner.HasHandlers(extensions.SubagentChunk) {
|
||||
m.SubscribeSubagent(ev.ToolCallID, func(childEvent Event) {
|
||||
chunk := extensions.SubagentChunkEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Task: task,
|
||||
}
|
||||
switch ce := childEvent.(type) {
|
||||
case MessageUpdateEvent:
|
||||
chunk.ChunkType = "text"
|
||||
chunk.Content = ce.Chunk
|
||||
case TurnStartEvent:
|
||||
chunk.ChunkType = "turn_start"
|
||||
case TurnEndEvent:
|
||||
chunk.ChunkType = "turn_end"
|
||||
case ToolCallEvent:
|
||||
chunk.ChunkType = "tool_call"
|
||||
chunk.ToolName = ce.ToolName
|
||||
chunk.ToolArgs = ce.ToolArgs
|
||||
case ToolExecutionStartEvent:
|
||||
chunk.ChunkType = "tool_execution_start"
|
||||
chunk.ToolName = ce.ToolName
|
||||
case ToolExecutionEndEvent:
|
||||
chunk.ChunkType = "tool_execution_end"
|
||||
chunk.ToolName = ce.ToolName
|
||||
case ToolResultEvent:
|
||||
chunk.ChunkType = "tool_result"
|
||||
chunk.ToolName = ce.ToolName
|
||||
chunk.ToolResult = ce.Result
|
||||
chunk.IsError = ce.IsError
|
||||
default:
|
||||
return // skip unknown event types
|
||||
}
|
||||
_, _ = runner.Emit(chunk)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Emit SubagentStartEvent when execution begins.
|
||||
if runner.HasHandlers(extensions.SubagentStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolExecutionStartEvent)
|
||||
if !ok || ev.ToolName != "subagent" {
|
||||
return
|
||||
}
|
||||
task := taskMu.get(taskByCallID, ev.ToolCallID)
|
||||
_, _ = runner.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Task: task,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Emit SubagentEndEvent when the tool result arrives.
|
||||
if runner.HasHandlers(extensions.SubagentEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolResultEvent)
|
||||
if !ok || ev.ToolName != "subagent" {
|
||||
return
|
||||
}
|
||||
task := taskMu.get(taskByCallID, ev.ToolCallID)
|
||||
taskMu.del(taskByCallID, ev.ToolCallID)
|
||||
errMsg := ""
|
||||
if ev.IsError {
|
||||
errMsg = ev.Result
|
||||
}
|
||||
response := ""
|
||||
if !ev.IsError {
|
||||
response = ev.Result
|
||||
}
|
||||
_, _ = runner.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Task: task,
|
||||
Response: response,
|
||||
ErrorMsg: errMsg,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Context filtering hook ---
|
||||
// Extension ContextPrepare → SDK ContextPrepare hook.
|
||||
if runner.HasHandlers(extensions.ContextPrepare) {
|
||||
@@ -204,3 +324,27 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// taskMutex is a simple mutex-protected map helper used by bridgeExtensions.
|
||||
// It lives in this file to avoid polluting the kit package with unexported types.
|
||||
type taskMutex struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *taskMutex) set(m map[string]string, key, val string) {
|
||||
t.mu.Lock()
|
||||
m[key] = val
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
func (t *taskMutex) get(m map[string]string, key string) string {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return m[key]
|
||||
}
|
||||
|
||||
func (t *taskMutex) del(m map[string]string, key string) {
|
||||
t.mu.Lock()
|
||||
delete(m, key)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
+133
-62
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -66,34 +67,24 @@ type Kit struct {
|
||||
// subagentListeners holds per-tool-call event listeners registered via
|
||||
// SubscribeSubagent(). Keyed by toolCallID → *subagentListenerSet.
|
||||
subagentListeners sync.Map
|
||||
|
||||
// steerCh is a buffered channel used to inject steering messages into
|
||||
// the running agent turn via Fantasy's PrepareStep. Created fresh for
|
||||
// each generate() call and set to nil when idle. Protected by steerMu.
|
||||
steerMu sync.Mutex
|
||||
steerCh chan string
|
||||
leftoverSteer []string // unconsumed steer messages from the last turn
|
||||
}
|
||||
|
||||
// Subscribe registers an EventListener that will be called for every lifecycle
|
||||
// event emitted during Prompt() and PromptWithCallbacks(). Returns an
|
||||
// unsubscribe function that removes the listener.
|
||||
// event emitted during Prompt(). Returns an unsubscribe function that removes
|
||||
// the listener.
|
||||
func (m *Kit) Subscribe(listener EventListener) func() {
|
||||
return m.events.subscribe(listener)
|
||||
}
|
||||
|
||||
// GetExtRunner returns the extension runner (nil if extensions are disabled).
|
||||
//
|
||||
// Deprecated: Use SetExtensionContext and EmitSessionStart instead. GetExtRunner
|
||||
// leaks the internal extensions.Runner type across the SDK boundary.
|
||||
func (m *Kit) GetExtRunner() *extensions.Runner { return m.extRunner }
|
||||
|
||||
// GetBufferedLogger returns the buffered debug logger (nil if not configured).
|
||||
//
|
||||
// Deprecated: Use GetBufferedDebugMessages instead.
|
||||
func (m *Kit) GetBufferedLogger() *tools.BufferedDebugLogger { return m.bufferedLogger }
|
||||
|
||||
// GetAgent returns the underlying agent.
|
||||
//
|
||||
// Deprecated: Use GetToolNames, GetLoadingMessage, GetLoadedServerNames,
|
||||
// GetMCPToolCount, GetExtensionToolCount instead.
|
||||
func (m *Kit) GetAgent() *agent.Agent { return m.agent }
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Narrow accessors — prefer these over GetAgent/GetExtRunner/GetBufferedLogger
|
||||
// Narrow accessors
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// GetToolNames returns the names of all tools available to the agent.
|
||||
@@ -529,8 +520,11 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
}
|
||||
|
||||
// Build a provider config from current settings, overriding the model.
|
||||
// Load system prompt properly (handles both file paths and inline content).
|
||||
systemPrompt, _ := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
config := &models.ProviderConfig{
|
||||
ModelString: modelString,
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
@@ -1053,6 +1047,15 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
// Bridge extension events to SDK hooks.
|
||||
if agentResult.ExtRunner != nil {
|
||||
k.bridgeExtensions(agentResult.ExtRunner)
|
||||
|
||||
// Initialize extension context with minimal defaults. SDK users can call
|
||||
// SetExtensionContext to override with richer implementations (TUI callbacks,
|
||||
// prompts, etc.). This ensures extensions never crash on nil function fields.
|
||||
k.SetExtensionContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: k.modelString,
|
||||
Interactive: false, // SDK mode defaults to non-interactive
|
||||
})
|
||||
}
|
||||
|
||||
return k, nil
|
||||
@@ -1253,7 +1256,7 @@ type SubagentConfig struct {
|
||||
SystemPrompt string
|
||||
|
||||
// Tools overrides the tool set. If nil, SubagentTools() is used (all
|
||||
// core tools except spawn_subagent, preventing infinite recursion).
|
||||
// core tools except subagent, preventing infinite recursion).
|
||||
Tools []Tool
|
||||
|
||||
// NoSession, when true, uses an in-memory ephemeral session. When false
|
||||
@@ -1327,7 +1330,7 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
systemPrompt = "You are a helpful coding assistant. Complete the task efficiently and thoroughly."
|
||||
}
|
||||
|
||||
// Default tools: everything except spawn_subagent.
|
||||
// Default tools: everything except subagent.
|
||||
tools := cfg.Tools
|
||||
if tools == nil {
|
||||
tools = SubagentTools()
|
||||
@@ -1405,8 +1408,37 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
// All prompt modes (Prompt, Steer, FollowUp, PromptWithOptions) share this
|
||||
// single code path so callback wiring is never duplicated.
|
||||
func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.GenerateWithLoopResult, error) {
|
||||
// Create a per-turn steer channel and attach it to the context so the
|
||||
// agent's PrepareStep can inject steering messages between steps.
|
||||
steerCh := make(chan string, 16)
|
||||
m.steerMu.Lock()
|
||||
m.steerCh = steerCh
|
||||
m.steerMu.Unlock()
|
||||
defer func() {
|
||||
// Drain any unconsumed steer messages before nilling the channel.
|
||||
// These are stored in leftoverSteer so DrainSteer() can return them.
|
||||
var leftover []string
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
leftover = append(leftover, msg)
|
||||
default:
|
||||
goto drained
|
||||
}
|
||||
}
|
||||
drained:
|
||||
m.steerMu.Lock()
|
||||
m.steerCh = nil
|
||||
m.leftoverSteer = leftover
|
||||
m.steerMu.Unlock()
|
||||
}()
|
||||
ctx = agent.ContextWithSteerCh(ctx, steerCh)
|
||||
ctx = agent.ContextWithSteerConsumed(ctx, func(count int) {
|
||||
m.events.emit(SteerConsumedEvent{Count: count})
|
||||
})
|
||||
|
||||
// Inject the in-process subagent spawner into the context so the
|
||||
// spawn_subagent core tool can create child Kit instances without
|
||||
// subagent core tool can create child Kit instances without
|
||||
// importing pkg/kit (which would create an import cycle).
|
||||
ctx = core.WithSubagentSpawner(ctx, func(
|
||||
spawnCtx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration,
|
||||
@@ -1491,6 +1523,19 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
IsStderr: isStderr,
|
||||
})
|
||||
},
|
||||
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
|
||||
// Emit step usage event for real-time cost tracking
|
||||
if viper.GetBool("debug") {
|
||||
log.Printf("[DEBUG] Kit.generate emitting StepUsageEvent: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens)
|
||||
}
|
||||
m.events.emit(StepUsageEvent{
|
||||
InputTokens: uint64(inputTokens),
|
||||
OutputTokens: uint64(outputTokens),
|
||||
CacheReadTokens: uint64(cacheReadTokens),
|
||||
CacheWriteTokens: uint64(cacheCreationTokens),
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1705,6 +1750,71 @@ func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
|
||||
return result.Response, nil
|
||||
}
|
||||
|
||||
// InjectSteer sends a steering message into the currently active agent turn.
|
||||
// The message will be injected as a user message between steps (after the
|
||||
// current tool execution finishes, before the next LLM call). If no turn is
|
||||
// active the message is silently dropped — callers should check IsGenerating()
|
||||
// or use Prompt()/Steer() for idle-state messaging.
|
||||
//
|
||||
// InjectSteer is safe to call from any goroutine. Multiple calls queue
|
||||
// messages in order; all pending steer messages are drained and injected
|
||||
// together at the next step boundary.
|
||||
//
|
||||
// This is the preferred way to redirect an agent mid-turn without cancelling
|
||||
// in-progress tool execution.
|
||||
func (m *Kit) InjectSteer(message string) {
|
||||
m.steerMu.Lock()
|
||||
ch := m.steerCh
|
||||
m.steerMu.Unlock()
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- message:
|
||||
default:
|
||||
// Channel full — extremely unlikely with buffer of 16, but don't block.
|
||||
}
|
||||
}
|
||||
|
||||
// IsGenerating returns true if an agent turn is currently in progress.
|
||||
// Use this to decide between InjectSteer (mid-turn) and Prompt (new turn).
|
||||
func (m *Kit) IsGenerating() bool {
|
||||
m.steerMu.Lock()
|
||||
defer m.steerMu.Unlock()
|
||||
return m.steerCh != nil
|
||||
}
|
||||
|
||||
// DrainSteer removes and returns all unconsumed steer messages. Called after
|
||||
// a turn completes so the app layer can process any steer messages that
|
||||
// arrived after the last PrepareStep fired (e.g. during a text-only response
|
||||
// with no tool calls, or after the agent finished its last step).
|
||||
func (m *Kit) DrainSteer() []string {
|
||||
m.steerMu.Lock()
|
||||
defer m.steerMu.Unlock()
|
||||
|
||||
// First check leftover messages saved when generate() returned.
|
||||
if len(m.leftoverSteer) > 0 {
|
||||
msgs := m.leftoverSteer
|
||||
m.leftoverSteer = nil
|
||||
return msgs
|
||||
}
|
||||
|
||||
// If a turn is still active, drain from the live channel.
|
||||
if m.steerCh != nil {
|
||||
var msgs []string
|
||||
for {
|
||||
select {
|
||||
case msg := <-m.steerCh:
|
||||
msgs = append(msgs, msg)
|
||||
default:
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PromptOptions configures a single PromptWithOptions call.
|
||||
type PromptOptions struct {
|
||||
// SystemMessage is prepended as a system message before the user prompt.
|
||||
@@ -1730,45 +1840,6 @@ func (m *Kit) PromptWithOptions(ctx context.Context, msg string, opts PromptOpti
|
||||
return result.Response, nil
|
||||
}
|
||||
|
||||
// PromptWithCallbacks sends a message with callbacks for monitoring tool
|
||||
// execution and streaming responses. Lifecycle events are also emitted to all
|
||||
// registered subscribers (via Subscribe).
|
||||
//
|
||||
// Deprecated: Use Subscribe/OnToolCall/OnToolResult/OnStreaming instead of
|
||||
// inline callbacks. PromptWithCallbacks is retained for backward compatibility.
|
||||
func (m *Kit) PromptWithCallbacks(
|
||||
ctx context.Context,
|
||||
message string,
|
||||
onToolCall func(name, args string),
|
||||
onToolResult func(name, args, result string, isError bool),
|
||||
onStreaming func(chunk string),
|
||||
) (string, error) {
|
||||
// Register temporary subscribers for the inline callbacks.
|
||||
var unsubs []func()
|
||||
if onToolCall != nil {
|
||||
unsubs = append(unsubs, m.OnToolCall(func(e ToolCallEvent) {
|
||||
onToolCall(e.ToolName, e.ToolArgs)
|
||||
}))
|
||||
}
|
||||
if onToolResult != nil {
|
||||
unsubs = append(unsubs, m.OnToolResult(func(e ToolResultEvent) {
|
||||
onToolResult(e.ToolName, e.ToolArgs, e.Result, e.IsError)
|
||||
}))
|
||||
}
|
||||
if onStreaming != nil {
|
||||
unsubs = append(unsubs, m.OnStreaming(func(e MessageUpdateEvent) {
|
||||
onStreaming(e.Chunk)
|
||||
}))
|
||||
}
|
||||
defer func() {
|
||||
for _, unsub := range unsubs {
|
||||
unsub()
|
||||
}
|
||||
}()
|
||||
|
||||
return m.Prompt(ctx, message)
|
||||
}
|
||||
|
||||
// PromptResult sends a message and returns the full turn result including
|
||||
// usage statistics and conversation messages. Use this instead of Prompt()
|
||||
// when you need more than just the response text.
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
)
|
||||
|
||||
@@ -86,3 +91,213 @@ func (m *Kit) SetSessionName(name string) error {
|
||||
_, err := m.treeSession.AppendSessionInfo(name)
|
||||
return err
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tree Navigation Bridge for Extensions (Phase 1)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// GetTreeNode returns a node by ID with full metadata and children.
|
||||
// Returns nil if entry not found or no tree session.
|
||||
func (m *Kit) GetTreeNode(entryID string) *TreeNode {
|
||||
if m.treeSession == nil {
|
||||
return nil
|
||||
}
|
||||
entry := m.treeSession.GetEntry(entryID)
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
return m.entryToTreeNode(entry)
|
||||
}
|
||||
|
||||
// GetCurrentBranch returns the path from root to current leaf as TreeNodes.
|
||||
func (m *Kit) GetCurrentBranch() []TreeNode {
|
||||
if m.treeSession == nil {
|
||||
return nil
|
||||
}
|
||||
branch := m.treeSession.GetBranch("")
|
||||
var nodes []TreeNode
|
||||
for _, entry := range branch {
|
||||
node := m.entryToTreeNode(entry)
|
||||
if node != nil {
|
||||
nodes = append(nodes, *node)
|
||||
}
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// GetChildren returns direct child IDs of an entry.
|
||||
func (m *Kit) GetChildren(parentID string) []string {
|
||||
if m.treeSession == nil {
|
||||
return nil
|
||||
}
|
||||
return m.treeSession.GetChildren(parentID)
|
||||
}
|
||||
|
||||
// NavigateTo branches/forks the session to the specified entry ID.
|
||||
// Returns error description or empty string for success.
|
||||
func (m *Kit) NavigateTo(entryID string) string {
|
||||
if m.treeSession == nil {
|
||||
return "no tree session available"
|
||||
}
|
||||
if err := m.treeSession.Branch(entryID); err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SummarizeBranch uses LLM to summarize a branch range.
|
||||
// Returns summary text or error string.
|
||||
func (m *Kit) SummarizeBranch(fromID, toID string) string {
|
||||
if m.treeSession == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Get the branch and find the range
|
||||
branch := m.treeSession.GetBranch("")
|
||||
var startIdx, endIdx = -1, -1
|
||||
for i, entry := range branch {
|
||||
id := m.getEntryID(entry)
|
||||
if id == fromID {
|
||||
startIdx = i
|
||||
}
|
||||
if id == toID {
|
||||
endIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
if startIdx < 0 || endIdx < 0 || startIdx > endIdx {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Build text to summarize
|
||||
var content strings.Builder
|
||||
for i := startIdx; i <= endIdx; i++ {
|
||||
node := m.entryToTreeNode(branch[i])
|
||||
if node != nil && node.Content != "" {
|
||||
fmt.Fprintf(&content, "[%s] %s\n\n", node.Role, node.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if content.Len() == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Use LLM to summarize
|
||||
resp, err := m.ExecuteCompletion(context.Background(), extensions.CompleteRequest{
|
||||
Model: "", // Use current model
|
||||
System: "You are a concise summarization assistant. Summarize the conversation in 2-3 sentences.",
|
||||
Prompt: content.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return resp.Text
|
||||
}
|
||||
|
||||
// CollapseBranch replaces a branch range with a summary entry.
|
||||
// Returns error description or empty string for success.
|
||||
func (m *Kit) CollapseBranch(fromID, toID, summary string) string {
|
||||
if m.treeSession == nil {
|
||||
return "no tree session available"
|
||||
}
|
||||
_, err := m.treeSession.AppendBranchSummary(fromID, summary)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// entryToTreeNode converts a session entry to a TreeNode.
|
||||
func (m *Kit) entryToTreeNode(entry any) *TreeNode {
|
||||
switch e := entry.(type) {
|
||||
case *session.MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var content strings.Builder
|
||||
for _, p := range msg.Parts {
|
||||
switch pt := p.(type) {
|
||||
case message.TextContent:
|
||||
content.WriteString(pt.Text)
|
||||
case message.ReasoningContent:
|
||||
content.WriteString(pt.Thinking)
|
||||
case message.ToolCall:
|
||||
fmt.Fprintf(&content, "[tool_call: %s]", pt.Name)
|
||||
case message.ToolResult:
|
||||
fmt.Fprintf(&content, "[tool_result: %s]", pt.Content)
|
||||
}
|
||||
}
|
||||
return &TreeNode{
|
||||
ID: e.ID,
|
||||
ParentID: e.ParentID,
|
||||
Type: "message",
|
||||
Role: string(msg.Role),
|
||||
Content: content.String(),
|
||||
Model: msg.Model,
|
||||
Provider: msg.Provider,
|
||||
Timestamp: e.Timestamp.Format(time.RFC3339),
|
||||
Children: m.treeSession.GetChildren(e.ID),
|
||||
}
|
||||
case *session.BranchSummaryEntry:
|
||||
return &TreeNode{
|
||||
ID: e.ID,
|
||||
ParentID: e.ParentID,
|
||||
Type: "branch_summary",
|
||||
Content: e.Summary,
|
||||
Timestamp: e.Timestamp.Format(time.RFC3339),
|
||||
Children: m.treeSession.GetChildren(e.ID),
|
||||
}
|
||||
case *session.ModelChangeEntry:
|
||||
return &TreeNode{
|
||||
ID: e.ID,
|
||||
ParentID: e.ParentID,
|
||||
Type: "model_change",
|
||||
Content: fmt.Sprintf("Model changed to %s/%s", e.Provider, e.ModelID),
|
||||
Model: e.Provider + "/" + e.ModelID,
|
||||
Provider: e.Provider,
|
||||
Timestamp: e.Timestamp.Format(time.RFC3339),
|
||||
Children: m.treeSession.GetChildren(e.ID),
|
||||
}
|
||||
case *session.ExtensionDataEntry:
|
||||
return &TreeNode{
|
||||
ID: e.ID,
|
||||
ParentID: e.ParentID,
|
||||
Type: "extension_data",
|
||||
Content: fmt.Sprintf("Extension data: %s", e.ExtType),
|
||||
Timestamp: e.Timestamp.Format(time.RFC3339),
|
||||
Children: m.treeSession.GetChildren(e.ID),
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// getEntryID extracts the ID from a session entry.
|
||||
func (m *Kit) getEntryID(entry any) string {
|
||||
switch e := entry.(type) {
|
||||
case *session.MessageEntry:
|
||||
return e.ID
|
||||
case *session.BranchSummaryEntry:
|
||||
return e.ID
|
||||
case *session.ModelChangeEntry:
|
||||
return e.ID
|
||||
case *session.ExtensionDataEntry:
|
||||
return e.ID
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// TreeNode represents a node in the session tree for SDK consumers.
|
||||
type TreeNode struct {
|
||||
ID string
|
||||
ParentID string
|
||||
Type string // "message", "branch_summary", "model_change", "extension_data"
|
||||
Role string // for messages: "user", "assistant", "system", "tool"
|
||||
Content string
|
||||
Model string
|
||||
Provider string
|
||||
Timestamp string
|
||||
Children []string
|
||||
}
|
||||
|
||||
+90
-1
@@ -1,6 +1,12 @@
|
||||
package kit
|
||||
|
||||
import "github.com/mark3labs/kit/internal/skills"
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/skills"
|
||||
)
|
||||
|
||||
// ==== Skills Types ====
|
||||
|
||||
@@ -67,3 +73,86 @@ func LoadPromptTemplate(path string) (*PromptTemplate, error) {
|
||||
func NewPromptBuilder(basePrompt string) *PromptBuilder {
|
||||
return skills.NewPromptBuilder(basePrompt)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Skill Bridge for Extensions (Phase 2)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// skillCache holds skills discovered for the current session.
|
||||
type skillCache struct {
|
||||
skills []*Skill
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var globalSkillCache skillCache
|
||||
|
||||
// DiscoverSkillsForExtension finds skills in standard locations for extensions.
|
||||
// Returns skills in the extension-facing format.
|
||||
func (m *Kit) DiscoverSkillsForExtension() []extensions.Skill {
|
||||
cwd, _ := os.Getwd()
|
||||
|
||||
// Check cache first
|
||||
globalSkillCache.mu.RLock()
|
||||
if len(globalSkillCache.skills) > 0 {
|
||||
globalSkillCache.mu.RUnlock()
|
||||
return m.convertSkills(globalSkillCache.skills)
|
||||
}
|
||||
globalSkillCache.mu.RUnlock()
|
||||
|
||||
// Load fresh
|
||||
skillList, _ := skills.LoadSkills(cwd)
|
||||
|
||||
globalSkillCache.mu.Lock()
|
||||
globalSkillCache.skills = skillList
|
||||
globalSkillCache.mu.Unlock()
|
||||
|
||||
return m.convertSkills(skillList)
|
||||
}
|
||||
|
||||
// LoadSkillForExtension loads a single skill file for extensions.
|
||||
func (m *Kit) LoadSkillForExtension(path string) (*extensions.Skill, string) {
|
||||
s, err := skills.LoadSkill(path)
|
||||
if err != nil {
|
||||
return nil, err.Error()
|
||||
}
|
||||
return m.convertSkill(s), ""
|
||||
}
|
||||
|
||||
// LoadSkillsFromDirForExtension loads all skills from a directory for extensions.
|
||||
func (m *Kit) LoadSkillsFromDirForExtension(dir string) extensions.SkillLoadResult {
|
||||
skillList, err := skills.LoadSkillsFromDir(dir)
|
||||
if err != nil {
|
||||
return extensions.SkillLoadResult{Error: err.Error()}
|
||||
}
|
||||
return extensions.SkillLoadResult{Skills: m.convertSkills(skillList)}
|
||||
}
|
||||
|
||||
// convertSkill converts internal skill to extension-facing format.
|
||||
func (m *Kit) convertSkill(s *skills.Skill) *extensions.Skill {
|
||||
return &extensions.Skill{
|
||||
Name: s.Name,
|
||||
Description: s.Description,
|
||||
Content: s.Content,
|
||||
Path: s.Path,
|
||||
Tags: s.Tags,
|
||||
When: s.When,
|
||||
}
|
||||
}
|
||||
|
||||
// convertSkills converts a slice of skills.
|
||||
func (m *Kit) convertSkills(skills []*skills.Skill) []extensions.Skill {
|
||||
result := make([]extensions.Skill, 0, len(skills))
|
||||
for _, s := range skills {
|
||||
if converted := m.convertSkill(s); converted != nil {
|
||||
result = append(result, *converted)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ClearSkillCache clears the global skill cache (called on reload).
|
||||
func (m *Kit) ClearSkillCache() {
|
||||
globalSkillCache.mu.Lock()
|
||||
globalSkillCache.skills = nil
|
||||
globalSkillCache.mu.Unlock()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,462 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Template Parsing Bridge for Extensions (Phase 3)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// varRegex matches {{variable}} placeholders in templates.
|
||||
var varRegex = regexp.MustCompile(`\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}`)
|
||||
|
||||
// ParseTemplate extracts {{variables}} from template content.
|
||||
func ParseTemplate(name, content string) extensions.PromptTemplate {
|
||||
matches := varRegex.FindAllStringSubmatch(content, -1)
|
||||
vars := make([]string, 0, len(matches))
|
||||
seen := make(map[string]bool)
|
||||
for _, m := range matches {
|
||||
if len(m) > 1 && !seen[m[1]] {
|
||||
seen[m[1]] = true
|
||||
vars = append(vars, m[1])
|
||||
}
|
||||
}
|
||||
return extensions.PromptTemplate{
|
||||
Name: name,
|
||||
Content: content,
|
||||
Variables: vars,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderTemplate substitutes variables into template content.
|
||||
func RenderTemplate(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
result := tpl.Content
|
||||
for name, value := range vars {
|
||||
placeholder := "{{" + name + "}}"
|
||||
result = strings.ReplaceAll(result, placeholder, value)
|
||||
// Also handle with spaces
|
||||
placeholderSpaced := "{{ " + name + " }}"
|
||||
result = strings.ReplaceAll(result, placeholderSpaced, value)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ParseArguments parses command-line style arguments.
|
||||
func ParseArguments(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
|
||||
result := extensions.ParseResult{
|
||||
Vars: make(map[string]string),
|
||||
Flags: make(map[string]string),
|
||||
}
|
||||
|
||||
fields := parseFields(input)
|
||||
if len(fields) == 0 {
|
||||
return result
|
||||
}
|
||||
|
||||
// First field is the command itself (if present)
|
||||
startIdx := 0
|
||||
if len(fields) > 0 && !strings.HasPrefix(fields[0], "-") {
|
||||
// Check if it's a command name or positional arg
|
||||
if len(pattern.Positional) == 0 || !isFlag(fields[0], pattern.Flags) {
|
||||
startIdx = 1 // Skip command name
|
||||
}
|
||||
}
|
||||
|
||||
// Parse flags
|
||||
i := startIdx
|
||||
for i < len(fields) {
|
||||
field := fields[i]
|
||||
|
||||
// Check for flags
|
||||
if strings.HasPrefix(field, "--") {
|
||||
flagName := field[2:]
|
||||
if varName, ok := pattern.Flags["--"+flagName]; ok {
|
||||
// Flag with value
|
||||
if i+1 < len(fields) && !strings.HasPrefix(fields[i+1], "-") {
|
||||
result.Flags["--"+flagName] = fields[i+1]
|
||||
result.Vars[varName] = fields[i+1]
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
// Boolean flag
|
||||
result.Flags["--"+flagName] = "true"
|
||||
result.Vars[varName] = "true"
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(field, "-") && len(field) > 1 {
|
||||
flagName := field[1:]
|
||||
if varName, ok := pattern.Flags["-"+flagName]; ok {
|
||||
// Flag with value
|
||||
if i+1 < len(fields) && !strings.HasPrefix(fields[i+1], "-") {
|
||||
result.Flags["-"+flagName] = fields[i+1]
|
||||
result.Vars[varName] = fields[i+1]
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
// Boolean flag
|
||||
result.Flags["-"+flagName] = "true"
|
||||
result.Vars[varName] = "true"
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
i++
|
||||
}
|
||||
|
||||
// Collect remaining as positional args and "rest"
|
||||
positional := make([]string, 0)
|
||||
i = startIdx
|
||||
for i < len(fields) {
|
||||
field := fields[i]
|
||||
if !strings.HasPrefix(field, "-") {
|
||||
// Check if this was consumed as a flag value
|
||||
consumed := false
|
||||
for _, v := range result.Vars {
|
||||
if v == field {
|
||||
// Might be consumed, check previous field
|
||||
if i > 0 {
|
||||
prev := fields[i-1]
|
||||
if strings.HasPrefix(prev, "-") {
|
||||
consumed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !consumed {
|
||||
positional = append(positional, field)
|
||||
}
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
// Map positional args
|
||||
for i, name := range pattern.Positional {
|
||||
if i < len(positional) {
|
||||
result.Vars[name] = positional[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Set rest
|
||||
if pattern.Rest != "" && len(positional) > len(pattern.Positional) {
|
||||
restStart := len(pattern.Positional)
|
||||
if restStart < len(positional) {
|
||||
result.Vars[pattern.Rest] = strings.Join(positional[restStart:], " ")
|
||||
}
|
||||
}
|
||||
|
||||
result.Rest = strings.Join(fields, " ")
|
||||
return result
|
||||
}
|
||||
|
||||
// SimpleParseArguments parses $1, $2, $@ style arguments.
|
||||
// Returns slice where [0]=full input, [1]=$1, [2]=$2, ... [n]=$@
|
||||
func SimpleParseArguments(input string, count int) []string {
|
||||
fields := parseFields(input)
|
||||
result := make([]string, 0, count+2)
|
||||
result = append(result, input) // [0] = full input
|
||||
|
||||
// [1]..[count] = positional args
|
||||
for i := 0; i < count; i++ {
|
||||
if i < len(fields) {
|
||||
result = append(result, fields[i])
|
||||
} else {
|
||||
result = append(result, "")
|
||||
}
|
||||
}
|
||||
|
||||
// [n] = $@ (all remaining)
|
||||
if len(fields) > count {
|
||||
result = append(result, strings.Join(fields[count:], " "))
|
||||
} else {
|
||||
result = append(result, "")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// parseFields splits input respecting quoted strings.
|
||||
func parseFields(input string) []string {
|
||||
var fields []string
|
||||
var current strings.Builder
|
||||
inQuote := false
|
||||
quoteChar := rune(0)
|
||||
|
||||
for _, r := range input {
|
||||
switch r {
|
||||
case '"', '\'':
|
||||
if !inQuote {
|
||||
inQuote = true
|
||||
quoteChar = r
|
||||
} else if r == quoteChar {
|
||||
inQuote = false
|
||||
quoteChar = 0
|
||||
} else {
|
||||
current.WriteRune(r)
|
||||
}
|
||||
case ' ', '\t':
|
||||
if inQuote {
|
||||
current.WriteRune(r)
|
||||
} else {
|
||||
if current.Len() > 0 {
|
||||
fields = append(fields, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
}
|
||||
default:
|
||||
current.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
fields = append(fields, current.String())
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
// isFlag checks if a field is a known flag.
|
||||
func isFlag(field string, flags map[string]string) bool {
|
||||
if strings.HasPrefix(field, "--") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(field, "-") && len(field) > 1 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// EvaluateModelConditional checks if condition matches current model.
|
||||
// Condition supports wildcards: * matches any, ? matches single char.
|
||||
func EvaluateModelConditional(currentModel, condition string) bool {
|
||||
// Handle comma-separated conditions (OR logic)
|
||||
for _, c := range strings.Split(condition, ",") {
|
||||
c = strings.TrimSpace(c)
|
||||
if matchModelPattern(currentModel, c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchModelPattern matches a model against a pattern with wildcards.
|
||||
func matchModelPattern(model, pattern string) bool {
|
||||
// Convert pattern to regexp
|
||||
pattern = strings.ReplaceAll(pattern, "*", ".*")
|
||||
pattern = strings.ReplaceAll(pattern, "?", ".")
|
||||
pattern = "^" + pattern + "$"
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
// Fallback: exact match
|
||||
return model == pattern
|
||||
}
|
||||
return re.MatchString(model)
|
||||
}
|
||||
|
||||
// RenderWithModelConditionals processes <if-model> blocks in content.
|
||||
func RenderWithModelConditionals(content, currentModel string) string {
|
||||
// Simple regex-based processor for <if-model> blocks
|
||||
// Supports: <if-model is="pattern">content</if-model>
|
||||
// And: <if-model is="pattern">content<else>other</if-model>
|
||||
|
||||
result := content
|
||||
|
||||
// Pattern for if-model blocks
|
||||
ifModelRegex := regexp.MustCompile(`(?s)<if-model\s+is="([^"]+)">(.*?)(?:<else>(.*?))?</if-model>`)
|
||||
|
||||
for {
|
||||
match := ifModelRegex.FindStringSubmatchIndex(result)
|
||||
if match == nil {
|
||||
break
|
||||
}
|
||||
|
||||
condition := result[match[2]:match[3]]
|
||||
ifContent := result[match[4]:match[5]]
|
||||
elseContent := ""
|
||||
if match[6] >= 0 && match[7] >= 0 {
|
||||
elseContent = result[match[6]:match[7]]
|
||||
}
|
||||
|
||||
var replacement string
|
||||
if EvaluateModelConditional(currentModel, condition) {
|
||||
replacement = ifContent
|
||||
} else {
|
||||
replacement = elseContent
|
||||
}
|
||||
|
||||
result = result[:match[0]] + replacement + result[match[1]:]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model Resolution Bridge for Extensions (Phase 4)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResolveModelChain attempts each model in order until one is available.
|
||||
func ResolveModelChain(preferences []string) extensions.ModelResolutionResult {
|
||||
result := extensions.ModelResolutionResult{
|
||||
Attempted: make([]string, 0, len(preferences)),
|
||||
}
|
||||
|
||||
registry := models.GetGlobalRegistry()
|
||||
|
||||
for _, pref := range preferences {
|
||||
pref = strings.TrimSpace(pref)
|
||||
result.Attempted = append(result.Attempted, pref)
|
||||
|
||||
// Parse model string
|
||||
provider, modelID, err := models.ParseModelString(pref)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if provider exists
|
||||
if registry.GetProviderInfo(provider) == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if model exists in registry
|
||||
modelInfo := registry.LookupModel(provider, modelID)
|
||||
if modelInfo == nil {
|
||||
// Try with just the model as bare name
|
||||
continue
|
||||
}
|
||||
|
||||
// Found available model
|
||||
result.Model = provider + "/" + modelID
|
||||
result.Capabilities = extensions.ModelCapabilities{
|
||||
Provider: provider,
|
||||
ModelID: modelID,
|
||||
ContextLimit: modelInfo.Limit.Context,
|
||||
OutputLimit: modelInfo.Limit.Output,
|
||||
Reasoning: modelInfo.Reasoning,
|
||||
Streaming: true, // Assume streaming support
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
result.Error = "no models in chain are available"
|
||||
return result
|
||||
}
|
||||
|
||||
// GetModelCapabilities returns capabilities for a specific model.
|
||||
// If model is empty, returns zero capabilities.
|
||||
func GetModelCapabilities(model string) (extensions.ModelCapabilities, string) {
|
||||
if model == "" {
|
||||
return extensions.ModelCapabilities{}, "no model specified"
|
||||
}
|
||||
|
||||
provider, modelID, err := models.ParseModelString(model)
|
||||
if err != nil {
|
||||
return extensions.ModelCapabilities{}, err.Error()
|
||||
}
|
||||
|
||||
registry := models.GetGlobalRegistry()
|
||||
modelInfo := registry.LookupModel(provider, modelID)
|
||||
if modelInfo == nil {
|
||||
return extensions.ModelCapabilities{}, "model not found in registry"
|
||||
}
|
||||
|
||||
return extensions.ModelCapabilities{
|
||||
Provider: provider,
|
||||
ModelID: modelID,
|
||||
ContextLimit: modelInfo.Limit.Context,
|
||||
OutputLimit: modelInfo.Limit.Output,
|
||||
Reasoning: modelInfo.Reasoning,
|
||||
Streaming: true,
|
||||
}, ""
|
||||
}
|
||||
|
||||
// CheckModelAvailable verifies if a model string is valid and provider exists.
|
||||
func CheckModelAvailable(model string) bool {
|
||||
provider, _, err := models.ParseModelString(model)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
registry := models.GetGlobalRegistry()
|
||||
if registry.GetProviderInfo(provider) == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Model doesn't need to be in registry - could be dynamic/Ollama
|
||||
return true
|
||||
}
|
||||
|
||||
// GetCurrentProvider extracts provider from model string.
|
||||
func GetCurrentProvider(model string) string {
|
||||
provider, _, _ := models.ParseModelString(model)
|
||||
return provider
|
||||
}
|
||||
|
||||
// GetCurrentModelID extracts model ID from model string.
|
||||
func GetCurrentModelID(model string) string {
|
||||
_, modelID, _ := models.ParseModelString(model)
|
||||
return modelID
|
||||
}
|
||||
|
||||
// JoinModel combines provider and model ID into a model string.
|
||||
func JoinModel(provider, modelID string) string {
|
||||
if provider == "" {
|
||||
return modelID
|
||||
}
|
||||
return provider + "/" + modelID
|
||||
}
|
||||
|
||||
// MatchModelGlob matches a model against a glob pattern.
|
||||
// Pattern can contain * (match any) and ? (match single).
|
||||
func MatchModelGlob(model, pattern string) bool {
|
||||
return matchModelPattern(model, pattern)
|
||||
}
|
||||
|
||||
// ExtractProviderFromPath extracts provider from a path-like model string.
|
||||
func ExtractProviderFromPath(model string) string {
|
||||
parts := strings.Split(model, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ExtractModelFromPath extracts model ID from a path-like model string.
|
||||
func ExtractModelFromPath(model string) string {
|
||||
parts := strings.Split(model, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[1]
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// IsBareModelID checks if a string is a bare model ID (no provider).
|
||||
func IsBareModelID(model string) bool {
|
||||
return !strings.Contains(model, "/")
|
||||
}
|
||||
|
||||
// AddProviderToModel adds a provider prefix to a bare model ID.
|
||||
func AddProviderToModel(provider, model string) string {
|
||||
if strings.Contains(model, "/") {
|
||||
return model // Already has provider
|
||||
}
|
||||
return provider + "/" + model
|
||||
}
|
||||
|
||||
// RemoveProviderFromModel removes the provider prefix from a model string.
|
||||
func RemoveProviderFromModel(model string) string {
|
||||
parts := strings.SplitN(model, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
return parts[1]
|
||||
}
|
||||
return model
|
||||
}
|
||||
+1
-1
@@ -52,7 +52,7 @@ func CodingTools(opts ...ToolOption) []Tool { return core.CodingTools(opts...) }
|
||||
// read, grep, find, ls.
|
||||
func ReadOnlyTools(opts ...ToolOption) []Tool { return core.ReadOnlyTools(opts...) }
|
||||
|
||||
// SubagentTools returns all core tools except spawn_subagent. Use this when
|
||||
// SubagentTools returns all core tools except subagent. Use this when
|
||||
// creating child Kit instances (in-process subagents) to prevent infinite
|
||||
// recursion.
|
||||
func SubagentTools(opts ...ToolOption) []Tool { return core.SubagentTools(opts...) }
|
||||
|
||||
@@ -1210,6 +1210,129 @@ func applyMode(ctx ext.Context, active bool, tools []string) {
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Bridged SDK APIs (New)
|
||||
|
||||
Extensions can now access powerful internal SDK capabilities that enable advanced features like conversation tree navigation, dynamic skill loading, template parsing, and model resolution.
|
||||
|
||||
### Tree Navigation
|
||||
|
||||
Navigate the conversation tree, summarize branches, and implement "fresh context" loops:
|
||||
|
||||
```go
|
||||
// Get a specific node by ID with full metadata and children
|
||||
node := ctx.GetTreeNode("entry-id")
|
||||
// node.ID, node.ParentID, node.Type ("message"/"branch_summary"/etc)
|
||||
// node.Role, node.Content, node.Model, node.Children ([]string)
|
||||
|
||||
// Get the current branch from root to leaf
|
||||
branch := ctx.GetCurrentBranch() // []ext.TreeNode
|
||||
|
||||
// Get child entry IDs of a node
|
||||
children := ctx.GetChildren("entry-id") // []string
|
||||
|
||||
// Navigate/fork to a different entry in the tree
|
||||
result := ctx.NavigateTo("entry-id") // ext.TreeNavigationResult{Success, Error}
|
||||
|
||||
// Summarize a range of the branch using LLM
|
||||
summary := ctx.SummarizeBranch("from-id", "to-id") // string
|
||||
|
||||
// Collapse a branch range into a summary entry (fresh context primitive)
|
||||
result := ctx.CollapseBranch("from-id", "to-id", "summary text")
|
||||
```
|
||||
|
||||
### Skill Loading
|
||||
|
||||
Load and inject skills dynamically at runtime:
|
||||
|
||||
```go
|
||||
// Discover skills from standard locations
|
||||
result := ctx.DiscoverSkills() // ext.SkillLoadResult{Skills, Error}
|
||||
// Standard locations: ~/.config/kit/skills/, .kit/skills/, .agents/skills/
|
||||
|
||||
// Load a specific skill file
|
||||
skill, err := ctx.LoadSkill("/path/to/skill.md") // (*ext.Skill, error string)
|
||||
// skill.Name, skill.Description, skill.Content, skill.Tags, skill.When
|
||||
|
||||
// Load all skills from a directory
|
||||
result := ctx.LoadSkillsFromDir("/path/to/skills") // ext.SkillLoadResult
|
||||
|
||||
// Inject a skill as context (pre-loads for next turn)
|
||||
err := ctx.InjectSkillAsContext("skill-name") // error string
|
||||
|
||||
// Inject a skill file directly
|
||||
err := ctx.InjectRawSkillAsContext("/path/to/skill.md") // error string
|
||||
|
||||
// Get all discovered skills
|
||||
skills := ctx.GetAvailableSkills() // []ext.Skill
|
||||
```
|
||||
|
||||
### Template Parsing
|
||||
|
||||
Parse and render templates with variable substitution:
|
||||
|
||||
```go
|
||||
// Parse a template to extract {{variables}}
|
||||
tpl := ctx.ParseTemplate("name", "Hello {{name}}, welcome to {{place}}!")
|
||||
// tpl.Name, tpl.Content, tpl.Variables ([]string)
|
||||
|
||||
// Render a template with variable values
|
||||
vars := map[string]string{"name": "Alice", "place": "Kit"}
|
||||
rendered := ctx.RenderTemplate(tpl, vars) // "Hello Alice, welcome to Kit!"
|
||||
|
||||
// Parse command-line style arguments
|
||||
pattern := ext.ArgumentPattern{
|
||||
Positional: []string{"command", "target"}, // $1, $2
|
||||
Rest: "args", // $@
|
||||
Flags: map[string]string{"--loop": "loop", "-f": "force"},
|
||||
}
|
||||
result := ctx.ParseArguments("deploy staging --loop 5", pattern)
|
||||
// result.Vars["command"] = "deploy"
|
||||
// result.Vars["target"] = "staging"
|
||||
// result.Flags["--loop"] = "5"
|
||||
|
||||
// Simple positional argument parsing ($1, $2, $@)
|
||||
args := ctx.SimpleParseArguments("deploy staging --force", 2)
|
||||
// args[0] = "deploy staging --force" (full input)
|
||||
// args[1] = "deploy" ($1)
|
||||
// args[2] = "staging" ($2)
|
||||
// args[3] = "--force" ($@)
|
||||
|
||||
// Evaluate model conditionals with wildcards
|
||||
matches := ctx.EvaluateModelConditional("claude-*") // bool
|
||||
// Patterns: * matches any, ? matches single char, comma = OR
|
||||
|
||||
// Render content with <if-model> conditionals
|
||||
content := `<if-model is="claude-*">Hi Claude<else>Hi there</if-model>`
|
||||
rendered := ctx.RenderWithModelConditionals(content) // based on current model
|
||||
```
|
||||
|
||||
### Model Resolution
|
||||
|
||||
Resolve model fallback chains and query capabilities:
|
||||
|
||||
```go
|
||||
// Resolve a chain of model preferences (tries each until available)
|
||||
result := ctx.ResolveModelChain([]string{
|
||||
"anthropic/claude-opus-4",
|
||||
"anthropic/claude-sonnet-4",
|
||||
"openai/gpt-4o",
|
||||
})
|
||||
// result.Model (selected), result.Capabilities, result.Attempted, result.Error
|
||||
|
||||
// Get capabilities for a specific model
|
||||
caps, err := ctx.GetModelCapabilities("anthropic/claude-sonnet-4")
|
||||
// caps.Provider, caps.ModelID, caps.ContextLimit, caps.Reasoning, caps.Streaming
|
||||
|
||||
// Check if a model is available (provider exists)
|
||||
available := ctx.CheckModelAvailable("anthropic/claude-sonnet-4") // bool
|
||||
|
||||
// Get current provider/model ID
|
||||
provider := ctx.GetCurrentProvider() // "anthropic"
|
||||
modelID := ctx.GetCurrentModelID() // "claude-sonnet-4"
|
||||
```
|
||||
|
||||
## Key Files for Reference
|
||||
|
||||
- [`internal/extensions/api.go`](https://github.com/mark3labs/kit/blob/main/internal/extensions/api.go) — Complete API type definitions
|
||||
|
||||
+4
-14
@@ -167,16 +167,6 @@ result, err := host.PromptResultWithMessages(ctx, []string{
|
||||
})
|
||||
```
|
||||
|
||||
### Legacy inline callbacks (deprecated — use event subscribers instead)
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(ctx, "List files",
|
||||
func(name, args string) { fmt.Printf("Tool: %s\n", name) },
|
||||
func(name, args, result string, isError bool) { /* tool result */ },
|
||||
func(chunk string) { fmt.Print(chunk) }, // streaming
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Event System
|
||||
@@ -261,7 +251,7 @@ Tools are classified by kind for UI rendering:
|
||||
- `ToolKindEdit` = `"edit"` — edit, write
|
||||
- `ToolKindRead` = `"read"` — read, ls
|
||||
- `ToolKindSearch` = `"search"` — grep, find
|
||||
- `ToolKindSubagent` = `"agent"` — spawn_subagent
|
||||
- `ToolKindSubagent` = `"agent"` — subagent
|
||||
|
||||
---
|
||||
|
||||
@@ -368,7 +358,7 @@ kit.NewLsTool(opts...) // directory listing
|
||||
kit.AllTools(opts...) // all 7 core tools
|
||||
kit.CodingTools(opts...) // bash, read, write, edit
|
||||
kit.ReadOnlyTools(opts...) // read, grep, find, ls
|
||||
kit.SubagentTools(opts...) // all except spawn_subagent (prevents recursion)
|
||||
kit.SubagentTools(opts...) // all except subagent (prevents recursion)
|
||||
```
|
||||
|
||||
### Tool options
|
||||
@@ -524,7 +514,7 @@ result, err := host.Subagent(ctx, kit.SubagentConfig{
|
||||
Prompt: "Analyze the test files and summarize coverage",
|
||||
Model: "anthropic/claude-haiku-3-5-20241022", // empty = parent's model
|
||||
SystemPrompt: "You are a test analysis expert.",
|
||||
Tools: nil, // nil = SubagentTools() (all except spawn_subagent)
|
||||
Tools: nil, // nil = SubagentTools() (all except subagent)
|
||||
NoSession: true, // ephemeral
|
||||
Timeout: 2 * time.Minute, // 0 = 5 minute default
|
||||
OnEvent: func(e kit.Event) {
|
||||
@@ -542,7 +532,7 @@ result, err := host.Subagent(ctx, kit.SubagentConfig{
|
||||
|
||||
```go
|
||||
host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
if e.ToolName == "spawn_subagent" {
|
||||
if e.ToolName == "subagent" {
|
||||
host.SubscribeSubagent(e.ToolCallID, func(child kit.Event) {
|
||||
// Real-time events scoped to this subagent
|
||||
})
|
||||
|
||||
@@ -32,12 +32,12 @@ Key flags for subprocess usage:
|
||||
|
||||
Positional arguments are the prompt. `@file` arguments attach file content as context.
|
||||
|
||||
## Built-in spawn_subagent tool
|
||||
## Built-in subagent tool
|
||||
|
||||
Kit includes a built-in `spawn_subagent` tool that the LLM can use to delegate tasks to independent child agents:
|
||||
Kit includes a built-in `subagent` tool that the LLM can use to delegate tasks to independent child agents:
|
||||
|
||||
```
|
||||
spawn_subagent(
|
||||
subagent(
|
||||
task: "Analyze the test files and summarize coverage",
|
||||
model: "anthropic/claude-haiku-latest", // optional
|
||||
system_prompt: "You are a test analysis expert.", // optional
|
||||
@@ -59,6 +59,79 @@ result := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
})
|
||||
```
|
||||
|
||||
### Monitoring subagents from extensions
|
||||
|
||||
When the LLM (not the extension itself) spawns a subagent using the `subagent` tool, extensions can monitor its activity in real-time using three lifecycle event handlers:
|
||||
|
||||
```go
|
||||
// Track active subagents and display their output
|
||||
var subagentWidgets map[string]*SubagentWidget
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Subagent started by the main agent
|
||||
api.OnSubagentStart(func(e ext.SubagentStartEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — unique ID for this subagent invocation
|
||||
// e.Task — the task/prompt sent to the subagent
|
||||
widget := NewWidget(e.ToolCallID, e.Task)
|
||||
subagentWidgets[e.ToolCallID] = widget
|
||||
ctx.SetWidget(widget.Config())
|
||||
})
|
||||
|
||||
// Real-time streaming from subagent
|
||||
api.OnSubagentChunk(func(e ext.SubagentChunkEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — matches the start event
|
||||
// e.ChunkType — "text", "tool_call", "tool_execution_start", "tool_result"
|
||||
// e.Content — text content
|
||||
// e.ToolName — tool name (for tool chunks)
|
||||
// e.IsError — true if tool result failed
|
||||
widget := subagentWidgets[e.ToolCallID]
|
||||
if widget != nil {
|
||||
widget.AddOutput(e)
|
||||
ctx.SetWidget(widget.Config())
|
||||
}
|
||||
})
|
||||
|
||||
// Subagent completed
|
||||
api.OnSubagentEnd(func(e ext.SubagentEndEvent, ctx ext.Context) {
|
||||
// e.Response — final response from subagent
|
||||
// e.ErrorMsg — error message if subagent failed
|
||||
widget := subagentWidgets[e.ToolCallID]
|
||||
if widget != nil {
|
||||
widget.MarkComplete(e.Response, e.ErrorMsg)
|
||||
ctx.SetWidget(widget.Config())
|
||||
delete(subagentWidgets, e.ToolCallID)
|
||||
}
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**Event structs:**
|
||||
|
||||
```go
|
||||
type SubagentStartEvent struct {
|
||||
ToolCallID string // Unique ID for this subagent invocation
|
||||
Task string // The task/prompt sent to subagent
|
||||
}
|
||||
|
||||
type SubagentChunkEvent struct {
|
||||
ToolCallID string // Matches SubagentStartEvent.ToolCallID
|
||||
Task string // Task description
|
||||
ChunkType string // "text", "tool_call", "tool_execution_start", "tool_result"
|
||||
Content string // For text chunks
|
||||
ToolName string // For tool-related chunks
|
||||
IsError bool // For tool_result chunks
|
||||
}
|
||||
|
||||
type SubagentEndEvent struct {
|
||||
ToolCallID string // Matches start event
|
||||
Task string // Task description
|
||||
Response string // Final response from subagent
|
||||
ErrorMsg string // Error message if failed
|
||||
}
|
||||
```
|
||||
|
||||
This enables building monitoring widgets that display real-time activity from all subagents spawned by the main agent.
|
||||
|
||||
## Go SDK subagents
|
||||
|
||||
The SDK provides in-process subagent spawning:
|
||||
@@ -74,11 +147,11 @@ result, err := host.Subagent(ctx, kit.SubagentConfig{
|
||||
|
||||
### Real-time subagent events
|
||||
|
||||
Use `SubscribeSubagent` to receive real-time events from LLM-initiated subagents (i.e., when the model uses the `spawn_subagent` tool). Register inside an `OnToolCall` handler using the tool call ID:
|
||||
Use `SubscribeSubagent` to receive real-time events from LLM-initiated subagents (i.e., when the model uses the `subagent` tool). Register inside an `OnToolCall` handler using the tool call ID:
|
||||
|
||||
```go
|
||||
host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
if e.ToolName == "spawn_subagent" {
|
||||
if e.ToolName == "subagent" {
|
||||
host.SubscribeSubagent(e.ToolCallID, func(event kit.Event) {
|
||||
switch ev := event.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
|
||||
@@ -74,7 +74,7 @@ These commands are available inside the Kit TUI during an interactive session:
|
||||
| `/reset-usage` | Reset usage statistics |
|
||||
| `/tree` | Navigate session tree |
|
||||
| `/fork` | Branch from an earlier message |
|
||||
| `/new` | Start a new session |
|
||||
| `/new` | Start a new session (creates new session file) |
|
||||
| `/name [name]` | Set or show session display name |
|
||||
| `/resume` | Open session picker to switch sessions (alias: `/r`) |
|
||||
| `/session` | Show session info |
|
||||
@@ -95,9 +95,17 @@ Press **ESC twice** to cancel the current operation:
|
||||
|
||||
This ensures that `tool_use` and `tool_result` messages are always sent to the API as matched pairs, avoiding errors from orphaned tool calls.
|
||||
|
||||
## Prompt templates
|
||||
### Mid-turn steering
|
||||
|
||||
Create reusable prompt templates with shell-style argument substitution. Templates are loaded from `~/.kit/prompts/*.md` and `.kit/prompts/*.md`.
|
||||
Press **Ctrl+S** during streaming to inject a system-level instruction mid-turn. This allows you to steer the conversation direction without waiting for the model to finish:
|
||||
|
||||
- Works during streaming output
|
||||
- Sends a steering instruction as a system message
|
||||
- Model continues from the interruption point with the new guidance
|
||||
|
||||
Example: While the model is writing code, press Ctrl+S and type "Use async/await instead" to change the implementation approach.
|
||||
|
||||
## Prompt templates
|
||||
|
||||
### Creating templates
|
||||
|
||||
|
||||
@@ -96,9 +96,45 @@ mcpServers:
|
||||
|
||||
A legacy format with `transport`, `args`, `env`, and `headers` fields is also supported.
|
||||
|
||||
## Theme configuration
|
||||
## Custom models
|
||||
|
||||
Set theme colors inline or reference an external file:
|
||||
Define custom models in your `.kit.yml` for use with the `custom` provider. This is useful for self-hosted models or API endpoints not in the built-in database:
|
||||
|
||||
```yaml
|
||||
customModels:
|
||||
my-model:
|
||||
name: "My Custom Model"
|
||||
reasoning: true
|
||||
temperature: true
|
||||
cost:
|
||||
input: 0.002
|
||||
output: 0.004
|
||||
limit:
|
||||
context: 128000
|
||||
output: 32000
|
||||
```
|
||||
|
||||
### Custom model fields
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
|-------|------|----------|-------------|
|
||||
| `name` | string | Yes | Display name for the model |
|
||||
| `reasoning` | bool | No | Whether the model supports reasoning/thinking |
|
||||
| `temperature` | bool | No | Whether the model supports temperature adjustment |
|
||||
| `cost.input` | float | No | Cost per 1K input tokens |
|
||||
| `cost.output` | float | No | Cost per 1K output tokens |
|
||||
| `limit.context` | int | Yes | Maximum context window in tokens |
|
||||
| `limit.output` | int | No | Maximum output tokens |
|
||||
|
||||
Use with a custom provider URL:
|
||||
|
||||
```bash
|
||||
kit --provider-url "http://localhost:8080/v1" --model custom/my-model "Hello"
|
||||
```
|
||||
|
||||
When `--provider-url` is specified without `--model`, Kit defaults to `custom/custom` which has zero cost tracking and a 262K context window.
|
||||
|
||||
## Theme configuration
|
||||
|
||||
```yaml
|
||||
# Inline partial overrides (unspecified fields inherit from default)
|
||||
|
||||
@@ -7,7 +7,7 @@ description: All extension capabilities — lifecycle events, tools, commands, w
|
||||
|
||||
## Lifecycle events
|
||||
|
||||
Extensions can hook into 20 lifecycle events:
|
||||
Extensions can hook into 23 lifecycle events:
|
||||
|
||||
| Event | Description |
|
||||
|-------|-------------|
|
||||
@@ -31,6 +31,9 @@ Extensions can hook into 20 lifecycle events:
|
||||
| `OnBeforeSessionSwitch` | Before switching sessions |
|
||||
| `OnBeforeCompact` | Before conversation compaction |
|
||||
| `OnCustomEvent` | Custom inter-extension event received |
|
||||
| `OnSubagentStart` | Subagent spawned by the main agent |
|
||||
| `OnSubagentChunk` | Real-time output from subagent (text, tool calls, results) |
|
||||
| `OnSubagentEnd` | Subagent completed with final response/error |
|
||||
|
||||
### Example
|
||||
|
||||
@@ -234,6 +237,54 @@ result := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
})
|
||||
```
|
||||
|
||||
### Monitoring subagents spawned by the main agent
|
||||
|
||||
When the LLM uses the built-in `subagent` tool, extensions can monitor the subagent's activity in real-time using three lifecycle events:
|
||||
|
||||
```go
|
||||
// Subagent started
|
||||
api.OnSubagentStart(func(e ext.SubagentStartEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — unique ID for this subagent invocation
|
||||
// e.Task — the task/prompt sent to the subagent
|
||||
ctx.PrintInfo(fmt.Sprintf("Subagent started: %s", e.Task))
|
||||
})
|
||||
|
||||
// Real-time streaming output from subagent
|
||||
api.OnSubagentChunk(func(e ext.SubagentChunkEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — matches the start event
|
||||
// e.Task — task description
|
||||
// e.ChunkType — "text", "tool_call", "tool_execution_start", "tool_result"
|
||||
// e.Content — text content (for text chunks)
|
||||
// e.ToolName — tool name (for tool-related chunks)
|
||||
// e.IsError — true if tool result is an error
|
||||
switch e.ChunkType {
|
||||
case "text":
|
||||
// Streaming text output
|
||||
case "tool_call":
|
||||
// Subagent is calling a tool
|
||||
case "tool_execution_start":
|
||||
// Tool execution started
|
||||
case "tool_result":
|
||||
// Tool execution completed (check e.IsError)
|
||||
}
|
||||
})
|
||||
|
||||
// Subagent completed
|
||||
api.OnSubagentEnd(func(e ext.SubagentEndEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — matches start event
|
||||
// e.Task — task description
|
||||
// e.Response — final response from subagent
|
||||
// e.ErrorMsg — error message if subagent failed
|
||||
if e.ErrorMsg != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Subagent failed: %s", e.ErrorMsg))
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Subagent completed: %s", e.Response))
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
This enables building widgets that display real-time subagent activity.
|
||||
|
||||
## LLM completion
|
||||
|
||||
Make direct model calls without going through the agent loop:
|
||||
@@ -283,3 +334,124 @@ api.OnCustomEvent("my-extension:data-ready", func(data any, ctx ext.Context) {
|
||||
// handle event
|
||||
})
|
||||
```
|
||||
|
||||
## Bridged SDK APIs
|
||||
|
||||
Extensions can access powerful internal SDK capabilities that enable advanced features like conversation tree navigation, dynamic skill loading, template parsing, and model resolution.
|
||||
|
||||
### Tree Navigation
|
||||
|
||||
Navigate the conversation tree, summarize branches, and implement "fresh context" loops:
|
||||
|
||||
```go
|
||||
// Get a specific node by ID with full metadata and children
|
||||
node := ctx.GetTreeNode("entry-id")
|
||||
// node.ID, node.ParentID, node.Type ("message"/"branch_summary"/etc)
|
||||
// node.Role, node.Content, node.Model, node.Children ([]string)
|
||||
|
||||
// Get the current branch from root to leaf
|
||||
branch := ctx.GetCurrentBranch() // []ext.TreeNode
|
||||
|
||||
// Get child entry IDs of a node
|
||||
children := ctx.GetChildren("entry-id") // []string
|
||||
|
||||
// Navigate/fork to a different entry in the tree
|
||||
result := ctx.NavigateTo("entry-id") // ext.TreeNavigationResult{Success, Error}
|
||||
|
||||
// Summarize a range of the branch using LLM
|
||||
summary := ctx.SummarizeBranch("from-id", "to-id") // string
|
||||
|
||||
// Collapse a branch range into a summary entry (fresh context primitive)
|
||||
result := ctx.CollapseBranch("from-id", "to-id", "summary text")
|
||||
```
|
||||
|
||||
### Skill Loading
|
||||
|
||||
Load and inject skills dynamically at runtime:
|
||||
|
||||
```go
|
||||
// Discover skills from standard locations
|
||||
result := ctx.DiscoverSkills() // ext.SkillLoadResult{Skills, Error}
|
||||
// Standard locations: ~/.config/kit/skills/, .kit/skills/, .agents/skills/
|
||||
|
||||
// Load a specific skill file
|
||||
skill, err := ctx.LoadSkill("/path/to/skill.md") // (*ext.Skill, error string)
|
||||
// skill.Name, skill.Description, skill.Content, skill.Tags, skill.When
|
||||
|
||||
// Load all skills from a directory
|
||||
result := ctx.LoadSkillsFromDir("/path/to/skills") // ext.SkillLoadResult
|
||||
|
||||
// Inject a skill as context (pre-loads for next turn)
|
||||
err := ctx.InjectSkillAsContext("skill-name") // error string
|
||||
|
||||
// Inject a skill file directly
|
||||
err := ctx.InjectRawSkillAsContext("/path/to/skill.md") // error string
|
||||
|
||||
// Get all discovered skills
|
||||
skills := ctx.GetAvailableSkills() // []ext.Skill
|
||||
```
|
||||
|
||||
### Template Parsing
|
||||
|
||||
Parse and render templates with variable substitution:
|
||||
|
||||
```go
|
||||
// Parse a template to extract {{variables}}
|
||||
tpl := ctx.ParseTemplate("name", "Hello {{name}}, welcome to {{place}}!")
|
||||
// tpl.Name, tpl.Content, tpl.Variables ([]string)
|
||||
|
||||
// Render a template with variable values
|
||||
vars := map[string]string{"name": "Alice", "place": "Kit"}
|
||||
rendered := ctx.RenderTemplate(tpl, vars) // "Hello Alice, welcome to Kit!"
|
||||
|
||||
// Parse command-line style arguments
|
||||
pattern := ext.ArgumentPattern{
|
||||
Positional: []string{"command", "target"}, // $1, $2
|
||||
Rest: "args", // $@
|
||||
Flags: map[string]string{"--loop": "loop", "-f": "force"},
|
||||
}
|
||||
result := ctx.ParseArguments("deploy staging --loop 5", pattern)
|
||||
// result.Vars["command"] = "deploy"
|
||||
// result.Vars["target"] = "staging"
|
||||
// result.Flags["--loop"] = "5"
|
||||
|
||||
// Simple positional argument parsing ($1, $2, $@)
|
||||
args := ctx.SimpleParseArguments("deploy staging --force", 2)
|
||||
// args[0] = "deploy staging --force" (full input)
|
||||
// args[1] = "deploy" ($1)
|
||||
// args[2] = "staging" ($2)
|
||||
// args[3] = "--force" ($@)
|
||||
|
||||
// Evaluate model conditionals with wildcards
|
||||
matches := ctx.EvaluateModelConditional("claude-*") // bool
|
||||
// Patterns: * matches any, ? matches single char, comma = OR
|
||||
|
||||
// Render content with <if-model> conditionals
|
||||
content := `<if-model is="claude-*">Hi Claude<else>Hi there</if-model>`
|
||||
rendered := ctx.RenderWithModelConditionals(content) // based on current model
|
||||
```
|
||||
|
||||
### Model Resolution
|
||||
|
||||
Resolve model fallback chains and query capabilities:
|
||||
|
||||
```go
|
||||
// Resolve a chain of model preferences (tries each until available)
|
||||
result := ctx.ResolveModelChain([]string{
|
||||
"anthropic/claude-opus-4",
|
||||
"anthropic/claude-sonnet-4",
|
||||
"openai/gpt-4o",
|
||||
})
|
||||
// result.Model (selected), result.Capabilities, result.Attempted, result.Error
|
||||
|
||||
// Get capabilities for a specific model
|
||||
caps, err := ctx.GetModelCapabilities("anthropic/claude-sonnet-4")
|
||||
// caps.Provider, caps.ModelID, caps.ContextLimit, caps.Reasoning, caps.Streaming
|
||||
|
||||
// Check if a model is available (provider exists)
|
||||
available := ctx.CheckModelAvailable("anthropic/claude-sonnet-4") // bool
|
||||
|
||||
// Get current provider/model ID
|
||||
provider := ctx.GetCurrentProvider() // "anthropic"
|
||||
modelID := ctx.GetCurrentModelID() // "claude-sonnet-4"
|
||||
```
|
||||
|
||||
@@ -51,6 +51,15 @@ Kit ships with a rich set of example extensions in the `examples/extensions/` di
|
||||
| [`summarize.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/summarize.go) | Conversation summarization |
|
||||
| [`lsp-diagnostics.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/lsp-diagnostics.go) | LSP diagnostic integration |
|
||||
|
||||
## Bridged SDK APIs
|
||||
|
||||
These examples demonstrate the new bridged SDK APIs that give extensions access to internal Kit capabilities:
|
||||
|
||||
| Extension | Description |
|
||||
|-----------|-------------|
|
||||
| [`conversation-manager.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/conversation-manager.go) | **NEW** Tree navigation (`GetTreeNode`, `GetCurrentBranch`, `NavigateTo`), branch summarization (`SummarizeBranch`), and fresh context loops (`CollapseBranch`) |
|
||||
| [`prompt-templates.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/prompt-templates.go) | **NEW** Frontmatter-driven templates with model fallback chains (`ResolveModelChain`), skill injection (`InjectSkillAsContext`), and template parsing (`ParseTemplate`, `RenderTemplate`) |
|
||||
|
||||
## Themes
|
||||
|
||||
| Extension | Description |
|
||||
|
||||
+1
-1
@@ -13,7 +13,7 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
|
||||
## Features
|
||||
|
||||
- **Multi-Provider LLM Support** — Anthropic, OpenAI, Google Gemini, Ollama, Azure OpenAI, AWS Bedrock, OpenRouter, and more
|
||||
- **Built-in Core Tools** — bash, read, write, edit, grep, find, ls, spawn_subagent with no MCP overhead
|
||||
- **Built-in Core Tools** — bash, read, write, edit, grep, find, ls, subagent with no MCP overhead
|
||||
- **MCP Integration** — Connect external MCP servers for expanded capabilities
|
||||
- **Extension System** — Write custom tools, commands, widgets, and UI modifications in Go
|
||||
- **Interactive TUI** — Rich terminal interface powered by Bubble Tea with streaming, syntax highlighting, and custom rendering
|
||||
|
||||
@@ -5,48 +5,6 @@ description: Monitor tool calls and streaming output with the Kit Go SDK.
|
||||
|
||||
# Callbacks
|
||||
|
||||
## PromptWithCallbacks
|
||||
|
||||
The `PromptWithCallbacks` method provides real-time visibility into tool calls and streaming output:
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(
|
||||
ctx,
|
||||
"List files in current directory",
|
||||
func(name, args string) {
|
||||
// Called when the model invokes a tool
|
||||
fmt.Println("Calling tool:", name)
|
||||
},
|
||||
func(name, args, result string, isError bool) {
|
||||
// Called when a tool returns its result
|
||||
if isError {
|
||||
fmt.Println("Tool failed:", name)
|
||||
}
|
||||
},
|
||||
func(chunk string) {
|
||||
// Called for each streaming text chunk
|
||||
fmt.Print(chunk)
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
### Callback signatures
|
||||
|
||||
| Callback | Signature | When |
|
||||
|----------|-----------|------|
|
||||
| `onToolCall` | `func(name, args string)` | Model requests a tool call |
|
||||
| `onToolResult` | `func(name, args, result string, isError bool)` | Tool execution completes |
|
||||
| `onStreaming` | `func(chunk string)` | Streaming text chunk received |
|
||||
|
||||
Any callback can be `nil` if you don't need it:
|
||||
|
||||
```go
|
||||
// Only care about streaming output
|
||||
response, err := host.PromptWithCallbacks(ctx, "Hello", nil, nil, func(chunk string) {
|
||||
fmt.Print(chunk)
|
||||
})
|
||||
```
|
||||
|
||||
## Event-based monitoring
|
||||
|
||||
For more granular control, use the event subscription API:
|
||||
@@ -116,11 +74,11 @@ The first argument is a priority (lower = runs first).
|
||||
|
||||
## Subagent event monitoring
|
||||
|
||||
Monitor real-time events from LLM-initiated subagents (when the model uses the `spawn_subagent` tool):
|
||||
Monitor real-time events from LLM-initiated subagents (when the model uses the `subagent` tool):
|
||||
|
||||
```go
|
||||
host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
if e.ToolName == "spawn_subagent" {
|
||||
if e.ToolName == "subagent" {
|
||||
host.SubscribeSubagent(e.ToolCallID, func(event kit.Event) {
|
||||
// Receives the same event types as Subscribe(), scoped to the child agent
|
||||
switch ev := event.(type) {
|
||||
|
||||
@@ -62,7 +62,6 @@ The SDK provides several prompt variants:
|
||||
| Method | Description |
|
||||
|--------|-------------|
|
||||
| `Prompt(ctx, message)` | Simple prompt, returns response string |
|
||||
| `PromptWithCallbacks(ctx, message, ...)` | With tool call and streaming callbacks |
|
||||
| `PromptWithOptions(ctx, message, opts)` | With per-call options |
|
||||
| `PromptResult(ctx, message)` | Returns full `TurnResult` with usage stats |
|
||||
| `PromptResultWithFiles(ctx, message, files)` | Multimodal with file attachments |
|
||||
|
||||
+11
-3
@@ -30,12 +30,20 @@ When conversations grow long, Kit can compact them to free up context window spa
|
||||
|
||||
Use `/compact [focus]` to manually compact, or enable `--auto-compact` to compact automatically near the context limit.
|
||||
|
||||
## Auto-cleanup
|
||||
|
||||
Kit automatically cleans up empty sessions on shutdown and when using `/resume`. A session is considered empty if it has no messages beyond the initial system prompt. This prevents cluttering your sessions directory with unused files.
|
||||
|
||||
To start fresh without creating a session file at all, use ephemeral mode:
|
||||
|
||||
```bash
|
||||
kit --no-session
|
||||
```
|
||||
|
||||
## Resuming sessions
|
||||
|
||||
### Continue most recent
|
||||
|
||||
Resume the most recent session for the current directory:
|
||||
|
||||
```bash
|
||||
kit --continue
|
||||
kit -c
|
||||
@@ -73,7 +81,7 @@ These slash commands are available during an interactive session:
|
||||
| `/share` | Upload session to GitHub Gist and get a shareable viewer URL |
|
||||
| `/tree` | Navigate the session tree |
|
||||
| `/fork` | Branch from an earlier message |
|
||||
| `/new` | Start a fresh session |
|
||||
| `/new` | Start a new session (creates new session file) |
|
||||
|
||||
## Ephemeral mode
|
||||
|
||||
|
||||
@@ -1566,7 +1566,7 @@ a:hover { text-decoration: underline; }
|
||||
'grep': '🔍',
|
||||
'find': '📁',
|
||||
'ls': '📂',
|
||||
'spawn_subagent': '🤖',
|
||||
'subagent': '🤖',
|
||||
'fetch': '🌐',
|
||||
'todo': '✅'
|
||||
};
|
||||
@@ -1612,7 +1612,7 @@ a:hover { text-decoration: underline; }
|
||||
headerLabel = formatLsHeader(input);
|
||||
bodyHtml = renderGenericBody(input, result);
|
||||
break;
|
||||
case 'spawn_subagent':
|
||||
case 'subagent':
|
||||
headerLabel = formatSubagentHeader(input);
|
||||
bodyHtml = renderSubagentBody(input, result);
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user