mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d9326fcf21 | |||
| 22c479277e | |||
| 8ae204f12f | |||
| 8b1665a4ce | |||
| 941f1daf0b | |||
| ab7e2bda61 | |||
| 741520927c | |||
| 4c1bda9541 | |||
| 3b69b13556 | |||
| 83a959a379 | |||
| 3491e05e9e | |||
| 0a54a8aa05 | |||
| 3cb3e5dba1 | |||
| 31966c469f | |||
| f03625d6e5 | |||
| d06641dc0a | |||
| bbf1106e27 | |||
| babed03a3d | |||
| 1cd074836f | |||
| ab3ce260c8 | |||
| 8e8cc3946d | |||
| e18e36625e | |||
| be55bc03f1 | |||
| 09919b6307 | |||
| 7a2de4cc3c | |||
| acd7fd7f45 | |||
| 3446f38516 | |||
| db4bb19bac | |||
| d1cffb85ef | |||
| 329cd4ea4a | |||
| 4e779d576f | |||
| fc054f50e8 | |||
| d8f1b32885 | |||
| 1e2a3e2589 | |||
| c7f43917b1 | |||
| 6a8833a7b1 | |||
| 82cbf1d457 | |||
| ab09d5c9e4 | |||
| 2347e0e506 | |||
| 3e1c19442b | |||
| 3fc0ad906e | |||
| f373c34f54 | |||
| 1206837af4 |
@@ -107,13 +107,8 @@ func resolveGoFilePath(inputJSON, cwd string) (string, bool) {
|
||||
}
|
||||
|
||||
func runGoDiagnostics(cwd, absPath string) string {
|
||||
target := absPath
|
||||
if rel, err := filepath.Rel(cwd, absPath); err == nil && !strings.HasPrefix(rel, "..") {
|
||||
target = rel
|
||||
}
|
||||
|
||||
gopls := runGopls(cwd, absPath)
|
||||
lint := runGolangCILint(cwd, target)
|
||||
lint := runGolangCILint(cwd, "./...")
|
||||
|
||||
return fmt.Sprintf(
|
||||
"<go_diagnostics file=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
|
||||
@@ -0,0 +1,304 @@
|
||||
//go:build ignore
|
||||
|
||||
// subagent-monitor — live horizontal widget strip for spawned subagents
|
||||
//
|
||||
// Subscribes to subagents spawned by the main Kit agent and displays a
|
||||
// single widget just above the input box. Each subagent occupies one column
|
||||
// in a side-by-side horizontal layout. Columns show scrolling real-time
|
||||
// output as the subagent works. When a subagent finishes its column is
|
||||
// removed automatically.
|
||||
//
|
||||
// Yaegi-safe design notes:
|
||||
// - No sync.Mutex (Yaegi has reflection issues with sync primitives)
|
||||
// - No channels in maps (Yaegi panics on range over map[string]chan)
|
||||
// - All ctx.* calls guarded with nil checks
|
||||
// - Simple data structures only
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per-subagent state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type submonEntry struct {
|
||||
id int
|
||||
callID string
|
||||
task string
|
||||
lines []string
|
||||
started time.Time
|
||||
elapsed time.Duration
|
||||
}
|
||||
|
||||
const (
|
||||
submonColWidth = 34 // visible character width per column
|
||||
submonMaxLines = 5 // scrolling output lines per column
|
||||
submonColGap = 2 // spaces between columns
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Package-level state - all simple types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
submonCtx ext.Context
|
||||
submonHasCtx bool
|
||||
submonEntries []*submonEntry
|
||||
submonNextID int
|
||||
)
|
||||
|
||||
func submonInit() {
|
||||
submonEntries = nil
|
||||
submonNextID = 1
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// String helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func submonPad(s string, w int) string {
|
||||
r := []rune(s)
|
||||
if len(r) >= w {
|
||||
return string(r[:w])
|
||||
}
|
||||
return s + strings.Repeat(" ", w-len(r))
|
||||
}
|
||||
|
||||
func submonTrunc(s string, w int) string {
|
||||
r := []rune(s)
|
||||
if len(r) <= w {
|
||||
return s
|
||||
}
|
||||
if w <= 1 {
|
||||
return "…"
|
||||
}
|
||||
return string(r[:w-1]) + "…"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Widget rendering
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func submonRenderColumn(e *submonEntry) []string {
|
||||
var rows []string
|
||||
|
||||
// Calculate elapsed time on-demand to avoid race conditions with ticker
|
||||
elapsed := e.elapsed
|
||||
if elapsed == 0 && !e.started.IsZero() {
|
||||
elapsed = time.Since(e.started)
|
||||
}
|
||||
secs := int(elapsed.Seconds())
|
||||
timeStr := fmt.Sprintf("%ds", secs)
|
||||
taskMax := submonColWidth - len(timeStr) - 3
|
||||
taskPart := submonTrunc(e.task, taskMax)
|
||||
header := fmt.Sprintf("#%d %s %s", e.id, taskPart, timeStr)
|
||||
rows = append(rows, submonPad(header, submonColWidth))
|
||||
|
||||
display := e.lines
|
||||
if len(display) > submonMaxLines {
|
||||
display = display[len(display)-submonMaxLines:]
|
||||
}
|
||||
for _, l := range display {
|
||||
rows = append(rows, submonPad(" "+submonTrunc(l, submonColWidth-2), submonColWidth))
|
||||
}
|
||||
for len(rows) < submonMaxLines+1 {
|
||||
if len(rows) == 1 && len(e.lines) == 0 {
|
||||
rows = append(rows, submonPad(" waiting…", submonColWidth))
|
||||
} else {
|
||||
rows = append(rows, strings.Repeat(" ", submonColWidth))
|
||||
}
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
func submonBuildWidget() string {
|
||||
if len(submonEntries) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
numCols := len(submonEntries)
|
||||
numRows := submonMaxLines + 1
|
||||
cols := make([][]string, numCols)
|
||||
for i, e := range submonEntries {
|
||||
rows := submonRenderColumn(e)
|
||||
col := make([]string, numRows)
|
||||
for j := 0; j < numRows; j++ {
|
||||
if j < len(rows) {
|
||||
col[j] = rows[j]
|
||||
} else {
|
||||
col[j] = strings.Repeat(" ", submonColWidth)
|
||||
}
|
||||
}
|
||||
cols[i] = col
|
||||
}
|
||||
|
||||
gap := strings.Repeat(" ", submonColGap)
|
||||
var sb strings.Builder
|
||||
for row := 0; row < numRows; row++ {
|
||||
for ci := range cols {
|
||||
if ci > 0 {
|
||||
sb.WriteString(gap)
|
||||
}
|
||||
sb.WriteString(cols[ci][row])
|
||||
}
|
||||
if row < numRows-1 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func submonPushWidget() {
|
||||
if !submonHasCtx {
|
||||
return
|
||||
}
|
||||
if submonCtx.SetWidget == nil {
|
||||
return
|
||||
}
|
||||
|
||||
text := submonBuildWidget()
|
||||
if len(submonEntries) == 0 {
|
||||
if submonCtx.RemoveWidget != nil {
|
||||
submonCtx.RemoveWidget("submon")
|
||||
}
|
||||
return
|
||||
}
|
||||
submonCtx.SetWidget(ext.WidgetConfig{
|
||||
ID: "submon",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: text},
|
||||
Style: ext.WidgetStyle{BorderColor: "#89b4fa"},
|
||||
Priority: 0,
|
||||
})
|
||||
}
|
||||
|
||||
func submonAppendLine(e *submonEntry, line string) {
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if strings.TrimSpace(line) == "" {
|
||||
return
|
||||
}
|
||||
e.lines = append(e.lines, line)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Init
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func Init(api ext.API) {
|
||||
submonInit()
|
||||
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
submonInit()
|
||||
if ctx.RemoveWidget != nil {
|
||||
ctx.RemoveWidget("submon")
|
||||
}
|
||||
})
|
||||
|
||||
api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
})
|
||||
|
||||
// ── SubagentStart ────────────────────────────────────────────────────────
|
||||
api.OnSubagentStart(func(e ext.SubagentStartEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
id := submonNextID
|
||||
submonNextID++
|
||||
entry := &submonEntry{
|
||||
id: id,
|
||||
callID: e.ToolCallID,
|
||||
task: e.Task,
|
||||
started: time.Now(),
|
||||
}
|
||||
submonEntries = append(submonEntries, entry)
|
||||
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SubagentChunk ────────────────────────────────────────────────────────
|
||||
api.OnSubagentChunk(func(e ext.SubagentChunkEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
var entry *submonEntry
|
||||
for _, en := range submonEntries {
|
||||
if en.callID == e.ToolCallID {
|
||||
entry = en
|
||||
break
|
||||
}
|
||||
}
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch e.ChunkType {
|
||||
case "text":
|
||||
for _, line := range strings.Split(e.Content, "\n") {
|
||||
submonAppendLine(entry, line)
|
||||
}
|
||||
case "tool_call":
|
||||
submonAppendLine(entry, "→ "+e.ToolName)
|
||||
case "tool_execution_start":
|
||||
submonAppendLine(entry, "⚙ "+e.ToolName)
|
||||
case "tool_result":
|
||||
if e.IsError {
|
||||
submonAppendLine(entry, "✗ "+e.ToolName)
|
||||
} else {
|
||||
submonAppendLine(entry, "✓ "+e.ToolName)
|
||||
}
|
||||
}
|
||||
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SubagentEnd ──────────────────────────────────────────────────────────
|
||||
api.OnSubagentEnd(func(e ext.SubagentEndEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
var entry *submonEntry
|
||||
for _, en := range submonEntries {
|
||||
if en.callID == e.ToolCallID {
|
||||
entry = en
|
||||
break
|
||||
}
|
||||
}
|
||||
if entry != nil {
|
||||
entry.elapsed = time.Since(entry.started)
|
||||
if e.ErrorMsg != "" {
|
||||
submonAppendLine(entry, "✗ "+submonTrunc(e.ErrorMsg, submonColWidth-2))
|
||||
}
|
||||
}
|
||||
|
||||
submonPushWidget()
|
||||
|
||||
// Remove the entry immediately (no goroutine to avoid races)
|
||||
newEntries := submonEntries[:0]
|
||||
for _, en := range submonEntries {
|
||||
if en.callID != e.ToolCallID {
|
||||
newEntries = append(newEntries, en)
|
||||
}
|
||||
}
|
||||
submonEntries = newEntries
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SessionShutdown ──────────────────────────────────────────────────────
|
||||
api.OnSessionShutdown(func(_ ext.SessionShutdownEvent, ctx ext.Context) {
|
||||
submonInit()
|
||||
// Guard ctx access - may be nil during shutdown
|
||||
if ctx.RemoveWidget != nil {
|
||||
ctx.RemoveWidget("submon")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -287,7 +287,7 @@ kit -e examples/extensions/minimal.go
|
||||
|
||||
### Extension Capabilities
|
||||
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolExecutionStart, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
|
||||
|
||||
**Custom Components**:
|
||||
- **Tools**: Add new tools the LLM can invoke
|
||||
@@ -336,12 +336,13 @@ See the `examples/extensions/` directory:
|
||||
- `subagent-widget.go` - Multi-agent orchestration with status widget
|
||||
- `subagent-test.go` - Subagent testing utilities
|
||||
- `summarize.go` - Conversation summarization
|
||||
- `go-edit-lint.go` - LSP diagnostic integration with TUI visibility
|
||||
- `tool-logger.go` - Log all tool calls
|
||||
- `neon-theme.go` - Custom theme registration and switching
|
||||
- `tool-renderer-demo.go` - Custom tool call rendering
|
||||
- `widget-status.go` - Persistent status widgets
|
||||
|
||||
Also see `.kit/extensions/go-edit-lint.go` (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
|
||||
|
||||
### Loading Extensions
|
||||
|
||||
**Auto-discovery** (loads automatically):
|
||||
@@ -703,8 +704,24 @@ npm/ - NPM package wrapper for distribution
|
||||
- **Google Vertex** - Claude on Vertex AI
|
||||
- **OpenRouter** - Multi-provider router
|
||||
- **Vercel AI** - Vercel AI SDK models
|
||||
- **Custom** - Any OpenAI-compatible endpoint via `--provider-url`
|
||||
- **Auto-routed** - Any provider from models.dev database
|
||||
|
||||
### Custom Provider
|
||||
|
||||
Use `custom/custom` when pointing Kit at any OpenAI-compatible endpoint with `--provider-url`:
|
||||
|
||||
```bash
|
||||
kit --provider-url "http://localhost:8080/v1" "Hello"
|
||||
```
|
||||
|
||||
This automatically defaults to `custom/custom` without needing to specify a model. The custom provider routes through fantasy's `openaicompat` provider and supports:
|
||||
|
||||
- Zero cost tracking (input/output = 0)
|
||||
- 262K context window, 65K output limit
|
||||
- Reasoning and temperature support
|
||||
- Optional `CUSTOM_API_KEY` environment variable or `--provider-api-key` flag
|
||||
|
||||
### Model String Format
|
||||
|
||||
```bash
|
||||
|
||||
+300
-7
@@ -1,9 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/huh/v2"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
@@ -14,7 +18,7 @@ import (
|
||||
// authCmd represents the auth command for managing AI provider authentication.
|
||||
// This command provides subcommands for login, logout, and status checking
|
||||
// of authentication credentials for various AI providers, with OAuth support
|
||||
// for providers like Anthropic.
|
||||
// for providers like Anthropic and OpenAI.
|
||||
var authCmd = &cobra.Command{
|
||||
Use: "auth",
|
||||
Short: "Manage authentication credentials for AI providers",
|
||||
@@ -25,9 +29,11 @@ using OAuth flows. Stored credentials take precedence over environment variables
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI API (OAuth and API key)
|
||||
|
||||
Examples:
|
||||
kit auth login anthropic
|
||||
kit auth login openai
|
||||
kit auth logout anthropic
|
||||
kit auth status`,
|
||||
}
|
||||
@@ -46,9 +52,11 @@ environment variables when making API calls.
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI ChatGPT Plus/Pro (Codex OAuth)
|
||||
|
||||
Example:
|
||||
kit auth login anthropic`,
|
||||
kit auth login anthropic
|
||||
kit auth login openai`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogin,
|
||||
}
|
||||
@@ -61,14 +69,16 @@ var authLogoutCmd = &cobra.Command{
|
||||
Short: "Remove stored authentication credentials for a provider",
|
||||
Long: `Remove stored authentication credentials for an AI provider.
|
||||
|
||||
This will delete the stored API key for the specified provider. You will need
|
||||
to use environment variables or command-line flags for authentication after logout.
|
||||
This will delete the stored API key or OAuth credentials for the specified provider.
|
||||
You will need to use environment variables or command-line flags for authentication after logout.
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API
|
||||
- openai: OpenAI API
|
||||
|
||||
Example:
|
||||
kit auth logout anthropic`,
|
||||
kit auth logout anthropic
|
||||
kit auth logout openai`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogout,
|
||||
}
|
||||
@@ -101,8 +111,10 @@ func runAuthLogin(cmd *cobra.Command, args []string) error {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return loginAnthropic()
|
||||
case "openai":
|
||||
return loginOpenAI()
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,8 +124,10 @@ func runAuthLogout(cmd *cobra.Command, args []string) error {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return logoutAnthropic()
|
||||
case "openai":
|
||||
return logoutOpenAI()
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,8 +171,44 @@ func runAuthStatus(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check OpenAI credentials
|
||||
fmt.Print("\nOpenAI: ")
|
||||
if hasOpenAICreds, err := cm.HasOpenAICredentials(); err != nil {
|
||||
fmt.Printf("Error checking credentials: %v\n", err)
|
||||
} else if hasOpenAICreds {
|
||||
if creds, err := cm.GetOpenAICredentials(); err != nil {
|
||||
fmt.Printf("Error reading credentials: %v\n", err)
|
||||
} else {
|
||||
authType := "API Key"
|
||||
status := "✓ Authenticated"
|
||||
|
||||
if creds.Type == "oauth" {
|
||||
authType = "OAuth (ChatGPT/Codex)"
|
||||
if creds.IsExpired() {
|
||||
status = "⚠️ Token expired (will refresh automatically)"
|
||||
} else if creds.NeedsRefresh() {
|
||||
status = "⚠️ Token expires soon (will refresh automatically)"
|
||||
}
|
||||
}
|
||||
|
||||
accountInfo := ""
|
||||
if creds.Type == "oauth" && creds.AccountID != "" {
|
||||
accountInfo = fmt.Sprintf(" [%s]", creds.AccountID)
|
||||
}
|
||||
|
||||
fmt.Printf("%s (%s%s, stored %s)\n", status, authType, accountInfo, creds.CreatedAt.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
} else {
|
||||
fmt.Println("✗ Not authenticated")
|
||||
// Check if environment variable is set
|
||||
if os.Getenv("OPENAI_API_KEY") != "" {
|
||||
fmt.Println(" (OPENAI_API_KEY environment variable is set)")
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\nTo authenticate with a provider:")
|
||||
fmt.Println(" kit auth login anthropic")
|
||||
fmt.Println(" kit auth login openai")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -282,3 +332,246 @@ func logoutAnthropic() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loginOpenAI() error {
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
// Check if already authenticated
|
||||
if hasAuth, err := cm.HasOpenAICredentials(); err == nil && hasAuth {
|
||||
var reauth bool
|
||||
err := huh.NewConfirm().
|
||||
Title("You are already authenticated with OpenAI (ChatGPT/Codex)").
|
||||
Description("Do you want to re-authenticate?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&reauth).
|
||||
Run()
|
||||
if err != nil || !reauth {
|
||||
fmt.Println("Authentication cancelled.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create OAuth client
|
||||
client := auth.NewOpenAIOAuthClient()
|
||||
|
||||
// Generate authorization URL
|
||||
fmt.Println("🔐 Starting OAuth authentication with OpenAI (ChatGPT/Codex)...")
|
||||
fmt.Println("This will open your browser to authenticate with your ChatGPT account.")
|
||||
fmt.Println()
|
||||
|
||||
authData, err := client.GetAuthorizationURL()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate authorization URL: %w", err)
|
||||
}
|
||||
|
||||
// Start local callback server
|
||||
callbackServer, err := startOpenAICallbackServer(authData.State)
|
||||
if err != nil {
|
||||
fmt.Printf("⚠️ Could not start local callback server: %v\n", err)
|
||||
fmt.Println("Falling back to manual code entry.")
|
||||
}
|
||||
if callbackServer != nil {
|
||||
defer callbackServer.Close()
|
||||
}
|
||||
|
||||
// Display URL and try to open browser
|
||||
fmt.Println("📱 Opening your browser for authentication...")
|
||||
fmt.Println("If the browser doesn't open automatically, please visit this URL:")
|
||||
fmt.Printf("\n%s\n\n", authData.URL)
|
||||
|
||||
// Try to open browser
|
||||
auth.TryOpenBrowser(authData.URL)
|
||||
|
||||
// Wait for callback or manual input
|
||||
var code string
|
||||
if callbackServer != nil {
|
||||
fmt.Println("Waiting for browser authentication...")
|
||||
select {
|
||||
case callbackCode := <-callbackServer.CodeChan:
|
||||
if callbackCode != "" {
|
||||
code = callbackCode
|
||||
fmt.Println("✓ Received authorization code from browser callback.")
|
||||
}
|
||||
case <-time.After(2 * time.Minute):
|
||||
fmt.Println("\n⏱️ Timeout waiting for browser callback.")
|
||||
callbackServer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// If no code from callback, prompt for manual entry
|
||||
if code == "" {
|
||||
fmt.Println("\nAfter authorizing, paste the callback URL or authorization code below.")
|
||||
fmt.Println("(The callback URL will look like: http://localhost:1455/auth/callback?code=...&state=...)")
|
||||
fmt.Println()
|
||||
|
||||
var input string
|
||||
err = huh.NewInput().
|
||||
Title("Callback URL or Code").
|
||||
Description("Paste the full callback URL or just the authorization code").
|
||||
Value(&input).
|
||||
Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read input: %w", err)
|
||||
}
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
if input == "" {
|
||||
return fmt.Errorf("authorization code cannot be empty")
|
||||
}
|
||||
|
||||
// Parse the input (could be full URL or just code)
|
||||
parsedCode, parsedState := auth.ParseOpenAIAuthorizationInput(input)
|
||||
if parsedCode == "" {
|
||||
return fmt.Errorf("could not extract authorization code from input")
|
||||
}
|
||||
|
||||
// Validate state if provided
|
||||
if parsedState != "" && parsedState != authData.State {
|
||||
return fmt.Errorf("state mismatch - possible security issue")
|
||||
}
|
||||
code = parsedCode
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
fmt.Println("\n🔄 Exchanging authorization code for access token...")
|
||||
creds, err := client.ExchangeCode(code, authData.Verifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exchange authorization code: %w", err)
|
||||
}
|
||||
|
||||
// Store the credentials
|
||||
if err := cm.SetOpenAIOAuthCredentials(creds); err != nil {
|
||||
return fmt.Errorf("failed to store credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✅ Successfully authenticated with OpenAI (ChatGPT/Codex)!")
|
||||
fmt.Printf("📁 Credentials stored in: %s\n", cm.GetCredentialsPath())
|
||||
fmt.Printf("👤 Account ID: %s\n", creds.AccountID)
|
||||
fmt.Println("\n🎉 Your OAuth credentials will now be used for OpenAI API calls.")
|
||||
fmt.Println("💡 You can check your authentication status with: kit auth status")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callbackServer holds the HTTP server and channel for receiving the OAuth callback
|
||||
type callbackServer struct {
|
||||
Server *http.Server
|
||||
CodeChan chan string
|
||||
State string
|
||||
}
|
||||
|
||||
// Close shuts down the callback server
|
||||
func (cs *callbackServer) Close() {
|
||||
if cs.Server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = cs.Server.Shutdown(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// startOpenAICallbackServer starts a local HTTP server to receive the OAuth callback
|
||||
func startOpenAICallbackServer(expectedState string) (*callbackServer, error) {
|
||||
codeChan := make(chan string, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
server := &http.Server{
|
||||
Addr: "127.0.0.1:1455",
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check state
|
||||
state := r.URL.Query().Get("state")
|
||||
if state != expectedState {
|
||||
http.Error(w, "State mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
http.Error(w, "Missing authorization code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Send code to channel
|
||||
select {
|
||||
case codeChan <- code:
|
||||
default:
|
||||
}
|
||||
|
||||
// Return success page
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Authentication Successful</title></head>
|
||||
<body style="font-family: sans-serif; text-align: center; padding: 50px;">
|
||||
<h1>✓ Authentication Successful</h1>
|
||||
<p>You can close this window and return to the terminal.</p>
|
||||
</body>
|
||||
</html>`)
|
||||
})
|
||||
|
||||
// Try to start server
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:1455")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("port 1455 not available: %w", err)
|
||||
}
|
||||
_ = listener.Close()
|
||||
|
||||
go func() {
|
||||
_ = server.ListenAndServe()
|
||||
}()
|
||||
|
||||
return &callbackServer{
|
||||
Server: server,
|
||||
CodeChan: codeChan,
|
||||
State: expectedState,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func logoutOpenAI() error {
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
// Check if authenticated
|
||||
hasAuth, err := cm.HasOpenAICredentials()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check authentication status: %w", err)
|
||||
}
|
||||
|
||||
if !hasAuth {
|
||||
fmt.Println("You are not currently authenticated with OpenAI.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Confirm logout
|
||||
var confirm bool
|
||||
err = huh.NewConfirm().
|
||||
Title("Remove OpenAI credentials").
|
||||
Description("Are you sure you want to remove your stored credentials?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&confirm).
|
||||
Run()
|
||||
if err != nil || !confirm {
|
||||
fmt.Println("Logout cancelled.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove credentials
|
||||
if err := cm.RemoveOpenAICredentials(); err != nil {
|
||||
return fmt.Errorf("failed to remove credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ Successfully logged out from OpenAI!")
|
||||
fmt.Println("You will need to use environment variables or command-line flags for authentication.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -47,6 +48,9 @@ func runModels(_ *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func printAllProviders(showAll bool) error {
|
||||
// Reload the registry to pick up any custom models from config
|
||||
models.ReloadGlobalRegistry()
|
||||
|
||||
var providerIDs []string
|
||||
if showAll {
|
||||
providerIDs = kit.GetSupportedProviders()
|
||||
@@ -98,6 +102,9 @@ func printAllProviders(showAll bool) error {
|
||||
}
|
||||
|
||||
func printProvider(provider string) error {
|
||||
// Reload the registry to pick up any custom models from config
|
||||
models.ReloadGlobalRegistry()
|
||||
|
||||
m, err := kit.GetModelsForProvider(provider)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unknown provider %q. Run 'kit models' to see all providers", provider)
|
||||
|
||||
+48
-1
@@ -13,6 +13,7 @@ import (
|
||||
"charm.land/fantasy"
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -689,6 +690,16 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// When --provider-url is set but no explicit --model was provided,
|
||||
// default to "custom/custom" so the user doesn't need to remember a
|
||||
// provider/model pair for custom OpenAI-compatible endpoints.
|
||||
// This intentionally overrides saved preferences but respects config-file
|
||||
// models — if you specify a model in ~/.kit.yml, it will be used with
|
||||
// custom/custom's provider routing.
|
||||
if viper.GetString("provider-url") != "" && !modelFlagChanged && !viper.InConfig("model") {
|
||||
viper.Set("model", "custom/custom")
|
||||
}
|
||||
|
||||
// Load MCP configuration.
|
||||
mcpConfig, err := config.LoadAndValidateConfig()
|
||||
if err != nil {
|
||||
@@ -800,7 +811,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.Steer(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.SetExtensionWidget(config)
|
||||
@@ -945,6 +956,24 @@ func runNormalMode(ctx context.Context) error {
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
@@ -1142,6 +1171,24 @@ func runNormalMode(ctx context.Context) error {
|
||||
// this callback runs synchronously inside BubbleTea's Update(), and
|
||||
// NotifyModelChanged calls prog.Send() which deadlocks. The UI layer
|
||||
// updates m.providerName and m.modelName directly after setModel returns.
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
emitModelChangeForUI := func(newModel, previousModel, source string) {
|
||||
|
||||
+12
-12
@@ -8,19 +8,21 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// skillCmd installs the kit-extensions skill via the skills.sh CLI (npx skills).
|
||||
// This teaches AI agents how to create Kit extensions with full knowledge of
|
||||
// the extension API, lifecycle events, widgets, tools, commands, and Yaegi constraints.
|
||||
// skillCmd installs Kit skills via the skills.sh CLI (npx skills).
|
||||
var skillCmd = &cobra.Command{
|
||||
Use: "skill",
|
||||
Short: "Install the Kit extensions skill via skills.sh",
|
||||
Long: `Install the kit-extensions skill that teaches AI agents how to create
|
||||
Kit extensions. Uses the skills.sh CLI (npx skills) to install the skill
|
||||
from the Kit repository.
|
||||
Short: "Install Kit skills via skills.sh",
|
||||
Long: `Install Kit skills that teach AI agents how to build with Kit.
|
||||
Uses the skills.sh CLI (npx skills) to install all skills from the Kit repository.
|
||||
|
||||
The skill provides comprehensive documentation of Kit's extension API including
|
||||
lifecycle events, custom tools, slash commands, widgets, editor interceptors,
|
||||
tool renderers, and critical Yaegi interpreter constraints.
|
||||
Two skills are provided:
|
||||
|
||||
1. Extensions — creating Kit extensions with full knowledge of the extension
|
||||
API, lifecycle events, widgets, tools, commands, editor interceptors,
|
||||
tool renderers, and Yaegi interpreter constraints.
|
||||
|
||||
2. SDK — building AI-powered applications with the Kit Go SDK, including
|
||||
providers, agents, tools, and MCP integration.
|
||||
|
||||
Example:
|
||||
kit skill`,
|
||||
@@ -41,8 +43,6 @@ func runSkill(_ *cobra.Command, _ []string) error {
|
||||
"skills",
|
||||
"add",
|
||||
"mark3labs/kit",
|
||||
"--skill",
|
||||
"kit-extensions",
|
||||
}
|
||||
|
||||
cmd := exec.Command(npx, args...)
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// TestSubagentMonitor_SessionStart verifies OnSessionStart initializes state
|
||||
// without panicking and properly guards nil ctx calls.
|
||||
func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Emit SessionStart - should not panic even with nil ctx functions
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_SubagentLifecycle verifies the full subagent lifecycle
|
||||
// creates entries and emits widget updates.
|
||||
func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Start session
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit SubagentStart
|
||||
_, err = harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit a few chunks
|
||||
for i := range 3 {
|
||||
_, err = harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
ChunkType: "text",
|
||||
Content: fmt.Sprintf("line %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Emit tool call chunk
|
||||
_, err = harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
ChunkType: "tool_call",
|
||||
ToolName: "bash",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk tool_call should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit SubagentEnd
|
||||
_, err = harness.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
Response: "done",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentEnd should not error: %v", err)
|
||||
}
|
||||
|
||||
// Give time for cleanup goroutine
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_MultipleSubagents verifies multiple parallel subagents.
|
||||
func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Start 3 subagents
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Emit chunks for each
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
ChunkType: "text",
|
||||
Content: fmt.Sprintf("output from agent %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// End all subagents
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
Response: "completed",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentEnd %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_SessionShutdown verifies shutdown doesn't panic
|
||||
// even with nil ctx functions.
|
||||
func TestSubagentMonitor_SessionShutdown(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Start then shutdown
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Start a subagent
|
||||
_, err = harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Shutdown - should not panic even with active subagent
|
||||
_, err = harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionShutdown should not error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
# SDK Examples
|
||||
|
||||
These examples demonstrate how to use the Kit SDK (`pkg/kit`) to build agents programmatically in Go.
|
||||
|
||||
## Examples
|
||||
|
||||
### [basic](basic/)
|
||||
|
||||
Shows core SDK usage: creating a Kit instance, sending prompts, overriding the model, subscribing to events (tool calls, streaming), and session management.
|
||||
|
||||
```bash
|
||||
go run ./examples/sdk/basic
|
||||
```
|
||||
|
||||
### [scripting](scripting/)
|
||||
|
||||
A minimal script-friendly wrapper that takes a prompt from the command line and prints the response — useful for piping and automation.
|
||||
|
||||
```bash
|
||||
go run ./examples/sdk/scripting "Explain what this repo does"
|
||||
```
|
||||
|
||||
### [crypto-monitor](crypto-monitor/)
|
||||
|
||||
A background agent that checks Bitcoin and Ethereum prices every 30 minutes and sends desktop notifications via `notify-send` (dbus). Demonstrates using the SDK for a long-running autonomous task with a single tool.
|
||||
|
||||
```bash
|
||||
go run ./examples/sdk/crypto-monitor
|
||||
|
||||
# Override the check interval:
|
||||
CRYPTO_INTERVAL=5m go run ./examples/sdk/crypto-monitor
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
|
||||
```go
|
||||
import kit "github.com/mark3labs/kit/pkg/kit"
|
||||
|
||||
host, err := kit.New(ctx, nil) // uses ~/.kit.yml defaults
|
||||
defer host.Close()
|
||||
|
||||
response, err := host.Prompt(ctx, "Hello!")
|
||||
```
|
||||
|
||||
See the [SDK README](../../pkg/kit/README.md) for the full API reference.
|
||||
@@ -0,0 +1,85 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
const systemPrompt = `You are a cryptocurrency price monitor. Your job is to:
|
||||
|
||||
1. Fetch the current prices of Bitcoin and Ethereum using bash with curl
|
||||
2. Send a desktop notification with the results using notify-send
|
||||
|
||||
To fetch prices, use this CoinGecko API endpoint (no API key needed):
|
||||
curl -s 'https://api.coingecko.com/api/v3/simple/price?ids=bitcoin,ethereum&vs_currencies=usd&include_24hr_change=true'
|
||||
|
||||
To send a desktop notification:
|
||||
notify-send -i dialog-information "Crypto Prices" "BTC: $XX,XXX (+X.X%)\nETH: $X,XXX (+X.X%)"
|
||||
|
||||
Include the 24h percentage change in the notification. Use a green arrow (▲) for
|
||||
positive changes and a red arrow (▼) for negative. Format prices with commas.
|
||||
|
||||
If the API call fails, send a notification about the failure instead.
|
||||
|
||||
Always complete both steps: fetch then notify. Be concise — no commentary needed.`
|
||||
|
||||
func main() {
|
||||
interval := 30 * time.Minute
|
||||
if os.Getenv("CRYPTO_INTERVAL") != "" {
|
||||
d, err := time.ParseDuration(os.Getenv("CRYPTO_INTERVAL"))
|
||||
if err == nil {
|
||||
interval = d
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
SystemPrompt: systemPrompt,
|
||||
Tools: []kit.Tool{kit.NewBashTool()},
|
||||
NoSession: true,
|
||||
Quiet: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create kit instance: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
fmt.Printf("Crypto price monitor started (every %s)\n", interval)
|
||||
fmt.Println("Press Ctrl+C to stop")
|
||||
|
||||
// Run immediately on startup, then on each tick.
|
||||
check(ctx, host)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
check(ctx, host)
|
||||
case <-ctx.Done():
|
||||
fmt.Println("\nStopping price monitor")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func check(ctx context.Context, host *kit.Kit) {
|
||||
fmt.Printf("[%s] Checking prices...\n", time.Now().Format("15:04:05"))
|
||||
|
||||
// Clear session so each check is independent.
|
||||
host.ClearSession()
|
||||
|
||||
_, err := host.Prompt(ctx, "Fetch current Bitcoin and Ethereum prices and send a desktop notification.")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@ go 1.26.1
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.0.0
|
||||
charm.land/bubbletea/v2 v2.0.2
|
||||
charm.land/fantasy v0.16.0
|
||||
charm.land/fantasy v0.17.1
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.2
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
@@ -13,7 +13,7 @@ require (
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/coder/acp-go-sdk v0.6.3
|
||||
github.com/mark3labs/mcp-go v0.45.0
|
||||
github.com/mark3labs/mcp-go v0.46.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
@@ -23,14 +23,14 @@ require (
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/auth v0.18.2 // indirect
|
||||
cloud.google.com/go/auth v0.19.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect
|
||||
@@ -45,8 +45,6 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.2 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
@@ -56,9 +54,9 @@ require (
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
@@ -77,18 +75,17 @@ require (
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/invopop/jsonschema v0.13.0 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.2.12 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.6 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.18 // indirect
|
||||
github.com/mailru/easyjson v0.9.2 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
@@ -96,7 +93,7 @@ require (
|
||||
github.com/muesli/mango-pflag v0.2.0 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/roff v0.1.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
@@ -105,10 +102,9 @@ require (
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.7.17 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.6 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
|
||||
@@ -122,7 +118,7 @@ require (
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.272.0 // indirect
|
||||
google.golang.org/api v0.273.0 // indirect
|
||||
google.golang.org/genai v1.51.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect
|
||||
google.golang.org/grpc v1.79.3 // indirect
|
||||
|
||||
@@ -2,16 +2,16 @@ charm.land/bubbles/v2 v2.0.0 h1:tE3eK/pHjmtrDiRdoC9uGNLgpopOd8fjhEe31B/ai5s=
|
||||
charm.land/bubbles/v2 v2.0.0/go.mod h1:rCHoleP2XhU8um45NTuOWBPNVHxnkXKTiZqcclL/qOI=
|
||||
charm.land/bubbletea/v2 v2.0.2 h1:4CRtRnuZOdFDTWSff9r8QFt/9+z6Emubz3aDMnf/dx0=
|
||||
charm.land/bubbletea/v2 v2.0.2/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
|
||||
charm.land/fantasy v0.16.0 h1:vE/6sR9nPcSD8qXJXX6wR8NXjtWlBVAzwQmTh5pHVrs=
|
||||
charm.land/fantasy v0.16.0/go.mod h1:VZjpXVh7IgeiIzGQybEnKzd68ofDsRj94+kzH1ZCAfQ=
|
||||
charm.land/fantasy v0.17.1 h1:SQzfnyJPDuQWt6e//KKmQmEEXdqHMC0IZz10XwkLcEM=
|
||||
charm.land/fantasy v0.17.1/go.mod h1:FF5ALCCHETacHJPBqU42CtwMInYQ0ul52fdzIHQMbQk=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs=
|
||||
charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
|
||||
cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M=
|
||||
cloud.google.com/go/auth v0.19.0 h1:DGYwtbcsGsT1ywuxsIoWi1u/vlks0moIblQHgSDgQkQ=
|
||||
cloud.google.com/go/auth v0.19.0/go.mod h1:2Aph7BT2KnaSFOM0JDPyiYgNh6PL9vGMiP8CUIXZ+IY=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
@@ -36,8 +36,8 @@ github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8=
|
||||
@@ -70,10 +70,6 @@ github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ
|
||||
github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk=
|
||||
github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
|
||||
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
@@ -104,14 +100,14 @@ github.com/charmbracelet/x/conpty v0.1.1 h1:s1bUxjoi7EpqiXysVtC+a8RrvPPNcNvAjfi4
|
||||
github.com/charmbracelet/x/conpty v0.1.1/go.mod h1:OmtR77VODEFbiTzGE9G1XiRJAga6011PIm4u5fTNZpk=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd h1:eStB6uX52pgrm6TxQcEKctPrEC+a/9ubJC+P671idOc=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca h1:62yAoS1Ynbuzwcn1LkNBxi3IMF5p0E0cHCoaLOOmN9w=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd h1:U8xj0UXwqHzO+UYHZJopKF+gWaQEW8oj60fmiq9TFY4=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca h1:QQoyQLgUzojMNWHVHToN6d9qTvT0KWtxUKIRPx/Ox5o=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -173,14 +169,16 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 h1:fYQaUOiGwll0cGj7jmHT/0nPlcrZDFPrZRhTsoCr8hE=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0/go.mod h1:w2ROXVdfGEVFXzmlciUU4EdjHgWvB5h2n6x/8XSTTJA=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 h1:NIKVuLhDlIV74muWlsMM4CcQZqN6JJ20Qcxd9YMuYcs=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
@@ -189,8 +187,6 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
|
||||
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
|
||||
github.com/kaptinlin/go-i18n v0.2.12 h1:ywDsvb4KDFddMC2dpI/rrIzGU2mWUSvHmWUm9BMsdl4=
|
||||
github.com/kaptinlin/go-i18n v0.2.12/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
|
||||
@@ -207,10 +203,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mailru/easyjson v0.9.2 h1:dX8U45hQsZpxd80nLvDGihsQ/OxlvTkVUXH2r/8cb2M=
|
||||
github.com/mailru/easyjson v0.9.2/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/mark3labs/mcp-go v0.45.0 h1:s0S8qR/9fWaQ3pHxz7pm1uQ0DrswoSnRIxKIjbiQtkc=
|
||||
github.com/mark3labs/mcp-go v0.45.0/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw=
|
||||
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
@@ -234,8 +228,8 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
|
||||
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
|
||||
@@ -279,14 +273,12 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/traefik/yaegi v0.16.1 h1:f1De3DVJqIDKmnasUF6MwmWv1dSEEat0wcpXhD2On3E=
|
||||
github.com/traefik/yaegi v0.16.1/go.mod h1:4eVhbPb3LnD2VigQjhYbEJ69vDRFdT2HQNrXx8eEwUY=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.7.17 h1:p36OVWwRb246iHxA/U4p8OPEpOTESm4n+g+8t0EE5uA=
|
||||
github.com/yuin/goldmark v1.7.17/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs=
|
||||
github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
@@ -328,10 +320,14 @@ golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.272.0 h1:eLUQZGnAS3OHn31URRf9sAmRk3w2JjMx37d2k8AjJmA=
|
||||
google.golang.org/api v0.272.0/go.mod h1:wKjowi5LNJc5qarNvDCvNQBn3rVK8nSy6jg2SwRwzIA=
|
||||
google.golang.org/api v0.273.0 h1:r/Bcv36Xa/te1ugaN1kdJ5LoA5Wj/cL+a4gj6FiPBjQ=
|
||||
google.golang.org/api v0.273.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg=
|
||||
google.golang.org/genai v1.51.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
|
||||
|
||||
+75
-5
@@ -63,6 +63,18 @@ type ToolCallContentHandler func(content string)
|
||||
// ReasoningDeltaHandler is a function type for handling streaming reasoning/thinking deltas.
|
||||
type ReasoningDeltaHandler func(delta string)
|
||||
|
||||
// ToolOutputHandler is a function type for handling streaming tool output chunks.
|
||||
// Used by tools like bash to stream output as it arrives rather than waiting
|
||||
// for the command to complete. The isStderr flag indicates if the chunk
|
||||
// contains stderr output.
|
||||
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
|
||||
type ToolOutputHandler = core.ToolOutputCallback
|
||||
|
||||
// StepUsageHandler is a function type for handling token usage after each
|
||||
// complete step in a multi-step agent turn. This enables real-time cost
|
||||
// tracking during long-running tool-calling conversations.
|
||||
type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64)
|
||||
|
||||
// Agent represents an AI agent with core tool integration using the fantasy library.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
|
||||
// fantasy.AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
@@ -171,7 +183,8 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
|
||||
// Pass generation parameters when available.
|
||||
if agentConfig.ModelConfig != nil {
|
||||
if agentConfig.ModelConfig.MaxTokens > 0 {
|
||||
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
|
||||
if agentConfig.ModelConfig.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
|
||||
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(agentConfig.ModelConfig.MaxTokens)))
|
||||
}
|
||||
if agentConfig.ModelConfig.Temperature != nil {
|
||||
@@ -218,7 +231,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
|
||||
onResponse, onToolCallContent, nil, nil)
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the fantasy agent with streaming and callbacks.
|
||||
@@ -229,8 +242,15 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
onStreamingResponse StreamingResponseHandler,
|
||||
onReasoningDelta ReasoningDeltaHandler,
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
|
||||
// Inject tool output handler into context for use by core tools (e.g., bash).
|
||||
if onToolOutput != nil {
|
||||
ctx = core.ContextWithToolOutputCallback(ctx, onToolOutput)
|
||||
}
|
||||
|
||||
// Fantasy requires the current user input as Prompt, with prior messages as history.
|
||||
// Extract the last user message text and files as the prompt, and pass everything
|
||||
// before it as Messages. Files (e.g. clipboard images) are passed via the Files
|
||||
@@ -256,7 +276,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
var completedStepMessages []fantasy.Message
|
||||
|
||||
// Use fantasy's streaming agent
|
||||
result, err := a.fantasyAgent.Stream(ctx, fantasy.AgentStreamCall{
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
Prompt: prompt,
|
||||
Files: files,
|
||||
Messages: history,
|
||||
@@ -338,9 +358,58 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if text != "" && len(toolCalls) > 0 && onToolCallContent != nil {
|
||||
onToolCallContent(text)
|
||||
}
|
||||
// Emit step usage for real-time cost tracking
|
||||
if onStepUsage != nil {
|
||||
onStepUsage(step.Usage.InputTokens, step.Usage.OutputTokens,
|
||||
step.Usage.CacheReadTokens, step.Usage.CacheCreationTokens)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// If a steer channel is attached to the context, wire up a
|
||||
// PrepareStep function that drains the channel between steps
|
||||
// and injects pending steer messages as user messages before
|
||||
// the next LLM call. This enables graceful mid-turn steering
|
||||
// without cancelling in-progress tool execution.
|
||||
if steerCh := steerChFromContext(ctx); steerCh != nil {
|
||||
onConsumed := steerConsumedFromContext(ctx)
|
||||
streamCall.PrepareStep = func(
|
||||
stepCtx context.Context,
|
||||
opts fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
// Drain all pending steer messages (non-blocking).
|
||||
var steered []string
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
steered = append(steered, msg)
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
result := fantasy.PrepareStepResult{
|
||||
Model: opts.Model,
|
||||
Messages: opts.Messages,
|
||||
}
|
||||
if len(steered) > 0 {
|
||||
// Inject each steer message as a user message so the
|
||||
// LLM sees the redirection on the next step.
|
||||
for _, text := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(text))
|
||||
}
|
||||
// Notify that steer messages were consumed.
|
||||
if onConsumed != nil {
|
||||
onConsumed(len(steered))
|
||||
}
|
||||
}
|
||||
return stepCtx, result, nil
|
||||
}
|
||||
}
|
||||
|
||||
result, err := a.fantasyAgent.Stream(ctx, streamCall)
|
||||
if err != nil {
|
||||
// On cancellation (or any error), return a partial result
|
||||
// containing messages from completed steps so the caller can
|
||||
@@ -604,7 +673,8 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
}
|
||||
|
||||
// Pass generation parameters when available.
|
||||
if config.MaxTokens > 0 {
|
||||
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
|
||||
if config.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
|
||||
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(config.MaxTokens)))
|
||||
}
|
||||
if config.Temperature != nil {
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
package agent
|
||||
|
||||
import "context"
|
||||
|
||||
// steerChKey is the context key for the steer channel.
|
||||
type steerChKey struct{}
|
||||
|
||||
// steerConsumedKey is the context key for the steer-consumed callback.
|
||||
type steerConsumedKey struct{}
|
||||
|
||||
// ContextWithSteerCh returns a new context with the steer channel attached.
|
||||
// The agent's PrepareStep function checks this channel between steps and
|
||||
// injects any pending steer messages as user messages before the next LLM call.
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan string) context.Context {
|
||||
return context.WithValue(ctx, steerChKey{}, ch)
|
||||
}
|
||||
|
||||
// ContextWithSteerConsumed returns a new context with a callback that fires
|
||||
// when steer messages are consumed by PrepareStep. The count argument is the
|
||||
// number of messages injected in this batch.
|
||||
func ContextWithSteerConsumed(ctx context.Context, fn func(count int)) context.Context {
|
||||
return context.WithValue(ctx, steerConsumedKey{}, fn)
|
||||
}
|
||||
|
||||
// steerChFromContext extracts the steer channel from the context, or nil.
|
||||
func steerChFromContext(ctx context.Context) <-chan string {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan string)
|
||||
return ch
|
||||
}
|
||||
|
||||
// steerConsumedFromContext extracts the steer-consumed callback, or nil.
|
||||
func steerConsumedFromContext(ctx context.Context) func(int) {
|
||||
fn, _ := ctx.Value(steerConsumedKey{}).(func(int))
|
||||
return fn
|
||||
}
|
||||
+121
-26
@@ -3,6 +3,7 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
@@ -159,11 +160,57 @@ func (a *App) QueueLength() int {
|
||||
return len(a.queue)
|
||||
}
|
||||
|
||||
// Steer cancels the current agent step (if running), clears the queue, and
|
||||
// sends a new message that will execute as soon as the current step finishes
|
||||
// cancelling. If the agent is idle, the message executes immediately.
|
||||
// This is the "steer" delivery mode for SendMessage.
|
||||
func (a *App) Steer(prompt string) {
|
||||
// Steer injects a steering message into the currently running agent turn.
|
||||
// If the agent is in a multi-step tool loop, the message is delivered after
|
||||
// the current tool execution finishes but before the next LLM call (graceful
|
||||
// mid-turn injection via Fantasy's PrepareStep). If the agent is streaming
|
||||
// a text-only response (no pending tool calls), the message waits until the
|
||||
// response completes and then executes as the next turn.
|
||||
//
|
||||
// If the agent is idle, the message starts executing immediately (same as Run).
|
||||
//
|
||||
// Returns the number of pending steer/queue items (0 = started immediately,
|
||||
// >0 = injected/queued). The caller must update UI state based on the return
|
||||
// value — Steer does NOT send events to the program to avoid deadlocking
|
||||
// when called from within Update().
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) Steer(prompt string) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
a.mu.Unlock()
|
||||
return 0
|
||||
}
|
||||
|
||||
if !a.busy {
|
||||
// Not busy — start immediately, same as Run().
|
||||
item := queueItem{Prompt: prompt}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
go a.drainQueue(item)
|
||||
return 0
|
||||
}
|
||||
|
||||
a.mu.Unlock()
|
||||
|
||||
// Agent is busy — inject via the SDK's steer channel. The message
|
||||
// will be picked up by PrepareStep between agent steps (after tool
|
||||
// execution, before next LLM call). If PrepareStep doesn't fire
|
||||
// (text-only response), drainQueue will pick it up after the turn.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.InjectSteer(prompt)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// InterruptAndSend cancels the current agent step (if running), clears the
|
||||
// queue, and sends a new message that will execute as soon as the current
|
||||
// step finishes cancelling. If the agent is idle, the message executes
|
||||
// immediately. This is the hard-cancel delivery mode used by extensions'
|
||||
// CancelAndSend.
|
||||
func (a *App) InterruptAndSend(prompt string) {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -226,6 +273,10 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
_ = old.Close()
|
||||
}
|
||||
a.opts.TreeSession = ts
|
||||
// Also update the kit SDK's tree session so messages are persisted correctly.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.SetTreeSession(ts)
|
||||
}
|
||||
// Reload messages from new session.
|
||||
a.store.Clear()
|
||||
if ts != nil {
|
||||
@@ -401,6 +452,13 @@ func (a *App) Close() {
|
||||
|
||||
// Wait for background goroutines.
|
||||
a.wg.Wait()
|
||||
|
||||
// Clean up empty session file on shutdown.
|
||||
if ts := a.opts.TreeSession; ts != nil && ts.IsEmpty() {
|
||||
if path := ts.GetFilePath(); path != "" {
|
||||
_ = os.Remove(path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -434,6 +492,24 @@ func (a *App) drainQueue(first queueItem) {
|
||||
// Process all collected items as a single batch
|
||||
a.runQueueBatch(items)
|
||||
|
||||
// Drain any unconsumed steer messages from the SDK channel.
|
||||
// These arrive when the user steered during a text-only response
|
||||
// (no tool calls, so PrepareStep didn't fire for a second step).
|
||||
// They go to the front of the queue so they run next.
|
||||
if a.opts.Kit != nil {
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
a.mu.Lock()
|
||||
steerItems := make([]queueItem, len(leftover))
|
||||
for i, text := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: text}
|
||||
}
|
||||
a.queue = append(steerItems, a.queue...)
|
||||
a.mu.Unlock()
|
||||
// Notify UI about the consumed steer messages.
|
||||
a.sendEvent(SteerConsumedEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// Check if more items were queued while we were processing
|
||||
a.mu.Lock()
|
||||
hasMore := len(a.queue) > 0
|
||||
@@ -486,11 +562,10 @@ func (a *App) runQueueBatch(items []queueItem) {
|
||||
result, err := a.executeBatch(stepCtx, items, eventFn)
|
||||
if err != nil {
|
||||
if stepCtx.Err() != nil {
|
||||
// Step was cancelled by the user (double-ESC). The SDK's
|
||||
// runTurn has rolled the tree session back to the pre-turn
|
||||
// state, discarding the user message and any tool call/result
|
||||
// pairs from the cancelled turn. Sync the in-memory store
|
||||
// to match the rolled-back tree session.
|
||||
// Step was cancelled by the user (double-ESC). The SDK
|
||||
// preserves the user message and any completed tool
|
||||
// call/result pairs; only the in-progress message or tool
|
||||
// call is discarded. Sync the in-memory store to match.
|
||||
if ts := a.opts.TreeSession; ts != nil {
|
||||
a.store.Replace(ts.GetFantasyMessages())
|
||||
}
|
||||
@@ -672,6 +747,15 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
sendFn(StreamChunkEvent{Content: ev.Chunk})
|
||||
case kit.ReasoningDeltaEvent:
|
||||
sendFn(ReasoningChunkEvent{Delta: ev.Delta})
|
||||
case kit.ToolOutputEvent:
|
||||
sendFn(ToolOutputEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName,
|
||||
Chunk: ev.Chunk,
|
||||
IsStderr: ev.IsStderr,
|
||||
})
|
||||
case kit.SteerConsumedEvent:
|
||||
sendFn(SteerConsumedEvent{})
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -842,28 +926,39 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
// configured UsageTracker. This is the SDK-path equivalent of updateUsage.
|
||||
// configured UsageTracker. Called once per turn after the turn completes.
|
||||
//
|
||||
// Cost/token accumulation uses TotalUsage (sum across all tool-calling steps in
|
||||
// the turn). Context-window fill uses FinalUsage.InputTokens only — that is the
|
||||
// number of tokens sent to the model on the last API call, which equals the
|
||||
// actual context window occupation (all accumulated messages + tool results).
|
||||
// OutputTokens are not added here because they are the response length, not
|
||||
// context fill.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string) {
|
||||
if a.opts.UsageTracker == nil || result == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if result.TotalUsage != nil {
|
||||
inputTokens := int(result.TotalUsage.InputTokens)
|
||||
outputTokens := int(result.TotalUsage.OutputTokens)
|
||||
if inputTokens > 0 && outputTokens > 0 {
|
||||
cacheReadTokens := int(result.TotalUsage.CacheReadTokens)
|
||||
cacheWriteTokens := int(result.TotalUsage.CacheCreationTokens)
|
||||
a.opts.UsageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
} else {
|
||||
a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response)
|
||||
return
|
||||
}
|
||||
// --- Accumulate cost/token totals for the session ---
|
||||
if result.TotalUsage != nil && result.TotalUsage.InputTokens > 0 {
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(result.TotalUsage.InputTokens),
|
||||
int(result.TotalUsage.OutputTokens),
|
||||
int(result.TotalUsage.CacheReadTokens),
|
||||
int(result.TotalUsage.CacheCreationTokens),
|
||||
)
|
||||
} else {
|
||||
// Provider didn't report token counts — fall back to character-based
|
||||
// estimates so the footer shows something rather than nothing.
|
||||
a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response)
|
||||
}
|
||||
|
||||
if result.FinalUsage != nil {
|
||||
if ct := int(result.FinalUsage.InputTokens) + int(result.FinalUsage.OutputTokens); ct > 0 {
|
||||
a.opts.UsageTracker.SetContextTokens(ct)
|
||||
}
|
||||
// --- Context window fill (drives the % bar) ---
|
||||
// Use FinalUsage.InputTokens: the input token count of the last API call
|
||||
// equals the number of tokens currently occupying the context window.
|
||||
// Adding OutputTokens would overstate fill since the response is not part
|
||||
// of the context that was *sent* to the model.
|
||||
if result.FinalUsage != nil && result.FinalUsage.InputTokens > 0 {
|
||||
a.opts.UsageTracker.SetContextTokens(int(result.FinalUsage.InputTokens))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,19 @@ type ToolResultEvent struct {
|
||||
IsError bool
|
||||
}
|
||||
|
||||
// ToolOutputEvent is sent when a tool produces streaming output chunks (e.g., bash output).
|
||||
// This allows the TUI to display tool output as it arrives, before the tool completes.
|
||||
type ToolOutputEvent struct {
|
||||
// ToolCallID is the stable identifier for the tool call producing output.
|
||||
ToolCallID string
|
||||
// ToolName is the name of the tool producing output.
|
||||
ToolName string
|
||||
// Chunk is a piece of the tool's output text.
|
||||
Chunk string
|
||||
// IsStderr indicates whether this chunk came from stderr.
|
||||
IsStderr bool
|
||||
}
|
||||
|
||||
// ToolCallContentEvent is sent when a step includes text content alongside tool calls.
|
||||
// This allows the TUI to display assistant commentary that accompanies tool usage.
|
||||
type ToolCallContentEvent struct {
|
||||
@@ -128,6 +141,12 @@ type CompactErrorEvent struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// SteerConsumedEvent is sent when one or more steering messages have been
|
||||
// consumed — either injected mid-turn via PrepareStep, or drained into the
|
||||
// queue after a turn completes. The TUI uses this to clear the steering
|
||||
// badge from the display.
|
||||
type SteerConsumedEvent struct{}
|
||||
|
||||
// ModelChangedEvent is sent when an extension changes the active model via
|
||||
// ctx.SetModel. The TUI updates the model name shown in the status bar and
|
||||
// message attribution.
|
||||
|
||||
@@ -10,9 +10,10 @@ import (
|
||||
)
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
// Currently supports Anthropic credentials with both OAuth and API key authentication methods.
|
||||
// Currently supports Anthropic and OpenAI credentials with both OAuth and API key authentication methods.
|
||||
type CredentialStore struct {
|
||||
Anthropic *AnthropicCredentials `json:"anthropic,omitempty"`
|
||||
OpenAI *OpenAICredentials `json:"openai,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicCredentials holds Anthropic API credentials supporting both OAuth
|
||||
@@ -28,6 +29,20 @@ type AnthropicCredentials struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// OpenAICredentials holds OpenAI API credentials supporting both OAuth
|
||||
// and API key authentication methods. The Type field indicates which authentication
|
||||
// method is being used. For OAuth, tokens are stored with expiration timestamps
|
||||
// for automatic refresh. For API keys, only the key itself is stored.
|
||||
type OpenAICredentials struct {
|
||||
Type string `json:"type"` // "oauth" or "api_key"
|
||||
APIKey string `json:"api_key,omitempty"` // For API key auth
|
||||
AccessToken string `json:"access_token,omitempty"` // For OAuth
|
||||
RefreshToken string `json:"refresh_token,omitempty"` // For OAuth
|
||||
ExpiresAt int64 `json:"expires_at,omitempty"` // For OAuth
|
||||
AccountID string `json:"account_id,omitempty"` // For OAuth (ChatGPT account ID)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) IsExpired() bool {
|
||||
@@ -48,6 +63,26 @@ func (c *AnthropicCredentials) NeedsRefresh() bool {
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) IsExpired() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= c.ExpiresAt
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
// will expire within the next 5 minutes. This allows for proactive token refresh
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) NeedsRefresh() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
}
|
||||
|
||||
// CredentialManager handles secure storage and retrieval of authentication credentials.
|
||||
// It manages a JSON file stored in the user's config directory with appropriate
|
||||
// file permissions for security.
|
||||
@@ -212,6 +247,142 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAICredentials stores OpenAI API key credentials. It validates the
|
||||
// API key format before storing. The API key must start with "sk-" and be
|
||||
// at least 20 characters long. Returns an error if the API key is invalid or
|
||||
// if storage fails.
|
||||
func (cm *CredentialManager) SetOpenAICredentials(apiKey string) error {
|
||||
if err := validateOpenAIAPIKey(apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = &OpenAICredentials{
|
||||
Type: "api_key",
|
||||
APIKey: apiKey,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetOpenAICredentials retrieves stored OpenAI credentials. Returns nil if
|
||||
// no credentials are stored. The returned credentials may be either OAuth or API
|
||||
// key type, check the Type field to determine which.
|
||||
func (cm *CredentialManager) GetOpenAICredentials() (*OpenAICredentials, error) {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return store.OpenAI, nil
|
||||
}
|
||||
|
||||
// RemoveOpenAICredentials removes stored OpenAI credentials from storage.
|
||||
// If this was the only credential stored, the entire credentials file is removed.
|
||||
// Returns an error if the removal fails.
|
||||
func (cm *CredentialManager) RemoveOpenAICredentials() error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = nil
|
||||
|
||||
// If store is empty, remove the file entirely
|
||||
if store.Anthropic == nil && store.OpenAI == nil {
|
||||
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// HasOpenAICredentials checks if valid OpenAI credentials are stored.
|
||||
// Returns true if either a non-empty OAuth access token or API key is present,
|
||||
// false otherwise. Returns an error if credentials cannot be loaded.
|
||||
func (cm *CredentialManager) HasOpenAICredentials() (bool, error) {
|
||||
creds, err := cm.GetOpenAICredentials()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if creds == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check based on credential type
|
||||
switch creds.Type {
|
||||
case "oauth":
|
||||
return creds.AccessToken != "", nil
|
||||
case "api_key":
|
||||
return creds.APIKey != "", nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAIOAuthCredentials stores OpenAI OAuth credentials in the credential manager's secure storage.
|
||||
// The credentials should include access token, refresh token, and expiration information.
|
||||
// Returns an error if the credentials cannot be saved.
|
||||
func (cm *CredentialManager) SetOpenAIOAuthCredentials(creds *OpenAICredentials) error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = creds
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetValidOpenAIAccessToken returns a valid access token for API requests. For OAuth credentials,
|
||||
// it automatically refreshes the token if it's expired or about to expire. For API key
|
||||
// credentials, it simply returns the API key. Returns an error if no credentials are found,
|
||||
// if token refresh fails, or if the credential type is unknown.
|
||||
func (cm *CredentialManager) GetValidOpenAIAccessToken() (string, error) {
|
||||
creds, err := cm.GetOpenAICredentials()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if creds == nil {
|
||||
return "", fmt.Errorf("no credentials found")
|
||||
}
|
||||
|
||||
// For API key auth, return the API key
|
||||
if creds.Type == "api_key" {
|
||||
return creds.APIKey, nil
|
||||
}
|
||||
|
||||
// For OAuth, check if token needs refresh
|
||||
if creds.Type == "oauth" {
|
||||
if creds.NeedsRefresh() {
|
||||
// Refresh the token
|
||||
client := NewOpenAIOAuthClient()
|
||||
newCreds, err := client.RefreshToken(creds.RefreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Update stored credentials
|
||||
if err := cm.SetOpenAIOAuthCredentials(newCreds); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
return newCreds.AccessToken, nil
|
||||
}
|
||||
|
||||
return creds.AccessToken, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unknown credential type: %s", creds.Type)
|
||||
}
|
||||
|
||||
// GetCredentialsPath returns the absolute path to the credentials JSON file.
|
||||
// This is useful for debugging or displaying the storage location to users.
|
||||
func (cm *CredentialManager) GetCredentialsPath() string {
|
||||
@@ -238,6 +409,26 @@ func validateAnthropicAPIKey(apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOpenAIAPIKey validates the format of an OpenAI API key
|
||||
func validateOpenAIAPIKey(apiKey string) error {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
|
||||
if apiKey == "" {
|
||||
return fmt.Errorf("API key cannot be empty")
|
||||
}
|
||||
|
||||
// OpenAI API keys typically start with "sk-" and are quite long
|
||||
if !strings.HasPrefix(apiKey, "sk-") {
|
||||
return fmt.Errorf("invalid OpenAI API key format (should start with 'sk-')")
|
||||
}
|
||||
|
||||
if len(apiKey) < 20 {
|
||||
return fmt.Errorf("API key appears to be too short")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
|
||||
// 1. Command-line flag value (highest priority)
|
||||
// 2. Stored credentials (OAuth or API key)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -30,6 +31,7 @@ type OAuthClient struct {
|
||||
type AuthData struct {
|
||||
URL string
|
||||
Verifier string
|
||||
State string // Optional state parameter for CSRF protection
|
||||
}
|
||||
|
||||
// NewOAuthClient creates a new OAuth client configured for Anthropic's OAuth service.
|
||||
@@ -199,6 +201,270 @@ func (c *OAuthClient) parseCodeAndState(code string) (parsedCode, parsedState st
|
||||
return
|
||||
}
|
||||
|
||||
// OpenAIOAuthClient handles OAuth 2.0 authentication flow with OpenAI Codex (ChatGPT Plus/Pro).
|
||||
// This uses OpenAI's auth0-based OAuth service for ChatGPT account authentication.
|
||||
type OpenAIOAuthClient struct {
|
||||
ClientID string
|
||||
AuthorizeURL string
|
||||
TokenURL string
|
||||
RedirectURI string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OAuth client configured for OpenAI Codex OAuth.
|
||||
// This uses the public client ID for CLI applications with PKCE for security.
|
||||
func NewOpenAIOAuthClient() *OpenAIOAuthClient {
|
||||
return &OpenAIOAuthClient{
|
||||
// Public client ID for OpenAI Codex CLI OAuth
|
||||
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
AuthorizeURL: "https://auth.openai.com/oauth/authorize",
|
||||
TokenURL: "https://auth.openai.com/oauth/token",
|
||||
RedirectURI: "http://localhost:1455/auth/callback",
|
||||
Scopes: "openid profile email offline_access",
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthorizationURL generates a complete authorization URL for the OAuth flow with
|
||||
// PKCE parameters. Returns an AuthData structure containing the URL for user
|
||||
// authentication and the PKCE verifier for the subsequent code exchange.
|
||||
func (c *OpenAIOAuthClient) GetAuthorizationURL() (*AuthData, error) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
|
||||
}
|
||||
|
||||
// Generate random state
|
||||
stateBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(stateBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
state := fmt.Sprintf("%x", stateBytes)
|
||||
|
||||
params := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {c.ClientID},
|
||||
"redirect_uri": {c.RedirectURI},
|
||||
"scope": {c.Scopes},
|
||||
"code_challenge": {challenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
"id_token_add_organizations": {"true"},
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
"originator": {"kit"},
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s?%s", c.AuthorizeURL, params.Encode())
|
||||
|
||||
return &AuthData{
|
||||
URL: authURL,
|
||||
Verifier: verifier,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges an authorization code for access and refresh tokens.
|
||||
// The code parameter should be the authorization code received from the OAuth callback.
|
||||
// The verifier parameter must be the same PKCE verifier generated during GetAuthorizationURL.
|
||||
// Returns OpenAICredentials containing the tokens, expiration, and account ID.
|
||||
func (c *OpenAIOAuthClient) ExchangeCode(code, verifier string) (*OpenAICredentials, error) {
|
||||
return c.exchangeAuthorizationCode(code, verifier, c.RedirectURI)
|
||||
}
|
||||
|
||||
// exchangeAuthorizationCode performs the token exchange with the OAuth server
|
||||
func (c *OpenAIOAuthClient) exchangeAuthorizationCode(code, verifier, redirectUri string) (*OpenAICredentials, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {c.ClientID},
|
||||
"code": {code},
|
||||
"code_verifier": {verifier},
|
||||
"redirect_uri": {redirectUri},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", c.TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make token request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" || tokenResp.RefreshToken == "" {
|
||||
return nil, fmt.Errorf("token response missing required fields")
|
||||
}
|
||||
|
||||
// Extract account ID from JWT token
|
||||
accountID := extractOpenAIAccountID(tokenResp.AccessToken)
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("failed to extract account ID from token")
|
||||
}
|
||||
|
||||
return &OpenAICredentials{
|
||||
Type: "oauth",
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
CreatedAt: time.Now(),
|
||||
AccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired or expiring access token using a refresh token.
|
||||
// Returns new OpenAICredentials with updated access token, refresh token (may be
|
||||
// rotated), and new expiration timestamp. Returns an error if the refresh fails or
|
||||
// the refresh token is invalid.
|
||||
func (c *OpenAIOAuthClient) RefreshToken(refreshToken string) (*OpenAICredentials, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"client_id": {c.ClientID},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", c.TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make refresh request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode refresh response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" || tokenResp.RefreshToken == "" {
|
||||
return nil, fmt.Errorf("refresh response missing required fields")
|
||||
}
|
||||
|
||||
// Extract account ID from JWT token
|
||||
accountID := extractOpenAIAccountID(tokenResp.AccessToken)
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("failed to extract account ID from refreshed token")
|
||||
}
|
||||
|
||||
return &OpenAICredentials{
|
||||
Type: "oauth",
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
CreatedAt: time.Now(),
|
||||
AccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractOpenAIAccountID extracts the ChatGPT account ID from a JWT access token.
|
||||
// The account ID is stored in the claim path https://api.openai.com/auth.chatgpt_account_id
|
||||
func extractOpenAIAccountID(token string) string {
|
||||
// JWT tokens are base64-encoded JSON payloads
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Decode payload (second part)
|
||||
payload := parts[1]
|
||||
// Add padding if needed
|
||||
if len(payload)%4 != 0 {
|
||||
payload += strings.Repeat("=", 4-len(payload)%4)
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Navigate to the claim path: https://api.openai.com/auth.chatgpt_account_id
|
||||
authPath, ok := claims["https://api.openai.com/auth"].(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
accountID, ok := authPath["chatgpt_account_id"].(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return accountID
|
||||
}
|
||||
|
||||
// ParseOpenAIAuthorizationInput parses various forms of authorization input:
|
||||
// - Full callback URL: http://localhost:1455/auth/callback?code=xxx&state=yyy
|
||||
// - Code#State format: abc123#state456
|
||||
// - Query string: code=abc123&state=state456
|
||||
// - Just the code: abc123
|
||||
func ParseOpenAIAuthorizationInput(input string) (code, state string) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Try parsing as URL
|
||||
if strings.HasPrefix(input, "http") {
|
||||
if u, err := url.Parse(input); err == nil {
|
||||
return u.Query().Get("code"), u.Query().Get("state")
|
||||
}
|
||||
}
|
||||
|
||||
// Try code#state format
|
||||
if strings.Contains(input, "#") {
|
||||
parts := strings.SplitN(input, "#", 2)
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
|
||||
// Try query string format
|
||||
if strings.Contains(input, "code=") {
|
||||
if values, err := url.ParseQuery(input); err == nil {
|
||||
return values.Get("code"), values.Get("state")
|
||||
}
|
||||
}
|
||||
|
||||
// Assume it's just the code
|
||||
return input, ""
|
||||
}
|
||||
|
||||
// SetOAuthCredentials stores OAuth credentials in the credential manager's secure storage.
|
||||
// The credentials should include access token, refresh token, and expiration information.
|
||||
// Returns an error if the credentials cannot be saved.
|
||||
|
||||
@@ -157,6 +157,32 @@ type Theme struct {
|
||||
Markdown MarkdownThemeConfig `json:"markdown,omitzero" yaml:"markdown,omitempty"`
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model that can be used with custom/custom
|
||||
// or other custom/ prefixed models. These models are loaded from the config file
|
||||
// and merged into the custom provider in the model registry.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
type CostConfig struct {
|
||||
Input float64 `json:"input" yaml:"input"`
|
||||
Output float64 `json:"output" yaml:"output"`
|
||||
}
|
||||
|
||||
// LimitConfig defines context and output limits for a custom model.
|
||||
type LimitConfig struct {
|
||||
Context int `json:"context" yaml:"context"`
|
||||
Output int `json:"output" yaml:"output"`
|
||||
}
|
||||
|
||||
// Config represents the complete application configuration including MCP servers,
|
||||
// model settings, UI preferences, and API credentials. It supports both command-line
|
||||
// flags and configuration file settings.
|
||||
@@ -187,6 +213,9 @@ type Config struct {
|
||||
// Prompt templates configuration
|
||||
Prompts []string `json:"prompts,omitempty" yaml:"prompts,omitempty"`
|
||||
NoPromptTemplates bool `json:"no-prompt-templates,omitempty" yaml:"no-prompt-templates,omitempty"`
|
||||
|
||||
// Custom model definitions (under custom/ provider)
|
||||
CustomModels map[string]CustomModelConfig `json:"customModels,omitempty" yaml:"customModels,omitempty"`
|
||||
}
|
||||
|
||||
// GetTransportType returns the transport type for the server config, mapping
|
||||
|
||||
+160
-11
@@ -1,17 +1,41 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// ToolOutputCallback is the signature for streaming tool output.
|
||||
// It receives tool call ID, tool name, output chunk, and whether it's stderr.
|
||||
type ToolOutputCallback func(toolCallID, toolName, chunk string, isStderr bool)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions.
|
||||
type contextKey string
|
||||
|
||||
const toolOutputCallbackKey contextKey = "toolOutputCallback"
|
||||
|
||||
// ContextWithToolOutputCallback returns a new context with the tool output callback set.
|
||||
func ContextWithToolOutputCallback(ctx context.Context, callback ToolOutputCallback) context.Context {
|
||||
return context.WithValue(ctx, toolOutputCallbackKey, callback)
|
||||
}
|
||||
|
||||
// toolOutputCallbackFromContext retrieves the tool output callback from context.
|
||||
func toolOutputCallbackFromContext(ctx context.Context) ToolOutputCallback {
|
||||
if cb, ok := ctx.Value(toolOutputCallbackKey).(ToolOutputCallback); ok {
|
||||
return cb
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const defaultBashTimeout = 120 * time.Second
|
||||
const maxBashTimeout = 600 * time.Second
|
||||
|
||||
@@ -99,32 +123,157 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
cmd.Env = append(os.Environ(), "SHELL="+bashPath)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
// Get the output callback if present (for streaming support)
|
||||
outputCallback := toolOutputCallbackFromContext(ctx)
|
||||
|
||||
err = cmd.Run()
|
||||
if outputCallback != nil {
|
||||
// Streaming mode: use pipes to capture output as it arrives
|
||||
return executeBashStreaming(cmdCtx, call, cmd, outputCallback)
|
||||
}
|
||||
|
||||
// Non-streaming mode: collect all output at once (original behavior)
|
||||
return executeBashBuffered(cmdCtx, call, cmd)
|
||||
}
|
||||
|
||||
// executeBashBuffered collects all output before returning (original behavior).
|
||||
// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly
|
||||
// close them when grandchild processes hold pipe handles open after the
|
||||
// direct child exits.
|
||||
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
}
|
||||
|
||||
// Read pipes concurrently
|
||||
var wg sync.WaitGroup
|
||||
var stdout, stderr strings.Builder
|
||||
var stdoutErr, stderrErr error
|
||||
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, stdoutErr = io.Copy(&stdout, stdoutPipe)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, stderrErr = io.Copy(&stderr, stderrPipe)
|
||||
}()
|
||||
|
||||
// Wait for the process to exit first. cmd.WaitDelay ensures that if
|
||||
// pipes remain open (held by grandchild processes), they'll be forcibly
|
||||
// closed after the grace period, which unblocks the io.Copy goroutines.
|
||||
waitErr := cmd.Wait()
|
||||
|
||||
// Wait for pipe readers to finish draining.
|
||||
wg.Wait()
|
||||
|
||||
// Ignore pipe read errors caused by WaitDelay force-closing —
|
||||
// we still have whatever was read before the close.
|
||||
_ = stdoutErr
|
||||
_ = stderrErr
|
||||
|
||||
exitCode := 0
|
||||
if waitErr != nil {
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
return fantasy.NewTextErrorResponse("command timed out"), nil
|
||||
}
|
||||
}
|
||||
|
||||
return buildBashResponse(stdout.String(), stderr.String(), exitCode)
|
||||
}
|
||||
|
||||
// executeBashStreaming streams output as it arrives via the callback.
|
||||
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
}
|
||||
|
||||
// Start command execution
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
}
|
||||
|
||||
// Stream stdout and stderr concurrently
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
var stdoutChunks, stderrChunks []string
|
||||
|
||||
streamOutput := func(reader io.Reader, isStderr bool) {
|
||||
defer wg.Done()
|
||||
scanner := bufio.NewScanner(reader)
|
||||
// Use larger buffer for long lines
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
chunk := scanner.Text()
|
||||
// Send chunk to UI
|
||||
outputCallback(call.ID, "bash", chunk, isStderr)
|
||||
// Collect for final result
|
||||
mu.Lock()
|
||||
if isStderr {
|
||||
stderrChunks = append(stderrChunks, chunk)
|
||||
} else {
|
||||
stdoutChunks = append(stdoutChunks, chunk)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(2)
|
||||
go streamOutput(stdoutPipe, false)
|
||||
go streamOutput(stderrPipe, true)
|
||||
|
||||
// Wait for the process to exit. cmd.WaitDelay ensures that if pipes
|
||||
// remain open (held by grandchild processes), they'll be forcibly closed
|
||||
// after the grace period, which unblocks the scanners above.
|
||||
err = cmd.Wait()
|
||||
|
||||
// Wait for the pipe readers to finish draining. This will complete
|
||||
// quickly since cmd.Wait() (with WaitDelay) has already ensured
|
||||
// the pipes are closed.
|
||||
wg.Wait()
|
||||
|
||||
exitCode := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("command timed out after %v", timeout)), nil
|
||||
return fantasy.NewTextErrorResponse("command timed out"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Build result
|
||||
return buildBashResponse(strings.Join(stdoutChunks, "\n"), strings.Join(stderrChunks, "\n"), exitCode)
|
||||
}
|
||||
|
||||
// buildBashResponse constructs the final tool response from stdout/stderr.
|
||||
func buildBashResponse(stdout, stderr string, exitCode int) (fantasy.ToolResponse, error) {
|
||||
var result strings.Builder
|
||||
if stdout.Len() > 0 {
|
||||
result.WriteString(stdout.String())
|
||||
if stdout != "" {
|
||||
result.WriteString(stdout)
|
||||
}
|
||||
if stderr.Len() > 0 {
|
||||
if stderr != "" {
|
||||
if result.Len() > 0 {
|
||||
result.WriteString("\n")
|
||||
}
|
||||
result.WriteString("STDERR:\n")
|
||||
result.WriteString(stderr.String())
|
||||
result.WriteString(stderr)
|
||||
}
|
||||
if exitCode != 0 {
|
||||
if result.Len() > 0 {
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// helper to create a bash tool call with the given command and optional timeout.
|
||||
func bashCall(command string, timeout float64) fantasy.ToolCall {
|
||||
args := map[string]any{"command": command}
|
||||
if timeout > 0 {
|
||||
args["timeout"] = timeout
|
||||
}
|
||||
input, _ := json.Marshal(args)
|
||||
return fantasy.ToolCall{
|
||||
ID: "test-call",
|
||||
Name: "bash",
|
||||
Input: string(input),
|
||||
}
|
||||
}
|
||||
|
||||
func TestBash_SimpleCommand(t *testing.T) {
|
||||
resp, err := executeBash(context.Background(), bashCall("echo hello", 0), "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("expected success, got error: %s", resp.Content)
|
||||
}
|
||||
if resp.Content != "hello\n" {
|
||||
t.Errorf("expected 'hello\\n', got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBash_TimeoutKillsProcess(t *testing.T) {
|
||||
start := time.Now()
|
||||
resp, err := executeBash(context.Background(), bashCall("sleep 60", 2), "")
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Fatal("expected error response for timed-out command")
|
||||
}
|
||||
if elapsed > 10*time.Second {
|
||||
t.Errorf("command took %v, expected ~2s timeout", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBash_BackgroundProcessDoesNotHang(t *testing.T) {
|
||||
// This command spawns a background sleep that would hold pipes open
|
||||
// forever if we didn't have process group killing + WaitDelay.
|
||||
start := time.Now()
|
||||
resp, err := executeBash(context.Background(), bashCall("echo done; sleep 3600 &", 5), "")
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// The foreground command (echo) should complete quickly
|
||||
if elapsed > 5*time.Second {
|
||||
t.Errorf("command took %v, should complete in <5s (background process should not block)", elapsed)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("expected success, got error: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBash_BackgroundProcessDoesNotHang_Streaming(t *testing.T) {
|
||||
// Same test but in streaming mode (with output callback).
|
||||
ctx := ContextWithToolOutputCallback(context.Background(), func(_, _, _ string, _ bool) {})
|
||||
start := time.Now()
|
||||
resp, err := executeBash(ctx, bashCall("echo streaming; sleep 3600 &", 5), "")
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if elapsed > 5*time.Second {
|
||||
t.Errorf("streaming command took %v, should complete in <5s", elapsed)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("expected success, got error: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBash_ContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
_, _ = executeBash(ctx, bashCall("sleep 60", 0), "")
|
||||
}()
|
||||
|
||||
// Cancel after a short delay
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
// Should return promptly after cancellation
|
||||
select {
|
||||
case <-done:
|
||||
// success
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("executeBash did not return after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBash_BannedCommand(t *testing.T) {
|
||||
resp, err := executeBash(context.Background(), bashCall("alias foo=bar", 0), "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Fatal("expected error for banned command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBash_EmptyCommand(t *testing.T) {
|
||||
resp, err := executeBash(context.Background(), bashCall("", 0), "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Fatal("expected error for empty command")
|
||||
}
|
||||
}
|
||||
+234
-44
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
@@ -13,19 +14,45 @@ import (
|
||||
udiff "github.com/aymanbagabas/go-udiff"
|
||||
)
|
||||
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
// Edit represents a single replacement in a multi-edit operation.
|
||||
type Edit struct {
|
||||
OldText string `json:"old_text"`
|
||||
NewText string `json:"new_text"`
|
||||
}
|
||||
|
||||
// editArgs holds the arguments for the edit tool.
|
||||
// Supports both single-edit mode (old_text/new_text) and multi-edit mode (edits array).
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
OldText string `json:"old_text"` // Single-edit mode
|
||||
NewText string `json:"new_text"` // Single-edit mode
|
||||
Edits []Edit `json:"edits"` // Multi-edit mode
|
||||
}
|
||||
|
||||
// replacement represents a normalized edit ready for processing.
|
||||
type replacement struct {
|
||||
oldText string // normalized old text for matching
|
||||
newText string // normalized new text
|
||||
originalOld string // original old text for metadata
|
||||
originalNew string // original new text for metadata
|
||||
index int // index in the original edits array (for error messages)
|
||||
}
|
||||
|
||||
// matchedReplacement represents a replacement with its match location.
|
||||
type matchedReplacement struct {
|
||||
replacement
|
||||
start int // start index in normalized content
|
||||
end int // end index in normalized content
|
||||
usedFuzzyMatch bool // true if fuzzy matching was used
|
||||
}
|
||||
|
||||
// NewEditTool creates the edit core tool.
|
||||
func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
cfg := ApplyOptions(opts)
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "edit",
|
||||
Description: "Edit a file by replacing exact text. The old_text must match exactly (including whitespace). Use this for precise, surgical edits. Fails if old_text is not found or matches multiple locations.",
|
||||
Description: "Edit a file by replacing exact text. Supports single edit via old_text/new_text, or multiple edits via the edits array. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Parameters: map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
@@ -33,14 +60,32 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace (must match exactly)",
|
||||
"description": "Exact text to find and replace (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text to replace the old text with",
|
||||
"description": "New text to replace the old text with (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"edits": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Array of edits for multi-region replacement. Each edit must have unique, non-overlapping old_text. All matches are against the original file content.",
|
||||
"items": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace for this edit",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text for this edit",
|
||||
},
|
||||
},
|
||||
"required": []string{"old_text", "new_text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "old_text", "new_text"},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeEdit(ctx, call, cfg.WorkDir)
|
||||
@@ -51,7 +96,7 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args editArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("path, old_text, and new_text parameters are required"), nil
|
||||
return fantasy.NewTextErrorResponse("failed to parse arguments: " + err.Error()), nil
|
||||
}
|
||||
if args.Path == "" {
|
||||
return fantasy.NewTextErrorResponse("path parameter is required"), nil
|
||||
@@ -69,56 +114,201 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
|
||||
content := string(contentBytes)
|
||||
|
||||
// Normalize line endings for matching
|
||||
normalized := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
normalizedOld := strings.ReplaceAll(args.OldText, "\r\n", "\n")
|
||||
|
||||
// Try exact match first
|
||||
count := strings.Count(normalized, normalizedOld)
|
||||
|
||||
// If no exact match, try fuzzy matching
|
||||
if count == 0 {
|
||||
if idx, matchLen := fuzzyMatch(normalized, normalizedOld); idx >= 0 {
|
||||
// Apply fuzzy match — the matched text is the original content slice
|
||||
matchedText := normalized[idx : idx+matchLen]
|
||||
newContent := normalized[:idx] + args.NewText + normalized[idx+matchLen:]
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, matchedText, args.NewText)), nil
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("old_text not found in %s", args.Path)), nil
|
||||
// Normalize and validate input
|
||||
replacements, err := normalizeEditInput(args)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("found %d matches for old_text in %s. Provide more context to identify the correct match.", count, args.Path)), nil
|
||||
// Apply all edits
|
||||
newContent, applied, err := applyEdits(content, replacements)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Apply the edit
|
||||
newContent := strings.Replace(normalized, normalizedOld, args.NewText, 1)
|
||||
|
||||
// Write the file
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, normalizedOld, args.NewText)), nil
|
||||
// Generate diff
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
diff := generateDiff(absPath, normalizedContent, newContent)
|
||||
|
||||
// Build response with fuzzy match indication
|
||||
fuzzyCount := 0
|
||||
for _, m := range applied {
|
||||
if m.usedFuzzyMatch {
|
||||
fuzzyCount++
|
||||
}
|
||||
}
|
||||
|
||||
var msg string
|
||||
if len(applied) == 1 {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff)
|
||||
}
|
||||
} else {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied %d edits (%d fuzzy) to %s\n%s", len(applied), fuzzyCount, args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied %d edits to %s\n%s", len(applied), args.Path, diff)
|
||||
}
|
||||
}
|
||||
|
||||
resp := fantasy.NewTextResponse(msg)
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, applied)), nil
|
||||
}
|
||||
|
||||
// normalizeEditInput validates and normalizes the edit input.
|
||||
// Returns error if both single-edit and multi-edit modes are used.
|
||||
func normalizeEditInput(args editArgs) ([]replacement, error) {
|
||||
singleMode := args.OldText != "" || args.NewText != ""
|
||||
multiMode := len(args.Edits) > 0
|
||||
|
||||
if singleMode && multiMode {
|
||||
return nil, fmt.Errorf("cannot use old_text/new_text together with edits array")
|
||||
}
|
||||
|
||||
if !singleMode && !multiMode {
|
||||
return nil, fmt.Errorf("must provide either old_text/new_text or edits array")
|
||||
}
|
||||
|
||||
if singleMode {
|
||||
if args.OldText == "" {
|
||||
return nil, fmt.Errorf("old_text is required when using single-edit mode")
|
||||
}
|
||||
if args.NewText == "" {
|
||||
return nil, fmt.Errorf("new_text is required when using single-edit mode")
|
||||
}
|
||||
return []replacement{{
|
||||
oldText: strings.ReplaceAll(args.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(args.NewText, "\r\n", "\n"),
|
||||
originalOld: args.OldText,
|
||||
originalNew: args.NewText,
|
||||
index: 0,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Multi-edit mode
|
||||
var reps []replacement
|
||||
for i, edit := range args.Edits {
|
||||
if edit.OldText == "" {
|
||||
return nil, fmt.Errorf("edits[%d].old_text is required", i)
|
||||
}
|
||||
reps = append(reps, replacement{
|
||||
oldText: strings.ReplaceAll(edit.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(edit.NewText, "\r\n", "\n"),
|
||||
originalOld: edit.OldText,
|
||||
originalNew: edit.NewText,
|
||||
index: i,
|
||||
})
|
||||
}
|
||||
return reps, nil
|
||||
}
|
||||
|
||||
// applyEdits applies multiple replacements to the content.
|
||||
// All matches are against the original content (non-incremental).
|
||||
// Returns the new content, the applied matches, and any error.
|
||||
func applyEdits(content string, edits []replacement) (string, []matchedReplacement, error) {
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
|
||||
// Find all matches
|
||||
var matched []matchedReplacement
|
||||
for _, edit := range edits {
|
||||
m, err := findMatch(normalizedContent, edit)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
matched = append(matched, *m)
|
||||
}
|
||||
|
||||
// Sort by position
|
||||
sort.Slice(matched, func(i, j int) bool {
|
||||
return matched[i].start < matched[j].start
|
||||
})
|
||||
|
||||
// Check for overlaps
|
||||
for i := 1; i < len(matched); i++ {
|
||||
if matched[i-1].end > matched[i].start {
|
||||
return "", nil, fmt.Errorf("edits[%d] and edits[%d] overlap; merge them into a single edit",
|
||||
matched[i-1].index, matched[i].index)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply edits in reverse order (end to start) to maintain stable offsets
|
||||
result := normalizedContent
|
||||
for i := len(matched) - 1; i >= 0; i-- {
|
||||
m := matched[i]
|
||||
result = result[:m.start] + m.newText + result[m.end:]
|
||||
}
|
||||
|
||||
return result, matched, nil
|
||||
}
|
||||
|
||||
// findMatch finds a unique match for the edit in the content.
|
||||
// Returns error if not found or ambiguous.
|
||||
func findMatch(content string, edit replacement) (*matchedReplacement, error) {
|
||||
// Try exact match first
|
||||
count := strings.Count(content, edit.oldText)
|
||||
|
||||
if count == 0 {
|
||||
// Try fuzzy match
|
||||
idx, matchLen := fuzzyMatch(content, edit.oldText)
|
||||
if idx < 0 {
|
||||
return nil, fmt.Errorf("edits[%d]: could not find old_text in file. The text must match exactly (including whitespace)", edit.index)
|
||||
}
|
||||
// Use the matched text from content for the replacement
|
||||
matchedText := content[idx : idx+matchLen]
|
||||
return &matchedReplacement{
|
||||
replacement: replacement{
|
||||
oldText: matchedText,
|
||||
newText: edit.newText,
|
||||
originalOld: edit.originalOld,
|
||||
originalNew: edit.originalNew,
|
||||
index: edit.index,
|
||||
},
|
||||
start: idx,
|
||||
end: idx + matchLen,
|
||||
usedFuzzyMatch: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return nil, fmt.Errorf("found %d matches for edits[%d].old_text; each old_text must be unique, provide more context to identify the correct match", count, edit.index)
|
||||
}
|
||||
|
||||
// Single exact match
|
||||
idx := strings.Index(content, edit.oldText)
|
||||
return &matchedReplacement{
|
||||
replacement: edit,
|
||||
start: idx,
|
||||
end: idx + len(edit.oldText),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// editDiffMeta builds the structured metadata attached to edit tool responses.
|
||||
func editDiffMeta(path, oldText, newText string) map[string]any {
|
||||
func editDiffMeta(path string, applied []matchedReplacement) map[string]any {
|
||||
var diffBlocks []map[string]any
|
||||
totalAdditions, totalDeletions := 0, 0
|
||||
|
||||
for _, m := range applied {
|
||||
diffBlocks = append(diffBlocks, map[string]any{
|
||||
"old_text": m.originalOld,
|
||||
"new_text": m.originalNew,
|
||||
})
|
||||
totalAdditions += strings.Count(m.originalNew, "\n") + 1
|
||||
totalDeletions += strings.Count(m.originalOld, "\n") + 1
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"file_diffs": []map[string]any{{
|
||||
"path": path,
|
||||
"additions": strings.Count(newText, "\n") + 1,
|
||||
"deletions": strings.Count(oldText, "\n") + 1,
|
||||
"diff_blocks": []map[string]any{{
|
||||
"old_text": oldText,
|
||||
"new_text": newText,
|
||||
}},
|
||||
"path": path,
|
||||
"additions": totalAdditions,
|
||||
"deletions": totalDeletions,
|
||||
"diff_blocks": diffBlocks,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -715,3 +715,315 @@ func TestExecuteEdit_MetadataContainsFileDiffs(t *testing.T) {
|
||||
t.Fatal("file_diffs should be a non-empty array")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Multi-edit tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Basic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "multi.txt")
|
||||
writeFileOrFail(t, path, "line1\nline2\nline3\nline4\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "line1", NewText: "LINE1"},
|
||||
{OldText: "line3", NewText: "LINE3"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "LINE1") {
|
||||
t.Error("first edit not applied: missing LINE1")
|
||||
}
|
||||
if !strings.Contains(gotStr, "LINE3") {
|
||||
t.Error("second edit not applied: missing LINE3")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line2") {
|
||||
t.Error("line2 was modified but should be untouched")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line4") {
|
||||
t.Error("line4 was modified but should be untouched")
|
||||
}
|
||||
|
||||
// Check response mentions multiple edits
|
||||
if !strings.Contains(resp.Content, "2 edits") {
|
||||
t.Errorf("response should mention '2 edits', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NonIncrementalMatching(t *testing.T) {
|
||||
// All edits are matched against the original content, not incrementally
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "noninc.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
want := "AAA\nBBB\nccc\n"
|
||||
if gotStr != want {
|
||||
t.Errorf("got %q, want %q", gotStr, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_OverlapDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "overlap.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "hello world", NewText: "GOODBYE"}, // Overlaps with first edit
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for overlapping edits")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "overlap") {
|
||||
t.Errorf("expected 'overlap' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_DuplicateDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "dup.txt")
|
||||
writeFileOrFail(t, path, "hello\nworld\nhello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "world", NewText: "WORLD"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for ambiguous old_text (duplicate matches)")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "unique") {
|
||||
t.Errorf("expected 'unique' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello\nworld\nhello\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "notfound.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "nonexistent", NewText: "REPLACEMENT"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for not found")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "edits[0]") {
|
||||
t.Errorf("expected 'edits[0]' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_EmptyArray(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for empty edits array")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "mixed.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(map[string]any{
|
||||
"path": path,
|
||||
"old_text": "hello",
|
||||
"new_text": "HELLO",
|
||||
"edits": []Edit{
|
||||
{OldText: "hello", NewText: "HI"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error when mixing single and multi-edit modes")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "cannot use") {
|
||||
t.Errorf("expected 'cannot use' in error, got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_FuzzyMatch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "fuzzy_multi.txt")
|
||||
// File has trailing whitespace
|
||||
original := "func foo() { \n\treturn 1 \n}\nfunc bar() { \n\treturn 2 \n}\n"
|
||||
writeFileOrFail(t, path, original)
|
||||
|
||||
// Search without trailing whitespace (common LLM behavior)
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "func foo() {\n\treturn 1\n}", NewText: "func foo() {\n\treturn 10\n}"},
|
||||
{OldText: "func bar() {\n\treturn 2\n}", NewText: "func bar() {\n\treturn 20\n}"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "return 10") {
|
||||
t.Error("first edit not applied")
|
||||
}
|
||||
if !strings.Contains(gotStr, "return 20") {
|
||||
t.Error("second edit not applied")
|
||||
}
|
||||
|
||||
// Response should mention fuzzy match
|
||||
if !strings.Contains(resp.Content, "fuzzy") {
|
||||
t.Errorf("response should mention 'fuzzy', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Metadata(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "meta_multi.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
var meta map[string]any
|
||||
if err := json.Unmarshal([]byte(resp.Metadata), &meta); err != nil {
|
||||
t.Fatalf("metadata is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
diffs, ok := meta["file_diffs"].([]any)
|
||||
if !ok || len(diffs) == 0 {
|
||||
t.Fatal("metadata missing file_diffs")
|
||||
}
|
||||
|
||||
firstDiff, ok := diffs[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("first diff is not an object")
|
||||
}
|
||||
|
||||
// Check that diff_blocks contains both edits
|
||||
diffBlocks, ok := firstDiff["diff_blocks"].([]any)
|
||||
if !ok || len(diffBlocks) != 2 {
|
||||
t.Fatalf("expected 2 diff_blocks, got %d", len(diffBlocks))
|
||||
}
|
||||
|
||||
// Verify each block has old_text and new_text
|
||||
for i, block := range diffBlocks {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("diff_block[%d] is not an object", i)
|
||||
}
|
||||
if _, ok := b["old_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing old_text", i)
|
||||
}
|
||||
if _, ok := b["new_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing new_text", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+102
-1
@@ -727,6 +727,7 @@ type API struct {
|
||||
onToolCall func(func(ToolCallEvent, Context) *ToolCallResult)
|
||||
onToolExecStart func(func(ToolExecutionStartEvent, Context))
|
||||
onToolExecEnd func(func(ToolExecutionEndEvent, Context))
|
||||
onToolOutput func(func(ToolOutputEvent, Context))
|
||||
onToolResult func(func(ToolResultEvent, Context) *ToolResultResult)
|
||||
onInput func(func(InputEvent, Context) *InputResult)
|
||||
onBeforeAgentStart func(func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult)
|
||||
@@ -749,6 +750,9 @@ type API struct {
|
||||
registerOption func(OptionDef)
|
||||
registerShortcutFn func(ShortcutDef, func(Context))
|
||||
registerMessageRendererFn func(MessageRendererConfig)
|
||||
onSubagentStart func(func(SubagentStartEvent, Context))
|
||||
onSubagentChunk func(func(SubagentChunkEvent, Context))
|
||||
onSubagentEnd func(func(SubagentEndEvent, Context))
|
||||
}
|
||||
|
||||
// OnToolCall registers a handler that fires before a tool executes.
|
||||
@@ -767,12 +771,40 @@ func (a *API) OnToolExecutionEnd(handler func(ToolExecutionEndEvent, Context)) {
|
||||
a.onToolExecEnd(handler)
|
||||
}
|
||||
|
||||
// OnToolOutput registers a handler for streaming tool output chunks.
|
||||
// This fires for each output line as it arrives from tools like bash,
|
||||
// allowing extensions to observe or process output in real-time.
|
||||
func (a *API) OnToolOutput(handler func(ToolOutputEvent, Context)) {
|
||||
a.onToolOutput(handler)
|
||||
}
|
||||
|
||||
// OnToolResult registers a handler that fires after tool execution.
|
||||
// Return a non-nil ToolResultResult to modify the output.
|
||||
func (a *API) OnToolResult(handler func(ToolResultEvent, Context) *ToolResultResult) {
|
||||
a.onToolResult(handler)
|
||||
}
|
||||
|
||||
// OnSubagentStart registers a handler that fires when a spawn_subagent tool
|
||||
// call begins executing. Use the ToolCallID to correlate with subsequent
|
||||
// OnSubagentChunk and OnSubagentEnd events for the same subagent.
|
||||
func (a *API) OnSubagentStart(handler func(SubagentStartEvent, Context)) {
|
||||
a.onSubagentStart(handler)
|
||||
}
|
||||
|
||||
// OnSubagentChunk registers a handler for real-time events from a running
|
||||
// subagent. ChunkType identifies the kind of event ("text", "tool_call",
|
||||
// "tool_result", "tool_execution_start", "tool_execution_end", etc.).
|
||||
// Correlate with OnSubagentStart via the ToolCallID field.
|
||||
func (a *API) OnSubagentChunk(handler func(SubagentChunkEvent, Context)) {
|
||||
a.onSubagentChunk(handler)
|
||||
}
|
||||
|
||||
// OnSubagentEnd registers a handler that fires when a spawn_subagent call
|
||||
// completes. ErrorMsg is non-empty when the subagent failed.
|
||||
func (a *API) OnSubagentEnd(handler func(SubagentEndEvent, Context)) {
|
||||
a.onSubagentEnd(handler)
|
||||
}
|
||||
|
||||
// OnInput registers a handler that fires when user input is received.
|
||||
// Return a non-nil InputResult to transform or handle the input.
|
||||
func (a *API) OnInput(handler func(InputEvent, Context) *InputResult) {
|
||||
@@ -1538,6 +1570,19 @@ type ToolExecutionEndEvent struct {
|
||||
|
||||
func (e ToolExecutionEndEvent) Type() EventType { return ToolExecutionEnd }
|
||||
|
||||
// ToolOutputEvent fires when a tool produces streaming output chunks.
|
||||
// This is primarily used for long-running tools like bash to show output
|
||||
// in real-time as it arrives, before the tool completes.
|
||||
type ToolOutputEvent struct {
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
ToolKind string
|
||||
Chunk string // Output text chunk
|
||||
IsStderr bool // Whether this chunk came from stderr
|
||||
}
|
||||
|
||||
func (e ToolOutputEvent) Type() EventType { return ToolOutput }
|
||||
|
||||
// ToolResultEvent fires after tool execution with the output.
|
||||
type ToolResultEvent struct {
|
||||
ToolCallID string
|
||||
@@ -1760,9 +1805,65 @@ type BeforeCompactResult struct {
|
||||
func (BeforeCompactResult) isResult() {}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Theme types (exposed to Yaegi — concrete structs, string hex colors)
|
||||
// Subagent lifecycle events (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentStartEvent fires when a spawn_subagent tool call begins executing.
|
||||
type SubagentStartEvent struct {
|
||||
// ToolCallID is the LLM-assigned ID of the spawn_subagent tool call.
|
||||
// Use this to correlate SubagentChunkEvent and SubagentEndEvent.
|
||||
ToolCallID string
|
||||
// Task is the task description passed to the subagent.
|
||||
Task string
|
||||
}
|
||||
|
||||
func (e SubagentStartEvent) Type() EventType { return SubagentStart }
|
||||
|
||||
// SubagentChunkEvent fires for each real-time event from a running subagent.
|
||||
// Type field indicates the kind of event; read the relevant fields accordingly.
|
||||
type SubagentChunkEvent struct {
|
||||
// ToolCallID matches the SubagentStartEvent.ToolCallID for this subagent.
|
||||
ToolCallID string
|
||||
// Task is the task description (repeated for convenience).
|
||||
Task string
|
||||
// ChunkType identifies the event kind:
|
||||
// "text" — LLM text chunk (read Content)
|
||||
// "reasoning" — reasoning/thinking delta (read Content)
|
||||
// "tool_call" — subagent called a tool (read ToolName, ToolArgs)
|
||||
// "tool_result" — tool returned a result (read ToolName, ToolResult, IsError)
|
||||
// "tool_execution_start" — tool began executing (read ToolName)
|
||||
// "tool_execution_end" — tool finished executing (read ToolName)
|
||||
// "turn_start" — subagent turn began
|
||||
// "turn_end" — subagent turn ended
|
||||
ChunkType string
|
||||
// Content carries text for "text" and "reasoning" chunk types.
|
||||
Content string
|
||||
// ToolName is set on tool-related chunk types.
|
||||
ToolName string
|
||||
// ToolArgs is the JSON-encoded tool arguments for "tool_call" chunks.
|
||||
ToolArgs string
|
||||
// ToolResult is the tool output for "tool_result" chunks.
|
||||
ToolResult string
|
||||
// IsError is true when a "tool_result" chunk represents an error.
|
||||
IsError bool
|
||||
}
|
||||
|
||||
func (e SubagentChunkEvent) Type() EventType { return SubagentChunk }
|
||||
|
||||
// SubagentEndEvent fires when a spawn_subagent tool call completes.
|
||||
type SubagentEndEvent struct {
|
||||
// ToolCallID matches the SubagentStartEvent.ToolCallID for this subagent.
|
||||
ToolCallID string
|
||||
// Task is the task description.
|
||||
Task string
|
||||
// Response is the subagent's final text response (empty on error).
|
||||
Response string
|
||||
// ErrorMsg is non-empty when the subagent failed.
|
||||
ErrorMsg string
|
||||
}
|
||||
|
||||
func (e SubagentEndEvent) Type() EventType { return SubagentEnd }
|
||||
|
||||
// ThemeColor is an adaptive color pair with light and dark hex values.
|
||||
// Either field may be empty to inherit from the default theme.
|
||||
type ThemeColor struct {
|
||||
|
||||
@@ -19,6 +19,9 @@ const (
|
||||
// ToolExecutionEnd fires when a tool finishes executing.
|
||||
ToolExecutionEnd EventType = "tool_execution_end"
|
||||
|
||||
// ToolOutput fires when a tool produces streaming output chunks.
|
||||
ToolOutput EventType = "tool_output"
|
||||
|
||||
// ToolResult fires after a tool executes. Handlers can modify the result.
|
||||
ToolResult EventType = "tool_result"
|
||||
|
||||
@@ -68,6 +71,18 @@ const (
|
||||
// BeforeCompact fires before context compaction runs. Handlers can
|
||||
// cancel compaction by returning Cancel=true.
|
||||
BeforeCompact EventType = "before_compact"
|
||||
|
||||
// SubagentStart fires when a spawn_subagent tool call begins executing.
|
||||
// Carries the tool call ID and the task description.
|
||||
SubagentStart EventType = "subagent_start"
|
||||
|
||||
// SubagentChunk fires for each real-time event emitted by a running
|
||||
// subagent: text chunks, tool calls, tool results, etc.
|
||||
SubagentChunk EventType = "subagent_chunk"
|
||||
|
||||
// SubagentEnd fires when a spawn_subagent tool call completes (success
|
||||
// or error). Carries the final response and any error message.
|
||||
SubagentEnd EventType = "subagent_end"
|
||||
)
|
||||
|
||||
// AllEventTypes returns every supported event type.
|
||||
@@ -79,6 +94,7 @@ func AllEventTypes() []EventType {
|
||||
SessionStart, SessionShutdown,
|
||||
ModelChange, ContextPrepare,
|
||||
BeforeFork, BeforeSessionSwitch, BeforeCompact,
|
||||
SubagentStart, SubagentChunk, SubagentEnd,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import "testing"
|
||||
|
||||
func TestAllEventTypes_Count(t *testing.T) {
|
||||
all := AllEventTypes()
|
||||
if len(all) != 18 {
|
||||
t.Fatalf("expected 18 event types, got %d", len(all))
|
||||
if len(all) != 21 {
|
||||
t.Fatalf("expected 21 event types, got %d", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,6 +55,9 @@ func TestEventType_TypeMethod(t *testing.T) {
|
||||
{BeforeForkEvent{TargetID: "abc"}, BeforeFork},
|
||||
{BeforeSessionSwitchEvent{Reason: "new"}, BeforeSessionSwitch},
|
||||
{BeforeCompactEvent{EstimatedTokens: 1000}, BeforeCompact},
|
||||
{SubagentStartEvent{ToolCallID: "x", Task: "t"}, SubagentStart},
|
||||
{SubagentChunkEvent{ToolCallID: "x", ChunkType: "text"}, SubagentChunk},
|
||||
{SubagentEndEvent{ToolCallID: "x"}, SubagentEnd},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -439,6 +439,12 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onToolOutput: func(h func(ToolOutputEvent, Context)) {
|
||||
reg(ToolOutput, func(e Event, c Context) Result {
|
||||
h(e.(ToolOutputEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onToolResult: func(h func(ToolResultEvent, Context) *ToolResultResult) {
|
||||
reg(ToolResult, func(e Event, c Context) Result {
|
||||
r := h(e.(ToolResultEvent), c)
|
||||
@@ -574,6 +580,24 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
registerShortcutFn: func(def ShortcutDef, handler func(Context)) {
|
||||
ext.Shortcuts = append(ext.Shortcuts, ShortcutEntry{Def: def, Handler: handler})
|
||||
},
|
||||
onSubagentStart: func(h func(SubagentStartEvent, Context)) {
|
||||
reg(SubagentStart, func(e Event, c Context) Result {
|
||||
h(e.(SubagentStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentChunk: func(h func(SubagentChunkEvent, Context)) {
|
||||
reg(SubagentChunk, func(e Event, c Context) Result {
|
||||
h(e.(SubagentChunkEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentEnd: func(h func(SubagentEndEvent, Context)) {
|
||||
reg(SubagentEnd, func(e Event, c Context) Result {
|
||||
h(e.(SubagentEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// Call Init — the extension registers its handlers, tools, commands.
|
||||
|
||||
@@ -56,11 +56,165 @@ func NewRunner(exts []LoadedExtension) *Runner {
|
||||
}
|
||||
|
||||
// SetContext updates the runtime context (session ID, model, etc.) that is
|
||||
// passed to every handler invocation. Thread-safe.
|
||||
// passed to every handler invocation. Nil function fields are replaced with
|
||||
// safe no-ops so extension handlers never panic on a missing callback.
|
||||
// Thread-safe.
|
||||
func (r *Runner) SetContext(ctx Context) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.ctx = ctx
|
||||
r.ctx = normalizeContext(ctx)
|
||||
}
|
||||
|
||||
// normalizeContext replaces nil function fields in ctx with no-op stubs so
|
||||
// that extension handlers can call any ctx method without a nil-function panic.
|
||||
func normalizeContext(ctx Context) Context {
|
||||
if ctx.Print == nil {
|
||||
ctx.Print = func(string) {}
|
||||
}
|
||||
if ctx.PrintInfo == nil {
|
||||
ctx.PrintInfo = func(string) {}
|
||||
}
|
||||
if ctx.PrintError == nil {
|
||||
ctx.PrintError = func(string) {}
|
||||
}
|
||||
if ctx.PrintBlock == nil {
|
||||
ctx.PrintBlock = func(PrintBlockOpts) {}
|
||||
}
|
||||
if ctx.SendMessage == nil {
|
||||
ctx.SendMessage = func(string) {}
|
||||
}
|
||||
if ctx.CancelAndSend == nil {
|
||||
ctx.CancelAndSend = func(string) {}
|
||||
}
|
||||
if ctx.SetWidget == nil {
|
||||
ctx.SetWidget = func(WidgetConfig) {}
|
||||
}
|
||||
if ctx.RemoveWidget == nil {
|
||||
ctx.RemoveWidget = func(string) {}
|
||||
}
|
||||
if ctx.SetHeader == nil {
|
||||
ctx.SetHeader = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveHeader == nil {
|
||||
ctx.RemoveHeader = func() {}
|
||||
}
|
||||
if ctx.SetFooter == nil {
|
||||
ctx.SetFooter = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveFooter == nil {
|
||||
ctx.RemoveFooter = func() {}
|
||||
}
|
||||
if ctx.PromptSelect == nil {
|
||||
ctx.PromptSelect = func(PromptSelectConfig) PromptSelectResult {
|
||||
return PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptConfirm == nil {
|
||||
ctx.PromptConfirm = func(PromptConfirmConfig) PromptConfirmResult {
|
||||
return PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptInput == nil {
|
||||
ctx.PromptInput = func(PromptInputConfig) PromptInputResult {
|
||||
return PromptInputResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptMultiSelect == nil {
|
||||
ctx.PromptMultiSelect = func(PromptMultiSelectConfig) PromptMultiSelectResult {
|
||||
return PromptMultiSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.ShowOverlay == nil {
|
||||
ctx.ShowOverlay = func(OverlayConfig) OverlayResult {
|
||||
return OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
}
|
||||
if ctx.SetEditor == nil {
|
||||
ctx.SetEditor = func(EditorConfig) {}
|
||||
}
|
||||
if ctx.ResetEditor == nil {
|
||||
ctx.ResetEditor = func() {}
|
||||
}
|
||||
if ctx.SetEditorText == nil {
|
||||
ctx.SetEditorText = func(string) {}
|
||||
}
|
||||
if ctx.SetUIVisibility == nil {
|
||||
ctx.SetUIVisibility = func(UIVisibility) {}
|
||||
}
|
||||
if ctx.SetStatus == nil {
|
||||
ctx.SetStatus = func(string, string, int) {}
|
||||
}
|
||||
if ctx.RemoveStatus == nil {
|
||||
ctx.RemoveStatus = func(string) {}
|
||||
}
|
||||
if ctx.GetContextStats == nil {
|
||||
ctx.GetContextStats = func() ContextStats { return ContextStats{} }
|
||||
}
|
||||
if ctx.GetMessages == nil {
|
||||
ctx.GetMessages = func() []SessionMessage { return nil }
|
||||
}
|
||||
if ctx.GetSessionPath == nil {
|
||||
ctx.GetSessionPath = func() string { return "" }
|
||||
}
|
||||
if ctx.AppendEntry == nil {
|
||||
ctx.AppendEntry = func(string, string) (string, error) { return "", nil }
|
||||
}
|
||||
if ctx.GetEntries == nil {
|
||||
ctx.GetEntries = func(string) []ExtensionEntry { return nil }
|
||||
}
|
||||
if ctx.GetOption == nil {
|
||||
ctx.GetOption = func(string) string { return "" }
|
||||
}
|
||||
if ctx.SetOption == nil {
|
||||
ctx.SetOption = func(string, string) {}
|
||||
}
|
||||
if ctx.SetModel == nil {
|
||||
ctx.SetModel = func(string) error { return nil }
|
||||
}
|
||||
if ctx.GetAvailableModels == nil {
|
||||
ctx.GetAvailableModels = func() []ModelInfoEntry { return nil }
|
||||
}
|
||||
if ctx.EmitCustomEvent == nil {
|
||||
ctx.EmitCustomEvent = func(string, string) {}
|
||||
}
|
||||
if ctx.GetAllTools == nil {
|
||||
ctx.GetAllTools = func() []ToolInfo { return nil }
|
||||
}
|
||||
if ctx.SetActiveTools == nil {
|
||||
ctx.SetActiveTools = func([]string) {}
|
||||
}
|
||||
if ctx.Exit == nil {
|
||||
ctx.Exit = func() {}
|
||||
}
|
||||
if ctx.Complete == nil {
|
||||
ctx.Complete = func(CompleteRequest) (CompleteResponse, error) {
|
||||
return CompleteResponse{}, nil
|
||||
}
|
||||
}
|
||||
if ctx.SuspendTUI == nil {
|
||||
ctx.SuspendTUI = func(callback func()) error { callback(); return nil }
|
||||
}
|
||||
if ctx.RenderMessage == nil {
|
||||
ctx.RenderMessage = func(string, string) {}
|
||||
}
|
||||
if ctx.RegisterTheme == nil {
|
||||
ctx.RegisterTheme = func(string, ThemeColorConfig) {}
|
||||
}
|
||||
if ctx.SetTheme == nil {
|
||||
ctx.SetTheme = func(string) error { return nil }
|
||||
}
|
||||
if ctx.ListThemes == nil {
|
||||
ctx.ListThemes = func() []string { return nil }
|
||||
}
|
||||
if ctx.ReloadExtensions == nil {
|
||||
ctx.ReloadExtensions = func() error { return nil }
|
||||
}
|
||||
if ctx.SpawnSubagent == nil {
|
||||
ctx.SpawnSubagent = func(SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// GetContext returns a snapshot of the current runtime context. Thread-safe.
|
||||
|
||||
@@ -119,6 +119,11 @@ func Symbols() interp.Exports {
|
||||
"SubagentHandle": reflect.ValueOf((*SubagentHandle)(nil)),
|
||||
"SubagentEvent": reflect.ValueOf((*SubagentEvent)(nil)),
|
||||
|
||||
// Subagent lifecycle events
|
||||
"SubagentStartEvent": reflect.ValueOf((*SubagentStartEvent)(nil)),
|
||||
"SubagentChunkEvent": reflect.ValueOf((*SubagentChunkEvent)(nil)),
|
||||
"SubagentEndEvent": reflect.ValueOf((*SubagentEndEvent)(nil)),
|
||||
|
||||
// Theme types
|
||||
"ThemeColor": reflect.ValueOf((*ThemeColor)(nil)),
|
||||
"ThemeColorConfig": reflect.ValueOf((*ThemeColorConfig)(nil)),
|
||||
@@ -128,6 +133,7 @@ func Symbols() interp.Exports {
|
||||
"ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)),
|
||||
"ToolExecutionStartEvent": reflect.ValueOf((*ToolExecutionStartEvent)(nil)),
|
||||
"ToolExecutionEndEvent": reflect.ValueOf((*ToolExecutionEndEvent)(nil)),
|
||||
"ToolOutputEvent": reflect.ValueOf((*ToolOutputEvent)(nil)),
|
||||
"ToolResultEvent": reflect.ValueOf((*ToolResultEvent)(nil)),
|
||||
"ToolResultResult": reflect.ValueOf((*ToolResultResult)(nil)),
|
||||
"InputEvent": reflect.ValueOf((*InputEvent)(nil)),
|
||||
|
||||
@@ -30,6 +30,12 @@ func NewTestAPI(ext *LoadedExtension) API {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onToolOutput: func(h func(ToolOutputEvent, Context)) {
|
||||
reg(ToolOutput, func(e Event, c Context) Result {
|
||||
h(e.(ToolOutputEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onToolResult: func(h func(ToolResultEvent, Context) *ToolResultResult) {
|
||||
reg(ToolResult, func(e Event, c Context) Result {
|
||||
r := h(e.(ToolResultEvent), c)
|
||||
@@ -165,5 +171,23 @@ func NewTestAPI(ext *LoadedExtension) API {
|
||||
registerMessageRendererFn: func(config MessageRendererConfig) {
|
||||
ext.MessageRenderers = append(ext.MessageRenderers, config)
|
||||
},
|
||||
onSubagentStart: func(h func(SubagentStartEvent, Context)) {
|
||||
reg(SubagentStart, func(e Event, c Context) Result {
|
||||
h(e.(SubagentStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentChunk: func(h func(SubagentChunkEvent, Context)) {
|
||||
reg(SubagentChunk, func(e Event, c Context) Result {
|
||||
h(e.(SubagentChunkEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentEnd: func(h func(SubagentEndEvent, Context)) {
|
||||
reg(SubagentEnd, func(e Event, c Context) Result {
|
||||
h(e.(SubagentEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,38 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// sanitizeToolCallID ensures the ID matches Anthropic's required pattern:
|
||||
// ^[a-zA-Z0-9_-]+$ (alphanumeric, underscores, and hyphens only).
|
||||
// Invalid characters are replaced with underscores.
|
||||
func sanitizeToolCallID(id string) string {
|
||||
var sb strings.Builder
|
||||
for _, r := range id {
|
||||
switch {
|
||||
case (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z'):
|
||||
sb.WriteRune(r)
|
||||
case r >= '0' && r <= '9':
|
||||
sb.WriteRune(r)
|
||||
case r == '_' || r == '-':
|
||||
sb.WriteRune(r)
|
||||
default:
|
||||
// Replace invalid characters with underscore
|
||||
sb.WriteByte('_')
|
||||
}
|
||||
}
|
||||
result := sb.String()
|
||||
// Ensure non-empty (Anthropic requires at least one character)
|
||||
if result == "" {
|
||||
return "tool_0"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ContentPart is the marker interface for all message content block types.
|
||||
// A message contains a heterogeneous slice of ContentPart values, enabling
|
||||
// rich structured messages that carry text, reasoning, tool calls, tool
|
||||
@@ -312,7 +339,7 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
// Add tool calls
|
||||
for _, tc := range m.ToolCalls() {
|
||||
parts = append(parts, fantasy.ToolCallPart{
|
||||
ToolCallID: tc.ID,
|
||||
ToolCallID: sanitizeToolCallID(tc.ID),
|
||||
ToolName: tc.Name,
|
||||
Input: tc.Input,
|
||||
})
|
||||
@@ -340,7 +367,7 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
}
|
||||
}
|
||||
parts = append(parts, fantasy.ToolResultPart{
|
||||
ToolCallID: result.ToolCallID,
|
||||
ToolCallID: sanitizeToolCallID(result.ToolCallID),
|
||||
Output: output,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeToolCallID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "valid alphanumeric ID",
|
||||
input: "call_123abc",
|
||||
expected: "call_123abc",
|
||||
},
|
||||
{
|
||||
name: "ID with dots (OpenCode/Kimi style)",
|
||||
input: "call.123.abc",
|
||||
expected: "call_123_abc",
|
||||
},
|
||||
{
|
||||
name: "ID with colons",
|
||||
input: "tool:123:abc",
|
||||
expected: "tool_123_abc",
|
||||
},
|
||||
{
|
||||
name: "ID with special characters",
|
||||
input: "tool@#$%^&*()",
|
||||
expected: "tool_________",
|
||||
},
|
||||
{
|
||||
name: "Anthropic style ID (already valid)",
|
||||
input: "toolu_0123456789ABCDEF",
|
||||
expected: "toolu_0123456789ABCDEF",
|
||||
},
|
||||
{
|
||||
name: "OpenAI style ID (already valid)",
|
||||
input: "call_O17Uplv4lJvD6DVdIvFFeRMw",
|
||||
expected: "call_O17Uplv4lJvD6DVdIvFFeRMw",
|
||||
},
|
||||
{
|
||||
name: "ID with hyphens",
|
||||
input: "my-tool-call-123",
|
||||
expected: "my-tool-call-123",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "tool_0",
|
||||
},
|
||||
{
|
||||
name: "only special characters",
|
||||
input: "@#$%",
|
||||
expected: "____",
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid",
|
||||
input: "call_123.abc-def@ghi",
|
||||
expected: "call_123_abc-def_ghi",
|
||||
},
|
||||
{
|
||||
name: "Unicode characters",
|
||||
input: "tool_日本語",
|
||||
expected: "tool____",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitizeToolCallID(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("sanitizeToolCallID(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeToolCallID_MatchesAnthropicPattern(t *testing.T) {
|
||||
// Test that sanitized IDs match Anthropic's required pattern: ^[a-zA-Z0-9_-]+$
|
||||
// This is a simplified check - in reality the pattern allows alphanumeric, underscore, hyphen
|
||||
testIDs := []string{
|
||||
"call.123.abc",
|
||||
"tool:123:def",
|
||||
"id@#$%^&*()",
|
||||
"mixed.valid-id_test",
|
||||
"",
|
||||
}
|
||||
|
||||
for _, id := range testIDs {
|
||||
sanitized := sanitizeToolCallID(id)
|
||||
|
||||
// Verify each character is valid
|
||||
for i, r := range sanitized {
|
||||
valid := (r >= 'a' && r <= 'z') ||
|
||||
(r >= 'A' && r <= 'Z') ||
|
||||
(r >= '0' && r <= '9') ||
|
||||
r == '_' ||
|
||||
r == '-'
|
||||
|
||||
if !valid {
|
||||
t.Errorf("sanitizeToolCallID(%q) = %q, contains invalid character at position %d: %q",
|
||||
id, sanitized, i, string(r))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify non-empty
|
||||
if sanitized == "" {
|
||||
t.Errorf("sanitizeToolCallID(%q) returned empty string", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// loadCustomModelsFromConfig loads custom model definitions from the config file
|
||||
// and returns them as a map of model ID -> ModelInfo. Returns nil if no custom
|
||||
// models are configured.
|
||||
func loadCustomModelsFromConfig() map[string]ModelInfo {
|
||||
if !viper.IsSet("customModels") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var customModels map[string]CustomModelConfig
|
||||
if err := viper.UnmarshalKey("customModels", &customModels); err != nil {
|
||||
log.Printf("Warning: Failed to parse customModels: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]ModelInfo, len(customModels))
|
||||
for modelID, cfg := range customModels {
|
||||
info := modelConfigToModelInfo(modelID, cfg)
|
||||
result[modelID] = info
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// modelConfigToModelInfo converts a CustomModelConfig to a ModelInfo.
|
||||
func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
return ModelInfo{
|
||||
ID: modelID,
|
||||
Name: cfg.Name,
|
||||
Attachment: cfg.Attachment,
|
||||
Reasoning: cfg.Reasoning,
|
||||
Temperature: cfg.Temperature,
|
||||
Cost: Cost{
|
||||
Input: cfg.Cost.Input,
|
||||
Output: cfg.Cost.Output,
|
||||
},
|
||||
Limit: Limit{
|
||||
Context: cfg.Limit.Context,
|
||||
Output: cfg.Limit.Output,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model configuration loaded from the config file.
|
||||
// This is a duplicate here to avoid circular dependencies with internal/config.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
type CostConfig struct {
|
||||
Input float64 `json:"input" yaml:"input"`
|
||||
Output float64 `json:"output" yaml:"output"`
|
||||
}
|
||||
|
||||
// LimitConfig defines context and output limits for a custom model.
|
||||
type LimitConfig struct {
|
||||
Context int `json:"context" yaml:"context"`
|
||||
Output int `json:"output" yaml:"output"`
|
||||
}
|
||||
@@ -17,15 +17,21 @@ type modelsDBProvider struct {
|
||||
|
||||
// modelsDBModel represents a model entry from models.dev/api.json.
|
||||
type modelsDBModel struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Family string `json:"family,omitempty"`
|
||||
Attachment bool `json:"attachment"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
ToolCall bool `json:"tool_call"`
|
||||
Temperature bool `json:"temperature"`
|
||||
Cost modelsDBCost `json:"cost"`
|
||||
Limit modelsDBLimit `json:"limit"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Family string `json:"family,omitempty"`
|
||||
Attachment bool `json:"attachment"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
ToolCall bool `json:"tool_call"`
|
||||
Temperature bool `json:"temperature"`
|
||||
Cost modelsDBCost `json:"cost"`
|
||||
Limit modelsDBLimit `json:"limit"`
|
||||
Provider *modelsDBModelProvider `json:"provider,omitempty"` // Model-specific provider override
|
||||
}
|
||||
|
||||
// modelsDBModelProvider represents a provider reference within a model.
|
||||
type modelsDBModelProvider struct {
|
||||
NPM string `json:"npm"`
|
||||
}
|
||||
|
||||
// modelsDBCost represents model pricing from models.dev.
|
||||
|
||||
@@ -169,6 +169,9 @@ type ProviderResult struct {
|
||||
// ProviderOptions contains provider-specific options to be passed to the
|
||||
// fantasy agent (e.g. OpenAI Responses API reasoning options).
|
||||
ProviderOptions fantasy.ProviderOptions
|
||||
// SkipMaxOutputTokens indicates that this provider doesn't support the
|
||||
// max_output_tokens parameter (e.g., OpenAI Codex OAuth API).
|
||||
SkipMaxOutputTokens bool
|
||||
}
|
||||
|
||||
// ParseModelString parses a model string in "provider/model" format (e.g. "anthropic/claude-sonnet-4-5").
|
||||
@@ -253,6 +256,8 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
return createBedrockProvider(ctx, config, modelName)
|
||||
case "vercel":
|
||||
return createVercelProvider(ctx, config, modelName)
|
||||
case "custom":
|
||||
return createCustomProvider(ctx, config, modelName)
|
||||
default:
|
||||
return autoRouteProvider(ctx, config, provider, modelName, registry)
|
||||
}
|
||||
@@ -261,14 +266,22 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
// autoRouteProvider attempts to create a provider by looking up its npm package
|
||||
// in the models.dev database and routing through the appropriate fantasy provider.
|
||||
// For openai-compatible providers, it uses the api URL from models.dev.
|
||||
// Models may have a provider override that specifies a different npm package than
|
||||
// the provider's default (e.g., opencode's claude-opus-4-6 uses @ai-sdk/anthropic).
|
||||
func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, modelName string, registry *ModelsRegistry) (*ProviderResult, error) {
|
||||
providerInfo := registry.GetProviderInfo(provider)
|
||||
if providerInfo == nil {
|
||||
return nil, fmt.Errorf("unsupported provider: %s (not found in model database)", provider)
|
||||
}
|
||||
|
||||
// Check for model-specific provider override
|
||||
npmPackage := providerInfo.NPM
|
||||
if modelInfo := registry.LookupModel(provider, modelName); modelInfo != nil && modelInfo.ProviderNPM != "" {
|
||||
npmPackage = modelInfo.ProviderNPM
|
||||
}
|
||||
|
||||
// Determine the fantasy provider for this npm package
|
||||
fantasyProvider := npmToFantasyProvider[providerInfo.NPM]
|
||||
fantasyProvider := npmToFantasyProvider[npmPackage]
|
||||
if fantasyProvider == "" && providerInfo.API != "" {
|
||||
// Unknown npm but has API URL → route through openaicompat
|
||||
fantasyProvider = "openaicompat"
|
||||
@@ -288,7 +301,7 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
|
||||
}
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no fantasy mapping)", provider, providerInfo.NPM)
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no fantasy mapping)", provider, npmPackage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,7 +359,10 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf
|
||||
opts = append(opts, anthropic.WithAPIKey(apiKey))
|
||||
|
||||
if config.ProviderURL != "" {
|
||||
opts = append(opts, anthropic.WithBaseURL(config.ProviderURL))
|
||||
// The anthropic client appends "/v1/messages" to the base URL.
|
||||
// If the provider URL ends with "/v1", strip it to avoid double "/v1/v1" paths.
|
||||
baseURL := strings.TrimSuffix(config.ProviderURL, "/v1")
|
||||
opts = append(opts, anthropic.WithBaseURL(baseURL))
|
||||
}
|
||||
|
||||
if config.TLSSkipVerify {
|
||||
@@ -608,13 +624,52 @@ func createVertexAnthropicProvider(ctx context.Context, config *ProviderConfig,
|
||||
|
||||
func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
apiKey := config.ProviderAPIKey
|
||||
source := "command-line flag"
|
||||
var accountID string
|
||||
var isCodexOAuth bool
|
||||
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("OPENAI_API_KEY")
|
||||
}
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("OpenAI API key not provided. Use --provider-api-key flag or OPENAI_API_KEY environment variable")
|
||||
// Check stored credentials first
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err == nil {
|
||||
if creds, err := cm.GetOpenAICredentials(); err == nil && creds != nil {
|
||||
if creds.Type == "oauth" && creds.AccessToken != "" {
|
||||
// For OAuth, get a valid access token (may refresh if needed)
|
||||
token, err := cm.GetValidOpenAIAccessToken()
|
||||
if err == nil && token != "" {
|
||||
apiKey = token
|
||||
accountID = creds.AccountID
|
||||
isCodexOAuth = true
|
||||
source = "stored Codex OAuth credentials"
|
||||
}
|
||||
} else if creds.Type == "api_key" && creds.APIKey != "" {
|
||||
apiKey = creds.APIKey
|
||||
source = "stored API key"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to environment variable
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("OPENAI_API_KEY")
|
||||
source = "OPENAI_API_KEY environment variable"
|
||||
}
|
||||
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("OpenAI API key not provided. Use 'kit auth login openai', --provider-api-key flag, or OPENAI_API_KEY environment variable")
|
||||
}
|
||||
|
||||
if os.Getenv("DEBUG") != "" || os.Getenv("KIT_DEBUG") != "" {
|
||||
fmt.Fprintf(os.Stderr, "Using OpenAI API key from: %s\n", source)
|
||||
}
|
||||
|
||||
// For Codex OAuth, use the ChatGPT backend API with custom headers
|
||||
if isCodexOAuth {
|
||||
return createOpenAICodexProvider(ctx, config, modelName, apiKey, accountID)
|
||||
}
|
||||
|
||||
// Regular OpenAI API key flow
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithAPIKey(apiKey))
|
||||
opts = append(opts, openai.WithUseResponsesAPI())
|
||||
@@ -643,6 +698,135 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
|
||||
}
|
||||
|
||||
// createOpenAICodexProvider creates a provider for ChatGPT/Codex OAuth tokens.
|
||||
// Uses the chatgpt.com/backend-api/codex endpoint with special headers.
|
||||
func createOpenAICodexProvider(ctx context.Context, config *ProviderConfig, modelName, token, accountID string) (*ProviderResult, error) {
|
||||
// Check for spark models which are not accessible via OAuth
|
||||
if detectCodexModelFamily(modelName) == "gpt-codex-spark" {
|
||||
return nil, fmt.Errorf("gpt-codex-spark models are not accessible via ChatGPT OAuth. " +
|
||||
"These models require special access or a different authentication method. " +
|
||||
"Please use regular Codex models like 'openai/gpt-5.3-codex' instead")
|
||||
}
|
||||
|
||||
// Use the ChatGPT backend API with /codex path
|
||||
baseURL := "https://chatgpt.com/backend-api/codex"
|
||||
if config.ProviderURL != "" {
|
||||
baseURL = config.ProviderURL
|
||||
}
|
||||
|
||||
// Build custom HTTP client with required headers
|
||||
httpClient := createCodexHTTPClient(token, accountID, config.TLSSkipVerify)
|
||||
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithAPIKey(token))
|
||||
opts = append(opts, openai.WithBaseURL(baseURL))
|
||||
opts = append(opts, openai.WithUseResponsesAPI())
|
||||
opts = append(opts, openai.WithHTTPClient(httpClient))
|
||||
|
||||
provider, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex provider: %w", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex model: %w", err)
|
||||
}
|
||||
|
||||
providerOpts := buildCodexProviderOptions(config, modelName)
|
||||
|
||||
return &ProviderResult{
|
||||
Model: model,
|
||||
ProviderOptions: providerOpts,
|
||||
SkipMaxOutputTokens: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildCodexProviderOptions returns fantasy.ProviderOptions configured for
|
||||
// OpenAI Codex API. The Codex API requires the system prompt to be passed
|
||||
// as 'instructions' rather than as a system message.
|
||||
func buildCodexProviderOptions(config *ProviderConfig, modelName string) fantasy.ProviderOptions {
|
||||
store := false
|
||||
opts := &openai.ResponsesProviderOptions{
|
||||
Store: &store,
|
||||
}
|
||||
|
||||
if config.SystemPrompt != "" {
|
||||
opts.Instructions = &config.SystemPrompt
|
||||
}
|
||||
|
||||
if openai.IsResponsesReasoningModel(modelName) {
|
||||
opts.ReasoningEffort = thinkingLevelToReasoningEffort(config.ThinkingLevel)
|
||||
}
|
||||
|
||||
return fantasy.ProviderOptions{openai.Name: opts}
|
||||
}
|
||||
|
||||
// detectCodexModelFamily determines the model family from the model name
|
||||
func detectCodexModelFamily(modelName string) string {
|
||||
modelName = strings.ToLower(modelName)
|
||||
if strings.Contains(modelName, "spark") {
|
||||
return "gpt-codex-spark"
|
||||
}
|
||||
if strings.Contains(modelName, "codex-mini") || strings.Contains(modelName, "mini-latest") {
|
||||
return "gpt-codex-mini"
|
||||
}
|
||||
if strings.Contains(modelName, "codex") {
|
||||
return "gpt-codex"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// createCodexHTTPClient creates an HTTP client with headers required for ChatGPT/Codex API
|
||||
func createCodexHTTPClient(token, accountID string, skipVerify bool) *http.Client {
|
||||
var base http.RoundTripper
|
||||
if skipVerify {
|
||||
base = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &codexTransport{
|
||||
base: base,
|
||||
token: token,
|
||||
accountID: accountID,
|
||||
},
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// codexTransport is a custom RoundTripper that adds ChatGPT/Codex specific headers
|
||||
type codexTransport struct {
|
||||
base http.RoundTripper
|
||||
token string
|
||||
accountID string
|
||||
}
|
||||
|
||||
func (t *codexTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Add required headers for ChatGPT/Codex API
|
||||
// These headers mimic the official pi client to avoid Cloudflare blocking
|
||||
newReq.Header.Set("Authorization", "Bearer "+t.token)
|
||||
if t.accountID != "" {
|
||||
newReq.Header.Set("chatgpt-account-id", t.accountID)
|
||||
}
|
||||
newReq.Header.Set("originator", "kit")
|
||||
newReq.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36")
|
||||
newReq.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
newReq.Header.Set("Accept", "text/event-stream")
|
||||
newReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
newReq.Header.Set("Cache-Control", "no-cache")
|
||||
newReq.Header.Set("Pragma", "no-cache")
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
apiKey := firstNonEmpty(
|
||||
config.ProviderAPIKey,
|
||||
@@ -779,6 +963,42 @@ func createVercelProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
|
||||
func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
if config.ProviderURL == "" {
|
||||
return nil, fmt.Errorf("custom provider requires --provider-url")
|
||||
}
|
||||
|
||||
apiKey := config.ProviderAPIKey
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("CUSTOM_API_KEY")
|
||||
}
|
||||
if apiKey == "" {
|
||||
// Many local/custom endpoints don't require a key; use a placeholder.
|
||||
apiKey = "custom"
|
||||
}
|
||||
|
||||
var opts []openaicompat.Option
|
||||
opts = append(opts, openaicompat.WithBaseURL(config.ProviderURL))
|
||||
opts = append(opts, openaicompat.WithAPIKey(apiKey))
|
||||
opts = append(opts, openaicompat.WithName("custom"))
|
||||
|
||||
if config.TLSSkipVerify {
|
||||
opts = append(opts, openaicompat.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
||||
}
|
||||
|
||||
p, err := openaicompat.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create custom provider: %w", err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create custom model: %w", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
|
||||
func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
baseURL := "http://localhost:11434"
|
||||
if host := os.Getenv("OLLAMA_HOST"); host != "" {
|
||||
|
||||
@@ -22,6 +22,7 @@ type ModelInfo struct {
|
||||
Temperature bool
|
||||
Cost Cost
|
||||
Limit Limit
|
||||
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
||||
}
|
||||
|
||||
// Cost represents the pricing information for a model.
|
||||
@@ -78,6 +79,10 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
for providerID, dp := range dbProviders {
|
||||
modelsMap := make(map[string]ModelInfo, len(dp.Models))
|
||||
for modelID, dm := range dp.Models {
|
||||
providerNPM := ""
|
||||
if dm.Provider != nil {
|
||||
providerNPM = dm.Provider.NPM
|
||||
}
|
||||
modelsMap[modelID] = ModelInfo{
|
||||
ID: dm.ID,
|
||||
Name: dm.Name,
|
||||
@@ -94,6 +99,7 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
Context: dm.Limit.Context,
|
||||
Output: dm.Limit.Output,
|
||||
},
|
||||
ProviderNPM: providerNPM,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,6 +122,47 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// Register the "custom" provider stub for --provider-url without --model.
|
||||
// This allows users to point kit at any OpenAI-compatible endpoint without
|
||||
// needing to specify a model from the database.
|
||||
providers["custom"] = ProviderInfo{
|
||||
ID: "custom",
|
||||
Name: "Custom",
|
||||
Models: map[string]ModelInfo{
|
||||
"custom": {
|
||||
ID: "custom",
|
||||
Name: "Custom",
|
||||
Attachment: false,
|
||||
Reasoning: true,
|
||||
Temperature: true,
|
||||
Cost: Cost{
|
||||
Input: 0,
|
||||
Output: 0,
|
||||
},
|
||||
Limit: Limit{
|
||||
Context: 262_144,
|
||||
Output: 65_536,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Load custom models from config file and merge into custom provider.
|
||||
// Config file models take precedence - if a model ID exists in both
|
||||
// models.dev and config, the config version wins.
|
||||
if customModels := loadCustomModelsFromConfig(); customModels != nil {
|
||||
for modelID, info := range customModels {
|
||||
// Validate custom model config
|
||||
if info.Limit.Context <= 0 {
|
||||
fmt.Fprintf(os.Stderr, "Warning: custom model %q has invalid context limit: %d\n", modelID, info.Limit.Context)
|
||||
}
|
||||
if info.Limit.Output <= 0 {
|
||||
fmt.Fprintf(os.Stderr, "Warning: custom model %q has invalid output limit: %d\n", modelID, info.Limit.Output)
|
||||
}
|
||||
providers["custom"].Models[modelID] = info
|
||||
}
|
||||
}
|
||||
|
||||
return providers
|
||||
}
|
||||
|
||||
@@ -178,6 +225,15 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
|
||||
}
|
||||
}
|
||||
|
||||
// For openai, check stored credentials (OAuth / API key)
|
||||
if provider == "openai" {
|
||||
if cm, err := auth.NewCredentialManager(); err == nil {
|
||||
if has, _ := cm.HasOpenAICredentials(); has {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
envVars, err := r.getRequiredEnvVars(provider)
|
||||
if err != nil {
|
||||
// Unknown provider — nothing to validate
|
||||
|
||||
@@ -96,6 +96,7 @@ func ListAllSessions() ([]SessionInfo, error) {
|
||||
}
|
||||
|
||||
// listSessionsInDir reads all .jsonl files in a directory and extracts session info.
|
||||
// Empty sessions (no messages) are automatically cleaned up and not returned.
|
||||
func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
@@ -117,6 +118,11 @@ func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
if err != nil {
|
||||
continue // skip malformed session files
|
||||
}
|
||||
// Clean up and skip empty sessions (no messages)
|
||||
if info.MessageCount == 0 {
|
||||
_ = os.Remove(path)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, *info)
|
||||
}
|
||||
|
||||
|
||||
@@ -628,6 +628,11 @@ func (tm *TreeManager) MessageCount() int {
|
||||
return count
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the session has no messages (only header).
|
||||
func (tm *TreeManager) IsEmpty() bool {
|
||||
return tm.MessageCount() == 0
|
||||
}
|
||||
|
||||
// Close closes the underlying file handle.
|
||||
func (tm *TreeManager) Close() error {
|
||||
tm.mu.Lock()
|
||||
|
||||
+2
-1
@@ -192,7 +192,8 @@ func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText stri
|
||||
outputTokens := int(usage.OutputTokens)
|
||||
|
||||
// Validate that the metadata seems reasonable
|
||||
if inputTokens > 0 && outputTokens > 0 {
|
||||
// Use API-reported tokens if input tokens are available (output may be 0 in some cases)
|
||||
if inputTokens > 0 {
|
||||
cacheReadTokens := int(usage.CacheReadTokens)
|
||||
cacheWriteTokens := int(usage.CacheCreationTokens)
|
||||
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
|
||||
+14
-1
@@ -65,6 +65,10 @@ type InputComponent struct {
|
||||
// hideHint suppresses the "enter submit · ctrl+j..." hint text.
|
||||
hideHint bool
|
||||
|
||||
// agentBusy indicates the agent is currently working. When true, the
|
||||
// hint text shows steering shortcut (Ctrl+S) instead of submit.
|
||||
agentBusy bool
|
||||
|
||||
// pendingImages holds clipboard images attached to the next submission.
|
||||
// Images are added via Ctrl+V and cleared on submit or Ctrl+U.
|
||||
pendingImages []ImageAttachment
|
||||
@@ -514,7 +518,16 @@ func (s *InputComponent) View() tea.View {
|
||||
// Adapt hint text to available width (accounting for left padding of 3).
|
||||
var hint string
|
||||
availableHintWidth := s.width - 3
|
||||
if availableHintWidth >= 67 {
|
||||
if s.agentBusy {
|
||||
// When the agent is working, show steering shortcut.
|
||||
if availableHintWidth >= 55 {
|
||||
hint = "enter queue • ctrl+s steer • esc esc cancel"
|
||||
} else if availableHintWidth >= 35 {
|
||||
hint = "↵ queue • ^S steer • esc×2 cancel"
|
||||
} else {
|
||||
hint = "^S steer"
|
||||
}
|
||||
} else if availableHintWidth >= 67 {
|
||||
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
|
||||
} else if availableHintWidth >= 40 {
|
||||
hint = "↵ submit • ctrl+j newline • ctrl+v image"
|
||||
|
||||
+14
-2
@@ -111,14 +111,26 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
result.WriteString(primaryVal)
|
||||
}
|
||||
|
||||
// Collect remaining parameters (skip large values like file content)
|
||||
// Collect remaining parameters, skipping body-content keys (already
|
||||
// rendered in the tool body) and any values that are too large.
|
||||
bodyKeys := map[string]bool{
|
||||
"content": true,
|
||||
"old_text": true,
|
||||
"new_text": true,
|
||||
"oldText": true,
|
||||
"newText": true,
|
||||
"edits": true,
|
||||
"todos": true,
|
||||
}
|
||||
var remaining []string
|
||||
for key, val := range params {
|
||||
if key == primaryKey {
|
||||
continue
|
||||
}
|
||||
if bodyKeys[key] {
|
||||
continue
|
||||
}
|
||||
valStr := fmt.Sprintf("%v", val)
|
||||
// Skip very large values (e.g., oldString, newString, content, todos)
|
||||
if len(valStr) > 100 {
|
||||
continue
|
||||
}
|
||||
|
||||
+316
-33
@@ -3,10 +3,12 @@ package ui
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
@@ -82,6 +84,9 @@ type AppController interface {
|
||||
// GetTreeSession returns the tree session manager, or nil if tree sessions
|
||||
// are not enabled. Used by slash commands like /tree, /fork, /session.
|
||||
GetTreeSession() *session.TreeManager
|
||||
// SwitchTreeSession replaces the active tree session with a new one,
|
||||
// closing the old session. Used by /new to create a completely fresh session.
|
||||
SwitchTreeSession(ts *session.TreeManager)
|
||||
// SendEvent sends a tea.Msg to the program asynchronously. Safe to call
|
||||
// from any goroutine. Used by extension command goroutines to deliver
|
||||
// results back to the TUI without going through tea.Cmd (which can stall
|
||||
@@ -97,6 +102,12 @@ type AppController interface {
|
||||
// alongside the text. Returns the current queue depth (0 = started
|
||||
// immediately, >0 = queued).
|
||||
RunWithFiles(prompt string, files []fantasy.FilePart) int
|
||||
// Steer injects a steering message into the currently running agent
|
||||
// turn. If the agent is busy, the message is delivered between steps
|
||||
// (after current tool finishes, before next LLM call). If idle, the
|
||||
// message starts executing immediately. Returns 0 if started
|
||||
// immediately, >0 if injected/pending.
|
||||
Steer(prompt string) int
|
||||
}
|
||||
|
||||
// SkillItem holds display metadata about a loaded skill for the startup
|
||||
@@ -414,6 +425,11 @@ type AppModel struct {
|
||||
// the input and move to scrollback when the agent picks them up.
|
||||
queuedMessages []string
|
||||
|
||||
// steeringMessages stores the text of prompts that were sent as steer
|
||||
// messages (injected mid-turn via Ctrl+S). Rendered with a "STEERING"
|
||||
// badge above the input. Cleared when the steer is consumed.
|
||||
steeringMessages []string
|
||||
|
||||
// pendingUserPrints holds user messages that have been consumed from the
|
||||
// queue but not yet printed to scrollback. They are deferred until
|
||||
// SpinnerEvent{Show: true} so the previous assistant response can be
|
||||
@@ -560,6 +576,18 @@ type AppModel struct {
|
||||
// width and height track the terminal dimensions.
|
||||
width int
|
||||
height int
|
||||
|
||||
// streamingBashOutput holds the current streaming bash output lines.
|
||||
// Lines are accumulated as they arrive and displayed in the stream region.
|
||||
streamingBashOutput []string
|
||||
// streamingBashStderr holds stderr lines separately (rendered differently).
|
||||
streamingBashStderr []string
|
||||
// streamingBashMaxLines caps how many lines to accumulate to prevent memory issues.
|
||||
streamingBashMaxLines int
|
||||
// streamingMu protects the streaming bash output fields from concurrent access.
|
||||
streamingMu sync.RWMutex
|
||||
// streamingBashCommand holds the command being executed for display as a header.
|
||||
streamingBashCommand string
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -670,6 +698,9 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
|
||||
m.mcpToolCount = opts.MCPToolCount
|
||||
m.extensionToolCount = opts.ExtensionToolCount
|
||||
|
||||
// Initialize streaming bash output buffer.
|
||||
m.streamingBashMaxLines = 50 // cap to prevent memory issues
|
||||
|
||||
// Wire up child components now that we have the concrete implementations.
|
||||
m.input = NewInputComponent(width, "Enter your prompt (Type /help for commands, Ctrl+C to quit)", appCtrl)
|
||||
|
||||
@@ -1056,6 +1087,45 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
// In other states pass ESC through to children below.
|
||||
|
||||
case "ctrl+s":
|
||||
// Steer: inject the current input as a steering message into the
|
||||
// running agent turn. Only active during stateWorking — in input
|
||||
// state, Ctrl+S is passed through to children (no-op by default).
|
||||
if m.state == stateWorking && m.appCtrl != nil {
|
||||
var text string
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
text = strings.TrimSpace(ic.textarea.Value())
|
||||
}
|
||||
if text != "" {
|
||||
// Clear the input and push to history.
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.pushHistory(text)
|
||||
ic.textarea.SetValue("")
|
||||
}
|
||||
|
||||
// Preprocess @file references.
|
||||
processedText := text
|
||||
if m.cwd != "" {
|
||||
processedText = ProcessFileAttachments(text, m.cwd)
|
||||
}
|
||||
|
||||
// Inject the steer message.
|
||||
sLen := m.appCtrl.Steer(processedText)
|
||||
if sLen > 0 {
|
||||
m.steeringMessages = append(m.steeringMessages, text)
|
||||
m.distributeHeight()
|
||||
} else {
|
||||
// Started immediately (agent was idle).
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, text)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
}
|
||||
|
||||
// Route key events to the focused child. Check for editor
|
||||
@@ -1302,6 +1372,18 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// rendered when the ToolResultEvent arrives.
|
||||
m.flushStreamContent()
|
||||
|
||||
// For bash commands, extract and store the command for the streaming output header.
|
||||
if msg.ToolName == "bash" {
|
||||
var args struct {
|
||||
Command string `json:"command"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(msg.ToolArgs), &args); err == nil && args.Command != "" {
|
||||
m.streamingMu.Lock()
|
||||
m.streamingBashCommand = args.Command
|
||||
m.streamingMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
case app.ToolExecutionEvent:
|
||||
// Pass to stream component for execution spinner display.
|
||||
if m.stream != nil {
|
||||
@@ -1312,12 +1394,36 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case app.ToolResultEvent:
|
||||
// Buffer tool result for scrollback.
|
||||
m.printToolResult(msg)
|
||||
// Clear streaming bash output since tool completed.
|
||||
m.streamingMu.Lock()
|
||||
m.streamingBashOutput = nil
|
||||
m.streamingBashStderr = nil
|
||||
m.streamingBashCommand = ""
|
||||
m.streamingMu.Unlock()
|
||||
// Start spinner again while waiting for the next LLM response.
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: true})
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
case app.ToolOutputEvent:
|
||||
// Accumulate streaming bash output for display.
|
||||
m.streamingMu.Lock()
|
||||
if msg.IsStderr {
|
||||
m.streamingBashStderr = append(m.streamingBashStderr, msg.Chunk)
|
||||
// Cap stderr lines to prevent memory issues.
|
||||
if len(m.streamingBashStderr) > m.streamingBashMaxLines {
|
||||
m.streamingBashStderr = m.streamingBashStderr[len(m.streamingBashStderr)-m.streamingBashMaxLines:]
|
||||
}
|
||||
} else {
|
||||
m.streamingBashOutput = append(m.streamingBashOutput, msg.Chunk)
|
||||
// Cap stdout lines to prevent memory issues.
|
||||
if len(m.streamingBashOutput) > m.streamingBashMaxLines {
|
||||
m.streamingBashOutput = m.streamingBashOutput[len(m.streamingBashOutput)-m.streamingBashMaxLines:]
|
||||
}
|
||||
}
|
||||
m.streamingMu.Unlock()
|
||||
|
||||
case app.ToolCallContentEvent:
|
||||
// In streaming mode this text was already delivered via StreamChunkEvents
|
||||
// and will be flushed before the next tool call. Ignore to avoid
|
||||
@@ -1352,6 +1458,38 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
m.distributeHeight()
|
||||
|
||||
case app.SteerConsumedEvent:
|
||||
// Steering messages were consumed — either injected mid-turn via
|
||||
// PrepareStep, or drained into the queue after a text-only turn.
|
||||
//
|
||||
// Two cases:
|
||||
//
|
||||
// 1. Mid-turn (stateWorking, PrepareStep fired): no SpinnerEvent{Show:
|
||||
// true} will follow within this turn, so we cannot rely on
|
||||
// flushStreamAndPendingUserMessages() being called. Flush any live
|
||||
// stream content first (assistant text up to the steer point), then
|
||||
// render the steering user messages immediately to scrollback.
|
||||
//
|
||||
// 2. Post-turn (text-only response, drained after StepComplete): a
|
||||
// SpinnerEvent{Show: true} for the next turn is already in flight.
|
||||
// Defer to pendingUserPrints so the previous assistant response is
|
||||
// flushed first, preserving chronological order.
|
||||
if m.state == stateWorking {
|
||||
// Case 1: mid-turn — flush + print immediately.
|
||||
m.flushStreamContent()
|
||||
for _, text := range m.steeringMessages {
|
||||
m.printUserMessage(text)
|
||||
}
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
} else {
|
||||
// Case 2: post-turn — defer so SpinnerEvent orders correctly.
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, m.steeringMessages...)
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
}
|
||||
|
||||
case app.StepCompleteEvent:
|
||||
// Keep stream content visible in the view — don't flush to scrollback
|
||||
// yet. Flushing + resetting in the same frame would shrink the view
|
||||
@@ -1604,6 +1742,7 @@ func (m *AppModel) View() tea.View {
|
||||
// Propagate hint visibility to the input component before rendering.
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.hideHint = vis.HideInputHint
|
||||
ic.agentBusy = m.state == stateWorking
|
||||
}
|
||||
|
||||
// When a prompt is active, it replaces the input area for consistency
|
||||
@@ -1670,24 +1809,127 @@ func (m *AppModel) View() tea.View {
|
||||
|
||||
// renderStream returns the stream region content.
|
||||
func (m *AppModel) renderStream() string {
|
||||
if m.stream == nil {
|
||||
theme := GetTheme()
|
||||
|
||||
var parts []string
|
||||
|
||||
// Stream component content (LLM streaming text, reasoning, spinner placeholder).
|
||||
if m.stream != nil {
|
||||
if content := m.stream.View().Content; content != "" {
|
||||
parts = append(parts, content)
|
||||
}
|
||||
}
|
||||
|
||||
// Streaming bash output section (if any).
|
||||
bashView := m.renderStreamingBashOutput(theme)
|
||||
if bashView != "" {
|
||||
parts = append(parts, bashView)
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Show canceling warning if set.
|
||||
if m.canceling {
|
||||
theme := GetTheme()
|
||||
warning := lipgloss.NewStyle().
|
||||
Foreground(theme.Warning).
|
||||
Bold(true).
|
||||
Render(" ⚠ Press ESC again to cancel")
|
||||
return lipgloss.JoinVertical(lipgloss.Left,
|
||||
m.stream.View().Content,
|
||||
warning,
|
||||
)
|
||||
parts = append(parts, warning)
|
||||
}
|
||||
|
||||
return m.stream.View().Content
|
||||
return lipgloss.JoinVertical(lipgloss.Left, parts...)
|
||||
}
|
||||
|
||||
// renderStreamingBashOutput renders accumulated streaming bash output (stdout + stderr)
|
||||
// below the LLM streaming text. Returns empty string if no bash output is present.
|
||||
// Lines are truncated to the terminal width and capped to maxBashLines to prevent
|
||||
// long-running commands from blowing up the TUI layout.
|
||||
func (m *AppModel) renderStreamingBashOutput(theme Theme) string {
|
||||
m.streamingMu.RLock()
|
||||
stdoutLines := make([]string, len(m.streamingBashOutput))
|
||||
copy(stdoutLines, m.streamingBashOutput)
|
||||
stderrLines := make([]string, len(m.streamingBashStderr))
|
||||
copy(stderrLines, m.streamingBashStderr)
|
||||
command := m.streamingBashCommand
|
||||
m.streamingMu.RUnlock()
|
||||
|
||||
if len(stdoutLines) == 0 && len(stderrLines) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
const lineIndent = " "
|
||||
lineWidth := max(m.width-2-len(lineIndent), 20)
|
||||
// Account for PaddingLeft(1) on the output/stderr styles.
|
||||
maxLineChars := lineWidth - 1
|
||||
|
||||
outputStyle := lipgloss.NewStyle().
|
||||
Background(theme.CodeBg).
|
||||
PaddingLeft(1)
|
||||
|
||||
stderrStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Background(theme.CodeBg).
|
||||
PaddingLeft(1)
|
||||
|
||||
// Header style for the command - muted text with a subtle indicator.
|
||||
headerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
PaddingLeft(1)
|
||||
|
||||
// Cap displayed lines to maxBashLines (show the tail, since streaming
|
||||
// output is most useful at the end). The buffer itself is larger to
|
||||
// preserve context, but we only render the last N lines.
|
||||
totalLines := len(stdoutLines) + len(stderrLines)
|
||||
var hiddenCount int
|
||||
if totalLines > maxBashLines {
|
||||
hiddenCount = totalLines - maxBashLines
|
||||
// Trim from stdout first (older output), then stderr.
|
||||
remaining := maxBashLines
|
||||
if len(stderrLines) >= remaining {
|
||||
stdoutLines = nil
|
||||
stderrLines = stderrLines[len(stderrLines)-remaining:]
|
||||
} else {
|
||||
remaining -= len(stderrLines)
|
||||
if len(stdoutLines) > remaining {
|
||||
stdoutLines = stdoutLines[len(stdoutLines)-remaining:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var lines []string
|
||||
|
||||
// Command header - show the bash command being executed.
|
||||
if command != "" {
|
||||
headerText := fmt.Sprintf("$ %s", command)
|
||||
headerContent := headerStyle.Width(lineWidth).Render(truncateLine(headerText, maxLineChars))
|
||||
lines = append(lines, lineIndent+headerContent)
|
||||
}
|
||||
|
||||
// Truncation hint at the top.
|
||||
if hiddenCount > 0 {
|
||||
hint := fmt.Sprintf("...(%d more lines above)", hiddenCount)
|
||||
hintContent := outputStyle.Width(lineWidth).
|
||||
Foreground(theme.Muted).Italic(true).Render(hint)
|
||||
lines = append(lines, lineIndent+hintContent)
|
||||
}
|
||||
|
||||
// Render stdout lines.
|
||||
for _, line := range stdoutLines {
|
||||
line = truncateLine(strings.TrimRight(line, "\n"), maxLineChars)
|
||||
styled := outputStyle.Width(lineWidth).Render(line)
|
||||
lines = append(lines, lineIndent+styled)
|
||||
}
|
||||
|
||||
// Render stderr lines with error styling.
|
||||
for _, line := range stderrLines {
|
||||
line = truncateLine(strings.TrimRight(line, "\n"), maxLineChars)
|
||||
styled := stderrStyle.Width(lineWidth).Render(line)
|
||||
lines = append(lines, lineIndent+styled)
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// renderStatusBar renders a persistent single-line status bar below the input.
|
||||
@@ -1808,16 +2050,26 @@ func (m *AppModel) cycleThinkingLevel() {
|
||||
go func() { _ = SaveThinkingLevelPreference(next) }()
|
||||
}
|
||||
|
||||
// renderSeparator renders the separator line with an optional queue count badge.
|
||||
// renderSeparator renders the separator line with an optional queue/steer count badge.
|
||||
func (m *AppModel) renderSeparator() string {
|
||||
theme := GetTheme()
|
||||
lineStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
queueLen := len(m.queuedMessages)
|
||||
steerLen := len(m.steeringMessages)
|
||||
|
||||
if queueLen > 0 {
|
||||
badge := lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Render(fmt.Sprintf("%d queued", queueLen))
|
||||
if steerLen > 0 || queueLen > 0 {
|
||||
var parts []string
|
||||
if steerLen > 0 {
|
||||
parts = append(parts, lipgloss.NewStyle().
|
||||
Foreground(theme.Warning).
|
||||
Render(fmt.Sprintf("%d steering", steerLen)))
|
||||
}
|
||||
if queueLen > 0 {
|
||||
parts = append(parts, lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Render(fmt.Sprintf("%d queued", queueLen)))
|
||||
}
|
||||
badge := strings.Join(parts, " ")
|
||||
|
||||
// Fill the separator with dashes up to the badge.
|
||||
dashWidth := max(m.width-lipgloss.Width(badge)-1, 0)
|
||||
@@ -1916,27 +2168,47 @@ func (m *AppModel) renderHeaderFooter(getter func() *WidgetData) string {
|
||||
return renderContentBlock(data.Text, m.width, opts...)
|
||||
}
|
||||
|
||||
// renderQueuedMessages renders queued prompts as styled content blocks with a
|
||||
// "QUEUED" badge, anchored between the separator and input. Each message is
|
||||
// displayed in a bordered block matching the overall message styling.
|
||||
// renderQueuedMessages renders queued and steering prompts as styled content
|
||||
// blocks with badges, anchored between the separator and input. Steering
|
||||
// messages use a distinct "STEERING" badge to differentiate from queued ones.
|
||||
func (m *AppModel) renderQueuedMessages() string {
|
||||
if len(m.queuedMessages) == 0 {
|
||||
if len(m.queuedMessages) == 0 && len(m.steeringMessages) == 0 {
|
||||
return ""
|
||||
}
|
||||
theme := GetTheme()
|
||||
badge := CreateBadge("QUEUED", theme.Accent)
|
||||
|
||||
var blocks []string
|
||||
for _, msg := range m.queuedMessages {
|
||||
content := msg + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Muted),
|
||||
)
|
||||
blocks = append(blocks, rendered)
|
||||
|
||||
// Render steering messages first (higher priority).
|
||||
if len(m.steeringMessages) > 0 {
|
||||
badge := CreateBadge("STEERING", theme.Warning)
|
||||
for _, msg := range m.steeringMessages {
|
||||
content := msg + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Warning),
|
||||
)
|
||||
blocks = append(blocks, rendered)
|
||||
}
|
||||
}
|
||||
|
||||
// Render queued messages.
|
||||
if len(m.queuedMessages) > 0 {
|
||||
badge := CreateBadge("QUEUED", theme.Accent)
|
||||
for _, msg := range m.queuedMessages {
|
||||
content := msg + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Muted),
|
||||
)
|
||||
blocks = append(blocks, rendered)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(blocks, "\n")
|
||||
}
|
||||
|
||||
@@ -2007,6 +2279,7 @@ func (m *AppModel) handleSlashCommand(sc *SlashCommand) tea.Cmd {
|
||||
m.appCtrl.ClearQueue()
|
||||
}
|
||||
m.queuedMessages = m.queuedMessages[:0]
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
|
||||
case "/tree":
|
||||
@@ -2153,7 +2426,7 @@ func (m *AppModel) printHelpMessage() {
|
||||
"**Navigation:**\n" +
|
||||
"- `/tree`: Navigate session tree (switch branches)\n" +
|
||||
"- `/fork`: Branch from an earlier message\n" +
|
||||
"- `/new`: Start a new branch (preserves history)\n" +
|
||||
"- `/new`: Start a new session (discards context, saves old session)\n" +
|
||||
"- `/resume`: Open session picker to switch sessions\n" +
|
||||
"- `/name <name>`: Set a display name for this session\n\n" +
|
||||
"**System:**\n" +
|
||||
@@ -2194,7 +2467,9 @@ func (m *AppModel) printHelpMessage() {
|
||||
"- `!!command`: Run shell command, output excluded from LLM context\n\n" +
|
||||
"**Keys:**\n" +
|
||||
"- `Ctrl+C`: Exit at any time\n" +
|
||||
"- `ESC` (x2): Cancel ongoing LLM generation\n\n" +
|
||||
"- `ESC` (x2): Cancel ongoing LLM generation\n" +
|
||||
"- `Ctrl+S`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
|
||||
"- `Enter` (while working): Queue message for after the agent finishes\n\n" +
|
||||
"You can also just type your message to chat with the AI assistant."
|
||||
m.printSystemMessage(help)
|
||||
}
|
||||
@@ -2719,7 +2994,8 @@ func (m *AppModel) handleForkCommand() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleNewCommand starts a fresh session by resetting the tree leaf.
|
||||
// handleNewCommand starts a completely new session (Pi-style /new behavior).
|
||||
// Creates a new session file, discarding all context from the previous conversation.
|
||||
func (m *AppModel) handleNewCommand() tea.Cmd {
|
||||
// Emit before-session-switch event in a goroutine so that extension
|
||||
// handlers can call blocking operations (e.g. ctx.PromptConfirm) without
|
||||
@@ -2742,6 +3018,8 @@ func (m *AppModel) handleNewCommand() tea.Cmd {
|
||||
|
||||
// performNewSession performs the actual session reset. Called either directly
|
||||
// (when no before-hook exists) or after the async hook completes.
|
||||
// Matches Pi behavior: creates a completely new session file, discarding all
|
||||
// context from the previous conversation.
|
||||
func (m *AppModel) performNewSession() tea.Cmd {
|
||||
ts := m.appCtrl.GetTreeSession()
|
||||
if ts == nil {
|
||||
@@ -2753,11 +3031,16 @@ func (m *AppModel) performNewSession() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
ts.ResetLeaf()
|
||||
if m.appCtrl != nil {
|
||||
m.appCtrl.ClearMessages()
|
||||
// Create a brand new session file (Pi-style /new behavior)
|
||||
newTs, err := session.CreateTreeSession(m.cwd)
|
||||
if err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to create new session: %v", err))
|
||||
return nil
|
||||
}
|
||||
m.printSystemMessage("New branch started. Previous conversation is preserved in the tree.")
|
||||
|
||||
// Switch to the new session, closing the old one
|
||||
m.appCtrl.SwitchTreeSession(newTs)
|
||||
m.printSystemMessage("New session started. Previous conversation saved.")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
+146
-9
@@ -54,6 +54,10 @@ func (s *stubAppController) GetTreeSession() *session.TreeManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAppController) SwitchTreeSession(_ *session.TreeManager) {
|
||||
// no-op in tests
|
||||
}
|
||||
|
||||
func (s *stubAppController) SendEvent(_ tea.Msg) {
|
||||
// no-op in tests
|
||||
}
|
||||
@@ -67,6 +71,11 @@ func (s *stubAppController) RunWithFiles(prompt string, _ []fantasy.FilePart) in
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
func (s *stubAppController) Steer(prompt string) int {
|
||||
s.runCalls = append(s.runCalls, prompt)
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Stub child components
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -112,15 +121,16 @@ func newTestAppModel(ctrl AppController) (*AppModel, *stubStreamComponent, *stub
|
||||
stream := &stubStreamComponent{}
|
||||
input := &stubInputComponent{}
|
||||
m := &AppModel{
|
||||
state: stateInput,
|
||||
appCtrl: ctrl,
|
||||
stream: stream,
|
||||
input: input,
|
||||
renderer: newMessageRenderer(80, false),
|
||||
compactMode: false,
|
||||
modelName: "test-model",
|
||||
width: 80,
|
||||
height: 24,
|
||||
state: stateInput,
|
||||
appCtrl: ctrl,
|
||||
stream: stream,
|
||||
input: input,
|
||||
renderer: newMessageRenderer(80, false),
|
||||
compactMode: false,
|
||||
modelName: "test-model",
|
||||
width: 80,
|
||||
height: 24,
|
||||
streamingBashMaxLines: 50, // Initialize buffer cap like NewAppModel does
|
||||
}
|
||||
return m, stream, input
|
||||
}
|
||||
@@ -602,6 +612,133 @@ func TestToolResult_printsAndStartsSpinner(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolOutputEvent_accumulatesBashOutput verifies that ToolOutputEvent
|
||||
// accumulates stdout and stderr lines into the streaming bash output buffers.
|
||||
func TestToolOutputEvent_accumulatesBashOutput(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
// Send stdout chunk.
|
||||
m = sendMsg(m, app.ToolOutputEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "bash",
|
||||
Chunk: "line one\n",
|
||||
IsStderr: false,
|
||||
})
|
||||
|
||||
if len(m.streamingBashOutput) != 1 || m.streamingBashOutput[0] != "line one\n" {
|
||||
t.Fatalf("expected streamingBashOutput=['line one\\n'], got %v", m.streamingBashOutput)
|
||||
}
|
||||
if len(m.streamingBashStderr) != 0 {
|
||||
t.Fatalf("expected empty streamingBashStderr, got %v", m.streamingBashStderr)
|
||||
}
|
||||
|
||||
// Send another stdout chunk.
|
||||
m = sendMsg(m, app.ToolOutputEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "bash",
|
||||
Chunk: "line two\n",
|
||||
IsStderr: false,
|
||||
})
|
||||
|
||||
if len(m.streamingBashOutput) != 2 {
|
||||
t.Fatalf("expected 2 stdout lines, got %d", len(m.streamingBashOutput))
|
||||
}
|
||||
|
||||
// Send stderr chunk.
|
||||
m = sendMsg(m, app.ToolOutputEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "bash",
|
||||
Chunk: "error: something failed\n",
|
||||
IsStderr: true,
|
||||
})
|
||||
|
||||
if len(m.streamingBashStderr) != 1 {
|
||||
t.Fatalf("expected 1 stderr line, got %d", len(m.streamingBashStderr))
|
||||
}
|
||||
if m.streamingBashStderr[0] != "error: something failed\n" {
|
||||
t.Fatalf("expected stderr 'error: something failed\\n', got %q", m.streamingBashStderr[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_clearsStreamingBashOutput verifies that ToolResultEvent clears
|
||||
// the streaming bash output buffers since the final result will be printed.
|
||||
func TestToolResult_clearsStreamingBashOutput(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
// Accumulate some bash output.
|
||||
m.streamingBashOutput = []string{"output line"}
|
||||
m.streamingBashStderr = []string{"error line"}
|
||||
|
||||
_, _ = m.Update(app.ToolResultEvent{
|
||||
ToolName: "bash",
|
||||
ToolArgs: `{"cmd":"ls"}`,
|
||||
Result: "output line\nerror line\n",
|
||||
IsError: false,
|
||||
})
|
||||
|
||||
if len(m.streamingBashOutput) != 0 {
|
||||
t.Fatalf("expected streamingBashOutput cleared, got %v", m.streamingBashOutput)
|
||||
}
|
||||
if len(m.streamingBashStderr) != 0 {
|
||||
t.Fatalf("expected streamingBashStderr cleared, got %v", m.streamingBashStderr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallStarted_extractsBashCommand verifies that ToolCallStartedEvent
|
||||
// extracts the bash command from ToolArgs and stores it for the streaming output header.
|
||||
func TestToolCallStarted_extractsBashCommand(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
// Send ToolCallStartedEvent with bash command.
|
||||
m = sendMsg(m, app.ToolCallStartedEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "bash",
|
||||
ToolArgs: `{"command":"ls -la /home"}`,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "ls -la /home" {
|
||||
t.Fatalf("expected streamingBashCommand='ls -la /home', got %q", m.streamingBashCommand)
|
||||
}
|
||||
|
||||
// ToolResultEvent should clear the command.
|
||||
m = sendMsg(m, app.ToolResultEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "bash",
|
||||
ToolArgs: `{"command":"ls -la /home"}`,
|
||||
Result: "output",
|
||||
IsError: false,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "" {
|
||||
t.Fatalf("expected streamingBashCommand cleared, got %q", m.streamingBashCommand)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallStarted_nonBashTool_doesNotSetCommand verifies that non-bash tools
|
||||
// do not set the streamingBashCommand field.
|
||||
func TestToolCallStarted_nonBashTool_doesNotSetCommand(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
// Send ToolCallStartedEvent with a non-bash tool.
|
||||
m = sendMsg(m, app.ToolCallStartedEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "read",
|
||||
ToolArgs: `{"file":"/etc/passwd"}`,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "" {
|
||||
t.Fatalf("expected streamingBashCommand to remain empty for non-bash tools, got %q", m.streamingBashCommand)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStepError_printCmd verifies that StepErrorEvent with a non-nil error
|
||||
// produces a non-nil cmd (the tea.Println call for the error message).
|
||||
func TestStepError_printCmd(t *testing.T) {
|
||||
|
||||
+25
-14
@@ -282,7 +282,8 @@ func (s *StreamComponent) GetRenderedContent() string {
|
||||
|
||||
text := s.streamContent.String()
|
||||
if text != "" {
|
||||
sections = append(sections, s.renderStreamingText(text))
|
||||
rendered := s.renderStreamingText(text)
|
||||
sections = append(sections, rendered)
|
||||
}
|
||||
|
||||
if len(sections) == 0 {
|
||||
@@ -415,7 +416,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// View implements tea.Model. Renders the current stream region content.
|
||||
func (s *StreamComponent) View() tea.View {
|
||||
return tea.NewView(s.render())
|
||||
fullContent := s.render()
|
||||
visibleContent := s.viewContent(fullContent)
|
||||
return tea.NewView(visibleContent)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -458,21 +461,27 @@ func (s *StreamComponent) render() string {
|
||||
|
||||
content := strings.Join(sections, "\n")
|
||||
|
||||
// Clamp to height if constrained: keep the last h lines so the most
|
||||
// recent output is always visible.
|
||||
if s.height > 0 && content != "" {
|
||||
lines := strings.Split(content, "\n")
|
||||
if len(lines) > s.height {
|
||||
lines = lines[len(lines)-s.height:]
|
||||
content = strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Cache FULL content without height clamping.
|
||||
// Height clamping is applied in View() for display only.
|
||||
s.renderCache = content
|
||||
s.renderDirty = false
|
||||
return content
|
||||
}
|
||||
|
||||
// viewContent returns the visible portion of content based on height constraint.
|
||||
// This is called by View() to get the slice that fits in the terminal.
|
||||
func (s *StreamComponent) viewContent(fullContent string) string {
|
||||
if s.height > 0 && fullContent != "" {
|
||||
lines := strings.Split(fullContent, "\n")
|
||||
if len(lines) > s.height {
|
||||
// Keep only the last h lines so the most recent output is visible.
|
||||
lines = lines[len(lines)-s.height:]
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
return fullContent
|
||||
}
|
||||
|
||||
// renderReasoningBlock renders the reasoning/thinking content in a surface-tinted
|
||||
// box. When collapsed, shows the last 10 lines with a truncation hint. When
|
||||
// expanded, shows all lines. Includes a "Thought for Xs" duration footer.
|
||||
@@ -484,6 +493,7 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
|
||||
contentStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.MutedBorder).
|
||||
Italic(true)
|
||||
|
||||
var parts []string
|
||||
@@ -495,6 +505,7 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
hidden := len(lines) - maxCollapsedLines
|
||||
hintStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted).
|
||||
Background(theme.MutedBorder).
|
||||
Italic(true)
|
||||
parts = append(parts, hintStyle.Render(fmt.Sprintf("... (%d lines hidden)", hidden)))
|
||||
lines = lines[len(lines)-maxCollapsedLines:]
|
||||
@@ -517,8 +528,8 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.1fs", duration.Seconds())
|
||||
}
|
||||
footer := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ") +
|
||||
lipgloss.NewStyle().Foreground(theme.Info).Render(durationStr)
|
||||
footer := lipgloss.NewStyle().Foreground(theme.VeryMuted).Background(theme.MutedBorder).Render("Thought for ") +
|
||||
lipgloss.NewStyle().Foreground(theme.Info).Background(theme.MutedBorder).Render(durationStr)
|
||||
parts = append(parts, footer)
|
||||
}
|
||||
|
||||
|
||||
+110
-15
@@ -7,11 +7,86 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Color derivation helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// parseHexColor parses a "#RRGGBB" hex string into r, g, b components (0-255).
|
||||
func parseHexColor(hex string) (r, g, b int) {
|
||||
hex = strings.TrimPrefix(hex, "#")
|
||||
if len(hex) == 6 {
|
||||
if v, err := strconv.ParseUint(hex[0:2], 16, 8); err == nil {
|
||||
r = int(v)
|
||||
}
|
||||
if v, err := strconv.ParseUint(hex[2:4], 16, 8); err == nil {
|
||||
g = int(v)
|
||||
}
|
||||
if v, err := strconv.ParseUint(hex[4:6], 16, 8); err == nil {
|
||||
b = int(v)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// blendHex linearly interpolates between two hex colors by amount (0.0–1.0).
|
||||
func blendHex(base, tint string, amount float64) string {
|
||||
br, bg, bb := parseHexColor(base)
|
||||
tr, tg, tb := parseHexColor(tint)
|
||||
clamp := func(v int) int {
|
||||
if v < 0 {
|
||||
return 0
|
||||
}
|
||||
if v > 255 {
|
||||
return 255
|
||||
}
|
||||
return v
|
||||
}
|
||||
r := clamp(int(float64(br)*(1-amount) + float64(tr)*amount))
|
||||
g := clamp(int(float64(bg)*(1-amount) + float64(tg)*amount))
|
||||
b := clamp(int(float64(bb)*(1-amount) + float64(tb)*amount))
|
||||
return fmt.Sprintf("#%02x%02x%02x", r, g, b)
|
||||
}
|
||||
|
||||
// deriveDiffBg computes diff / code background colors from the theme's
|
||||
// background, success, and error hex pairs. Returns an adaptive color for each
|
||||
// diff element. The tint amounts are tuned for subtle differentiation.
|
||||
func deriveDiffBg(bgPair, successPair, errorPair [2]string) (diffInsert, diffDelete, diffEqual, diffMissing, codeBg, gutterBg, writeBg color.Color) {
|
||||
derive := func(idx int) (color.Color, color.Color, color.Color, color.Color) {
|
||||
bg := bgPair[idx]
|
||||
// Contrast target: darken for light mode (idx 0), lighten for dark (idx 1).
|
||||
contrast := "#000000"
|
||||
if idx == 1 {
|
||||
contrast = "#ffffff"
|
||||
}
|
||||
ins := blendHex(bg, successPair[idx], 0.13)
|
||||
del := blendHex(bg, errorPair[idx], 0.13)
|
||||
eq := blendHex(bg, contrast, 0.05)
|
||||
miss := blendHex(bg, contrast, 0.03)
|
||||
return AdaptiveColor(ins, ins), AdaptiveColor(del, del), AdaptiveColor(eq, eq), AdaptiveColor(miss, miss)
|
||||
}
|
||||
|
||||
// Pick the correct index based on detected background.
|
||||
idx := 0
|
||||
if isDarkBg {
|
||||
idx = 1
|
||||
}
|
||||
insL, delL, eqL, missL := derive(idx)
|
||||
diffInsert = insL
|
||||
diffDelete = delL
|
||||
diffEqual = eqL
|
||||
diffMissing = missL
|
||||
codeBg = eqL
|
||||
gutterBg = missL
|
||||
writeBg = insL
|
||||
return
|
||||
}
|
||||
|
||||
// ThemeEntry is a named, loadable theme — either built-in or discovered from disk.
|
||||
type ThemeEntry struct {
|
||||
Name string // Display name (filename stem or preset name)
|
||||
@@ -80,14 +155,9 @@ func makeTheme(p presetColors) Theme {
|
||||
Accent: acOr(p.accent, ac(p.primary)),
|
||||
Highlight: acOr(p.highlight, def.Highlight),
|
||||
}
|
||||
// Derive diff/code backgrounds from the base background.
|
||||
t.DiffInsertBg = def.DiffInsertBg
|
||||
t.DiffDeleteBg = def.DiffDeleteBg
|
||||
t.DiffEqualBg = def.DiffEqualBg
|
||||
t.DiffMissingBg = def.DiffMissingBg
|
||||
t.CodeBg = def.CodeBg
|
||||
t.GutterBg = def.GutterBg
|
||||
t.WriteBg = def.WriteBg
|
||||
// Derive diff/code backgrounds from the theme's own palette.
|
||||
t.DiffInsertBg, t.DiffDeleteBg, t.DiffEqualBg, t.DiffMissingBg,
|
||||
t.CodeBg, t.GutterBg, t.WriteBg = deriveDiffBg(p.background, p.success, p.error_)
|
||||
// Markdown colors.
|
||||
t.Markdown = MarkdownThemeColors{
|
||||
Text: t.Text,
|
||||
@@ -609,6 +679,17 @@ func loadThemeFile(path string) (Theme, error) {
|
||||
|
||||
func fileConfigToTheme(cfg themeFileConfig) Theme {
|
||||
def := DefaultTheme()
|
||||
|
||||
// Resolve the base background/success/error hex pairs for diff derivation.
|
||||
// We need the raw hex strings to feed deriveDiffBg.
|
||||
bgPair := resolveHexPair(cfg.Background, [2]string{"#F0F0F0", "#0D0D0D"})
|
||||
successPair := resolveHexPair(cfg.Success, [2]string{"#998800", "#CCAA00"})
|
||||
errorPair := resolveHexPair(cfg.Error, [2]string{"#CC0000", "#FF3333"})
|
||||
|
||||
// Derive diff backgrounds from the theme's own palette.
|
||||
derivedInsert, derivedDelete, derivedEqual, derivedMissing,
|
||||
derivedCodeBg, derivedGutterBg, derivedWriteBg := deriveDiffBg(bgPair, successPair, errorPair)
|
||||
|
||||
return Theme{
|
||||
Primary: cfg.Primary.resolve(def.Primary),
|
||||
Secondary: cfg.Secondary.resolve(def.Secondary),
|
||||
@@ -627,13 +708,13 @@ func fileConfigToTheme(cfg themeFileConfig) Theme {
|
||||
Accent: cfg.Accent.resolve(def.Accent),
|
||||
Highlight: cfg.Highlight.resolve(def.Highlight),
|
||||
|
||||
DiffInsertBg: cfg.DiffInsertBg.resolve(def.DiffInsertBg),
|
||||
DiffDeleteBg: cfg.DiffDeleteBg.resolve(def.DiffDeleteBg),
|
||||
DiffEqualBg: cfg.DiffEqualBg.resolve(def.DiffEqualBg),
|
||||
DiffMissingBg: cfg.DiffMissingBg.resolve(def.DiffMissingBg),
|
||||
CodeBg: cfg.CodeBg.resolve(def.CodeBg),
|
||||
GutterBg: cfg.GutterBg.resolve(def.GutterBg),
|
||||
WriteBg: cfg.WriteBg.resolve(def.WriteBg),
|
||||
DiffInsertBg: cfg.DiffInsertBg.resolve(derivedInsert),
|
||||
DiffDeleteBg: cfg.DiffDeleteBg.resolve(derivedDelete),
|
||||
DiffEqualBg: cfg.DiffEqualBg.resolve(derivedEqual),
|
||||
DiffMissingBg: cfg.DiffMissingBg.resolve(derivedMissing),
|
||||
CodeBg: cfg.CodeBg.resolve(derivedCodeBg),
|
||||
GutterBg: cfg.GutterBg.resolve(derivedGutterBg),
|
||||
WriteBg: cfg.WriteBg.resolve(derivedWriteBg),
|
||||
|
||||
Markdown: MarkdownThemeColors{
|
||||
Text: cfg.Markdown.Text.resolve(def.Markdown.Text),
|
||||
@@ -651,3 +732,17 @@ func fileConfigToTheme(cfg themeFileConfig) Theme {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// resolveHexPair returns the hex pair from an adaptiveColorPair, falling back
|
||||
// to defaults when the pair is empty.
|
||||
func resolveHexPair(a adaptiveColorPair, fallback [2]string) [2]string {
|
||||
light := a.Light
|
||||
if light == "" {
|
||||
light = fallback[0]
|
||||
}
|
||||
dark := a.Dark
|
||||
if dark == "" {
|
||||
dark = fallback[1]
|
||||
}
|
||||
return [2]string{light, dark}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseHexColor(t *testing.T) {
|
||||
tests := []struct {
|
||||
hex string
|
||||
r, g, b int
|
||||
}{
|
||||
{"#000000", 0, 0, 0},
|
||||
{"#ffffff", 255, 255, 255},
|
||||
{"#1e1e2e", 0x1e, 0x1e, 0x2e},
|
||||
{"#a6e3a1", 0xa6, 0xe3, 0xa1},
|
||||
{"#f38ba8", 0xf3, 0x8b, 0xa8},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
r, g, b := parseHexColor(tt.hex)
|
||||
if r != tt.r || g != tt.g || b != tt.b {
|
||||
t.Errorf("parseHexColor(%q) = (%d,%d,%d), want (%d,%d,%d)",
|
||||
tt.hex, r, g, b, tt.r, tt.g, tt.b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlendHex(t *testing.T) {
|
||||
// Blending with 0 amount should return the base color.
|
||||
got := blendHex("#1e1e2e", "#a6e3a1", 0.0)
|
||||
if got != "#1e1e2e" {
|
||||
t.Errorf("blendHex with 0.0 = %q, want #1e1e2e", got)
|
||||
}
|
||||
|
||||
// Blending with 1.0 amount should return the tint color.
|
||||
got = blendHex("#1e1e2e", "#a6e3a1", 1.0)
|
||||
if got != "#a6e3a1" {
|
||||
t.Errorf("blendHex with 1.0 = %q, want #a6e3a1", got)
|
||||
}
|
||||
|
||||
// Blending black and white at 0.5 should give mid gray.
|
||||
got = blendHex("#000000", "#ffffff", 0.5)
|
||||
// 127 = int(0 + 255*0.5) — truncated, so #7f7f7f
|
||||
if got != "#7f7f7f" {
|
||||
t.Errorf("blendHex black/white at 0.5 = %q, want #7f7f7f", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveDiffBgProducesDifferentColorsPerTheme(t *testing.T) {
|
||||
// Catppuccin palette
|
||||
catBg := [2]string{"#eff1f5", "#1e1e2e"}
|
||||
catSuccess := [2]string{"#40a02b", "#a6e3a1"}
|
||||
catError := [2]string{"#d20f39", "#f38ba8"}
|
||||
|
||||
// KITT palette
|
||||
kittBg := [2]string{"#F0F0F0", "#0D0D0D"}
|
||||
kittSuccess := [2]string{"#998800", "#CCAA00"}
|
||||
kittError := [2]string{"#CC0000", "#FF3333"}
|
||||
|
||||
catInsert, catDelete, _, _, _, _, _ := deriveDiffBg(catBg, catSuccess, catError)
|
||||
kittInsert, kittDelete, _, _, _, _, _ := deriveDiffBg(kittBg, kittSuccess, kittError)
|
||||
|
||||
if catInsert == kittInsert {
|
||||
t.Error("catppuccin DiffInsertBg should differ from kitt DiffInsertBg")
|
||||
}
|
||||
if catDelete == kittDelete {
|
||||
t.Error("catppuccin DiffDeleteBg should differ from kitt DiffDeleteBg")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeThemeDerivesUniqueDiffColors(t *testing.T) {
|
||||
themes := builtinThemes()
|
||||
kitt := themes["kitt"]
|
||||
cat := themes["catppuccin"]
|
||||
|
||||
// The catppuccin diff backgrounds should NOT equal the kitt defaults.
|
||||
if cat.DiffInsertBg == kitt.DiffInsertBg {
|
||||
t.Error("catppuccin DiffInsertBg should differ from kitt default")
|
||||
}
|
||||
if cat.DiffDeleteBg == kitt.DiffDeleteBg {
|
||||
t.Error("catppuccin DiffDeleteBg should differ from kitt default")
|
||||
}
|
||||
if cat.DiffEqualBg == kitt.DiffEqualBg {
|
||||
t.Error("catppuccin DiffEqualBg should differ from kitt default")
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ const (
|
||||
maxCodeLines = 20 // lines for Read / code blocks
|
||||
maxWriteLines = 10 // lines for Write blocks
|
||||
maxBashLines = 20 // lines for Bash output (matches Read)
|
||||
maxLsLines = 20 // lines for Ls directory listings
|
||||
)
|
||||
|
||||
// renderToolBody dispatches to tool-specific body renderers based on tool name.
|
||||
@@ -63,21 +64,44 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderEditBody renders a side-by-side diff from old_text/new_text in toolArgs.
|
||||
// Supports both single-edit mode and multi-edit mode (edits array).
|
||||
func renderEditBody(toolArgs, toolResult string, width int) string {
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
// Check for multi-edit mode (edits array)
|
||||
if editsArr, ok := args["edits"].([]any); ok && len(editsArr) > 0 {
|
||||
var results []string
|
||||
for _, edit := range editsArr {
|
||||
if e, ok := edit.(map[string]any); ok {
|
||||
oldText, _ := e["old_text"].(string)
|
||||
newText, _ := e["new_text"].(string)
|
||||
if oldText != "" || newText != "" {
|
||||
diff := renderDiffBlock(oldText, newText, startLine, width)
|
||||
if diff != "" {
|
||||
results = append(results, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(results) > 0 {
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Single-edit mode (legacy)
|
||||
oldText, _ := args["old_text"].(string)
|
||||
newText, _ := args["new_text"].(string)
|
||||
if oldText == "" && newText == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
return renderDiffBlock(oldText, newText, startLine, width)
|
||||
}
|
||||
|
||||
@@ -229,7 +253,7 @@ func renderDiffBlock(before, after string, startLine int, width int) string {
|
||||
gutterMissing := lipgloss.NewStyle().Background(theme.DiffMissingBg)
|
||||
|
||||
contentInsert := lipgloss.NewStyle().Background(theme.DiffInsertBg)
|
||||
contentDelete := lipgloss.NewStyle().Background(theme.DiffDeleteBg).Strikethrough(true)
|
||||
contentDelete := lipgloss.NewStyle().Background(theme.DiffDeleteBg)
|
||||
contentEqual := lipgloss.NewStyle().Foreground(theme.Muted).Background(theme.DiffEqualBg)
|
||||
contentMissing := lipgloss.NewStyle().Background(theme.DiffMissingBg)
|
||||
|
||||
@@ -315,6 +339,13 @@ func renderLsBody(toolResult string, width int) string {
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
|
||||
// Truncate to maxLsLines for display
|
||||
var hiddenCount int
|
||||
if len(lines) > maxLsLines {
|
||||
hiddenCount = len(lines) - maxLsLines
|
||||
lines = lines[:maxLsLines]
|
||||
}
|
||||
|
||||
const indent = " "
|
||||
codeWidth := max(width-len(indent), 20)
|
||||
|
||||
@@ -329,6 +360,13 @@ func renderLsBody(toolResult string, width int) string {
|
||||
result = append(result, indent+styled)
|
||||
}
|
||||
|
||||
if hiddenCount > 0 {
|
||||
hint := fmt.Sprintf("...(%d more entries)", hiddenCount)
|
||||
hintContent := codeStyle.Width(codeWidth).
|
||||
Foreground(theme.Muted).Italic(true).Render(hint)
|
||||
result = append(result, indent+hintContent)
|
||||
}
|
||||
|
||||
return strings.Join(result, "\n")
|
||||
}
|
||||
|
||||
|
||||
@@ -266,3 +266,14 @@ func (ut *UsageTracker) SetWidth(width int) {
|
||||
defer ut.mu.Unlock()
|
||||
ut.width = width
|
||||
}
|
||||
|
||||
// UpdateModelInfo updates the model information and OAuth status when the model
|
||||
// is switched mid-session. This ensures token costs and context limits are
|
||||
// calculated correctly for the new model.
|
||||
func (ut *UsageTracker) UpdateModelInfo(modelInfo *models.ModelInfo, provider string, isOAuth bool) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
ut.modelInfo = modelInfo
|
||||
ut.provider = provider
|
||||
ut.isOAuth = isOAuth
|
||||
}
|
||||
|
||||
@@ -9,6 +9,10 @@ type CredentialManager = auth.CredentialManager
|
||||
// and API key authentication methods.
|
||||
type AnthropicCredentials = auth.AnthropicCredentials
|
||||
|
||||
// OpenAICredentials holds OpenAI API credentials supporting both OAuth
|
||||
// and API key authentication methods.
|
||||
type OpenAICredentials = auth.OpenAICredentials
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
type CredentialStore = auth.CredentialStore
|
||||
|
||||
@@ -42,3 +46,34 @@ func GetAnthropicAPIKey() string {
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// HasOpenAICredentials checks if valid OpenAI credentials are stored
|
||||
// (either OAuth token or API key).
|
||||
func HasOpenAICredentials() bool {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
has, err := cm.HasOpenAICredentials()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return has
|
||||
}
|
||||
|
||||
// GetOpenAIAPIKey resolves the OpenAI API key using the standard
|
||||
// resolution order: stored credentials -> OPENAI_API_KEY env var.
|
||||
// Returns an empty string if no key is found.
|
||||
func GetOpenAIAPIKey() string {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// Try to get valid access token (handles OAuth refresh)
|
||||
token, err := cm.GetValidOpenAIAccessToken()
|
||||
if err == nil && token != "" {
|
||||
return token
|
||||
}
|
||||
// Fall back to environment variable
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -39,6 +39,12 @@ const (
|
||||
EventCompaction EventType = "compaction"
|
||||
// EventReasoningDelta fires for each streaming reasoning/thinking chunk.
|
||||
EventReasoningDelta EventType = "reasoning_delta"
|
||||
// EventToolOutput fires when a tool produces streaming output chunks.
|
||||
EventToolOutput EventType = "tool_output"
|
||||
EventStepUsage EventType = "step_usage"
|
||||
// EventSteerConsumed fires when one or more steering messages have been
|
||||
// injected into the agent turn via PrepareStep.
|
||||
EventSteerConsumed EventType = "steer_consumed"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -143,6 +149,17 @@ type ReasoningDeltaEvent struct {
|
||||
// EventType implements Event.
|
||||
func (e ReasoningDeltaEvent) EventType() EventType { return EventReasoningDelta }
|
||||
|
||||
// ToolOutputEvent fires when a tool produces streaming output chunks (e.g., bash output).
|
||||
type ToolOutputEvent struct {
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
Chunk string
|
||||
IsStderr bool
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e ToolOutputEvent) EventType() EventType { return EventToolOutput }
|
||||
|
||||
// MessageEndEvent fires when the assistant message is complete.
|
||||
type MessageEndEvent struct {
|
||||
Content string
|
||||
@@ -236,6 +253,19 @@ type ResponseEvent struct {
|
||||
// EventType implements Event.
|
||||
func (e ResponseEvent) EventType() EventType { return EventResponse }
|
||||
|
||||
// StepUsageEvent fires after each complete step in a multi-step agent turn,
|
||||
// carrying the token usage for that specific step. This enables real-time
|
||||
// cost tracking during long-running tool-calling conversations.
|
||||
type StepUsageEvent struct {
|
||||
InputTokens uint64
|
||||
OutputTokens uint64
|
||||
CacheReadTokens uint64
|
||||
CacheWriteTokens uint64
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e StepUsageEvent) EventType() EventType { return EventStepUsage }
|
||||
|
||||
// CompactionEvent fires after a successful compaction.
|
||||
type CompactionEvent struct {
|
||||
Summary string
|
||||
@@ -249,6 +279,16 @@ type CompactionEvent struct {
|
||||
// EventType implements Event.
|
||||
func (e CompactionEvent) EventType() EventType { return EventCompaction }
|
||||
|
||||
// SteerConsumedEvent fires when one or more steering messages have been
|
||||
// injected into the agent turn via PrepareStep. The Count indicates how
|
||||
// many messages were consumed in this batch.
|
||||
type SteerConsumedEvent struct {
|
||||
Count int
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e SteerConsumedEvent) EventType() EventType { return EventSteerConsumed }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EventBus
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -322,6 +362,16 @@ func (m *Kit) OnToolResult(handler func(ToolResultEvent)) func() {
|
||||
})
|
||||
}
|
||||
|
||||
// OnToolOutput registers a handler that fires only for ToolOutputEvent
|
||||
// (streaming tool output chunks, e.g., from bash). Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolOutput(handler func(ToolOutputEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if to, ok := e.(ToolOutputEvent); ok {
|
||||
handler(to)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnStreaming registers a handler that fires only for MessageUpdateEvent
|
||||
// (streaming text chunks). Returns an unsubscribe function.
|
||||
func (m *Kit) OnStreaming(handler func(MessageUpdateEvent)) func() {
|
||||
|
||||
@@ -2,6 +2,7 @@ package kit
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
@@ -86,6 +87,20 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
})
|
||||
}
|
||||
|
||||
// Tool output streaming events (observation only).
|
||||
if runner.HasHandlers(extensions.ToolOutput) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ToolOutputEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ToolOutputEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName,
|
||||
Chunk: ev.Chunk,
|
||||
IsStderr: ev.IsStderr,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if runner.HasHandlers(extensions.AgentEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(TurnEndEvent); ok {
|
||||
@@ -105,6 +120,125 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
})
|
||||
}
|
||||
|
||||
// --- Subagent lifecycle events ---
|
||||
// When an extension registers OnSubagentStart/Chunk/End handlers, bridge
|
||||
// the SDK's per-subagent event stream (SubscribeSubagent) into the
|
||||
// extension runner.
|
||||
//
|
||||
// Flow:
|
||||
// ToolExecutionStartEvent(spawn_subagent) → emit SubagentStartEvent
|
||||
// → SubscribeSubagent → emit SubagentChunkEvents
|
||||
// ToolResultEvent(spawn_subagent) → emit SubagentEndEvent
|
||||
//
|
||||
// We use ToolExecutionStart (not ToolCall) for SubagentStart because that
|
||||
// is when the subagent actually begins running. We use ToolResult for
|
||||
// SubagentEnd because that carries the final response text.
|
||||
wantsSubagent := runner.HasHandlers(extensions.SubagentStart) ||
|
||||
runner.HasHandlers(extensions.SubagentChunk) ||
|
||||
runner.HasHandlers(extensions.SubagentEnd)
|
||||
|
||||
if wantsSubagent {
|
||||
// taskByCallID tracks the task description extracted from ToolCall input,
|
||||
// keyed by toolCallID. Populated on ToolCall, consumed on ToolResult.
|
||||
taskByCallID := make(map[string]string)
|
||||
var taskMu = &taskMutex{}
|
||||
|
||||
// Intercept ToolCall to capture the task and subscribe to child events.
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolCallEvent)
|
||||
if !ok || ev.ToolName != "spawn_subagent" {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract task from parsed args.
|
||||
task := ""
|
||||
if ev.ParsedArgs != nil {
|
||||
if t, ok := ev.ParsedArgs["task"].(string); ok {
|
||||
task = t
|
||||
}
|
||||
}
|
||||
taskMu.set(taskByCallID, ev.ToolCallID, task)
|
||||
|
||||
// Subscribe to child events so we can forward them as SubagentChunkEvents.
|
||||
if runner.HasHandlers(extensions.SubagentChunk) {
|
||||
m.SubscribeSubagent(ev.ToolCallID, func(childEvent Event) {
|
||||
chunk := extensions.SubagentChunkEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Task: task,
|
||||
}
|
||||
switch ce := childEvent.(type) {
|
||||
case MessageUpdateEvent:
|
||||
chunk.ChunkType = "text"
|
||||
chunk.Content = ce.Chunk
|
||||
case TurnStartEvent:
|
||||
chunk.ChunkType = "turn_start"
|
||||
case TurnEndEvent:
|
||||
chunk.ChunkType = "turn_end"
|
||||
case ToolCallEvent:
|
||||
chunk.ChunkType = "tool_call"
|
||||
chunk.ToolName = ce.ToolName
|
||||
chunk.ToolArgs = ce.ToolArgs
|
||||
case ToolExecutionStartEvent:
|
||||
chunk.ChunkType = "tool_execution_start"
|
||||
chunk.ToolName = ce.ToolName
|
||||
case ToolExecutionEndEvent:
|
||||
chunk.ChunkType = "tool_execution_end"
|
||||
chunk.ToolName = ce.ToolName
|
||||
case ToolResultEvent:
|
||||
chunk.ChunkType = "tool_result"
|
||||
chunk.ToolName = ce.ToolName
|
||||
chunk.ToolResult = ce.Result
|
||||
chunk.IsError = ce.IsError
|
||||
default:
|
||||
return // skip unknown event types
|
||||
}
|
||||
_, _ = runner.Emit(chunk)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Emit SubagentStartEvent when execution begins.
|
||||
if runner.HasHandlers(extensions.SubagentStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolExecutionStartEvent)
|
||||
if !ok || ev.ToolName != "spawn_subagent" {
|
||||
return
|
||||
}
|
||||
task := taskMu.get(taskByCallID, ev.ToolCallID)
|
||||
_, _ = runner.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Task: task,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Emit SubagentEndEvent when the tool result arrives.
|
||||
if runner.HasHandlers(extensions.SubagentEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolResultEvent)
|
||||
if !ok || ev.ToolName != "spawn_subagent" {
|
||||
return
|
||||
}
|
||||
task := taskMu.get(taskByCallID, ev.ToolCallID)
|
||||
taskMu.del(taskByCallID, ev.ToolCallID)
|
||||
errMsg := ""
|
||||
if ev.IsError {
|
||||
errMsg = ev.Result
|
||||
}
|
||||
response := ""
|
||||
if !ev.IsError {
|
||||
response = ev.Result
|
||||
}
|
||||
_, _ = runner.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Task: task,
|
||||
Response: response,
|
||||
ErrorMsg: errMsg,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Context filtering hook ---
|
||||
// Extension ContextPrepare → SDK ContextPrepare hook.
|
||||
if runner.HasHandlers(extensions.ContextPrepare) {
|
||||
@@ -190,3 +324,27 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// taskMutex is a simple mutex-protected map helper used by bridgeExtensions.
|
||||
// It lives in this file to avoid polluting the kit package with unexported types.
|
||||
type taskMutex struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *taskMutex) set(m map[string]string, key, val string) {
|
||||
t.mu.Lock()
|
||||
m[key] = val
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
func (t *taskMutex) get(m map[string]string, key string) string {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return m[key]
|
||||
}
|
||||
|
||||
func (t *taskMutex) del(m map[string]string, key string) {
|
||||
t.mu.Lock()
|
||||
delete(m, key)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
+147
-26
@@ -66,6 +66,13 @@ type Kit struct {
|
||||
// subagentListeners holds per-tool-call event listeners registered via
|
||||
// SubscribeSubagent(). Keyed by toolCallID → *subagentListenerSet.
|
||||
subagentListeners sync.Map
|
||||
|
||||
// steerCh is a buffered channel used to inject steering messages into
|
||||
// the running agent turn via Fantasy's PrepareStep. Created fresh for
|
||||
// each generate() call and set to nil when idle. Protected by steerMu.
|
||||
steerMu sync.Mutex
|
||||
steerCh chan string
|
||||
leftoverSteer []string // unconsumed steer messages from the last turn
|
||||
}
|
||||
|
||||
// Subscribe registers an EventListener that will be called for every lifecycle
|
||||
@@ -529,8 +536,11 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
}
|
||||
|
||||
// Build a provider config from current settings, overriding the model.
|
||||
// Load system prompt properly (handles both file paths and inline content).
|
||||
systemPrompt, _ := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
config := &models.ProviderConfig{
|
||||
ModelString: modelString,
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
@@ -917,8 +927,12 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
setSDKDefaults()
|
||||
|
||||
// Initialize config (loads config files and env vars).
|
||||
if err := InitConfig(opts.ConfigFile, false); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize config: %w", err)
|
||||
// Only initialize if not already done (e.g., by CLI's cobra.OnInitialize).
|
||||
// Check if model is already set, which indicates config was loaded.
|
||||
if viper.GetString("model") == "" {
|
||||
if err := InitConfig(opts.ConfigFile, false); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle CLI debug mode.
|
||||
@@ -1049,6 +1063,15 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
// Bridge extension events to SDK hooks.
|
||||
if agentResult.ExtRunner != nil {
|
||||
k.bridgeExtensions(agentResult.ExtRunner)
|
||||
|
||||
// Initialize extension context with minimal defaults. SDK users can call
|
||||
// SetExtensionContext to override with richer implementations (TUI callbacks,
|
||||
// prompts, etc.). This ensures extensions never crash on nil function fields.
|
||||
k.SetExtensionContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: k.modelString,
|
||||
Interactive: false, // SDK mode defaults to non-interactive
|
||||
})
|
||||
}
|
||||
|
||||
return k, nil
|
||||
@@ -1401,6 +1424,35 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
// All prompt modes (Prompt, Steer, FollowUp, PromptWithOptions) share this
|
||||
// single code path so callback wiring is never duplicated.
|
||||
func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.GenerateWithLoopResult, error) {
|
||||
// Create a per-turn steer channel and attach it to the context so the
|
||||
// agent's PrepareStep can inject steering messages between steps.
|
||||
steerCh := make(chan string, 16)
|
||||
m.steerMu.Lock()
|
||||
m.steerCh = steerCh
|
||||
m.steerMu.Unlock()
|
||||
defer func() {
|
||||
// Drain any unconsumed steer messages before nilling the channel.
|
||||
// These are stored in leftoverSteer so DrainSteer() can return them.
|
||||
var leftover []string
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
leftover = append(leftover, msg)
|
||||
default:
|
||||
goto drained
|
||||
}
|
||||
}
|
||||
drained:
|
||||
m.steerMu.Lock()
|
||||
m.steerCh = nil
|
||||
m.leftoverSteer = leftover
|
||||
m.steerMu.Unlock()
|
||||
}()
|
||||
ctx = agent.ContextWithSteerCh(ctx, steerCh)
|
||||
ctx = agent.ContextWithSteerConsumed(ctx, func(count int) {
|
||||
m.events.emit(SteerConsumedEvent{Count: count})
|
||||
})
|
||||
|
||||
// Inject the in-process subagent spawner into the context so the
|
||||
// spawn_subagent core tool can create child Kit instances without
|
||||
// importing pkg/kit (which would create an import cycle).
|
||||
@@ -1478,6 +1530,24 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
func(delta string) {
|
||||
m.events.emit(ReasoningDeltaEvent{Delta: delta})
|
||||
},
|
||||
func(toolCallID, toolName, chunk string, isStderr bool) {
|
||||
// Emit tool output chunk event for streaming bash output
|
||||
m.events.emit(ToolOutputEvent{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Chunk: chunk,
|
||||
IsStderr: isStderr,
|
||||
})
|
||||
},
|
||||
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
|
||||
// Emit step usage event for real-time cost tracking
|
||||
m.events.emit(StepUsageEvent{
|
||||
InputTokens: uint64(inputTokens),
|
||||
OutputTokens: uint64(outputTokens),
|
||||
CacheReadTokens: uint64(cacheReadTokens),
|
||||
CacheWriteTokens: uint64(cacheCreationTokens),
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1537,13 +1607,6 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
}
|
||||
}
|
||||
|
||||
// Save the leaf position before appending anything so we can roll back
|
||||
// to this point if the turn is cancelled (double-ESC). Rolling back
|
||||
// discards the user message and any tool call / tool result pairs that
|
||||
// were generated, which avoids leaving orphaned tool_use messages
|
||||
// without matching tool_result (APIs require them in pairs).
|
||||
preLeafID := m.treeSession.GetLeafID()
|
||||
|
||||
// Persist pre-generation messages to tree session.
|
||||
for _, msg := range preMessages {
|
||||
_, _ = m.treeSession.AppendFantasyMessage(msg)
|
||||
@@ -1571,23 +1634,16 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
|
||||
result, err := m.generate(ctx, messages)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
// Context was cancelled (e.g. user pressed ESC twice). Roll
|
||||
// the tree session back to the pre-turn leaf so that the
|
||||
// user message, any tool_use messages, and any tool_result
|
||||
// messages from this turn are all discarded. APIs require
|
||||
// tool calls and tool results to appear in matched pairs;
|
||||
// persisting a partial turn would leave orphaned entries
|
||||
// that break subsequent requests.
|
||||
_ = m.treeSession.Branch(preLeafID)
|
||||
} else {
|
||||
// Non-cancellation error (e.g. API failure). Persist any
|
||||
// messages that were generated during this turn (completed
|
||||
// tool call/result pairs) so partial progress is not lost.
|
||||
if result != nil && len(result.ConversationMessages) > sentCount {
|
||||
for _, msg := range result.ConversationMessages[sentCount:] {
|
||||
_, _ = m.treeSession.AppendFantasyMessage(msg)
|
||||
}
|
||||
// Persist any messages from completed steps (tool call/result
|
||||
// pairs) so partial progress is not lost. The agent layer only
|
||||
// includes fully-paired tool_use + tool_result messages in
|
||||
// completedStepMessages, so there are no orphaned entries that
|
||||
// would break subsequent API requests. The user message and any
|
||||
// completed work remain in the session; only the in-progress
|
||||
// (pending) message or tool call is discarded.
|
||||
if result != nil && len(result.ConversationMessages) > sentCount {
|
||||
for _, msg := range result.ConversationMessages[sentCount:] {
|
||||
_, _ = m.treeSession.AppendFantasyMessage(msg)
|
||||
}
|
||||
}
|
||||
m.events.emit(TurnEndEvent{Error: err})
|
||||
@@ -1706,6 +1762,71 @@ func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
|
||||
return result.Response, nil
|
||||
}
|
||||
|
||||
// InjectSteer sends a steering message into the currently active agent turn.
|
||||
// The message will be injected as a user message between steps (after the
|
||||
// current tool execution finishes, before the next LLM call). If no turn is
|
||||
// active the message is silently dropped — callers should check IsGenerating()
|
||||
// or use Prompt()/Steer() for idle-state messaging.
|
||||
//
|
||||
// InjectSteer is safe to call from any goroutine. Multiple calls queue
|
||||
// messages in order; all pending steer messages are drained and injected
|
||||
// together at the next step boundary.
|
||||
//
|
||||
// This is the preferred way to redirect an agent mid-turn without cancelling
|
||||
// in-progress tool execution.
|
||||
func (m *Kit) InjectSteer(message string) {
|
||||
m.steerMu.Lock()
|
||||
ch := m.steerCh
|
||||
m.steerMu.Unlock()
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- message:
|
||||
default:
|
||||
// Channel full — extremely unlikely with buffer of 16, but don't block.
|
||||
}
|
||||
}
|
||||
|
||||
// IsGenerating returns true if an agent turn is currently in progress.
|
||||
// Use this to decide between InjectSteer (mid-turn) and Prompt (new turn).
|
||||
func (m *Kit) IsGenerating() bool {
|
||||
m.steerMu.Lock()
|
||||
defer m.steerMu.Unlock()
|
||||
return m.steerCh != nil
|
||||
}
|
||||
|
||||
// DrainSteer removes and returns all unconsumed steer messages. Called after
|
||||
// a turn completes so the app layer can process any steer messages that
|
||||
// arrived after the last PrepareStep fired (e.g. during a text-only response
|
||||
// with no tool calls, or after the agent finished its last step).
|
||||
func (m *Kit) DrainSteer() []string {
|
||||
m.steerMu.Lock()
|
||||
defer m.steerMu.Unlock()
|
||||
|
||||
// First check leftover messages saved when generate() returned.
|
||||
if len(m.leftoverSteer) > 0 {
|
||||
msgs := m.leftoverSteer
|
||||
m.leftoverSteer = nil
|
||||
return msgs
|
||||
}
|
||||
|
||||
// If a turn is still active, drain from the live channel.
|
||||
if m.steerCh != nil {
|
||||
var msgs []string
|
||||
for {
|
||||
select {
|
||||
case msg := <-m.steerCh:
|
||||
msgs = append(msgs, msg)
|
||||
default:
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PromptOptions configures a single PromptWithOptions call.
|
||||
type PromptOptions struct {
|
||||
// SystemMessage is prepended as a system message before the user prompt.
|
||||
|
||||
+4
-1
@@ -16,7 +16,10 @@ func TestNew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test default initialization
|
||||
host, err := kit.New(ctx, nil)
|
||||
opts := &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
}
|
||||
host, err := kit.New(ctx, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Kit with defaults: %v", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,772 @@
|
||||
---
|
||||
name: kit-sdk
|
||||
description: Guide for building Go applications with the Kit SDK. Use when the user asks to create a program, service, script, or application that uses Kit programmatically as a Go library — e.g. embedding LLM interactions, building agents, creating CLI tools powered by Kit, or integrating Kit into backend services. Do NOT use for Kit extensions (use kit-extensions skill instead).
|
||||
---
|
||||
|
||||
# Kit SDK Development Guide
|
||||
|
||||
The Kit SDK (`pkg/kit`) lets you embed Kit's full agent capabilities — LLM interactions, tool execution, session management, streaming, hooks — into any Go application. Unlike extensions (which are interpreted scripts running inside Kit's TUI), SDK programs are standalone compiled Go binaries.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get github.com/mark3labs/kit
|
||||
```
|
||||
|
||||
Import path (alias recommended):
|
||||
|
||||
```go
|
||||
import kit "github.com/mark3labs/kit/pkg/kit"
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
host, err := kit.New(ctx, nil) // nil = load ~/.kit.yml defaults
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
response, err := host.Prompt(ctx, "What is 2+2?")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println(response)
|
||||
}
|
||||
```
|
||||
|
||||
## Core Lifecycle
|
||||
|
||||
1. **Create**: `kit.New(ctx, opts)` — loads config, initializes MCP servers, creates LLM provider, sets up agent
|
||||
2. **Interact**: `host.Prompt(ctx, msg)` — send messages, agent uses tools as needed
|
||||
3. **Close**: `host.Close()` — cleans up MCP connections, model resources, session file handle
|
||||
|
||||
Always defer `Close()`:
|
||||
|
||||
```go
|
||||
defer func() { _ = host.Close() }()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Options Reference
|
||||
|
||||
All fields are optional. Zero values use CLI defaults.
|
||||
|
||||
```go
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
// Model
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929", // "provider/model" format
|
||||
SystemPrompt: "You are a helpful assistant",
|
||||
ConfigFile: "/path/to/config.yml", // default: ~/.kit.yml
|
||||
|
||||
// Behavior
|
||||
MaxSteps: 10, // 0 = unlimited tool-calling steps
|
||||
Streaming: true, // stream LLM output (default from config)
|
||||
Quiet: true, // suppress debug output
|
||||
Debug: true, // enable debug logging
|
||||
|
||||
// Session
|
||||
SessionDir: "/path/to/project", // base dir for session discovery (default: cwd)
|
||||
SessionPath: "/path/to/session.jsonl", // open specific session file
|
||||
Continue: true, // resume most recent session for SessionDir
|
||||
NoSession: true, // ephemeral in-memory session, no disk persistence
|
||||
|
||||
// Tools
|
||||
Tools: []kit.Tool{kit.NewBashTool()}, // REPLACES entire default tool set
|
||||
ExtraTools: []kit.Tool{myTool}, // ADDS alongside core/MCP/extension tools
|
||||
|
||||
// Skills
|
||||
Skills: []string{"/path/to/skill.md"}, // explicit skill files (empty = auto-discover)
|
||||
SkillsDir: "/path/to/skills", // override project-local skills dir
|
||||
|
||||
// Compaction
|
||||
AutoCompact: true, // auto-compact near context limit
|
||||
CompactionOptions: &kit.CompactionOptions{...}, // nil = defaults
|
||||
})
|
||||
```
|
||||
|
||||
**Critical distinction**: `Tools` replaces ALL default tools (core + MCP + extension). `ExtraTools` adds tools alongside the defaults. Use `Tools` to restrict the agent's capabilities; use `ExtraTools` to extend them.
|
||||
|
||||
---
|
||||
|
||||
## Prompt Methods
|
||||
|
||||
### Simple prompt — string in, string out
|
||||
|
||||
```go
|
||||
response, err := host.Prompt(ctx, "Explain this code")
|
||||
```
|
||||
|
||||
### Full result with usage stats
|
||||
|
||||
```go
|
||||
result, err := host.PromptResult(ctx, "Analyze this file")
|
||||
// result.Response — assistant's text
|
||||
// result.StopReason — "stop", "length", "tool-calls", "error", etc.
|
||||
// result.SessionID — session UUID
|
||||
// result.TotalUsage — aggregate tokens across all steps (*kit.FantasyUsage)
|
||||
// result.FinalUsage — tokens from last API call only
|
||||
// result.Messages — full updated conversation ([]kit.FantasyMessage)
|
||||
```
|
||||
|
||||
### Multimodal with file attachments
|
||||
|
||||
```go
|
||||
import "charm.land/fantasy"
|
||||
|
||||
files := []fantasy.FilePart{{
|
||||
Name: "screenshot.png",
|
||||
MediaType: "image/png",
|
||||
Data: imageBytes,
|
||||
}}
|
||||
result, err := host.PromptResultWithFiles(ctx, "What's in this image?", files)
|
||||
```
|
||||
|
||||
### Per-call system message injection
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithOptions(ctx, "Review this PR", kit.PromptOptions{
|
||||
SystemMessage: "Focus on security vulnerabilities only.",
|
||||
})
|
||||
```
|
||||
|
||||
### System-level steering (no visible user message)
|
||||
|
||||
```go
|
||||
response, err := host.Steer(ctx, "Switch to a more formal tone")
|
||||
```
|
||||
|
||||
### Continue without new input
|
||||
|
||||
```go
|
||||
response, err := host.FollowUp(ctx, "") // empty = "Continue."
|
||||
```
|
||||
|
||||
### Multiple user messages in one turn
|
||||
|
||||
```go
|
||||
result, err := host.PromptResultWithMessages(ctx, []string{
|
||||
"Here is the code:",
|
||||
"@file.go", // content from earlier
|
||||
"Please review it.",
|
||||
})
|
||||
```
|
||||
|
||||
### Legacy inline callbacks (deprecated — use event subscribers instead)
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(ctx, "List files",
|
||||
func(name, args string) { fmt.Printf("Tool: %s\n", name) },
|
||||
func(name, args, result string, isError bool) { /* tool result */ },
|
||||
func(chunk string) { fmt.Print(chunk) }, // streaming
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Event System
|
||||
|
||||
Events are read-only observations of the agent lifecycle. Register before calling Prompt.
|
||||
|
||||
### Typed convenience subscribers
|
||||
|
||||
```go
|
||||
// Each returns an unsubscribe function.
|
||||
unsub := host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
// e.ToolCallID, e.ToolName, e.ToolKind, e.ToolArgs, e.ParsedArgs
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
// e.ToolCallID, e.ToolName, e.ToolKind, e.ToolArgs, e.ParsedArgs
|
||||
// e.Result, e.IsError, e.Metadata (*ToolResultMetadata)
|
||||
})
|
||||
|
||||
host.OnToolOutput(func(e kit.ToolOutputEvent) {
|
||||
// e.ToolCallID, e.ToolName, e.Chunk, e.IsStderr
|
||||
// Streaming bash output chunks
|
||||
})
|
||||
|
||||
host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
fmt.Print(e.Chunk) // real-time text streaming
|
||||
})
|
||||
|
||||
host.OnResponse(func(e kit.ResponseEvent) {
|
||||
// e.Content — final response text
|
||||
})
|
||||
|
||||
host.OnTurnStart(func(e kit.TurnStartEvent) {
|
||||
// e.Prompt
|
||||
})
|
||||
|
||||
host.OnTurnEnd(func(e kit.TurnEndEvent) {
|
||||
// e.Response, e.Error, e.StopReason
|
||||
})
|
||||
```
|
||||
|
||||
### Generic subscriber (receives all events)
|
||||
|
||||
```go
|
||||
unsub := host.Subscribe(func(e kit.Event) {
|
||||
switch ev := e.(type) {
|
||||
case kit.ToolCallEvent:
|
||||
// ...
|
||||
case kit.MessageUpdateEvent:
|
||||
// ...
|
||||
case kit.CompactionEvent:
|
||||
// ev.Summary, ev.OriginalTokens, ev.CompactedTokens
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### All event types
|
||||
|
||||
| Event Type | Struct | Key Fields |
|
||||
|------------|--------|------------|
|
||||
| `turn_start` | `TurnStartEvent` | `Prompt` |
|
||||
| `turn_end` | `TurnEndEvent` | `Response`, `Error`, `StopReason` |
|
||||
| `message_start` | `MessageStartEvent` | *(none)* |
|
||||
| `message_update` | `MessageUpdateEvent` | `Chunk` |
|
||||
| `message_end` | `MessageEndEvent` | `Content` |
|
||||
| `tool_call` | `ToolCallEvent` | `ToolCallID`, `ToolName`, `ToolKind`, `ToolArgs`, `ParsedArgs` |
|
||||
| `tool_execution_start` | `ToolExecutionStartEvent` | `ToolCallID`, `ToolName`, `ToolKind`, `ToolArgs` |
|
||||
| `tool_execution_end` | `ToolExecutionEndEvent` | `ToolCallID`, `ToolName`, `ToolKind` |
|
||||
| `tool_result` | `ToolResultEvent` | `ToolCallID`, `ToolName`, `ToolKind`, `ToolArgs`, `ParsedArgs`, `Result`, `IsError`, `Metadata` |
|
||||
| `tool_call_content` | `ToolCallContentEvent` | `Content` |
|
||||
| `tool_output` | `ToolOutputEvent` | `ToolCallID`, `ToolName`, `Chunk`, `IsStderr` |
|
||||
| `response` | `ResponseEvent` | `Content` |
|
||||
| `compaction` | `CompactionEvent` | `Summary`, `OriginalTokens`, `CompactedTokens`, `MessagesRemoved`, `ReadFiles`, `ModifiedFiles` |
|
||||
| `reasoning_delta` | `ReasoningDeltaEvent` | `Delta` |
|
||||
|
||||
### Tool kind constants
|
||||
|
||||
Tools are classified by kind for UI rendering:
|
||||
|
||||
- `ToolKindExecute` = `"execute"` — bash
|
||||
- `ToolKindEdit` = `"edit"` — edit, write
|
||||
- `ToolKindRead` = `"read"` — read, ls
|
||||
- `ToolKindSearch` = `"search"` — grep, find
|
||||
- `ToolKindSubagent` = `"agent"` — spawn_subagent
|
||||
|
||||
---
|
||||
|
||||
## Hook System (Interceptors)
|
||||
|
||||
Hooks can **modify or cancel** operations. Events are read-only; hooks are read-write.
|
||||
|
||||
### BeforeToolCall — block tool execution
|
||||
|
||||
```go
|
||||
unsub := host.OnBeforeToolCall(kit.HookPriorityNormal, func(h kit.BeforeToolCallHook) *kit.BeforeToolCallResult {
|
||||
// h.ToolCallID, h.ToolName, h.ToolArgs
|
||||
if h.ToolName == "bash" {
|
||||
return &kit.BeforeToolCallResult{Block: true, Reason: "bash disabled"}
|
||||
}
|
||||
return nil // allow
|
||||
})
|
||||
```
|
||||
|
||||
### AfterToolResult — modify tool output
|
||||
|
||||
```go
|
||||
host.OnAfterToolResult(kit.HookPriorityNormal, func(h kit.AfterToolResultHook) *kit.AfterToolResultResult {
|
||||
// h.ToolCallID, h.ToolName, h.ToolArgs, h.Result, h.IsError
|
||||
if h.ToolName == "read" {
|
||||
filtered := redactSecrets(h.Result)
|
||||
return &kit.AfterToolResultResult{Result: &filtered}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### BeforeTurn — modify prompt, inject messages
|
||||
|
||||
```go
|
||||
host.OnBeforeTurn(kit.HookPriorityNormal, func(h kit.BeforeTurnHook) *kit.BeforeTurnResult {
|
||||
// h.Prompt
|
||||
newPrompt := h.Prompt + "\nAlways respond in JSON."
|
||||
return &kit.BeforeTurnResult{Prompt: &newPrompt}
|
||||
// Also available: SystemPrompt *string, InjectText *string
|
||||
})
|
||||
```
|
||||
|
||||
### AfterTurn — observation only
|
||||
|
||||
```go
|
||||
host.OnAfterTurn(kit.HookPriorityNormal, func(h kit.AfterTurnHook) {
|
||||
// h.Response, h.Error
|
||||
log.Printf("Turn completed: %d chars", len(h.Response))
|
||||
})
|
||||
```
|
||||
|
||||
### ContextPrepare — filter/inject context window
|
||||
|
||||
```go
|
||||
host.OnContextPrepare(kit.HookPriorityNormal, func(h kit.ContextPrepareHook) *kit.ContextPrepareResult {
|
||||
// h.Messages — []fantasy.Message (the full context being sent to the LLM)
|
||||
// Return nil to pass through, or replace entire context:
|
||||
return &kit.ContextPrepareResult{Messages: filteredMessages}
|
||||
})
|
||||
```
|
||||
|
||||
### BeforeCompact — cancel or customize compaction
|
||||
|
||||
```go
|
||||
host.OnBeforeCompact(kit.HookPriorityNormal, func(h kit.BeforeCompactHook) *kit.BeforeCompactResult {
|
||||
// h.EstimatedTokens, h.ContextLimit, h.UsagePercent, h.MessageCount, h.IsAutomatic
|
||||
if h.IsAutomatic && h.UsagePercent < 0.9 {
|
||||
return &kit.BeforeCompactResult{Cancel: true, Reason: "not yet"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Hook priorities
|
||||
|
||||
```go
|
||||
kit.HookPriorityHigh = 0 // runs first
|
||||
kit.HookPriorityNormal = 50 // default
|
||||
kit.HookPriorityLow = 100 // runs last
|
||||
```
|
||||
|
||||
Lower values run first. Within the same priority, registration order applies. First non-nil result wins.
|
||||
|
||||
---
|
||||
|
||||
## Tools
|
||||
|
||||
### Built-in tool constructors
|
||||
|
||||
```go
|
||||
kit.NewReadTool(opts...) // file reading
|
||||
kit.NewWriteTool(opts...) // file writing
|
||||
kit.NewEditTool(opts...) // surgical text editing
|
||||
kit.NewBashTool(opts...) // bash command execution
|
||||
kit.NewGrepTool(opts...) // content search (uses ripgrep when available)
|
||||
kit.NewFindTool(opts...) // file search (uses fd when available)
|
||||
kit.NewLsTool(opts...) // directory listing
|
||||
```
|
||||
|
||||
### Tool bundles
|
||||
|
||||
```go
|
||||
kit.AllTools(opts...) // all 7 core tools
|
||||
kit.CodingTools(opts...) // bash, read, write, edit
|
||||
kit.ReadOnlyTools(opts...) // read, grep, find, ls
|
||||
kit.SubagentTools(opts...) // all except spawn_subagent (prevents recursion)
|
||||
```
|
||||
|
||||
### Tool options
|
||||
|
||||
```go
|
||||
kit.WithWorkDir("/path/to/dir") // override working directory for file-based tools
|
||||
```
|
||||
|
||||
### Using tools in Options
|
||||
|
||||
```go
|
||||
// Restricted: agent can ONLY run bash
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
Tools: []kit.Tool{kit.NewBashTool()},
|
||||
})
|
||||
|
||||
// Extended: all defaults PLUS a custom tool
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{myCustomTool},
|
||||
})
|
||||
```
|
||||
|
||||
### Querying tools at runtime
|
||||
|
||||
```go
|
||||
names := host.GetToolNames() // []string of all tool names
|
||||
tools := host.GetTools() // []kit.Tool (full tool objects)
|
||||
mcpCount := host.GetMCPToolCount() // tools from MCP servers
|
||||
extCount := host.GetExtensionToolCount() // tools from extensions
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Session Management
|
||||
|
||||
Sessions automatically persist as JSONL tree files. No explicit save needed.
|
||||
|
||||
### Session modes (via Options)
|
||||
|
||||
| Mode | Options | Behavior |
|
||||
|------|---------|----------|
|
||||
| Default | `{}` | New session file for cwd |
|
||||
| Specific file | `{SessionPath: "path.jsonl"}` | Open existing session |
|
||||
| Continue | `{Continue: true}` | Resume most recent session for cwd |
|
||||
| Ephemeral | `{NoSession: true}` | In-memory only, no disk persistence |
|
||||
| Custom dir | `{SessionDir: "/path"}` | Base directory for session discovery |
|
||||
|
||||
### Instance methods
|
||||
|
||||
```go
|
||||
host.GetSessionPath() // file path of active session
|
||||
host.GetSessionID() // UUID of active session
|
||||
host.ClearSession() // reset to fresh branch (doesn't delete file)
|
||||
host.Branch("entry-id") // branch from a specific entry
|
||||
host.SetSessionName("my session") // set display name
|
||||
|
||||
// Get conversation messages
|
||||
msgs := host.GetSessionMessages() // []extensions.SessionMessage (flattened text)
|
||||
msgs := host.GetStructuredMessages() // []kit.StructuredMessage (typed content parts)
|
||||
```
|
||||
|
||||
### Package-level session operations (no Kit instance needed)
|
||||
|
||||
```go
|
||||
sessions, _ := kit.ListSessions("/path/to/project") // sessions for a directory
|
||||
sessions, _ := kit.ListAllSessions() // all sessions everywhere
|
||||
kit.DeleteSession("/path/to/session.jsonl")
|
||||
tm, _ := kit.OpenTreeSession("/path/to/session.jsonl") // open for direct access
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Model Management
|
||||
|
||||
### At creation time
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
Model: "openai/gpt-4o",
|
||||
})
|
||||
```
|
||||
|
||||
### At runtime
|
||||
|
||||
```go
|
||||
err := host.SetModel(ctx, "anthropic/claude-sonnet-4-5-20250929")
|
||||
modelStr := host.GetModelString() // "provider/model"
|
||||
info := host.GetModelInfo() // *kit.ModelInfo (capabilities, limits, pricing) or nil
|
||||
isReasoning := host.IsReasoningModel()
|
||||
level := host.GetThinkingLevel()
|
||||
err = host.SetThinkingLevel(ctx, "medium") // recreates agent with new thinking budget
|
||||
```
|
||||
|
||||
### Model registry
|
||||
|
||||
```go
|
||||
models := host.GetAvailableModels() // []extensions.ModelInfoEntry
|
||||
providers := kit.GetSupportedProviders() // []string
|
||||
providers := kit.GetFantasyProviders() // providers usable with fantasy
|
||||
models, _ := kit.GetModelsForProvider("anthropic") // map[string]kit.ModelInfo
|
||||
info := kit.LookupModel("anthropic", "claude-sonnet-4-5-20250929") // *kit.ModelInfo
|
||||
info := kit.GetProviderInfo("openai") // *kit.ProviderInfo (env vars, API URL)
|
||||
err := kit.ValidateEnvironment("anthropic", "") // check API keys
|
||||
suggestions := kit.SuggestModels("anthropic", "claudee") // fuzzy match
|
||||
kit.RefreshModelRegistry() // reload model database
|
||||
```
|
||||
|
||||
### Model string format
|
||||
|
||||
Always `"provider/model"`: `"anthropic/claude-sonnet-4-5-20250929"`, `"openai/gpt-4o"`, `"ollama/qwen3:8b"`.
|
||||
|
||||
```go
|
||||
provider, modelID, err := kit.ParseModelString("anthropic/claude-sonnet-4-5-20250929")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Context & Compaction
|
||||
|
||||
```go
|
||||
tokens := host.EstimateContextTokens() // heuristic token count
|
||||
shouldCompact := host.ShouldCompact() // true if near context limit
|
||||
|
||||
stats := host.GetContextStats()
|
||||
// stats.EstimatedTokens — uses API-reported count when available (more accurate)
|
||||
// stats.ContextLimit — model's context window size
|
||||
// stats.UsagePercent — fraction used (0.0–1.0)
|
||||
// stats.MessageCount — number of messages
|
||||
|
||||
// Manual compaction
|
||||
result, err := host.Compact(ctx, nil, "") // nil opts = defaults, "" = default prompt
|
||||
// result.Summary, result.OriginalTokens, result.CompactedTokens, result.MessagesRemoved
|
||||
|
||||
// Auto-compaction via Options
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
AutoCompact: true,
|
||||
CompactionOptions: &kit.CompactionOptions{
|
||||
ReserveTokens: 16384,
|
||||
KeepRecentTokens: 4096,
|
||||
ContextWindow: 200000,
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## In-Process Subagents
|
||||
|
||||
Spawn child Kit instances without subprocess overhead:
|
||||
|
||||
```go
|
||||
result, err := host.Subagent(ctx, kit.SubagentConfig{
|
||||
Prompt: "Analyze the test files and summarize coverage",
|
||||
Model: "anthropic/claude-haiku-3-5-20241022", // empty = parent's model
|
||||
SystemPrompt: "You are a test analysis expert.",
|
||||
Tools: nil, // nil = SubagentTools() (all except spawn_subagent)
|
||||
NoSession: true, // ephemeral
|
||||
Timeout: 2 * time.Minute, // 0 = 5 minute default
|
||||
OnEvent: func(e kit.Event) {
|
||||
// Real-time events from the child agent
|
||||
if chunk, ok := e.(kit.MessageUpdateEvent); ok {
|
||||
fmt.Print(chunk.Chunk)
|
||||
}
|
||||
},
|
||||
})
|
||||
// result.Response, result.Error, result.SessionID, result.StopReason
|
||||
// result.Usage (*kit.FantasyUsage), result.Elapsed (time.Duration)
|
||||
```
|
||||
|
||||
### Subscribing to subagent events from parent
|
||||
|
||||
```go
|
||||
host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
if e.ToolName == "spawn_subagent" {
|
||||
host.SubscribeSubagent(e.ToolCallID, func(child kit.Event) {
|
||||
// Real-time events scoped to this subagent
|
||||
})
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Authentication
|
||||
|
||||
```go
|
||||
cm, _ := kit.NewCredentialManager()
|
||||
hasKey := kit.HasAnthropicCredentials()
|
||||
apiKey := kit.GetAnthropicAPIKey() // stored creds → ANTHROPIC_API_KEY env var
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Skills
|
||||
|
||||
```go
|
||||
// Load a single skill file
|
||||
skill, _ := kit.LoadSkill("/path/to/SKILL.md")
|
||||
// skill.Name, skill.Description, skill.Content, skill.Path
|
||||
|
||||
// Load from directory
|
||||
skills, _ := kit.LoadSkillsFromDir("/path/to/skills")
|
||||
|
||||
// Auto-discover (global + project-local)
|
||||
skills, _ := kit.LoadSkills("/path/to/project")
|
||||
|
||||
// Prompt building with skills
|
||||
pb := kit.NewPromptBuilder("You are an assistant")
|
||||
pb.WithSkills(skills)
|
||||
pb.WithSection("", "Extra context here")
|
||||
systemPrompt := pb.Build()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Re-exported Types
|
||||
|
||||
The SDK re-exports internal types so you don't need direct internal imports:
|
||||
|
||||
```go
|
||||
// Message types
|
||||
kit.Message, kit.MessageRole, kit.ContentPart
|
||||
kit.TextContent, kit.ReasoningContent, kit.ToolCall, kit.ToolResult, kit.Finish
|
||||
kit.RoleUser, kit.RoleAssistant, kit.RoleTool, kit.RoleSystem
|
||||
|
||||
// Session types
|
||||
kit.SessionInfo, kit.TreeManager, kit.SessionHeader, kit.MessageEntry
|
||||
|
||||
// Config types
|
||||
kit.Config, kit.MCPServerConfig
|
||||
|
||||
// Provider types
|
||||
kit.ProviderConfig, kit.ProviderResult, kit.ModelInfo, kit.ModelCost, kit.ModelLimit
|
||||
|
||||
// Fantasy types (from charm.land/fantasy)
|
||||
kit.FantasyMessage, kit.FantasyUsage, kit.FantasyResponse
|
||||
|
||||
// Compaction types
|
||||
kit.CompactionResult, kit.CompactionOptions
|
||||
|
||||
// Conversion helpers
|
||||
msgs := kit.ConvertToFantasyMessages(&msg) // SDK message → fantasy messages
|
||||
msg := kit.ConvertFromFantasyMessage(fMsg) // fantasy message → SDK message
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern: Scripting / CLI pipe
|
||||
|
||||
Minimal program for automation — stdout-only output:
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{Quiet: true})
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
response, _ := host.Prompt(ctx, os.Args[1])
|
||||
fmt.Println(response)
|
||||
```
|
||||
|
||||
### Pattern: Long-running autonomous agent
|
||||
|
||||
Daemon that performs repeated independent tasks:
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
SystemPrompt: taskPrompt,
|
||||
Tools: []kit.Tool{kit.NewBashTool()},
|
||||
NoSession: true,
|
||||
Quiet: true,
|
||||
})
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Minute)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
host.ClearSession() // fresh context each iteration
|
||||
host.Prompt(ctx, "Perform the monitoring task")
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Pattern: Streaming output to terminal
|
||||
|
||||
```go
|
||||
host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
fmt.Print(e.Chunk)
|
||||
})
|
||||
response, _ := host.Prompt(ctx, "Write a poem")
|
||||
```
|
||||
|
||||
### Pattern: Multi-turn conversation with memory
|
||||
|
||||
```go
|
||||
host.Prompt(ctx, "My name is Alice")
|
||||
response, _ := host.Prompt(ctx, "What's my name?")
|
||||
// Session automatically maintains context across calls
|
||||
fmt.Printf("Session: %s\n", host.GetSessionPath())
|
||||
```
|
||||
|
||||
### Pattern: Tool execution monitoring
|
||||
|
||||
```go
|
||||
host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
fmt.Printf("[%s] %s(%s)\n", e.ToolKind, e.ToolName, e.ToolArgs)
|
||||
})
|
||||
host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
status := "✓"
|
||||
if e.IsError { status = "✗" }
|
||||
fmt.Printf("[%s] %s %s\n", e.ToolKind, status, e.ToolName)
|
||||
})
|
||||
```
|
||||
|
||||
### Pattern: Guard rails with hooks
|
||||
|
||||
```go
|
||||
// Block dangerous commands
|
||||
host.OnBeforeToolCall(kit.HookPriorityHigh, func(h kit.BeforeToolCallHook) *kit.BeforeToolCallResult {
|
||||
if h.ToolName == "bash" && strings.Contains(h.ToolArgs, "rm -rf") {
|
||||
return &kit.BeforeToolCallResult{Block: true, Reason: "dangerous command"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Inject context before every turn
|
||||
host.OnBeforeTurn(kit.HookPriorityNormal, func(h kit.BeforeTurnHook) *kit.BeforeTurnResult {
|
||||
context := "Current user: admin\nEnvironment: production"
|
||||
return &kit.BeforeTurnResult{InjectText: &context}
|
||||
})
|
||||
```
|
||||
|
||||
### Pattern: Parallel subagents
|
||||
|
||||
```go
|
||||
var wg sync.WaitGroup
|
||||
results := make([]*kit.SubagentResult, 3)
|
||||
|
||||
tasks := []string{"Analyze auth module", "Analyze database layer", "Analyze API routes"}
|
||||
for i, task := range tasks {
|
||||
wg.Add(1)
|
||||
go func(idx int, t string) {
|
||||
defer wg.Done()
|
||||
results[idx], _ = host.Subagent(ctx, kit.SubagentConfig{
|
||||
Prompt: t,
|
||||
NoSession: true,
|
||||
Timeout: 3 * time.Minute,
|
||||
})
|
||||
}(i, task)
|
||||
}
|
||||
wg.Wait()
|
||||
```
|
||||
|
||||
### Pattern: Read-only analysis agent
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
SystemPrompt: "You are a code reviewer. Only read and analyze, never modify files.",
|
||||
Tools: kit.ReadOnlyTools(),
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
The SDK loads config identically to the CLI:
|
||||
|
||||
1. Explicit `ConfigFile` in Options (highest priority)
|
||||
2. `.kit.yml` in current directory
|
||||
3. `~/.kit.yml` in home directory
|
||||
4. Environment variables with `KIT_` prefix (`KIT_MODEL`, etc.)
|
||||
5. Provider-specific env vars (`ANTHROPIC_API_KEY`, `OPENAI_API_KEY`, etc.)
|
||||
|
||||
Config files support `${ENV_VAR}` expansion.
|
||||
|
||||
```go
|
||||
// Initialize config manually (usually not needed — kit.New handles this)
|
||||
kit.InitConfig("/path/to/config.yml", false)
|
||||
kit.LoadConfigWithEnvSubstitution("/path/to/config.yml")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Files for Reference
|
||||
|
||||
- [`pkg/kit/kit.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/kit.go) — Kit struct, New(), Prompt methods, Subagent, Close
|
||||
- [`pkg/kit/types.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/types.go) — Re-exported types from internal packages
|
||||
- [`pkg/kit/tools.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/tools.go) — Tool constructors and bundles
|
||||
- [`pkg/kit/events.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/events.go) — Event types, EventBus, typed subscribers
|
||||
- [`pkg/kit/hooks.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/hooks.go) — Hook system (BeforeToolCall, AfterToolResult, etc.)
|
||||
- [`pkg/kit/sessions.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/sessions.go) — Session management
|
||||
- [`pkg/kit/compaction.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/compaction.go) — Context compaction
|
||||
- [`pkg/kit/models.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/models.go) — Model registry lookups
|
||||
- [`pkg/kit/config.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/config.go) — Config initialization and defaults
|
||||
- [`pkg/kit/skills.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/skills.go) — Skills loading and prompt building
|
||||
- [`pkg/kit/auth.go`](https://github.com/mark3labs/kit/blob/main/pkg/kit/auth.go) — Credential management
|
||||
- [`examples/sdk/`](https://github.com/mark3labs/kit/tree/main/examples/sdk) — Working example programs
|
||||
@@ -59,6 +59,79 @@ result := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
})
|
||||
```
|
||||
|
||||
### Monitoring subagents from extensions
|
||||
|
||||
When the LLM (not the extension itself) spawns a subagent using the `spawn_subagent` tool, extensions can monitor its activity in real-time using three lifecycle event handlers:
|
||||
|
||||
```go
|
||||
// Track active subagents and display their output
|
||||
var subagentWidgets map[string]*SubagentWidget
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Subagent started by the main agent
|
||||
api.OnSubagentStart(func(e ext.SubagentStartEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — unique ID for this subagent invocation
|
||||
// e.Task — the task/prompt sent to the subagent
|
||||
widget := NewWidget(e.ToolCallID, e.Task)
|
||||
subagentWidgets[e.ToolCallID] = widget
|
||||
ctx.SetWidget(widget.Config())
|
||||
})
|
||||
|
||||
// Real-time streaming from subagent
|
||||
api.OnSubagentChunk(func(e ext.SubagentChunkEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — matches the start event
|
||||
// e.ChunkType — "text", "tool_call", "tool_execution_start", "tool_result"
|
||||
// e.Content — text content
|
||||
// e.ToolName — tool name (for tool chunks)
|
||||
// e.IsError — true if tool result failed
|
||||
widget := subagentWidgets[e.ToolCallID]
|
||||
if widget != nil {
|
||||
widget.AddOutput(e)
|
||||
ctx.SetWidget(widget.Config())
|
||||
}
|
||||
})
|
||||
|
||||
// Subagent completed
|
||||
api.OnSubagentEnd(func(e ext.SubagentEndEvent, ctx ext.Context) {
|
||||
// e.Response — final response from subagent
|
||||
// e.ErrorMsg — error message if subagent failed
|
||||
widget := subagentWidgets[e.ToolCallID]
|
||||
if widget != nil {
|
||||
widget.MarkComplete(e.Response, e.ErrorMsg)
|
||||
ctx.SetWidget(widget.Config())
|
||||
delete(subagentWidgets, e.ToolCallID)
|
||||
}
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**Event structs:**
|
||||
|
||||
```go
|
||||
type SubagentStartEvent struct {
|
||||
ToolCallID string // Unique ID for this subagent invocation
|
||||
Task string // The task/prompt sent to subagent
|
||||
}
|
||||
|
||||
type SubagentChunkEvent struct {
|
||||
ToolCallID string // Matches SubagentStartEvent.ToolCallID
|
||||
Task string // Task description
|
||||
ChunkType string // "text", "tool_call", "tool_execution_start", "tool_result"
|
||||
Content string // For text chunks
|
||||
ToolName string // For tool-related chunks
|
||||
IsError bool // For tool_result chunks
|
||||
}
|
||||
|
||||
type SubagentEndEvent struct {
|
||||
ToolCallID string // Matches start event
|
||||
Task string // Task description
|
||||
Response string // Final response from subagent
|
||||
ErrorMsg string // Error message if failed
|
||||
}
|
||||
```
|
||||
|
||||
This enables building monitoring widgets that display real-time activity from all subagents spawned by the main agent.
|
||||
|
||||
## Go SDK subagents
|
||||
|
||||
The SDK provides in-process subagent spawning:
|
||||
|
||||
@@ -74,7 +74,7 @@ These commands are available inside the Kit TUI during an interactive session:
|
||||
| `/reset-usage` | Reset usage statistics |
|
||||
| `/tree` | Navigate session tree |
|
||||
| `/fork` | Branch from an earlier message |
|
||||
| `/new` | Start a new session |
|
||||
| `/new` | Start a new session (creates new session file) |
|
||||
| `/name [name]` | Set or show session display name |
|
||||
| `/resume` | Open session picker to switch sessions (alias: `/r`) |
|
||||
| `/session` | Show session info |
|
||||
@@ -95,9 +95,17 @@ Press **ESC twice** to cancel the current operation:
|
||||
|
||||
This ensures that `tool_use` and `tool_result` messages are always sent to the API as matched pairs, avoiding errors from orphaned tool calls.
|
||||
|
||||
## Prompt templates
|
||||
### Mid-turn steering
|
||||
|
||||
Create reusable prompt templates with shell-style argument substitution. Templates are loaded from `~/.kit/prompts/*.md` and `.kit/prompts/*.md`.
|
||||
Press **Ctrl+S** during streaming to inject a system-level instruction mid-turn. This allows you to steer the conversation direction without waiting for the model to finish:
|
||||
|
||||
- Works during streaming output
|
||||
- Sends a steering instruction as a system message
|
||||
- Model continues from the interruption point with the new guidance
|
||||
|
||||
Example: While the model is writing code, press Ctrl+S and type "Use async/await instead" to change the implementation approach.
|
||||
|
||||
## Prompt templates
|
||||
|
||||
### Creating templates
|
||||
|
||||
|
||||
@@ -96,9 +96,45 @@ mcpServers:
|
||||
|
||||
A legacy format with `transport`, `args`, `env`, and `headers` fields is also supported.
|
||||
|
||||
## Theme configuration
|
||||
## Custom models
|
||||
|
||||
Set theme colors inline or reference an external file:
|
||||
Define custom models in your `.kit.yml` for use with the `custom` provider. This is useful for self-hosted models or API endpoints not in the built-in database:
|
||||
|
||||
```yaml
|
||||
customModels:
|
||||
my-model:
|
||||
name: "My Custom Model"
|
||||
reasoning: true
|
||||
temperature: true
|
||||
cost:
|
||||
input: 0.002
|
||||
output: 0.004
|
||||
limit:
|
||||
context: 128000
|
||||
output: 32000
|
||||
```
|
||||
|
||||
### Custom model fields
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
|-------|------|----------|-------------|
|
||||
| `name` | string | Yes | Display name for the model |
|
||||
| `reasoning` | bool | No | Whether the model supports reasoning/thinking |
|
||||
| `temperature` | bool | No | Whether the model supports temperature adjustment |
|
||||
| `cost.input` | float | No | Cost per 1K input tokens |
|
||||
| `cost.output` | float | No | Cost per 1K output tokens |
|
||||
| `limit.context` | int | Yes | Maximum context window in tokens |
|
||||
| `limit.output` | int | No | Maximum output tokens |
|
||||
|
||||
Use with a custom provider URL:
|
||||
|
||||
```bash
|
||||
kit --provider-url "http://localhost:8080/v1" --model custom/my-model "Hello"
|
||||
```
|
||||
|
||||
When `--provider-url` is specified without `--model`, Kit defaults to `custom/custom` which has zero cost tracking and a 262K context window.
|
||||
|
||||
## Theme configuration
|
||||
|
||||
```yaml
|
||||
# Inline partial overrides (unspecified fields inherit from default)
|
||||
|
||||
@@ -7,7 +7,7 @@ description: All extension capabilities — lifecycle events, tools, commands, w
|
||||
|
||||
## Lifecycle events
|
||||
|
||||
Extensions can hook into 18 lifecycle events:
|
||||
Extensions can hook into 23 lifecycle events:
|
||||
|
||||
| Event | Description |
|
||||
|-------|-------------|
|
||||
@@ -18,6 +18,7 @@ Extensions can hook into 18 lifecycle events:
|
||||
| `OnAgentEnd` | Agent loop completed |
|
||||
| `OnToolCall` | Tool call requested by the model |
|
||||
| `OnToolExecutionStart` | Tool execution beginning |
|
||||
| `OnToolOutput` | Streaming tool output chunk (for long-running tools) |
|
||||
| `OnToolExecutionEnd` | Tool execution completed |
|
||||
| `OnToolResult` | Tool result returned |
|
||||
| `OnInput` | User input received |
|
||||
@@ -29,6 +30,10 @@ Extensions can hook into 18 lifecycle events:
|
||||
| `OnBeforeFork` | Before forking a conversation branch |
|
||||
| `OnBeforeSessionSwitch` | Before switching sessions |
|
||||
| `OnBeforeCompact` | Before conversation compaction |
|
||||
| `OnCustomEvent` | Custom inter-extension event received |
|
||||
| `OnSubagentStart` | Subagent spawned by the main agent |
|
||||
| `OnSubagentChunk` | Real-time output from subagent (text, tool calls, results) |
|
||||
| `OnSubagentEnd` | Subagent completed with final response/error |
|
||||
|
||||
### Example
|
||||
|
||||
@@ -232,6 +237,54 @@ result := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
})
|
||||
```
|
||||
|
||||
### Monitoring subagents spawned by the main agent
|
||||
|
||||
When the LLM uses the built-in `spawn_subagent` tool, extensions can monitor the subagent's activity in real-time using three lifecycle events:
|
||||
|
||||
```go
|
||||
// Subagent started
|
||||
api.OnSubagentStart(func(e ext.SubagentStartEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — unique ID for this subagent invocation
|
||||
// e.Task — the task/prompt sent to the subagent
|
||||
ctx.PrintInfo(fmt.Sprintf("Subagent started: %s", e.Task))
|
||||
})
|
||||
|
||||
// Real-time streaming output from subagent
|
||||
api.OnSubagentChunk(func(e ext.SubagentChunkEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — matches the start event
|
||||
// e.Task — task description
|
||||
// e.ChunkType — "text", "tool_call", "tool_execution_start", "tool_result"
|
||||
// e.Content — text content (for text chunks)
|
||||
// e.ToolName — tool name (for tool-related chunks)
|
||||
// e.IsError — true if tool result is an error
|
||||
switch e.ChunkType {
|
||||
case "text":
|
||||
// Streaming text output
|
||||
case "tool_call":
|
||||
// Subagent is calling a tool
|
||||
case "tool_execution_start":
|
||||
// Tool execution started
|
||||
case "tool_result":
|
||||
// Tool execution completed (check e.IsError)
|
||||
}
|
||||
})
|
||||
|
||||
// Subagent completed
|
||||
api.OnSubagentEnd(func(e ext.SubagentEndEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — matches start event
|
||||
// e.Task — task description
|
||||
// e.Response — final response from subagent
|
||||
// e.ErrorMsg — error message if subagent failed
|
||||
if e.ErrorMsg != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Subagent failed: %s", e.ErrorMsg))
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Subagent completed: %s", e.Response))
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
This enables building widgets that display real-time subagent activity.
|
||||
|
||||
## LLM completion
|
||||
|
||||
Make direct model calls without going through the agent loop:
|
||||
|
||||
@@ -51,6 +51,12 @@ Kit ships with a rich set of example extensions in the `examples/extensions/` di
|
||||
| [`summarize.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/summarize.go) | Conversation summarization |
|
||||
| [`lsp-diagnostics.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/lsp-diagnostics.go) | LSP diagnostic integration |
|
||||
|
||||
## Themes
|
||||
|
||||
| Extension | Description |
|
||||
|-----------|-------------|
|
||||
| [`neon-theme.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/neon-theme.go) | Custom theme registration and switching |
|
||||
|
||||
## Multi-agent
|
||||
|
||||
| Extension | Description |
|
||||
@@ -74,3 +80,7 @@ Kit ships with a rich set of example extensions in the `examples/extensions/` di
|
||||
| [`kit-kit-agents/`](https://github.com/mark3labs/kit/tree/master/examples/extensions/kit-kit-agents) | Multi-agent orchestration example |
|
||||
| [`kit-telegram/`](https://github.com/mark3labs/kit/tree/master/examples/extensions/kit-telegram) | Telegram bot integration |
|
||||
| [`status-tools/`](https://github.com/mark3labs/kit/tree/master/examples/extensions/status-tools) | Status bar tool examples |
|
||||
|
||||
## Project-local example
|
||||
|
||||
The Kit repository also includes a project-local extension at `.kit/extensions/go-edit-lint.go` that demonstrates running `gopls` and `golangci-lint` on Go file edits. This serves as an example of how to create extensions specific to a project by placing them in the `.kit/extensions/` directory.
|
||||
|
||||
@@ -20,6 +20,7 @@ Kit supports a wide range of LLM providers through a unified `provider/model` st
|
||||
| **Google Vertex** | `google-vertex-anthropic/` | Claude on Vertex AI |
|
||||
| **OpenRouter** | `openrouter/` | Multi-provider router |
|
||||
| **Vercel AI** | `vercel/` | Vercel AI SDK models |
|
||||
| **Custom** | `custom/` | Any OpenAI-compatible endpoint |
|
||||
| **Auto-routed** | any | Any provider from the models.dev database |
|
||||
|
||||
## Model string format
|
||||
@@ -132,6 +133,16 @@ For self-hosted or proxy endpoints:
|
||||
kit --provider-url "https://my-proxy.example.com/v1" --model openai/gpt-4o
|
||||
```
|
||||
|
||||
When `--provider-url` is provided without `--model`, Kit automatically defaults to `custom/custom`:
|
||||
|
||||
```bash
|
||||
kit --provider-url "http://localhost:8080/v1" "Hello"
|
||||
```
|
||||
|
||||
The `custom/custom` model has zero cost, 262K context window, and supports reasoning. It routes through fantasy's `openaicompat` provider and accepts any OpenAI-compatible API endpoint.
|
||||
|
||||
Optionally set `CUSTOM_API_KEY` environment variable or use `--provider-api-key` for endpoints requiring authentication.
|
||||
|
||||
## Model database
|
||||
|
||||
Kit ships with a local model database that maps provider names to API configurations. You can manage it with:
|
||||
|
||||
+11
-3
@@ -30,12 +30,20 @@ When conversations grow long, Kit can compact them to free up context window spa
|
||||
|
||||
Use `/compact [focus]` to manually compact, or enable `--auto-compact` to compact automatically near the context limit.
|
||||
|
||||
## Auto-cleanup
|
||||
|
||||
Kit automatically cleans up empty sessions on shutdown and when using `/resume`. A session is considered empty if it has no messages beyond the initial system prompt. This prevents cluttering your sessions directory with unused files.
|
||||
|
||||
To start fresh without creating a session file at all, use ephemeral mode:
|
||||
|
||||
```bash
|
||||
kit --no-session
|
||||
```
|
||||
|
||||
## Resuming sessions
|
||||
|
||||
### Continue most recent
|
||||
|
||||
Resume the most recent session for the current directory:
|
||||
|
||||
```bash
|
||||
kit --continue
|
||||
kit -c
|
||||
@@ -73,7 +81,7 @@ These slash commands are available during an interactive session:
|
||||
| `/share` | Upload session to GitHub Gist and get a shareable viewer URL |
|
||||
| `/tree` | Navigate the session tree |
|
||||
| `/fork` | Branch from an earlier message |
|
||||
| `/new` | Start a fresh session |
|
||||
| `/new` | Start a new session (creates new session file) |
|
||||
|
||||
## Ephemeral mode
|
||||
|
||||
|
||||
Reference in New Issue
Block a user