mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-18 21:36:30 +00:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b68b3dd0bf | |||
| 48521bf76d | |||
| 16df3a738c | |||
| 9d0b8c8cef | |||
| d9326fcf21 | |||
| 22c479277e | |||
| 8ae204f12f | |||
| 8b1665a4ce | |||
| 941f1daf0b | |||
| ab7e2bda61 | |||
| 741520927c | |||
| 4c1bda9541 | |||
| 3b69b13556 | |||
| 83a959a379 | |||
| 3491e05e9e | |||
| 0a54a8aa05 | |||
| 3cb3e5dba1 | |||
| 31966c469f | |||
| f03625d6e5 | |||
| d06641dc0a | |||
| bbf1106e27 | |||
| babed03a3d | |||
| 1cd074836f | |||
| ab3ce260c8 | |||
| 8e8cc3946d | |||
| e18e36625e | |||
| be55bc03f1 | |||
| 09919b6307 | |||
| 7a2de4cc3c | |||
| acd7fd7f45 | |||
| 3446f38516 | |||
| db4bb19bac |
@@ -0,0 +1,304 @@
|
||||
//go:build ignore
|
||||
|
||||
// subagent-monitor — live horizontal widget strip for spawned subagents
|
||||
//
|
||||
// Subscribes to subagents spawned by the main Kit agent and displays a
|
||||
// single widget just above the input box. Each subagent occupies one column
|
||||
// in a side-by-side horizontal layout. Columns show scrolling real-time
|
||||
// output as the subagent works. When a subagent finishes its column is
|
||||
// removed automatically.
|
||||
//
|
||||
// Yaegi-safe design notes:
|
||||
// - No sync.Mutex (Yaegi has reflection issues with sync primitives)
|
||||
// - No channels in maps (Yaegi panics on range over map[string]chan)
|
||||
// - All ctx.* calls guarded with nil checks
|
||||
// - Simple data structures only
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per-subagent state
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type submonEntry struct {
|
||||
id int
|
||||
callID string
|
||||
task string
|
||||
lines []string
|
||||
started time.Time
|
||||
elapsed time.Duration
|
||||
}
|
||||
|
||||
const (
|
||||
submonColWidth = 34 // visible character width per column
|
||||
submonMaxLines = 5 // scrolling output lines per column
|
||||
submonColGap = 2 // spaces between columns
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Package-level state - all simple types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
submonCtx ext.Context
|
||||
submonHasCtx bool
|
||||
submonEntries []*submonEntry
|
||||
submonNextID int
|
||||
)
|
||||
|
||||
func submonInit() {
|
||||
submonEntries = nil
|
||||
submonNextID = 1
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// String helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func submonPad(s string, w int) string {
|
||||
r := []rune(s)
|
||||
if len(r) >= w {
|
||||
return string(r[:w])
|
||||
}
|
||||
return s + strings.Repeat(" ", w-len(r))
|
||||
}
|
||||
|
||||
func submonTrunc(s string, w int) string {
|
||||
r := []rune(s)
|
||||
if len(r) <= w {
|
||||
return s
|
||||
}
|
||||
if w <= 1 {
|
||||
return "…"
|
||||
}
|
||||
return string(r[:w-1]) + "…"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Widget rendering
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func submonRenderColumn(e *submonEntry) []string {
|
||||
var rows []string
|
||||
|
||||
// Calculate elapsed time on-demand to avoid race conditions with ticker
|
||||
elapsed := e.elapsed
|
||||
if elapsed == 0 && !e.started.IsZero() {
|
||||
elapsed = time.Since(e.started)
|
||||
}
|
||||
secs := int(elapsed.Seconds())
|
||||
timeStr := fmt.Sprintf("%ds", secs)
|
||||
taskMax := submonColWidth - len(timeStr) - 3
|
||||
taskPart := submonTrunc(e.task, taskMax)
|
||||
header := fmt.Sprintf("#%d %s %s", e.id, taskPart, timeStr)
|
||||
rows = append(rows, submonPad(header, submonColWidth))
|
||||
|
||||
display := e.lines
|
||||
if len(display) > submonMaxLines {
|
||||
display = display[len(display)-submonMaxLines:]
|
||||
}
|
||||
for _, l := range display {
|
||||
rows = append(rows, submonPad(" "+submonTrunc(l, submonColWidth-2), submonColWidth))
|
||||
}
|
||||
for len(rows) < submonMaxLines+1 {
|
||||
if len(rows) == 1 && len(e.lines) == 0 {
|
||||
rows = append(rows, submonPad(" waiting…", submonColWidth))
|
||||
} else {
|
||||
rows = append(rows, strings.Repeat(" ", submonColWidth))
|
||||
}
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
func submonBuildWidget() string {
|
||||
if len(submonEntries) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
numCols := len(submonEntries)
|
||||
numRows := submonMaxLines + 1
|
||||
cols := make([][]string, numCols)
|
||||
for i, e := range submonEntries {
|
||||
rows := submonRenderColumn(e)
|
||||
col := make([]string, numRows)
|
||||
for j := 0; j < numRows; j++ {
|
||||
if j < len(rows) {
|
||||
col[j] = rows[j]
|
||||
} else {
|
||||
col[j] = strings.Repeat(" ", submonColWidth)
|
||||
}
|
||||
}
|
||||
cols[i] = col
|
||||
}
|
||||
|
||||
gap := strings.Repeat(" ", submonColGap)
|
||||
var sb strings.Builder
|
||||
for row := 0; row < numRows; row++ {
|
||||
for ci := range cols {
|
||||
if ci > 0 {
|
||||
sb.WriteString(gap)
|
||||
}
|
||||
sb.WriteString(cols[ci][row])
|
||||
}
|
||||
if row < numRows-1 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func submonPushWidget() {
|
||||
if !submonHasCtx {
|
||||
return
|
||||
}
|
||||
if submonCtx.SetWidget == nil {
|
||||
return
|
||||
}
|
||||
|
||||
text := submonBuildWidget()
|
||||
if len(submonEntries) == 0 {
|
||||
if submonCtx.RemoveWidget != nil {
|
||||
submonCtx.RemoveWidget("submon")
|
||||
}
|
||||
return
|
||||
}
|
||||
submonCtx.SetWidget(ext.WidgetConfig{
|
||||
ID: "submon",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: text},
|
||||
Style: ext.WidgetStyle{BorderColor: "#89b4fa"},
|
||||
Priority: 0,
|
||||
})
|
||||
}
|
||||
|
||||
func submonAppendLine(e *submonEntry, line string) {
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if strings.TrimSpace(line) == "" {
|
||||
return
|
||||
}
|
||||
e.lines = append(e.lines, line)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Init
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func Init(api ext.API) {
|
||||
submonInit()
|
||||
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
submonInit()
|
||||
if ctx.RemoveWidget != nil {
|
||||
ctx.RemoveWidget("submon")
|
||||
}
|
||||
})
|
||||
|
||||
api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
})
|
||||
|
||||
// ── SubagentStart ────────────────────────────────────────────────────────
|
||||
api.OnSubagentStart(func(e ext.SubagentStartEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
id := submonNextID
|
||||
submonNextID++
|
||||
entry := &submonEntry{
|
||||
id: id,
|
||||
callID: e.ToolCallID,
|
||||
task: e.Task,
|
||||
started: time.Now(),
|
||||
}
|
||||
submonEntries = append(submonEntries, entry)
|
||||
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SubagentChunk ────────────────────────────────────────────────────────
|
||||
api.OnSubagentChunk(func(e ext.SubagentChunkEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
var entry *submonEntry
|
||||
for _, en := range submonEntries {
|
||||
if en.callID == e.ToolCallID {
|
||||
entry = en
|
||||
break
|
||||
}
|
||||
}
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch e.ChunkType {
|
||||
case "text":
|
||||
for _, line := range strings.Split(e.Content, "\n") {
|
||||
submonAppendLine(entry, line)
|
||||
}
|
||||
case "tool_call":
|
||||
submonAppendLine(entry, "→ "+e.ToolName)
|
||||
case "tool_execution_start":
|
||||
submonAppendLine(entry, "⚙ "+e.ToolName)
|
||||
case "tool_result":
|
||||
if e.IsError {
|
||||
submonAppendLine(entry, "✗ "+e.ToolName)
|
||||
} else {
|
||||
submonAppendLine(entry, "✓ "+e.ToolName)
|
||||
}
|
||||
}
|
||||
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SubagentEnd ──────────────────────────────────────────────────────────
|
||||
api.OnSubagentEnd(func(e ext.SubagentEndEvent, ctx ext.Context) {
|
||||
submonCtx = ctx
|
||||
submonHasCtx = true
|
||||
|
||||
var entry *submonEntry
|
||||
for _, en := range submonEntries {
|
||||
if en.callID == e.ToolCallID {
|
||||
entry = en
|
||||
break
|
||||
}
|
||||
}
|
||||
if entry != nil {
|
||||
entry.elapsed = time.Since(entry.started)
|
||||
if e.ErrorMsg != "" {
|
||||
submonAppendLine(entry, "✗ "+submonTrunc(e.ErrorMsg, submonColWidth-2))
|
||||
}
|
||||
}
|
||||
|
||||
submonPushWidget()
|
||||
|
||||
// Remove the entry immediately (no goroutine to avoid races)
|
||||
newEntries := submonEntries[:0]
|
||||
for _, en := range submonEntries {
|
||||
if en.callID != e.ToolCallID {
|
||||
newEntries = append(newEntries, en)
|
||||
}
|
||||
}
|
||||
submonEntries = newEntries
|
||||
submonPushWidget()
|
||||
})
|
||||
|
||||
// ── SessionShutdown ──────────────────────────────────────────────────────
|
||||
api.OnSessionShutdown(func(_ ext.SessionShutdownEvent, ctx ext.Context) {
|
||||
submonInit()
|
||||
// Guard ctx access - may be nil during shutdown
|
||||
if ctx.RemoveWidget != nil {
|
||||
ctx.RemoveWidget("submon")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -287,7 +287,7 @@ kit -e examples/extensions/minimal.go
|
||||
|
||||
### Extension Capabilities
|
||||
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
|
||||
|
||||
**Custom Components**:
|
||||
- **Tools**: Add new tools the LLM can invoke
|
||||
|
||||
+300
-7
@@ -1,9 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/huh/v2"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
@@ -14,7 +18,7 @@ import (
|
||||
// authCmd represents the auth command for managing AI provider authentication.
|
||||
// This command provides subcommands for login, logout, and status checking
|
||||
// of authentication credentials for various AI providers, with OAuth support
|
||||
// for providers like Anthropic.
|
||||
// for providers like Anthropic and OpenAI.
|
||||
var authCmd = &cobra.Command{
|
||||
Use: "auth",
|
||||
Short: "Manage authentication credentials for AI providers",
|
||||
@@ -25,9 +29,11 @@ using OAuth flows. Stored credentials take precedence over environment variables
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI API (OAuth and API key)
|
||||
|
||||
Examples:
|
||||
kit auth login anthropic
|
||||
kit auth login openai
|
||||
kit auth logout anthropic
|
||||
kit auth status`,
|
||||
}
|
||||
@@ -46,9 +52,11 @@ environment variables when making API calls.
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI ChatGPT Plus/Pro (Codex OAuth)
|
||||
|
||||
Example:
|
||||
kit auth login anthropic`,
|
||||
kit auth login anthropic
|
||||
kit auth login openai`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogin,
|
||||
}
|
||||
@@ -61,14 +69,16 @@ var authLogoutCmd = &cobra.Command{
|
||||
Short: "Remove stored authentication credentials for a provider",
|
||||
Long: `Remove stored authentication credentials for an AI provider.
|
||||
|
||||
This will delete the stored API key for the specified provider. You will need
|
||||
to use environment variables or command-line flags for authentication after logout.
|
||||
This will delete the stored API key or OAuth credentials for the specified provider.
|
||||
You will need to use environment variables or command-line flags for authentication after logout.
|
||||
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API
|
||||
- openai: OpenAI API
|
||||
|
||||
Example:
|
||||
kit auth logout anthropic`,
|
||||
kit auth logout anthropic
|
||||
kit auth logout openai`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogout,
|
||||
}
|
||||
@@ -101,8 +111,10 @@ func runAuthLogin(cmd *cobra.Command, args []string) error {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return loginAnthropic()
|
||||
case "openai":
|
||||
return loginOpenAI()
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,8 +124,10 @@ func runAuthLogout(cmd *cobra.Command, args []string) error {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return logoutAnthropic()
|
||||
case "openai":
|
||||
return logoutOpenAI()
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,8 +171,44 @@ func runAuthStatus(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check OpenAI credentials
|
||||
fmt.Print("\nOpenAI: ")
|
||||
if hasOpenAICreds, err := cm.HasOpenAICredentials(); err != nil {
|
||||
fmt.Printf("Error checking credentials: %v\n", err)
|
||||
} else if hasOpenAICreds {
|
||||
if creds, err := cm.GetOpenAICredentials(); err != nil {
|
||||
fmt.Printf("Error reading credentials: %v\n", err)
|
||||
} else {
|
||||
authType := "API Key"
|
||||
status := "✓ Authenticated"
|
||||
|
||||
if creds.Type == "oauth" {
|
||||
authType = "OAuth (ChatGPT/Codex)"
|
||||
if creds.IsExpired() {
|
||||
status = "⚠️ Token expired (will refresh automatically)"
|
||||
} else if creds.NeedsRefresh() {
|
||||
status = "⚠️ Token expires soon (will refresh automatically)"
|
||||
}
|
||||
}
|
||||
|
||||
accountInfo := ""
|
||||
if creds.Type == "oauth" && creds.AccountID != "" {
|
||||
accountInfo = fmt.Sprintf(" [%s]", creds.AccountID)
|
||||
}
|
||||
|
||||
fmt.Printf("%s (%s%s, stored %s)\n", status, authType, accountInfo, creds.CreatedAt.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
} else {
|
||||
fmt.Println("✗ Not authenticated")
|
||||
// Check if environment variable is set
|
||||
if os.Getenv("OPENAI_API_KEY") != "" {
|
||||
fmt.Println(" (OPENAI_API_KEY environment variable is set)")
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\nTo authenticate with a provider:")
|
||||
fmt.Println(" kit auth login anthropic")
|
||||
fmt.Println(" kit auth login openai")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -282,3 +332,246 @@ func logoutAnthropic() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loginOpenAI() error {
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
// Check if already authenticated
|
||||
if hasAuth, err := cm.HasOpenAICredentials(); err == nil && hasAuth {
|
||||
var reauth bool
|
||||
err := huh.NewConfirm().
|
||||
Title("You are already authenticated with OpenAI (ChatGPT/Codex)").
|
||||
Description("Do you want to re-authenticate?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&reauth).
|
||||
Run()
|
||||
if err != nil || !reauth {
|
||||
fmt.Println("Authentication cancelled.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create OAuth client
|
||||
client := auth.NewOpenAIOAuthClient()
|
||||
|
||||
// Generate authorization URL
|
||||
fmt.Println("🔐 Starting OAuth authentication with OpenAI (ChatGPT/Codex)...")
|
||||
fmt.Println("This will open your browser to authenticate with your ChatGPT account.")
|
||||
fmt.Println()
|
||||
|
||||
authData, err := client.GetAuthorizationURL()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate authorization URL: %w", err)
|
||||
}
|
||||
|
||||
// Start local callback server
|
||||
callbackServer, err := startOpenAICallbackServer(authData.State)
|
||||
if err != nil {
|
||||
fmt.Printf("⚠️ Could not start local callback server: %v\n", err)
|
||||
fmt.Println("Falling back to manual code entry.")
|
||||
}
|
||||
if callbackServer != nil {
|
||||
defer callbackServer.Close()
|
||||
}
|
||||
|
||||
// Display URL and try to open browser
|
||||
fmt.Println("📱 Opening your browser for authentication...")
|
||||
fmt.Println("If the browser doesn't open automatically, please visit this URL:")
|
||||
fmt.Printf("\n%s\n\n", authData.URL)
|
||||
|
||||
// Try to open browser
|
||||
auth.TryOpenBrowser(authData.URL)
|
||||
|
||||
// Wait for callback or manual input
|
||||
var code string
|
||||
if callbackServer != nil {
|
||||
fmt.Println("Waiting for browser authentication...")
|
||||
select {
|
||||
case callbackCode := <-callbackServer.CodeChan:
|
||||
if callbackCode != "" {
|
||||
code = callbackCode
|
||||
fmt.Println("✓ Received authorization code from browser callback.")
|
||||
}
|
||||
case <-time.After(2 * time.Minute):
|
||||
fmt.Println("\n⏱️ Timeout waiting for browser callback.")
|
||||
callbackServer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// If no code from callback, prompt for manual entry
|
||||
if code == "" {
|
||||
fmt.Println("\nAfter authorizing, paste the callback URL or authorization code below.")
|
||||
fmt.Println("(The callback URL will look like: http://localhost:1455/auth/callback?code=...&state=...)")
|
||||
fmt.Println()
|
||||
|
||||
var input string
|
||||
err = huh.NewInput().
|
||||
Title("Callback URL or Code").
|
||||
Description("Paste the full callback URL or just the authorization code").
|
||||
Value(&input).
|
||||
Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read input: %w", err)
|
||||
}
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
if input == "" {
|
||||
return fmt.Errorf("authorization code cannot be empty")
|
||||
}
|
||||
|
||||
// Parse the input (could be full URL or just code)
|
||||
parsedCode, parsedState := auth.ParseOpenAIAuthorizationInput(input)
|
||||
if parsedCode == "" {
|
||||
return fmt.Errorf("could not extract authorization code from input")
|
||||
}
|
||||
|
||||
// Validate state if provided
|
||||
if parsedState != "" && parsedState != authData.State {
|
||||
return fmt.Errorf("state mismatch - possible security issue")
|
||||
}
|
||||
code = parsedCode
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
fmt.Println("\n🔄 Exchanging authorization code for access token...")
|
||||
creds, err := client.ExchangeCode(code, authData.Verifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exchange authorization code: %w", err)
|
||||
}
|
||||
|
||||
// Store the credentials
|
||||
if err := cm.SetOpenAIOAuthCredentials(creds); err != nil {
|
||||
return fmt.Errorf("failed to store credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✅ Successfully authenticated with OpenAI (ChatGPT/Codex)!")
|
||||
fmt.Printf("📁 Credentials stored in: %s\n", cm.GetCredentialsPath())
|
||||
fmt.Printf("👤 Account ID: %s\n", creds.AccountID)
|
||||
fmt.Println("\n🎉 Your OAuth credentials will now be used for OpenAI API calls.")
|
||||
fmt.Println("💡 You can check your authentication status with: kit auth status")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callbackServer holds the HTTP server and channel for receiving the OAuth callback
|
||||
type callbackServer struct {
|
||||
Server *http.Server
|
||||
CodeChan chan string
|
||||
State string
|
||||
}
|
||||
|
||||
// Close shuts down the callback server
|
||||
func (cs *callbackServer) Close() {
|
||||
if cs.Server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = cs.Server.Shutdown(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// startOpenAICallbackServer starts a local HTTP server to receive the OAuth callback
|
||||
func startOpenAICallbackServer(expectedState string) (*callbackServer, error) {
|
||||
codeChan := make(chan string, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
server := &http.Server{
|
||||
Addr: "127.0.0.1:1455",
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check state
|
||||
state := r.URL.Query().Get("state")
|
||||
if state != expectedState {
|
||||
http.Error(w, "State mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
http.Error(w, "Missing authorization code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Send code to channel
|
||||
select {
|
||||
case codeChan <- code:
|
||||
default:
|
||||
}
|
||||
|
||||
// Return success page
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Authentication Successful</title></head>
|
||||
<body style="font-family: sans-serif; text-align: center; padding: 50px;">
|
||||
<h1>✓ Authentication Successful</h1>
|
||||
<p>You can close this window and return to the terminal.</p>
|
||||
</body>
|
||||
</html>`)
|
||||
})
|
||||
|
||||
// Try to start server
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:1455")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("port 1455 not available: %w", err)
|
||||
}
|
||||
_ = listener.Close()
|
||||
|
||||
go func() {
|
||||
_ = server.ListenAndServe()
|
||||
}()
|
||||
|
||||
return &callbackServer{
|
||||
Server: server,
|
||||
CodeChan: codeChan,
|
||||
State: expectedState,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func logoutOpenAI() error {
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
// Check if authenticated
|
||||
hasAuth, err := cm.HasOpenAICredentials()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check authentication status: %w", err)
|
||||
}
|
||||
|
||||
if !hasAuth {
|
||||
fmt.Println("You are not currently authenticated with OpenAI.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Confirm logout
|
||||
var confirm bool
|
||||
err = huh.NewConfirm().
|
||||
Title("Remove OpenAI credentials").
|
||||
Description("Are you sure you want to remove your stored credentials?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&confirm).
|
||||
Run()
|
||||
if err != nil || !confirm {
|
||||
fmt.Println("Logout cancelled.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove credentials
|
||||
if err := cm.RemoveOpenAICredentials(); err != nil {
|
||||
return fmt.Errorf("failed to remove credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ Successfully logged out from OpenAI!")
|
||||
fmt.Println("You will need to use environment variables or command-line flags for authentication.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+38
-1
@@ -13,6 +13,7 @@ import (
|
||||
"charm.land/fantasy"
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -810,7 +811,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.Steer(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.SetExtensionWidget(config)
|
||||
@@ -955,6 +956,24 @@ func runNormalMode(ctx context.Context) error {
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
@@ -1152,6 +1171,24 @@ func runNormalMode(ctx context.Context) error {
|
||||
// this callback runs synchronously inside BubbleTea's Update(), and
|
||||
// NotifyModelChanged calls prog.Send() which deadlocks. The UI layer
|
||||
// updates m.providerName and m.modelName directly after setModel returns.
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
emitModelChangeForUI := func(newModel, previousModel, source string) {
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// TestSubagentMonitor_SessionStart verifies OnSessionStart initializes state
|
||||
// without panicking and properly guards nil ctx calls.
|
||||
func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Emit SessionStart - should not panic even with nil ctx functions
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_SubagentLifecycle verifies the full subagent lifecycle
|
||||
// creates entries and emits widget updates.
|
||||
func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Start session
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit SubagentStart
|
||||
_, err = harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit a few chunks
|
||||
for i := range 3 {
|
||||
_, err = harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
ChunkType: "text",
|
||||
Content: fmt.Sprintf("line %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Emit tool call chunk
|
||||
_, err = harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
ChunkType: "tool_call",
|
||||
ToolName: "bash",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk tool_call should not error: %v", err)
|
||||
}
|
||||
|
||||
// Emit SubagentEnd
|
||||
_, err = harness.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
Response: "done",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentEnd should not error: %v", err)
|
||||
}
|
||||
|
||||
// Give time for cleanup goroutine
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_MultipleSubagents verifies multiple parallel subagents.
|
||||
func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Start 3 subagents
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Emit chunks for each
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
ChunkType: "text",
|
||||
Content: fmt.Sprintf("output from agent %d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentChunk %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// End all subagents
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := harness.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: fmt.Sprintf("call-%d", i),
|
||||
Task: fmt.Sprintf("task %d", i),
|
||||
Response: "completed",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentEnd %d should not error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_SessionShutdown verifies shutdown doesn't panic
|
||||
// even with nil ctx functions.
|
||||
func TestSubagentMonitor_SessionShutdown(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
// Start then shutdown
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Start a subagent
|
||||
_, err = harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: "call-1",
|
||||
Task: "test task",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubagentStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Shutdown - should not panic even with active subagent
|
||||
_, err = harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionShutdown should not error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@ go 1.26.1
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.0.0
|
||||
charm.land/bubbletea/v2 v2.0.2
|
||||
charm.land/fantasy v0.16.0
|
||||
charm.land/fantasy v0.17.1
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.2
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
@@ -13,7 +13,7 @@ require (
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/coder/acp-go-sdk v0.6.3
|
||||
github.com/mark3labs/mcp-go v0.45.0
|
||||
github.com/mark3labs/mcp-go v0.46.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
@@ -23,14 +23,14 @@ require (
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/auth v0.18.2 // indirect
|
||||
cloud.google.com/go/auth v0.19.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect
|
||||
@@ -45,8 +45,6 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.2 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
@@ -56,9 +54,9 @@ require (
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
@@ -77,18 +75,17 @@ require (
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/invopop/jsonschema v0.13.0 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.2.12 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.6 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.18 // indirect
|
||||
github.com/mailru/easyjson v0.9.2 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
@@ -96,7 +93,7 @@ require (
|
||||
github.com/muesli/mango-pflag v0.2.0 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/roff v0.1.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
@@ -105,10 +102,9 @@ require (
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.7.17 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.6 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
|
||||
@@ -122,7 +118,7 @@ require (
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.272.0 // indirect
|
||||
google.golang.org/api v0.273.0 // indirect
|
||||
google.golang.org/genai v1.51.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect
|
||||
google.golang.org/grpc v1.79.3 // indirect
|
||||
|
||||
@@ -2,16 +2,16 @@ charm.land/bubbles/v2 v2.0.0 h1:tE3eK/pHjmtrDiRdoC9uGNLgpopOd8fjhEe31B/ai5s=
|
||||
charm.land/bubbles/v2 v2.0.0/go.mod h1:rCHoleP2XhU8um45NTuOWBPNVHxnkXKTiZqcclL/qOI=
|
||||
charm.land/bubbletea/v2 v2.0.2 h1:4CRtRnuZOdFDTWSff9r8QFt/9+z6Emubz3aDMnf/dx0=
|
||||
charm.land/bubbletea/v2 v2.0.2/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
|
||||
charm.land/fantasy v0.16.0 h1:vE/6sR9nPcSD8qXJXX6wR8NXjtWlBVAzwQmTh5pHVrs=
|
||||
charm.land/fantasy v0.16.0/go.mod h1:VZjpXVh7IgeiIzGQybEnKzd68ofDsRj94+kzH1ZCAfQ=
|
||||
charm.land/fantasy v0.17.1 h1:SQzfnyJPDuQWt6e//KKmQmEEXdqHMC0IZz10XwkLcEM=
|
||||
charm.land/fantasy v0.17.1/go.mod h1:FF5ALCCHETacHJPBqU42CtwMInYQ0ul52fdzIHQMbQk=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs=
|
||||
charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
|
||||
cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M=
|
||||
cloud.google.com/go/auth v0.19.0 h1:DGYwtbcsGsT1ywuxsIoWi1u/vlks0moIblQHgSDgQkQ=
|
||||
cloud.google.com/go/auth v0.19.0/go.mod h1:2Aph7BT2KnaSFOM0JDPyiYgNh6PL9vGMiP8CUIXZ+IY=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
@@ -36,8 +36,8 @@ github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8=
|
||||
@@ -70,10 +70,6 @@ github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ
|
||||
github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk=
|
||||
github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
|
||||
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
@@ -104,14 +100,14 @@ github.com/charmbracelet/x/conpty v0.1.1 h1:s1bUxjoi7EpqiXysVtC+a8RrvPPNcNvAjfi4
|
||||
github.com/charmbracelet/x/conpty v0.1.1/go.mod h1:OmtR77VODEFbiTzGE9G1XiRJAga6011PIm4u5fTNZpk=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd h1:eStB6uX52pgrm6TxQcEKctPrEC+a/9ubJC+P671idOc=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260322003602-9b007323c5cd/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca h1:62yAoS1Ynbuzwcn1LkNBxi3IMF5p0E0cHCoaLOOmN9w=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd h1:U8xj0UXwqHzO+UYHZJopKF+gWaQEW8oj60fmiq9TFY4=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260322003602-9b007323c5cd/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca h1:QQoyQLgUzojMNWHVHToN6d9qTvT0KWtxUKIRPx/Ox5o=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -173,14 +169,16 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 h1:fYQaUOiGwll0cGj7jmHT/0nPlcrZDFPrZRhTsoCr8hE=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0/go.mod h1:w2ROXVdfGEVFXzmlciUU4EdjHgWvB5h2n6x/8XSTTJA=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 h1:NIKVuLhDlIV74muWlsMM4CcQZqN6JJ20Qcxd9YMuYcs=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
@@ -189,8 +187,6 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
|
||||
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
|
||||
github.com/kaptinlin/go-i18n v0.2.12 h1:ywDsvb4KDFddMC2dpI/rrIzGU2mWUSvHmWUm9BMsdl4=
|
||||
github.com/kaptinlin/go-i18n v0.2.12/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
|
||||
@@ -207,10 +203,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mailru/easyjson v0.9.2 h1:dX8U45hQsZpxd80nLvDGihsQ/OxlvTkVUXH2r/8cb2M=
|
||||
github.com/mailru/easyjson v0.9.2/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/mark3labs/mcp-go v0.45.0 h1:s0S8qR/9fWaQ3pHxz7pm1uQ0DrswoSnRIxKIjbiQtkc=
|
||||
github.com/mark3labs/mcp-go v0.45.0/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw=
|
||||
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
@@ -234,8 +228,8 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
|
||||
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
|
||||
@@ -279,14 +273,12 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/traefik/yaegi v0.16.1 h1:f1De3DVJqIDKmnasUF6MwmWv1dSEEat0wcpXhD2On3E=
|
||||
github.com/traefik/yaegi v0.16.1/go.mod h1:4eVhbPb3LnD2VigQjhYbEJ69vDRFdT2HQNrXx8eEwUY=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.7.17 h1:p36OVWwRb246iHxA/U4p8OPEpOTESm4n+g+8t0EE5uA=
|
||||
github.com/yuin/goldmark v1.7.17/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs=
|
||||
github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
@@ -328,10 +320,14 @@ golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.272.0 h1:eLUQZGnAS3OHn31URRf9sAmRk3w2JjMx37d2k8AjJmA=
|
||||
google.golang.org/api v0.272.0/go.mod h1:wKjowi5LNJc5qarNvDCvNQBn3rVK8nSy6jg2SwRwzIA=
|
||||
google.golang.org/api v0.273.0 h1:r/Bcv36Xa/te1ugaN1kdJ5LoA5Wj/cL+a4gj6FiPBjQ=
|
||||
google.golang.org/api v0.273.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg=
|
||||
google.golang.org/genai v1.51.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
|
||||
|
||||
+62
-5
@@ -70,6 +70,11 @@ type ReasoningDeltaHandler func(delta string)
|
||||
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
|
||||
type ToolOutputHandler = core.ToolOutputCallback
|
||||
|
||||
// StepUsageHandler is a function type for handling token usage after each
|
||||
// complete step in a multi-step agent turn. This enables real-time cost
|
||||
// tracking during long-running tool-calling conversations.
|
||||
type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64)
|
||||
|
||||
// Agent represents an AI agent with core tool integration using the fantasy library.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
|
||||
// fantasy.AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
@@ -178,7 +183,8 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
|
||||
// Pass generation parameters when available.
|
||||
if agentConfig.ModelConfig != nil {
|
||||
if agentConfig.ModelConfig.MaxTokens > 0 {
|
||||
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
|
||||
if agentConfig.ModelConfig.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
|
||||
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(agentConfig.ModelConfig.MaxTokens)))
|
||||
}
|
||||
if agentConfig.ModelConfig.Temperature != nil {
|
||||
@@ -225,7 +231,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
|
||||
onResponse, onToolCallContent, nil, nil, nil)
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the fantasy agent with streaming and callbacks.
|
||||
@@ -237,6 +243,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onStreamingResponse StreamingResponseHandler,
|
||||
onReasoningDelta ReasoningDeltaHandler,
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
|
||||
// Inject tool output handler into context for use by core tools (e.g., bash).
|
||||
@@ -269,7 +276,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
var completedStepMessages []fantasy.Message
|
||||
|
||||
// Use fantasy's streaming agent
|
||||
result, err := a.fantasyAgent.Stream(ctx, fantasy.AgentStreamCall{
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
Prompt: prompt,
|
||||
Files: files,
|
||||
Messages: history,
|
||||
@@ -351,9 +358,58 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if text != "" && len(toolCalls) > 0 && onToolCallContent != nil {
|
||||
onToolCallContent(text)
|
||||
}
|
||||
// Emit step usage for real-time cost tracking
|
||||
if onStepUsage != nil {
|
||||
onStepUsage(step.Usage.InputTokens, step.Usage.OutputTokens,
|
||||
step.Usage.CacheReadTokens, step.Usage.CacheCreationTokens)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// If a steer channel is attached to the context, wire up a
|
||||
// PrepareStep function that drains the channel between steps
|
||||
// and injects pending steer messages as user messages before
|
||||
// the next LLM call. This enables graceful mid-turn steering
|
||||
// without cancelling in-progress tool execution.
|
||||
if steerCh := steerChFromContext(ctx); steerCh != nil {
|
||||
onConsumed := steerConsumedFromContext(ctx)
|
||||
streamCall.PrepareStep = func(
|
||||
stepCtx context.Context,
|
||||
opts fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
// Drain all pending steer messages (non-blocking).
|
||||
var steered []string
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
steered = append(steered, msg)
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
result := fantasy.PrepareStepResult{
|
||||
Model: opts.Model,
|
||||
Messages: opts.Messages,
|
||||
}
|
||||
if len(steered) > 0 {
|
||||
// Inject each steer message as a user message so the
|
||||
// LLM sees the redirection on the next step.
|
||||
for _, text := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(text))
|
||||
}
|
||||
// Notify that steer messages were consumed.
|
||||
if onConsumed != nil {
|
||||
onConsumed(len(steered))
|
||||
}
|
||||
}
|
||||
return stepCtx, result, nil
|
||||
}
|
||||
}
|
||||
|
||||
result, err := a.fantasyAgent.Stream(ctx, streamCall)
|
||||
if err != nil {
|
||||
// On cancellation (or any error), return a partial result
|
||||
// containing messages from completed steps so the caller can
|
||||
@@ -617,7 +673,8 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
}
|
||||
|
||||
// Pass generation parameters when available.
|
||||
if config.MaxTokens > 0 {
|
||||
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
|
||||
if config.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
|
||||
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(config.MaxTokens)))
|
||||
}
|
||||
if config.Temperature != nil {
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
package agent
|
||||
|
||||
import "context"
|
||||
|
||||
// steerChKey is the context key for the steer channel.
|
||||
type steerChKey struct{}
|
||||
|
||||
// steerConsumedKey is the context key for the steer-consumed callback.
|
||||
type steerConsumedKey struct{}
|
||||
|
||||
// ContextWithSteerCh returns a new context with the steer channel attached.
|
||||
// The agent's PrepareStep function checks this channel between steps and
|
||||
// injects any pending steer messages as user messages before the next LLM call.
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan string) context.Context {
|
||||
return context.WithValue(ctx, steerChKey{}, ch)
|
||||
}
|
||||
|
||||
// ContextWithSteerConsumed returns a new context with a callback that fires
|
||||
// when steer messages are consumed by PrepareStep. The count argument is the
|
||||
// number of messages injected in this batch.
|
||||
func ContextWithSteerConsumed(ctx context.Context, fn func(count int)) context.Context {
|
||||
return context.WithValue(ctx, steerConsumedKey{}, fn)
|
||||
}
|
||||
|
||||
// steerChFromContext extracts the steer channel from the context, or nil.
|
||||
func steerChFromContext(ctx context.Context) <-chan string {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan string)
|
||||
return ch
|
||||
}
|
||||
|
||||
// steerConsumedFromContext extracts the steer-consumed callback, or nil.
|
||||
func steerConsumedFromContext(ctx context.Context) func(int) {
|
||||
fn, _ := ctx.Value(steerConsumedKey{}).(func(int))
|
||||
return fn
|
||||
}
|
||||
+154
-32
@@ -3,7 +3,9 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
@@ -159,11 +161,57 @@ func (a *App) QueueLength() int {
|
||||
return len(a.queue)
|
||||
}
|
||||
|
||||
// Steer cancels the current agent step (if running), clears the queue, and
|
||||
// sends a new message that will execute as soon as the current step finishes
|
||||
// cancelling. If the agent is idle, the message executes immediately.
|
||||
// This is the "steer" delivery mode for SendMessage.
|
||||
func (a *App) Steer(prompt string) {
|
||||
// Steer injects a steering message into the currently running agent turn.
|
||||
// If the agent is in a multi-step tool loop, the message is delivered after
|
||||
// the current tool execution finishes but before the next LLM call (graceful
|
||||
// mid-turn injection via Fantasy's PrepareStep). If the agent is streaming
|
||||
// a text-only response (no pending tool calls), the message waits until the
|
||||
// response completes and then executes as the next turn.
|
||||
//
|
||||
// If the agent is idle, the message starts executing immediately (same as Run).
|
||||
//
|
||||
// Returns the number of pending steer/queue items (0 = started immediately,
|
||||
// >0 = injected/queued). The caller must update UI state based on the return
|
||||
// value — Steer does NOT send events to the program to avoid deadlocking
|
||||
// when called from within Update().
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) Steer(prompt string) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
a.mu.Unlock()
|
||||
return 0
|
||||
}
|
||||
|
||||
if !a.busy {
|
||||
// Not busy — start immediately, same as Run().
|
||||
item := queueItem{Prompt: prompt}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
go a.drainQueue(item)
|
||||
return 0
|
||||
}
|
||||
|
||||
a.mu.Unlock()
|
||||
|
||||
// Agent is busy — inject via the SDK's steer channel. The message
|
||||
// will be picked up by PrepareStep between agent steps (after tool
|
||||
// execution, before next LLM call). If PrepareStep doesn't fire
|
||||
// (text-only response), drainQueue will pick it up after the turn.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.InjectSteer(prompt)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// InterruptAndSend cancels the current agent step (if running), clears the
|
||||
// queue, and sends a new message that will execute as soon as the current
|
||||
// step finishes cancelling. If the agent is idle, the message executes
|
||||
// immediately. This is the hard-cancel delivery mode used by extensions'
|
||||
// CancelAndSend.
|
||||
func (a *App) InterruptAndSend(prompt string) {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -226,6 +274,10 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
_ = old.Close()
|
||||
}
|
||||
a.opts.TreeSession = ts
|
||||
// Also update the kit SDK's tree session so messages are persisted correctly.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.SetTreeSession(ts)
|
||||
}
|
||||
// Reload messages from new session.
|
||||
a.store.Clear()
|
||||
if ts != nil {
|
||||
@@ -401,6 +453,13 @@ func (a *App) Close() {
|
||||
|
||||
// Wait for background goroutines.
|
||||
a.wg.Wait()
|
||||
|
||||
// Clean up empty session file on shutdown.
|
||||
if ts := a.opts.TreeSession; ts != nil && ts.IsEmpty() {
|
||||
if path := ts.GetFilePath(); path != "" {
|
||||
_ = os.Remove(path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -434,6 +493,24 @@ func (a *App) drainQueue(first queueItem) {
|
||||
// Process all collected items as a single batch
|
||||
a.runQueueBatch(items)
|
||||
|
||||
// Drain any unconsumed steer messages from the SDK channel.
|
||||
// These arrive when the user steered during a text-only response
|
||||
// (no tool calls, so PrepareStep didn't fire for a second step).
|
||||
// They go to the front of the queue so they run next.
|
||||
if a.opts.Kit != nil {
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
a.mu.Lock()
|
||||
steerItems := make([]queueItem, len(leftover))
|
||||
for i, text := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: text}
|
||||
}
|
||||
a.queue = append(steerItems, a.queue...)
|
||||
a.mu.Unlock()
|
||||
// Notify UI about the consumed steer messages.
|
||||
a.sendEvent(SteerConsumedEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// Check if more items were queued while we were processing
|
||||
a.mu.Lock()
|
||||
hasMore := len(a.queue) > 0
|
||||
@@ -522,9 +599,10 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -544,8 +622,9 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker.
|
||||
a.updateUsageFromTurnResult(result, prompt)
|
||||
// Update usage tracker. If per-step usage was already recorded from
|
||||
// StepUsageEvent callbacks, avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -569,9 +648,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -626,8 +706,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker (using last item's prompt for tracking).
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt)
|
||||
// Update usage tracker (using last item's prompt for fallback estimation).
|
||||
// If per-step usage was already recorded from StepUsageEvent callbacks,
|
||||
// avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -644,9 +726,10 @@ func (a *App) sendEvent(msg tea.Msg) {
|
||||
}
|
||||
|
||||
// subscribeSDKEvents registers temporary SDK event subscribers that convert
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. Returns an
|
||||
// unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. When stepUsageSeen
|
||||
// is provided, it is set to true after any non-zero StepUsageEvent is observed.
|
||||
// Returns an unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Bool) func() {
|
||||
k := a.opts.Kit
|
||||
var unsubs []func()
|
||||
|
||||
@@ -678,6 +761,10 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
Chunk: ev.Chunk,
|
||||
IsStderr: ev.IsStderr,
|
||||
})
|
||||
case kit.SteerConsumedEvent:
|
||||
sendFn(SteerConsumedEvent{})
|
||||
case kit.StepUsageEvent:
|
||||
a.recordStepUsage(ev, stepUsageSeen)
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -847,29 +934,64 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
}
|
||||
}
|
||||
|
||||
// recordStepUsage applies token/cost usage reported for a completed step.
|
||||
// Step usage events arrive even when a turn is later cancelled, so this keeps
|
||||
// the usage widget accurate on all stop paths.
|
||||
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool) {
|
||||
hasUsage := ev.InputTokens > 0 || ev.OutputTokens > 0 || ev.CacheReadTokens > 0 || ev.CacheWriteTokens > 0
|
||||
if !hasUsage {
|
||||
return
|
||||
}
|
||||
if stepUsageSeen != nil {
|
||||
stepUsageSeen.Store(true)
|
||||
}
|
||||
if a.opts.UsageTracker == nil {
|
||||
return
|
||||
}
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(ev.InputTokens),
|
||||
int(ev.OutputTokens),
|
||||
int(ev.CacheReadTokens),
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// Keep context fill reasonably fresh during long/partial turns.
|
||||
a.opts.UsageTracker.SetContextTokens(int(ev.InputTokens + ev.OutputTokens))
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
// configured UsageTracker. This is the SDK-path equivalent of updateUsage.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string) {
|
||||
// configured UsageTracker. Called once per turn after the turn completes.
|
||||
//
|
||||
// When sawStepUsage is true, totals were already accumulated incrementally via
|
||||
// StepUsageEvent callbacks; in that case this method only updates context fill.
|
||||
// Otherwise it falls back to TotalUsage (or estimation) to keep costs/tokens
|
||||
// visible for providers/modes that don't emit per-step usage.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string, sawStepUsage bool) {
|
||||
if a.opts.UsageTracker == nil || result == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if result.TotalUsage != nil {
|
||||
inputTokens := int(result.TotalUsage.InputTokens)
|
||||
outputTokens := int(result.TotalUsage.OutputTokens)
|
||||
if inputTokens > 0 && outputTokens > 0 {
|
||||
cacheReadTokens := int(result.TotalUsage.CacheReadTokens)
|
||||
cacheWriteTokens := int(result.TotalUsage.CacheCreationTokens)
|
||||
a.opts.UsageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
// --- Accumulate cost/token totals for the session ---
|
||||
if !sawStepUsage {
|
||||
if result.TotalUsage != nil && result.TotalUsage.InputTokens > 0 {
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(result.TotalUsage.InputTokens),
|
||||
int(result.TotalUsage.OutputTokens),
|
||||
int(result.TotalUsage.CacheReadTokens),
|
||||
int(result.TotalUsage.CacheCreationTokens),
|
||||
)
|
||||
} else {
|
||||
// Provider didn't report token counts — fall back to character-based
|
||||
// estimates so the footer shows something rather than nothing.
|
||||
a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if result.FinalUsage != nil {
|
||||
if ct := int(result.FinalUsage.InputTokens) + int(result.FinalUsage.OutputTokens); ct > 0 {
|
||||
a.opts.UsageTracker.SetContextTokens(ct)
|
||||
}
|
||||
// --- Context window fill (drives the % bar) ---
|
||||
// Use FinalUsage.InputTokens: the input token count of the last API call
|
||||
// equals the number of tokens currently occupying the context window.
|
||||
// Adding OutputTokens would overstate fill since the response is not part
|
||||
// of the context that was *sent* to the model.
|
||||
if result.FinalUsage != nil && result.FinalUsage.InputTokens > 0 {
|
||||
a.opts.UsageTracker.SetContextTokens(int(result.FinalUsage.InputTokens))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
@@ -14,6 +16,47 @@ import (
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
type usageUpdaterStub struct {
|
||||
mu sync.Mutex
|
||||
|
||||
updateCalls int
|
||||
estimateCalls int
|
||||
contextCalls int
|
||||
|
||||
lastUpdateInput int
|
||||
lastUpdateOutput int
|
||||
lastUpdateCacheRead int
|
||||
lastUpdateCacheWrite int
|
||||
lastContextTokens int
|
||||
lastEstimateInput string
|
||||
lastEstimateOutput string
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.updateCalls++
|
||||
s.lastUpdateInput = inputTokens
|
||||
s.lastUpdateOutput = outputTokens
|
||||
s.lastUpdateCacheRead = cacheReadTokens
|
||||
s.lastUpdateCacheWrite = cacheWriteTokens
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.estimateCalls++
|
||||
s.lastEstimateInput = inputText
|
||||
s.lastEstimateOutput = outputText
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) SetContextTokens(tokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.contextCalls++
|
||||
s.lastContextTokens = tokens
|
||||
}
|
||||
|
||||
// turnResult builds a minimal TurnResult with response text t.
|
||||
func turnResult(t string) *kit.TurnResult {
|
||||
return &kit.TurnResult{Response: t}
|
||||
@@ -489,3 +532,67 @@ func TestQueueLength_reflects(t *testing.T) {
|
||||
t.Fatalf("expected 3, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecordStepUsage_updatesTracker verifies that per-step usage updates are
|
||||
// recorded immediately (including context tokens) for stop-path correctness.
|
||||
func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.recordStepUsage(kit.StepUsageEvent{
|
||||
InputTokens: 120,
|
||||
OutputTokens: 45,
|
||||
CacheReadTokens: 5,
|
||||
CacheWriteTokens: 2,
|
||||
}, nil)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 1 {
|
||||
t.Fatalf("expected 1 update call, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.lastUpdateInput != 120 || usage.lastUpdateOutput != 45 || usage.lastUpdateCacheRead != 5 || usage.lastUpdateCacheWrite != 2 {
|
||||
t.Fatalf("unexpected usage update payload: in=%d out=%d cache_read=%d cache_write=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput, usage.lastUpdateCacheRead, usage.lastUpdateCacheWrite)
|
||||
}
|
||||
if usage.contextCalls != 1 {
|
||||
t.Fatalf("expected 1 context token update, got %d", usage.contextCalls)
|
||||
}
|
||||
if usage.lastContextTokens != 165 {
|
||||
t.Fatalf("expected context tokens 165, got %d", usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen ensures we avoid
|
||||
// double-counting totals once StepUsageEvent-based updates were already applied.
|
||||
func TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
InputTokens: 999,
|
||||
OutputTokens: 111,
|
||||
CacheReadTokens: 7,
|
||||
CacheCreationTokens: 3,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{InputTokens: 456},
|
||||
}, "prompt", true)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 0 {
|
||||
t.Fatalf("expected no total usage update when sawStepUsage=true, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.estimateCalls != 0 {
|
||||
t.Fatalf("expected no estimate update when sawStepUsage=true, got %d", usage.estimateCalls)
|
||||
}
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 456 {
|
||||
t.Fatalf("expected final context tokens=456, got calls=%d tokens=%d", usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +141,12 @@ type CompactErrorEvent struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// SteerConsumedEvent is sent when one or more steering messages have been
|
||||
// consumed — either injected mid-turn via PrepareStep, or drained into the
|
||||
// queue after a turn completes. The TUI uses this to clear the steering
|
||||
// badge from the display.
|
||||
type SteerConsumedEvent struct{}
|
||||
|
||||
// ModelChangedEvent is sent when an extension changes the active model via
|
||||
// ctx.SetModel. The TUI updates the model name shown in the status bar and
|
||||
// message attribution.
|
||||
|
||||
@@ -10,9 +10,10 @@ import (
|
||||
)
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
// Currently supports Anthropic credentials with both OAuth and API key authentication methods.
|
||||
// Currently supports Anthropic and OpenAI credentials with both OAuth and API key authentication methods.
|
||||
type CredentialStore struct {
|
||||
Anthropic *AnthropicCredentials `json:"anthropic,omitempty"`
|
||||
OpenAI *OpenAICredentials `json:"openai,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicCredentials holds Anthropic API credentials supporting both OAuth
|
||||
@@ -28,6 +29,20 @@ type AnthropicCredentials struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// OpenAICredentials holds OpenAI API credentials supporting both OAuth
|
||||
// and API key authentication methods. The Type field indicates which authentication
|
||||
// method is being used. For OAuth, tokens are stored with expiration timestamps
|
||||
// for automatic refresh. For API keys, only the key itself is stored.
|
||||
type OpenAICredentials struct {
|
||||
Type string `json:"type"` // "oauth" or "api_key"
|
||||
APIKey string `json:"api_key,omitempty"` // For API key auth
|
||||
AccessToken string `json:"access_token,omitempty"` // For OAuth
|
||||
RefreshToken string `json:"refresh_token,omitempty"` // For OAuth
|
||||
ExpiresAt int64 `json:"expires_at,omitempty"` // For OAuth
|
||||
AccountID string `json:"account_id,omitempty"` // For OAuth (ChatGPT account ID)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) IsExpired() bool {
|
||||
@@ -48,6 +63,26 @@ func (c *AnthropicCredentials) NeedsRefresh() bool {
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) IsExpired() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= c.ExpiresAt
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
// will expire within the next 5 minutes. This allows for proactive token refresh
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) NeedsRefresh() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
}
|
||||
|
||||
// CredentialManager handles secure storage and retrieval of authentication credentials.
|
||||
// It manages a JSON file stored in the user's config directory with appropriate
|
||||
// file permissions for security.
|
||||
@@ -212,6 +247,142 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAICredentials stores OpenAI API key credentials. It validates the
|
||||
// API key format before storing. The API key must start with "sk-" and be
|
||||
// at least 20 characters long. Returns an error if the API key is invalid or
|
||||
// if storage fails.
|
||||
func (cm *CredentialManager) SetOpenAICredentials(apiKey string) error {
|
||||
if err := validateOpenAIAPIKey(apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = &OpenAICredentials{
|
||||
Type: "api_key",
|
||||
APIKey: apiKey,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetOpenAICredentials retrieves stored OpenAI credentials. Returns nil if
|
||||
// no credentials are stored. The returned credentials may be either OAuth or API
|
||||
// key type, check the Type field to determine which.
|
||||
func (cm *CredentialManager) GetOpenAICredentials() (*OpenAICredentials, error) {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return store.OpenAI, nil
|
||||
}
|
||||
|
||||
// RemoveOpenAICredentials removes stored OpenAI credentials from storage.
|
||||
// If this was the only credential stored, the entire credentials file is removed.
|
||||
// Returns an error if the removal fails.
|
||||
func (cm *CredentialManager) RemoveOpenAICredentials() error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = nil
|
||||
|
||||
// If store is empty, remove the file entirely
|
||||
if store.Anthropic == nil && store.OpenAI == nil {
|
||||
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// HasOpenAICredentials checks if valid OpenAI credentials are stored.
|
||||
// Returns true if either a non-empty OAuth access token or API key is present,
|
||||
// false otherwise. Returns an error if credentials cannot be loaded.
|
||||
func (cm *CredentialManager) HasOpenAICredentials() (bool, error) {
|
||||
creds, err := cm.GetOpenAICredentials()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if creds == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check based on credential type
|
||||
switch creds.Type {
|
||||
case "oauth":
|
||||
return creds.AccessToken != "", nil
|
||||
case "api_key":
|
||||
return creds.APIKey != "", nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAIOAuthCredentials stores OpenAI OAuth credentials in the credential manager's secure storage.
|
||||
// The credentials should include access token, refresh token, and expiration information.
|
||||
// Returns an error if the credentials cannot be saved.
|
||||
func (cm *CredentialManager) SetOpenAIOAuthCredentials(creds *OpenAICredentials) error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = creds
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetValidOpenAIAccessToken returns a valid access token for API requests. For OAuth credentials,
|
||||
// it automatically refreshes the token if it's expired or about to expire. For API key
|
||||
// credentials, it simply returns the API key. Returns an error if no credentials are found,
|
||||
// if token refresh fails, or if the credential type is unknown.
|
||||
func (cm *CredentialManager) GetValidOpenAIAccessToken() (string, error) {
|
||||
creds, err := cm.GetOpenAICredentials()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if creds == nil {
|
||||
return "", fmt.Errorf("no credentials found")
|
||||
}
|
||||
|
||||
// For API key auth, return the API key
|
||||
if creds.Type == "api_key" {
|
||||
return creds.APIKey, nil
|
||||
}
|
||||
|
||||
// For OAuth, check if token needs refresh
|
||||
if creds.Type == "oauth" {
|
||||
if creds.NeedsRefresh() {
|
||||
// Refresh the token
|
||||
client := NewOpenAIOAuthClient()
|
||||
newCreds, err := client.RefreshToken(creds.RefreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Update stored credentials
|
||||
if err := cm.SetOpenAIOAuthCredentials(newCreds); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
return newCreds.AccessToken, nil
|
||||
}
|
||||
|
||||
return creds.AccessToken, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unknown credential type: %s", creds.Type)
|
||||
}
|
||||
|
||||
// GetCredentialsPath returns the absolute path to the credentials JSON file.
|
||||
// This is useful for debugging or displaying the storage location to users.
|
||||
func (cm *CredentialManager) GetCredentialsPath() string {
|
||||
@@ -238,6 +409,26 @@ func validateAnthropicAPIKey(apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOpenAIAPIKey validates the format of an OpenAI API key
|
||||
func validateOpenAIAPIKey(apiKey string) error {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
|
||||
if apiKey == "" {
|
||||
return fmt.Errorf("API key cannot be empty")
|
||||
}
|
||||
|
||||
// OpenAI API keys typically start with "sk-" and are quite long
|
||||
if !strings.HasPrefix(apiKey, "sk-") {
|
||||
return fmt.Errorf("invalid OpenAI API key format (should start with 'sk-')")
|
||||
}
|
||||
|
||||
if len(apiKey) < 20 {
|
||||
return fmt.Errorf("API key appears to be too short")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
|
||||
// 1. Command-line flag value (highest priority)
|
||||
// 2. Stored credentials (OAuth or API key)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -30,6 +31,7 @@ type OAuthClient struct {
|
||||
type AuthData struct {
|
||||
URL string
|
||||
Verifier string
|
||||
State string // Optional state parameter for CSRF protection
|
||||
}
|
||||
|
||||
// NewOAuthClient creates a new OAuth client configured for Anthropic's OAuth service.
|
||||
@@ -199,6 +201,270 @@ func (c *OAuthClient) parseCodeAndState(code string) (parsedCode, parsedState st
|
||||
return
|
||||
}
|
||||
|
||||
// OpenAIOAuthClient handles OAuth 2.0 authentication flow with OpenAI Codex (ChatGPT Plus/Pro).
|
||||
// This uses OpenAI's auth0-based OAuth service for ChatGPT account authentication.
|
||||
type OpenAIOAuthClient struct {
|
||||
ClientID string
|
||||
AuthorizeURL string
|
||||
TokenURL string
|
||||
RedirectURI string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OAuth client configured for OpenAI Codex OAuth.
|
||||
// This uses the public client ID for CLI applications with PKCE for security.
|
||||
func NewOpenAIOAuthClient() *OpenAIOAuthClient {
|
||||
return &OpenAIOAuthClient{
|
||||
// Public client ID for OpenAI Codex CLI OAuth
|
||||
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
AuthorizeURL: "https://auth.openai.com/oauth/authorize",
|
||||
TokenURL: "https://auth.openai.com/oauth/token",
|
||||
RedirectURI: "http://localhost:1455/auth/callback",
|
||||
Scopes: "openid profile email offline_access",
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthorizationURL generates a complete authorization URL for the OAuth flow with
|
||||
// PKCE parameters. Returns an AuthData structure containing the URL for user
|
||||
// authentication and the PKCE verifier for the subsequent code exchange.
|
||||
func (c *OpenAIOAuthClient) GetAuthorizationURL() (*AuthData, error) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
|
||||
}
|
||||
|
||||
// Generate random state
|
||||
stateBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(stateBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
state := fmt.Sprintf("%x", stateBytes)
|
||||
|
||||
params := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {c.ClientID},
|
||||
"redirect_uri": {c.RedirectURI},
|
||||
"scope": {c.Scopes},
|
||||
"code_challenge": {challenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
"id_token_add_organizations": {"true"},
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
"originator": {"kit"},
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s?%s", c.AuthorizeURL, params.Encode())
|
||||
|
||||
return &AuthData{
|
||||
URL: authURL,
|
||||
Verifier: verifier,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges an authorization code for access and refresh tokens.
|
||||
// The code parameter should be the authorization code received from the OAuth callback.
|
||||
// The verifier parameter must be the same PKCE verifier generated during GetAuthorizationURL.
|
||||
// Returns OpenAICredentials containing the tokens, expiration, and account ID.
|
||||
func (c *OpenAIOAuthClient) ExchangeCode(code, verifier string) (*OpenAICredentials, error) {
|
||||
return c.exchangeAuthorizationCode(code, verifier, c.RedirectURI)
|
||||
}
|
||||
|
||||
// exchangeAuthorizationCode performs the token exchange with the OAuth server
|
||||
func (c *OpenAIOAuthClient) exchangeAuthorizationCode(code, verifier, redirectUri string) (*OpenAICredentials, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {c.ClientID},
|
||||
"code": {code},
|
||||
"code_verifier": {verifier},
|
||||
"redirect_uri": {redirectUri},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", c.TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make token request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" || tokenResp.RefreshToken == "" {
|
||||
return nil, fmt.Errorf("token response missing required fields")
|
||||
}
|
||||
|
||||
// Extract account ID from JWT token
|
||||
accountID := extractOpenAIAccountID(tokenResp.AccessToken)
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("failed to extract account ID from token")
|
||||
}
|
||||
|
||||
return &OpenAICredentials{
|
||||
Type: "oauth",
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
CreatedAt: time.Now(),
|
||||
AccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired or expiring access token using a refresh token.
|
||||
// Returns new OpenAICredentials with updated access token, refresh token (may be
|
||||
// rotated), and new expiration timestamp. Returns an error if the refresh fails or
|
||||
// the refresh token is invalid.
|
||||
func (c *OpenAIOAuthClient) RefreshToken(refreshToken string) (*OpenAICredentials, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"client_id": {c.ClientID},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", c.TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make refresh request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode refresh response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" || tokenResp.RefreshToken == "" {
|
||||
return nil, fmt.Errorf("refresh response missing required fields")
|
||||
}
|
||||
|
||||
// Extract account ID from JWT token
|
||||
accountID := extractOpenAIAccountID(tokenResp.AccessToken)
|
||||
if accountID == "" {
|
||||
return nil, fmt.Errorf("failed to extract account ID from refreshed token")
|
||||
}
|
||||
|
||||
return &OpenAICredentials{
|
||||
Type: "oauth",
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
CreatedAt: time.Now(),
|
||||
AccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractOpenAIAccountID extracts the ChatGPT account ID from a JWT access token.
|
||||
// The account ID is stored in the claim path https://api.openai.com/auth.chatgpt_account_id
|
||||
func extractOpenAIAccountID(token string) string {
|
||||
// JWT tokens are base64-encoded JSON payloads
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Decode payload (second part)
|
||||
payload := parts[1]
|
||||
// Add padding if needed
|
||||
if len(payload)%4 != 0 {
|
||||
payload += strings.Repeat("=", 4-len(payload)%4)
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Navigate to the claim path: https://api.openai.com/auth.chatgpt_account_id
|
||||
authPath, ok := claims["https://api.openai.com/auth"].(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
accountID, ok := authPath["chatgpt_account_id"].(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return accountID
|
||||
}
|
||||
|
||||
// ParseOpenAIAuthorizationInput parses various forms of authorization input:
|
||||
// - Full callback URL: http://localhost:1455/auth/callback?code=xxx&state=yyy
|
||||
// - Code#State format: abc123#state456
|
||||
// - Query string: code=abc123&state=state456
|
||||
// - Just the code: abc123
|
||||
func ParseOpenAIAuthorizationInput(input string) (code, state string) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Try parsing as URL
|
||||
if strings.HasPrefix(input, "http") {
|
||||
if u, err := url.Parse(input); err == nil {
|
||||
return u.Query().Get("code"), u.Query().Get("state")
|
||||
}
|
||||
}
|
||||
|
||||
// Try code#state format
|
||||
if strings.Contains(input, "#") {
|
||||
parts := strings.SplitN(input, "#", 2)
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
|
||||
// Try query string format
|
||||
if strings.Contains(input, "code=") {
|
||||
if values, err := url.ParseQuery(input); err == nil {
|
||||
return values.Get("code"), values.Get("state")
|
||||
}
|
||||
}
|
||||
|
||||
// Assume it's just the code
|
||||
return input, ""
|
||||
}
|
||||
|
||||
// SetOAuthCredentials stores OAuth credentials in the credential manager's secure storage.
|
||||
// The credentials should include access token, refresh token, and expiration information.
|
||||
// Returns an error if the credentials cannot be saved.
|
||||
|
||||
+234
-44
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
@@ -13,19 +14,45 @@ import (
|
||||
udiff "github.com/aymanbagabas/go-udiff"
|
||||
)
|
||||
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
// Edit represents a single replacement in a multi-edit operation.
|
||||
type Edit struct {
|
||||
OldText string `json:"old_text"`
|
||||
NewText string `json:"new_text"`
|
||||
}
|
||||
|
||||
// editArgs holds the arguments for the edit tool.
|
||||
// Supports both single-edit mode (old_text/new_text) and multi-edit mode (edits array).
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
OldText string `json:"old_text"` // Single-edit mode
|
||||
NewText string `json:"new_text"` // Single-edit mode
|
||||
Edits []Edit `json:"edits"` // Multi-edit mode
|
||||
}
|
||||
|
||||
// replacement represents a normalized edit ready for processing.
|
||||
type replacement struct {
|
||||
oldText string // normalized old text for matching
|
||||
newText string // normalized new text
|
||||
originalOld string // original old text for metadata
|
||||
originalNew string // original new text for metadata
|
||||
index int // index in the original edits array (for error messages)
|
||||
}
|
||||
|
||||
// matchedReplacement represents a replacement with its match location.
|
||||
type matchedReplacement struct {
|
||||
replacement
|
||||
start int // start index in normalized content
|
||||
end int // end index in normalized content
|
||||
usedFuzzyMatch bool // true if fuzzy matching was used
|
||||
}
|
||||
|
||||
// NewEditTool creates the edit core tool.
|
||||
func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
cfg := ApplyOptions(opts)
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "edit",
|
||||
Description: "Edit a file by replacing exact text. The old_text must match exactly (including whitespace). Use this for precise, surgical edits. Fails if old_text is not found or matches multiple locations.",
|
||||
Description: "Edit a file by replacing exact text. Supports single edit via old_text/new_text, or multiple edits via the edits array. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Parameters: map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
@@ -33,14 +60,32 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace (must match exactly)",
|
||||
"description": "Exact text to find and replace (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text to replace the old text with",
|
||||
"description": "New text to replace the old text with (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"edits": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Array of edits for multi-region replacement. Each edit must have unique, non-overlapping old_text. All matches are against the original file content.",
|
||||
"items": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace for this edit",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text for this edit",
|
||||
},
|
||||
},
|
||||
"required": []string{"old_text", "new_text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "old_text", "new_text"},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeEdit(ctx, call, cfg.WorkDir)
|
||||
@@ -51,7 +96,7 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args editArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("path, old_text, and new_text parameters are required"), nil
|
||||
return fantasy.NewTextErrorResponse("failed to parse arguments: " + err.Error()), nil
|
||||
}
|
||||
if args.Path == "" {
|
||||
return fantasy.NewTextErrorResponse("path parameter is required"), nil
|
||||
@@ -69,56 +114,201 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
|
||||
content := string(contentBytes)
|
||||
|
||||
// Normalize line endings for matching
|
||||
normalized := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
normalizedOld := strings.ReplaceAll(args.OldText, "\r\n", "\n")
|
||||
|
||||
// Try exact match first
|
||||
count := strings.Count(normalized, normalizedOld)
|
||||
|
||||
// If no exact match, try fuzzy matching
|
||||
if count == 0 {
|
||||
if idx, matchLen := fuzzyMatch(normalized, normalizedOld); idx >= 0 {
|
||||
// Apply fuzzy match — the matched text is the original content slice
|
||||
matchedText := normalized[idx : idx+matchLen]
|
||||
newContent := normalized[:idx] + args.NewText + normalized[idx+matchLen:]
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, matchedText, args.NewText)), nil
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("old_text not found in %s", args.Path)), nil
|
||||
// Normalize and validate input
|
||||
replacements, err := normalizeEditInput(args)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("found %d matches for old_text in %s. Provide more context to identify the correct match.", count, args.Path)), nil
|
||||
// Apply all edits
|
||||
newContent, applied, err := applyEdits(content, replacements)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Apply the edit
|
||||
newContent := strings.Replace(normalized, normalizedOld, args.NewText, 1)
|
||||
|
||||
// Write the file
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, normalizedOld, args.NewText)), nil
|
||||
// Generate diff
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
diff := generateDiff(absPath, normalizedContent, newContent)
|
||||
|
||||
// Build response with fuzzy match indication
|
||||
fuzzyCount := 0
|
||||
for _, m := range applied {
|
||||
if m.usedFuzzyMatch {
|
||||
fuzzyCount++
|
||||
}
|
||||
}
|
||||
|
||||
var msg string
|
||||
if len(applied) == 1 {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff)
|
||||
}
|
||||
} else {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied %d edits (%d fuzzy) to %s\n%s", len(applied), fuzzyCount, args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied %d edits to %s\n%s", len(applied), args.Path, diff)
|
||||
}
|
||||
}
|
||||
|
||||
resp := fantasy.NewTextResponse(msg)
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, applied)), nil
|
||||
}
|
||||
|
||||
// normalizeEditInput validates and normalizes the edit input.
|
||||
// Returns error if both single-edit and multi-edit modes are used.
|
||||
func normalizeEditInput(args editArgs) ([]replacement, error) {
|
||||
singleMode := args.OldText != "" || args.NewText != ""
|
||||
multiMode := len(args.Edits) > 0
|
||||
|
||||
if singleMode && multiMode {
|
||||
return nil, fmt.Errorf("cannot use old_text/new_text together with edits array")
|
||||
}
|
||||
|
||||
if !singleMode && !multiMode {
|
||||
return nil, fmt.Errorf("must provide either old_text/new_text or edits array")
|
||||
}
|
||||
|
||||
if singleMode {
|
||||
if args.OldText == "" {
|
||||
return nil, fmt.Errorf("old_text is required when using single-edit mode")
|
||||
}
|
||||
if args.NewText == "" {
|
||||
return nil, fmt.Errorf("new_text is required when using single-edit mode")
|
||||
}
|
||||
return []replacement{{
|
||||
oldText: strings.ReplaceAll(args.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(args.NewText, "\r\n", "\n"),
|
||||
originalOld: args.OldText,
|
||||
originalNew: args.NewText,
|
||||
index: 0,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Multi-edit mode
|
||||
var reps []replacement
|
||||
for i, edit := range args.Edits {
|
||||
if edit.OldText == "" {
|
||||
return nil, fmt.Errorf("edits[%d].old_text is required", i)
|
||||
}
|
||||
reps = append(reps, replacement{
|
||||
oldText: strings.ReplaceAll(edit.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(edit.NewText, "\r\n", "\n"),
|
||||
originalOld: edit.OldText,
|
||||
originalNew: edit.NewText,
|
||||
index: i,
|
||||
})
|
||||
}
|
||||
return reps, nil
|
||||
}
|
||||
|
||||
// applyEdits applies multiple replacements to the content.
|
||||
// All matches are against the original content (non-incremental).
|
||||
// Returns the new content, the applied matches, and any error.
|
||||
func applyEdits(content string, edits []replacement) (string, []matchedReplacement, error) {
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
|
||||
// Find all matches
|
||||
var matched []matchedReplacement
|
||||
for _, edit := range edits {
|
||||
m, err := findMatch(normalizedContent, edit)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
matched = append(matched, *m)
|
||||
}
|
||||
|
||||
// Sort by position
|
||||
sort.Slice(matched, func(i, j int) bool {
|
||||
return matched[i].start < matched[j].start
|
||||
})
|
||||
|
||||
// Check for overlaps
|
||||
for i := 1; i < len(matched); i++ {
|
||||
if matched[i-1].end > matched[i].start {
|
||||
return "", nil, fmt.Errorf("edits[%d] and edits[%d] overlap; merge them into a single edit",
|
||||
matched[i-1].index, matched[i].index)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply edits in reverse order (end to start) to maintain stable offsets
|
||||
result := normalizedContent
|
||||
for i := len(matched) - 1; i >= 0; i-- {
|
||||
m := matched[i]
|
||||
result = result[:m.start] + m.newText + result[m.end:]
|
||||
}
|
||||
|
||||
return result, matched, nil
|
||||
}
|
||||
|
||||
// findMatch finds a unique match for the edit in the content.
|
||||
// Returns error if not found or ambiguous.
|
||||
func findMatch(content string, edit replacement) (*matchedReplacement, error) {
|
||||
// Try exact match first
|
||||
count := strings.Count(content, edit.oldText)
|
||||
|
||||
if count == 0 {
|
||||
// Try fuzzy match
|
||||
idx, matchLen := fuzzyMatch(content, edit.oldText)
|
||||
if idx < 0 {
|
||||
return nil, fmt.Errorf("edits[%d]: could not find old_text in file. The text must match exactly (including whitespace)", edit.index)
|
||||
}
|
||||
// Use the matched text from content for the replacement
|
||||
matchedText := content[idx : idx+matchLen]
|
||||
return &matchedReplacement{
|
||||
replacement: replacement{
|
||||
oldText: matchedText,
|
||||
newText: edit.newText,
|
||||
originalOld: edit.originalOld,
|
||||
originalNew: edit.originalNew,
|
||||
index: edit.index,
|
||||
},
|
||||
start: idx,
|
||||
end: idx + matchLen,
|
||||
usedFuzzyMatch: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return nil, fmt.Errorf("found %d matches for edits[%d].old_text; each old_text must be unique, provide more context to identify the correct match", count, edit.index)
|
||||
}
|
||||
|
||||
// Single exact match
|
||||
idx := strings.Index(content, edit.oldText)
|
||||
return &matchedReplacement{
|
||||
replacement: edit,
|
||||
start: idx,
|
||||
end: idx + len(edit.oldText),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// editDiffMeta builds the structured metadata attached to edit tool responses.
|
||||
func editDiffMeta(path, oldText, newText string) map[string]any {
|
||||
func editDiffMeta(path string, applied []matchedReplacement) map[string]any {
|
||||
var diffBlocks []map[string]any
|
||||
totalAdditions, totalDeletions := 0, 0
|
||||
|
||||
for _, m := range applied {
|
||||
diffBlocks = append(diffBlocks, map[string]any{
|
||||
"old_text": m.originalOld,
|
||||
"new_text": m.originalNew,
|
||||
})
|
||||
totalAdditions += strings.Count(m.originalNew, "\n") + 1
|
||||
totalDeletions += strings.Count(m.originalOld, "\n") + 1
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"file_diffs": []map[string]any{{
|
||||
"path": path,
|
||||
"additions": strings.Count(newText, "\n") + 1,
|
||||
"deletions": strings.Count(oldText, "\n") + 1,
|
||||
"diff_blocks": []map[string]any{{
|
||||
"old_text": oldText,
|
||||
"new_text": newText,
|
||||
}},
|
||||
"path": path,
|
||||
"additions": totalAdditions,
|
||||
"deletions": totalDeletions,
|
||||
"diff_blocks": diffBlocks,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -715,3 +715,315 @@ func TestExecuteEdit_MetadataContainsFileDiffs(t *testing.T) {
|
||||
t.Fatal("file_diffs should be a non-empty array")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Multi-edit tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Basic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "multi.txt")
|
||||
writeFileOrFail(t, path, "line1\nline2\nline3\nline4\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "line1", NewText: "LINE1"},
|
||||
{OldText: "line3", NewText: "LINE3"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "LINE1") {
|
||||
t.Error("first edit not applied: missing LINE1")
|
||||
}
|
||||
if !strings.Contains(gotStr, "LINE3") {
|
||||
t.Error("second edit not applied: missing LINE3")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line2") {
|
||||
t.Error("line2 was modified but should be untouched")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line4") {
|
||||
t.Error("line4 was modified but should be untouched")
|
||||
}
|
||||
|
||||
// Check response mentions multiple edits
|
||||
if !strings.Contains(resp.Content, "2 edits") {
|
||||
t.Errorf("response should mention '2 edits', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NonIncrementalMatching(t *testing.T) {
|
||||
// All edits are matched against the original content, not incrementally
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "noninc.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
want := "AAA\nBBB\nccc\n"
|
||||
if gotStr != want {
|
||||
t.Errorf("got %q, want %q", gotStr, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_OverlapDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "overlap.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "hello world", NewText: "GOODBYE"}, // Overlaps with first edit
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for overlapping edits")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "overlap") {
|
||||
t.Errorf("expected 'overlap' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_DuplicateDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "dup.txt")
|
||||
writeFileOrFail(t, path, "hello\nworld\nhello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "world", NewText: "WORLD"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for ambiguous old_text (duplicate matches)")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "unique") {
|
||||
t.Errorf("expected 'unique' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello\nworld\nhello\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "notfound.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "nonexistent", NewText: "REPLACEMENT"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for not found")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "edits[0]") {
|
||||
t.Errorf("expected 'edits[0]' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_EmptyArray(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for empty edits array")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "mixed.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(map[string]any{
|
||||
"path": path,
|
||||
"old_text": "hello",
|
||||
"new_text": "HELLO",
|
||||
"edits": []Edit{
|
||||
{OldText: "hello", NewText: "HI"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error when mixing single and multi-edit modes")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "cannot use") {
|
||||
t.Errorf("expected 'cannot use' in error, got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_FuzzyMatch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "fuzzy_multi.txt")
|
||||
// File has trailing whitespace
|
||||
original := "func foo() { \n\treturn 1 \n}\nfunc bar() { \n\treturn 2 \n}\n"
|
||||
writeFileOrFail(t, path, original)
|
||||
|
||||
// Search without trailing whitespace (common LLM behavior)
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "func foo() {\n\treturn 1\n}", NewText: "func foo() {\n\treturn 10\n}"},
|
||||
{OldText: "func bar() {\n\treturn 2\n}", NewText: "func bar() {\n\treturn 20\n}"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "return 10") {
|
||||
t.Error("first edit not applied")
|
||||
}
|
||||
if !strings.Contains(gotStr, "return 20") {
|
||||
t.Error("second edit not applied")
|
||||
}
|
||||
|
||||
// Response should mention fuzzy match
|
||||
if !strings.Contains(resp.Content, "fuzzy") {
|
||||
t.Errorf("response should mention 'fuzzy', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Metadata(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "meta_multi.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
var meta map[string]any
|
||||
if err := json.Unmarshal([]byte(resp.Metadata), &meta); err != nil {
|
||||
t.Fatalf("metadata is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
diffs, ok := meta["file_diffs"].([]any)
|
||||
if !ok || len(diffs) == 0 {
|
||||
t.Fatal("metadata missing file_diffs")
|
||||
}
|
||||
|
||||
firstDiff, ok := diffs[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("first diff is not an object")
|
||||
}
|
||||
|
||||
// Check that diff_blocks contains both edits
|
||||
diffBlocks, ok := firstDiff["diff_blocks"].([]any)
|
||||
if !ok || len(diffBlocks) != 2 {
|
||||
t.Fatalf("expected 2 diff_blocks, got %d", len(diffBlocks))
|
||||
}
|
||||
|
||||
// Verify each block has old_text and new_text
|
||||
for i, block := range diffBlocks {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("diff_block[%d] is not an object", i)
|
||||
}
|
||||
if _, ok := b["old_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing old_text", i)
|
||||
}
|
||||
if _, ok := b["new_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing new_text", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -750,6 +750,9 @@ type API struct {
|
||||
registerOption func(OptionDef)
|
||||
registerShortcutFn func(ShortcutDef, func(Context))
|
||||
registerMessageRendererFn func(MessageRendererConfig)
|
||||
onSubagentStart func(func(SubagentStartEvent, Context))
|
||||
onSubagentChunk func(func(SubagentChunkEvent, Context))
|
||||
onSubagentEnd func(func(SubagentEndEvent, Context))
|
||||
}
|
||||
|
||||
// OnToolCall registers a handler that fires before a tool executes.
|
||||
@@ -781,6 +784,27 @@ func (a *API) OnToolResult(handler func(ToolResultEvent, Context) *ToolResultRes
|
||||
a.onToolResult(handler)
|
||||
}
|
||||
|
||||
// OnSubagentStart registers a handler that fires when a spawn_subagent tool
|
||||
// call begins executing. Use the ToolCallID to correlate with subsequent
|
||||
// OnSubagentChunk and OnSubagentEnd events for the same subagent.
|
||||
func (a *API) OnSubagentStart(handler func(SubagentStartEvent, Context)) {
|
||||
a.onSubagentStart(handler)
|
||||
}
|
||||
|
||||
// OnSubagentChunk registers a handler for real-time events from a running
|
||||
// subagent. ChunkType identifies the kind of event ("text", "tool_call",
|
||||
// "tool_result", "tool_execution_start", "tool_execution_end", etc.).
|
||||
// Correlate with OnSubagentStart via the ToolCallID field.
|
||||
func (a *API) OnSubagentChunk(handler func(SubagentChunkEvent, Context)) {
|
||||
a.onSubagentChunk(handler)
|
||||
}
|
||||
|
||||
// OnSubagentEnd registers a handler that fires when a spawn_subagent call
|
||||
// completes. ErrorMsg is non-empty when the subagent failed.
|
||||
func (a *API) OnSubagentEnd(handler func(SubagentEndEvent, Context)) {
|
||||
a.onSubagentEnd(handler)
|
||||
}
|
||||
|
||||
// OnInput registers a handler that fires when user input is received.
|
||||
// Return a non-nil InputResult to transform or handle the input.
|
||||
func (a *API) OnInput(handler func(InputEvent, Context) *InputResult) {
|
||||
@@ -1781,9 +1805,65 @@ type BeforeCompactResult struct {
|
||||
func (BeforeCompactResult) isResult() {}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Theme types (exposed to Yaegi — concrete structs, string hex colors)
|
||||
// Subagent lifecycle events (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentStartEvent fires when a spawn_subagent tool call begins executing.
|
||||
type SubagentStartEvent struct {
|
||||
// ToolCallID is the LLM-assigned ID of the spawn_subagent tool call.
|
||||
// Use this to correlate SubagentChunkEvent and SubagentEndEvent.
|
||||
ToolCallID string
|
||||
// Task is the task description passed to the subagent.
|
||||
Task string
|
||||
}
|
||||
|
||||
func (e SubagentStartEvent) Type() EventType { return SubagentStart }
|
||||
|
||||
// SubagentChunkEvent fires for each real-time event from a running subagent.
|
||||
// Type field indicates the kind of event; read the relevant fields accordingly.
|
||||
type SubagentChunkEvent struct {
|
||||
// ToolCallID matches the SubagentStartEvent.ToolCallID for this subagent.
|
||||
ToolCallID string
|
||||
// Task is the task description (repeated for convenience).
|
||||
Task string
|
||||
// ChunkType identifies the event kind:
|
||||
// "text" — LLM text chunk (read Content)
|
||||
// "reasoning" — reasoning/thinking delta (read Content)
|
||||
// "tool_call" — subagent called a tool (read ToolName, ToolArgs)
|
||||
// "tool_result" — tool returned a result (read ToolName, ToolResult, IsError)
|
||||
// "tool_execution_start" — tool began executing (read ToolName)
|
||||
// "tool_execution_end" — tool finished executing (read ToolName)
|
||||
// "turn_start" — subagent turn began
|
||||
// "turn_end" — subagent turn ended
|
||||
ChunkType string
|
||||
// Content carries text for "text" and "reasoning" chunk types.
|
||||
Content string
|
||||
// ToolName is set on tool-related chunk types.
|
||||
ToolName string
|
||||
// ToolArgs is the JSON-encoded tool arguments for "tool_call" chunks.
|
||||
ToolArgs string
|
||||
// ToolResult is the tool output for "tool_result" chunks.
|
||||
ToolResult string
|
||||
// IsError is true when a "tool_result" chunk represents an error.
|
||||
IsError bool
|
||||
}
|
||||
|
||||
func (e SubagentChunkEvent) Type() EventType { return SubagentChunk }
|
||||
|
||||
// SubagentEndEvent fires when a spawn_subagent tool call completes.
|
||||
type SubagentEndEvent struct {
|
||||
// ToolCallID matches the SubagentStartEvent.ToolCallID for this subagent.
|
||||
ToolCallID string
|
||||
// Task is the task description.
|
||||
Task string
|
||||
// Response is the subagent's final text response (empty on error).
|
||||
Response string
|
||||
// ErrorMsg is non-empty when the subagent failed.
|
||||
ErrorMsg string
|
||||
}
|
||||
|
||||
func (e SubagentEndEvent) Type() EventType { return SubagentEnd }
|
||||
|
||||
// ThemeColor is an adaptive color pair with light and dark hex values.
|
||||
// Either field may be empty to inherit from the default theme.
|
||||
type ThemeColor struct {
|
||||
|
||||
@@ -71,6 +71,18 @@ const (
|
||||
// BeforeCompact fires before context compaction runs. Handlers can
|
||||
// cancel compaction by returning Cancel=true.
|
||||
BeforeCompact EventType = "before_compact"
|
||||
|
||||
// SubagentStart fires when a spawn_subagent tool call begins executing.
|
||||
// Carries the tool call ID and the task description.
|
||||
SubagentStart EventType = "subagent_start"
|
||||
|
||||
// SubagentChunk fires for each real-time event emitted by a running
|
||||
// subagent: text chunks, tool calls, tool results, etc.
|
||||
SubagentChunk EventType = "subagent_chunk"
|
||||
|
||||
// SubagentEnd fires when a spawn_subagent tool call completes (success
|
||||
// or error). Carries the final response and any error message.
|
||||
SubagentEnd EventType = "subagent_end"
|
||||
)
|
||||
|
||||
// AllEventTypes returns every supported event type.
|
||||
@@ -82,6 +94,7 @@ func AllEventTypes() []EventType {
|
||||
SessionStart, SessionShutdown,
|
||||
ModelChange, ContextPrepare,
|
||||
BeforeFork, BeforeSessionSwitch, BeforeCompact,
|
||||
SubagentStart, SubagentChunk, SubagentEnd,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import "testing"
|
||||
|
||||
func TestAllEventTypes_Count(t *testing.T) {
|
||||
all := AllEventTypes()
|
||||
if len(all) != 18 {
|
||||
t.Fatalf("expected 18 event types, got %d", len(all))
|
||||
if len(all) != 21 {
|
||||
t.Fatalf("expected 21 event types, got %d", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,6 +55,9 @@ func TestEventType_TypeMethod(t *testing.T) {
|
||||
{BeforeForkEvent{TargetID: "abc"}, BeforeFork},
|
||||
{BeforeSessionSwitchEvent{Reason: "new"}, BeforeSessionSwitch},
|
||||
{BeforeCompactEvent{EstimatedTokens: 1000}, BeforeCompact},
|
||||
{SubagentStartEvent{ToolCallID: "x", Task: "t"}, SubagentStart},
|
||||
{SubagentChunkEvent{ToolCallID: "x", ChunkType: "text"}, SubagentChunk},
|
||||
{SubagentEndEvent{ToolCallID: "x"}, SubagentEnd},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -580,6 +580,24 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
registerShortcutFn: func(def ShortcutDef, handler func(Context)) {
|
||||
ext.Shortcuts = append(ext.Shortcuts, ShortcutEntry{Def: def, Handler: handler})
|
||||
},
|
||||
onSubagentStart: func(h func(SubagentStartEvent, Context)) {
|
||||
reg(SubagentStart, func(e Event, c Context) Result {
|
||||
h(e.(SubagentStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentChunk: func(h func(SubagentChunkEvent, Context)) {
|
||||
reg(SubagentChunk, func(e Event, c Context) Result {
|
||||
h(e.(SubagentChunkEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentEnd: func(h func(SubagentEndEvent, Context)) {
|
||||
reg(SubagentEnd, func(e Event, c Context) Result {
|
||||
h(e.(SubagentEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// Call Init — the extension registers its handlers, tools, commands.
|
||||
|
||||
@@ -56,11 +56,165 @@ func NewRunner(exts []LoadedExtension) *Runner {
|
||||
}
|
||||
|
||||
// SetContext updates the runtime context (session ID, model, etc.) that is
|
||||
// passed to every handler invocation. Thread-safe.
|
||||
// passed to every handler invocation. Nil function fields are replaced with
|
||||
// safe no-ops so extension handlers never panic on a missing callback.
|
||||
// Thread-safe.
|
||||
func (r *Runner) SetContext(ctx Context) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.ctx = ctx
|
||||
r.ctx = normalizeContext(ctx)
|
||||
}
|
||||
|
||||
// normalizeContext replaces nil function fields in ctx with no-op stubs so
|
||||
// that extension handlers can call any ctx method without a nil-function panic.
|
||||
func normalizeContext(ctx Context) Context {
|
||||
if ctx.Print == nil {
|
||||
ctx.Print = func(string) {}
|
||||
}
|
||||
if ctx.PrintInfo == nil {
|
||||
ctx.PrintInfo = func(string) {}
|
||||
}
|
||||
if ctx.PrintError == nil {
|
||||
ctx.PrintError = func(string) {}
|
||||
}
|
||||
if ctx.PrintBlock == nil {
|
||||
ctx.PrintBlock = func(PrintBlockOpts) {}
|
||||
}
|
||||
if ctx.SendMessage == nil {
|
||||
ctx.SendMessage = func(string) {}
|
||||
}
|
||||
if ctx.CancelAndSend == nil {
|
||||
ctx.CancelAndSend = func(string) {}
|
||||
}
|
||||
if ctx.SetWidget == nil {
|
||||
ctx.SetWidget = func(WidgetConfig) {}
|
||||
}
|
||||
if ctx.RemoveWidget == nil {
|
||||
ctx.RemoveWidget = func(string) {}
|
||||
}
|
||||
if ctx.SetHeader == nil {
|
||||
ctx.SetHeader = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveHeader == nil {
|
||||
ctx.RemoveHeader = func() {}
|
||||
}
|
||||
if ctx.SetFooter == nil {
|
||||
ctx.SetFooter = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveFooter == nil {
|
||||
ctx.RemoveFooter = func() {}
|
||||
}
|
||||
if ctx.PromptSelect == nil {
|
||||
ctx.PromptSelect = func(PromptSelectConfig) PromptSelectResult {
|
||||
return PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptConfirm == nil {
|
||||
ctx.PromptConfirm = func(PromptConfirmConfig) PromptConfirmResult {
|
||||
return PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptInput == nil {
|
||||
ctx.PromptInput = func(PromptInputConfig) PromptInputResult {
|
||||
return PromptInputResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptMultiSelect == nil {
|
||||
ctx.PromptMultiSelect = func(PromptMultiSelectConfig) PromptMultiSelectResult {
|
||||
return PromptMultiSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.ShowOverlay == nil {
|
||||
ctx.ShowOverlay = func(OverlayConfig) OverlayResult {
|
||||
return OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
}
|
||||
if ctx.SetEditor == nil {
|
||||
ctx.SetEditor = func(EditorConfig) {}
|
||||
}
|
||||
if ctx.ResetEditor == nil {
|
||||
ctx.ResetEditor = func() {}
|
||||
}
|
||||
if ctx.SetEditorText == nil {
|
||||
ctx.SetEditorText = func(string) {}
|
||||
}
|
||||
if ctx.SetUIVisibility == nil {
|
||||
ctx.SetUIVisibility = func(UIVisibility) {}
|
||||
}
|
||||
if ctx.SetStatus == nil {
|
||||
ctx.SetStatus = func(string, string, int) {}
|
||||
}
|
||||
if ctx.RemoveStatus == nil {
|
||||
ctx.RemoveStatus = func(string) {}
|
||||
}
|
||||
if ctx.GetContextStats == nil {
|
||||
ctx.GetContextStats = func() ContextStats { return ContextStats{} }
|
||||
}
|
||||
if ctx.GetMessages == nil {
|
||||
ctx.GetMessages = func() []SessionMessage { return nil }
|
||||
}
|
||||
if ctx.GetSessionPath == nil {
|
||||
ctx.GetSessionPath = func() string { return "" }
|
||||
}
|
||||
if ctx.AppendEntry == nil {
|
||||
ctx.AppendEntry = func(string, string) (string, error) { return "", nil }
|
||||
}
|
||||
if ctx.GetEntries == nil {
|
||||
ctx.GetEntries = func(string) []ExtensionEntry { return nil }
|
||||
}
|
||||
if ctx.GetOption == nil {
|
||||
ctx.GetOption = func(string) string { return "" }
|
||||
}
|
||||
if ctx.SetOption == nil {
|
||||
ctx.SetOption = func(string, string) {}
|
||||
}
|
||||
if ctx.SetModel == nil {
|
||||
ctx.SetModel = func(string) error { return nil }
|
||||
}
|
||||
if ctx.GetAvailableModels == nil {
|
||||
ctx.GetAvailableModels = func() []ModelInfoEntry { return nil }
|
||||
}
|
||||
if ctx.EmitCustomEvent == nil {
|
||||
ctx.EmitCustomEvent = func(string, string) {}
|
||||
}
|
||||
if ctx.GetAllTools == nil {
|
||||
ctx.GetAllTools = func() []ToolInfo { return nil }
|
||||
}
|
||||
if ctx.SetActiveTools == nil {
|
||||
ctx.SetActiveTools = func([]string) {}
|
||||
}
|
||||
if ctx.Exit == nil {
|
||||
ctx.Exit = func() {}
|
||||
}
|
||||
if ctx.Complete == nil {
|
||||
ctx.Complete = func(CompleteRequest) (CompleteResponse, error) {
|
||||
return CompleteResponse{}, nil
|
||||
}
|
||||
}
|
||||
if ctx.SuspendTUI == nil {
|
||||
ctx.SuspendTUI = func(callback func()) error { callback(); return nil }
|
||||
}
|
||||
if ctx.RenderMessage == nil {
|
||||
ctx.RenderMessage = func(string, string) {}
|
||||
}
|
||||
if ctx.RegisterTheme == nil {
|
||||
ctx.RegisterTheme = func(string, ThemeColorConfig) {}
|
||||
}
|
||||
if ctx.SetTheme == nil {
|
||||
ctx.SetTheme = func(string) error { return nil }
|
||||
}
|
||||
if ctx.ListThemes == nil {
|
||||
ctx.ListThemes = func() []string { return nil }
|
||||
}
|
||||
if ctx.ReloadExtensions == nil {
|
||||
ctx.ReloadExtensions = func() error { return nil }
|
||||
}
|
||||
if ctx.SpawnSubagent == nil {
|
||||
ctx.SpawnSubagent = func(SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// GetContext returns a snapshot of the current runtime context. Thread-safe.
|
||||
|
||||
@@ -119,6 +119,11 @@ func Symbols() interp.Exports {
|
||||
"SubagentHandle": reflect.ValueOf((*SubagentHandle)(nil)),
|
||||
"SubagentEvent": reflect.ValueOf((*SubagentEvent)(nil)),
|
||||
|
||||
// Subagent lifecycle events
|
||||
"SubagentStartEvent": reflect.ValueOf((*SubagentStartEvent)(nil)),
|
||||
"SubagentChunkEvent": reflect.ValueOf((*SubagentChunkEvent)(nil)),
|
||||
"SubagentEndEvent": reflect.ValueOf((*SubagentEndEvent)(nil)),
|
||||
|
||||
// Theme types
|
||||
"ThemeColor": reflect.ValueOf((*ThemeColor)(nil)),
|
||||
"ThemeColorConfig": reflect.ValueOf((*ThemeColorConfig)(nil)),
|
||||
|
||||
@@ -171,5 +171,23 @@ func NewTestAPI(ext *LoadedExtension) API {
|
||||
registerMessageRendererFn: func(config MessageRendererConfig) {
|
||||
ext.MessageRenderers = append(ext.MessageRenderers, config)
|
||||
},
|
||||
onSubagentStart: func(h func(SubagentStartEvent, Context)) {
|
||||
reg(SubagentStart, func(e Event, c Context) Result {
|
||||
h(e.(SubagentStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentChunk: func(h func(SubagentChunkEvent, Context)) {
|
||||
reg(SubagentChunk, func(e Event, c Context) Result {
|
||||
h(e.(SubagentChunkEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSubagentEnd: func(h func(SubagentEndEvent, Context)) {
|
||||
reg(SubagentEnd, func(e Event, c Context) Result {
|
||||
h(e.(SubagentEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,38 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// sanitizeToolCallID ensures the ID matches Anthropic's required pattern:
|
||||
// ^[a-zA-Z0-9_-]+$ (alphanumeric, underscores, and hyphens only).
|
||||
// Invalid characters are replaced with underscores.
|
||||
func sanitizeToolCallID(id string) string {
|
||||
var sb strings.Builder
|
||||
for _, r := range id {
|
||||
switch {
|
||||
case (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z'):
|
||||
sb.WriteRune(r)
|
||||
case r >= '0' && r <= '9':
|
||||
sb.WriteRune(r)
|
||||
case r == '_' || r == '-':
|
||||
sb.WriteRune(r)
|
||||
default:
|
||||
// Replace invalid characters with underscore
|
||||
sb.WriteByte('_')
|
||||
}
|
||||
}
|
||||
result := sb.String()
|
||||
// Ensure non-empty (Anthropic requires at least one character)
|
||||
if result == "" {
|
||||
return "tool_0"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ContentPart is the marker interface for all message content block types.
|
||||
// A message contains a heterogeneous slice of ContentPart values, enabling
|
||||
// rich structured messages that carry text, reasoning, tool calls, tool
|
||||
@@ -312,7 +339,7 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
// Add tool calls
|
||||
for _, tc := range m.ToolCalls() {
|
||||
parts = append(parts, fantasy.ToolCallPart{
|
||||
ToolCallID: tc.ID,
|
||||
ToolCallID: sanitizeToolCallID(tc.ID),
|
||||
ToolName: tc.Name,
|
||||
Input: tc.Input,
|
||||
})
|
||||
@@ -340,7 +367,7 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
}
|
||||
}
|
||||
parts = append(parts, fantasy.ToolResultPart{
|
||||
ToolCallID: result.ToolCallID,
|
||||
ToolCallID: sanitizeToolCallID(result.ToolCallID),
|
||||
Output: output,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeToolCallID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "valid alphanumeric ID",
|
||||
input: "call_123abc",
|
||||
expected: "call_123abc",
|
||||
},
|
||||
{
|
||||
name: "ID with dots (OpenCode/Kimi style)",
|
||||
input: "call.123.abc",
|
||||
expected: "call_123_abc",
|
||||
},
|
||||
{
|
||||
name: "ID with colons",
|
||||
input: "tool:123:abc",
|
||||
expected: "tool_123_abc",
|
||||
},
|
||||
{
|
||||
name: "ID with special characters",
|
||||
input: "tool@#$%^&*()",
|
||||
expected: "tool_________",
|
||||
},
|
||||
{
|
||||
name: "Anthropic style ID (already valid)",
|
||||
input: "toolu_0123456789ABCDEF",
|
||||
expected: "toolu_0123456789ABCDEF",
|
||||
},
|
||||
{
|
||||
name: "OpenAI style ID (already valid)",
|
||||
input: "call_O17Uplv4lJvD6DVdIvFFeRMw",
|
||||
expected: "call_O17Uplv4lJvD6DVdIvFFeRMw",
|
||||
},
|
||||
{
|
||||
name: "ID with hyphens",
|
||||
input: "my-tool-call-123",
|
||||
expected: "my-tool-call-123",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "tool_0",
|
||||
},
|
||||
{
|
||||
name: "only special characters",
|
||||
input: "@#$%",
|
||||
expected: "____",
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid",
|
||||
input: "call_123.abc-def@ghi",
|
||||
expected: "call_123_abc-def_ghi",
|
||||
},
|
||||
{
|
||||
name: "Unicode characters",
|
||||
input: "tool_日本語",
|
||||
expected: "tool____",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitizeToolCallID(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("sanitizeToolCallID(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeToolCallID_MatchesAnthropicPattern(t *testing.T) {
|
||||
// Test that sanitized IDs match Anthropic's required pattern: ^[a-zA-Z0-9_-]+$
|
||||
// This is a simplified check - in reality the pattern allows alphanumeric, underscore, hyphen
|
||||
testIDs := []string{
|
||||
"call.123.abc",
|
||||
"tool:123:def",
|
||||
"id@#$%^&*()",
|
||||
"mixed.valid-id_test",
|
||||
"",
|
||||
}
|
||||
|
||||
for _, id := range testIDs {
|
||||
sanitized := sanitizeToolCallID(id)
|
||||
|
||||
// Verify each character is valid
|
||||
for i, r := range sanitized {
|
||||
valid := (r >= 'a' && r <= 'z') ||
|
||||
(r >= 'A' && r <= 'Z') ||
|
||||
(r >= '0' && r <= '9') ||
|
||||
r == '_' ||
|
||||
r == '-'
|
||||
|
||||
if !valid {
|
||||
t.Errorf("sanitizeToolCallID(%q) = %q, contains invalid character at position %d: %q",
|
||||
id, sanitized, i, string(r))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify non-empty
|
||||
if sanitized == "" {
|
||||
t.Errorf("sanitizeToolCallID(%q) returned empty string", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -17,15 +17,21 @@ type modelsDBProvider struct {
|
||||
|
||||
// modelsDBModel represents a model entry from models.dev/api.json.
|
||||
type modelsDBModel struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Family string `json:"family,omitempty"`
|
||||
Attachment bool `json:"attachment"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
ToolCall bool `json:"tool_call"`
|
||||
Temperature bool `json:"temperature"`
|
||||
Cost modelsDBCost `json:"cost"`
|
||||
Limit modelsDBLimit `json:"limit"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Family string `json:"family,omitempty"`
|
||||
Attachment bool `json:"attachment"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
ToolCall bool `json:"tool_call"`
|
||||
Temperature bool `json:"temperature"`
|
||||
Cost modelsDBCost `json:"cost"`
|
||||
Limit modelsDBLimit `json:"limit"`
|
||||
Provider *modelsDBModelProvider `json:"provider,omitempty"` // Model-specific provider override
|
||||
}
|
||||
|
||||
// modelsDBModelProvider represents a provider reference within a model.
|
||||
type modelsDBModelProvider struct {
|
||||
NPM string `json:"npm"`
|
||||
}
|
||||
|
||||
// modelsDBCost represents model pricing from models.dev.
|
||||
|
||||
@@ -169,6 +169,9 @@ type ProviderResult struct {
|
||||
// ProviderOptions contains provider-specific options to be passed to the
|
||||
// fantasy agent (e.g. OpenAI Responses API reasoning options).
|
||||
ProviderOptions fantasy.ProviderOptions
|
||||
// SkipMaxOutputTokens indicates that this provider doesn't support the
|
||||
// max_output_tokens parameter (e.g., OpenAI Codex OAuth API).
|
||||
SkipMaxOutputTokens bool
|
||||
}
|
||||
|
||||
// ParseModelString parses a model string in "provider/model" format (e.g. "anthropic/claude-sonnet-4-5").
|
||||
@@ -263,14 +266,22 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
// autoRouteProvider attempts to create a provider by looking up its npm package
|
||||
// in the models.dev database and routing through the appropriate fantasy provider.
|
||||
// For openai-compatible providers, it uses the api URL from models.dev.
|
||||
// Models may have a provider override that specifies a different npm package than
|
||||
// the provider's default (e.g., opencode's claude-opus-4-6 uses @ai-sdk/anthropic).
|
||||
func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, modelName string, registry *ModelsRegistry) (*ProviderResult, error) {
|
||||
providerInfo := registry.GetProviderInfo(provider)
|
||||
if providerInfo == nil {
|
||||
return nil, fmt.Errorf("unsupported provider: %s (not found in model database)", provider)
|
||||
}
|
||||
|
||||
// Check for model-specific provider override
|
||||
npmPackage := providerInfo.NPM
|
||||
if modelInfo := registry.LookupModel(provider, modelName); modelInfo != nil && modelInfo.ProviderNPM != "" {
|
||||
npmPackage = modelInfo.ProviderNPM
|
||||
}
|
||||
|
||||
// Determine the fantasy provider for this npm package
|
||||
fantasyProvider := npmToFantasyProvider[providerInfo.NPM]
|
||||
fantasyProvider := npmToFantasyProvider[npmPackage]
|
||||
if fantasyProvider == "" && providerInfo.API != "" {
|
||||
// Unknown npm but has API URL → route through openaicompat
|
||||
fantasyProvider = "openaicompat"
|
||||
@@ -290,7 +301,7 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
|
||||
}
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no fantasy mapping)", provider, providerInfo.NPM)
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no fantasy mapping)", provider, npmPackage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -348,7 +359,10 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf
|
||||
opts = append(opts, anthropic.WithAPIKey(apiKey))
|
||||
|
||||
if config.ProviderURL != "" {
|
||||
opts = append(opts, anthropic.WithBaseURL(config.ProviderURL))
|
||||
// The anthropic client appends "/v1/messages" to the base URL.
|
||||
// If the provider URL ends with "/v1", strip it to avoid double "/v1/v1" paths.
|
||||
baseURL := strings.TrimSuffix(config.ProviderURL, "/v1")
|
||||
opts = append(opts, anthropic.WithBaseURL(baseURL))
|
||||
}
|
||||
|
||||
if config.TLSSkipVerify {
|
||||
@@ -610,13 +624,52 @@ func createVertexAnthropicProvider(ctx context.Context, config *ProviderConfig,
|
||||
|
||||
func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
apiKey := config.ProviderAPIKey
|
||||
source := "command-line flag"
|
||||
var accountID string
|
||||
var isCodexOAuth bool
|
||||
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("OPENAI_API_KEY")
|
||||
}
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("OpenAI API key not provided. Use --provider-api-key flag or OPENAI_API_KEY environment variable")
|
||||
// Check stored credentials first
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err == nil {
|
||||
if creds, err := cm.GetOpenAICredentials(); err == nil && creds != nil {
|
||||
if creds.Type == "oauth" && creds.AccessToken != "" {
|
||||
// For OAuth, get a valid access token (may refresh if needed)
|
||||
token, err := cm.GetValidOpenAIAccessToken()
|
||||
if err == nil && token != "" {
|
||||
apiKey = token
|
||||
accountID = creds.AccountID
|
||||
isCodexOAuth = true
|
||||
source = "stored Codex OAuth credentials"
|
||||
}
|
||||
} else if creds.Type == "api_key" && creds.APIKey != "" {
|
||||
apiKey = creds.APIKey
|
||||
source = "stored API key"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to environment variable
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("OPENAI_API_KEY")
|
||||
source = "OPENAI_API_KEY environment variable"
|
||||
}
|
||||
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("OpenAI API key not provided. Use 'kit auth login openai', --provider-api-key flag, or OPENAI_API_KEY environment variable")
|
||||
}
|
||||
|
||||
if os.Getenv("DEBUG") != "" || os.Getenv("KIT_DEBUG") != "" {
|
||||
fmt.Fprintf(os.Stderr, "Using OpenAI API key from: %s\n", source)
|
||||
}
|
||||
|
||||
// For Codex OAuth, use the ChatGPT backend API with custom headers
|
||||
if isCodexOAuth {
|
||||
return createOpenAICodexProvider(ctx, config, modelName, apiKey, accountID)
|
||||
}
|
||||
|
||||
// Regular OpenAI API key flow
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithAPIKey(apiKey))
|
||||
opts = append(opts, openai.WithUseResponsesAPI())
|
||||
@@ -645,6 +698,135 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
|
||||
}
|
||||
|
||||
// createOpenAICodexProvider creates a provider for ChatGPT/Codex OAuth tokens.
|
||||
// Uses the chatgpt.com/backend-api/codex endpoint with special headers.
|
||||
func createOpenAICodexProvider(ctx context.Context, config *ProviderConfig, modelName, token, accountID string) (*ProviderResult, error) {
|
||||
// Check for spark models which are not accessible via OAuth
|
||||
if detectCodexModelFamily(modelName) == "gpt-codex-spark" {
|
||||
return nil, fmt.Errorf("gpt-codex-spark models are not accessible via ChatGPT OAuth. " +
|
||||
"These models require special access or a different authentication method. " +
|
||||
"Please use regular Codex models like 'openai/gpt-5.3-codex' instead")
|
||||
}
|
||||
|
||||
// Use the ChatGPT backend API with /codex path
|
||||
baseURL := "https://chatgpt.com/backend-api/codex"
|
||||
if config.ProviderURL != "" {
|
||||
baseURL = config.ProviderURL
|
||||
}
|
||||
|
||||
// Build custom HTTP client with required headers
|
||||
httpClient := createCodexHTTPClient(token, accountID, config.TLSSkipVerify)
|
||||
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithAPIKey(token))
|
||||
opts = append(opts, openai.WithBaseURL(baseURL))
|
||||
opts = append(opts, openai.WithUseResponsesAPI())
|
||||
opts = append(opts, openai.WithHTTPClient(httpClient))
|
||||
|
||||
provider, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex provider: %w", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex model: %w", err)
|
||||
}
|
||||
|
||||
providerOpts := buildCodexProviderOptions(config, modelName)
|
||||
|
||||
return &ProviderResult{
|
||||
Model: model,
|
||||
ProviderOptions: providerOpts,
|
||||
SkipMaxOutputTokens: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildCodexProviderOptions returns fantasy.ProviderOptions configured for
|
||||
// OpenAI Codex API. The Codex API requires the system prompt to be passed
|
||||
// as 'instructions' rather than as a system message.
|
||||
func buildCodexProviderOptions(config *ProviderConfig, modelName string) fantasy.ProviderOptions {
|
||||
store := false
|
||||
opts := &openai.ResponsesProviderOptions{
|
||||
Store: &store,
|
||||
}
|
||||
|
||||
if config.SystemPrompt != "" {
|
||||
opts.Instructions = &config.SystemPrompt
|
||||
}
|
||||
|
||||
if openai.IsResponsesReasoningModel(modelName) {
|
||||
opts.ReasoningEffort = thinkingLevelToReasoningEffort(config.ThinkingLevel)
|
||||
}
|
||||
|
||||
return fantasy.ProviderOptions{openai.Name: opts}
|
||||
}
|
||||
|
||||
// detectCodexModelFamily determines the model family from the model name
|
||||
func detectCodexModelFamily(modelName string) string {
|
||||
modelName = strings.ToLower(modelName)
|
||||
if strings.Contains(modelName, "spark") {
|
||||
return "gpt-codex-spark"
|
||||
}
|
||||
if strings.Contains(modelName, "codex-mini") || strings.Contains(modelName, "mini-latest") {
|
||||
return "gpt-codex-mini"
|
||||
}
|
||||
if strings.Contains(modelName, "codex") {
|
||||
return "gpt-codex"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// createCodexHTTPClient creates an HTTP client with headers required for ChatGPT/Codex API
|
||||
func createCodexHTTPClient(token, accountID string, skipVerify bool) *http.Client {
|
||||
var base http.RoundTripper
|
||||
if skipVerify {
|
||||
base = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &codexTransport{
|
||||
base: base,
|
||||
token: token,
|
||||
accountID: accountID,
|
||||
},
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// codexTransport is a custom RoundTripper that adds ChatGPT/Codex specific headers
|
||||
type codexTransport struct {
|
||||
base http.RoundTripper
|
||||
token string
|
||||
accountID string
|
||||
}
|
||||
|
||||
func (t *codexTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Add required headers for ChatGPT/Codex API
|
||||
// These headers mimic the official pi client to avoid Cloudflare blocking
|
||||
newReq.Header.Set("Authorization", "Bearer "+t.token)
|
||||
if t.accountID != "" {
|
||||
newReq.Header.Set("chatgpt-account-id", t.accountID)
|
||||
}
|
||||
newReq.Header.Set("originator", "kit")
|
||||
newReq.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36")
|
||||
newReq.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
newReq.Header.Set("Accept", "text/event-stream")
|
||||
newReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
newReq.Header.Set("Cache-Control", "no-cache")
|
||||
newReq.Header.Set("Pragma", "no-cache")
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
apiKey := firstNonEmpty(
|
||||
config.ProviderAPIKey,
|
||||
|
||||
@@ -22,6 +22,7 @@ type ModelInfo struct {
|
||||
Temperature bool
|
||||
Cost Cost
|
||||
Limit Limit
|
||||
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
||||
}
|
||||
|
||||
// Cost represents the pricing information for a model.
|
||||
@@ -78,6 +79,10 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
for providerID, dp := range dbProviders {
|
||||
modelsMap := make(map[string]ModelInfo, len(dp.Models))
|
||||
for modelID, dm := range dp.Models {
|
||||
providerNPM := ""
|
||||
if dm.Provider != nil {
|
||||
providerNPM = dm.Provider.NPM
|
||||
}
|
||||
modelsMap[modelID] = ModelInfo{
|
||||
ID: dm.ID,
|
||||
Name: dm.Name,
|
||||
@@ -94,6 +99,7 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
Context: dm.Limit.Context,
|
||||
Output: dm.Limit.Output,
|
||||
},
|
||||
ProviderNPM: providerNPM,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -219,6 +225,15 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
|
||||
}
|
||||
}
|
||||
|
||||
// For openai, check stored credentials (OAuth / API key)
|
||||
if provider == "openai" {
|
||||
if cm, err := auth.NewCredentialManager(); err == nil {
|
||||
if has, _ := cm.HasOpenAICredentials(); has {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
envVars, err := r.getRequiredEnvVars(provider)
|
||||
if err != nil {
|
||||
// Unknown provider — nothing to validate
|
||||
|
||||
@@ -96,6 +96,7 @@ func ListAllSessions() ([]SessionInfo, error) {
|
||||
}
|
||||
|
||||
// listSessionsInDir reads all .jsonl files in a directory and extracts session info.
|
||||
// Empty sessions (no messages) are automatically cleaned up and not returned.
|
||||
func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
@@ -117,6 +118,11 @@ func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
if err != nil {
|
||||
continue // skip malformed session files
|
||||
}
|
||||
// Clean up and skip empty sessions (no messages)
|
||||
if info.MessageCount == 0 {
|
||||
_ = os.Remove(path)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, *info)
|
||||
}
|
||||
|
||||
|
||||
@@ -628,6 +628,11 @@ func (tm *TreeManager) MessageCount() int {
|
||||
return count
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the session has no messages (only header).
|
||||
func (tm *TreeManager) IsEmpty() bool {
|
||||
return tm.MessageCount() == 0
|
||||
}
|
||||
|
||||
// Close closes the underlying file handle.
|
||||
func (tm *TreeManager) Close() error {
|
||||
tm.mu.Lock()
|
||||
|
||||
@@ -349,7 +349,7 @@ func TestStreamComponent_SpinnerKeepsRunningDuringStreaming(t *testing.T) {
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "hello"})
|
||||
|
||||
// Flush pending chunks (simulates the 16ms tick firing).
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{})
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
|
||||
if !c.spinning {
|
||||
t.Fatal("expected spinning=true after first chunk")
|
||||
@@ -376,7 +376,7 @@ func TestStreamComponent_ChunkAccumulation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Flush pending chunks (simulates the 16ms tick firing).
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{})
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
|
||||
got := c.streamContent.String()
|
||||
want := "Hello, world!"
|
||||
@@ -396,6 +396,7 @@ func TestStreamComponent_ToolExecution_IsStarting_ShowsSpinner(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
_, cmd := c.Update(app.ToolExecutionEvent{
|
||||
ToolCallID: "call-exec-1",
|
||||
ToolName: "exec_tool",
|
||||
IsStarting: true,
|
||||
})
|
||||
@@ -403,8 +404,9 @@ func TestStreamComponent_ToolExecution_IsStarting_ShowsSpinner(t *testing.T) {
|
||||
if !c.spinning {
|
||||
t.Fatal("expected spinning=true during tool execution")
|
||||
}
|
||||
if len(c.activeTools) != 1 || !strings.Contains(c.activeTools[0], "exec_tool") {
|
||||
t.Fatalf("expected activeTools to contain tool name, got %v", c.activeTools)
|
||||
tools := c.activeToolDisplays()
|
||||
if len(tools) != 1 || !strings.Contains(tools[0], "exec_tool") {
|
||||
t.Fatalf("expected activeTools to contain tool name, got %v", tools)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Fatal("expected tick cmd from ToolExecutionEvent{IsStarting:true}")
|
||||
@@ -418,11 +420,13 @@ func TestStreamComponent_ToolExecution_NotStarting_KeepsSpinning(t *testing.T) {
|
||||
c = sendStreamMsg(c, app.SpinnerEvent{Show: true})
|
||||
// Simulate a tool starting
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{
|
||||
ToolCallID: "call-some-1",
|
||||
ToolName: "some_tool",
|
||||
IsStarting: true,
|
||||
})
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{
|
||||
ToolCallID: "call-some-1",
|
||||
ToolName: "some_tool",
|
||||
IsStarting: false,
|
||||
})
|
||||
@@ -440,9 +444,9 @@ func TestStreamComponent_ParallelToolExecution(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
// Start three tools in parallel
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "grep", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "find", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read", ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-grep", ToolName: "grep", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-find", ToolName: "find", IsStarting: true})
|
||||
|
||||
if len(c.activeTools) != 3 {
|
||||
t.Fatalf("expected 3 active tools, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
@@ -455,19 +459,44 @@ func TestStreamComponent_ParallelToolExecution(t *testing.T) {
|
||||
}
|
||||
|
||||
// Finish one tool
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "grep", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-grep", ToolName: "grep", IsStarting: false})
|
||||
if len(c.activeTools) != 2 {
|
||||
t.Fatalf("expected 2 active tools after one finished, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
}
|
||||
|
||||
// Finish remaining tools
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "read", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolName: "find", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read", ToolName: "read", IsStarting: false})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-find", ToolName: "find", IsStarting: false})
|
||||
if len(c.activeTools) != 0 {
|
||||
t.Fatalf("expected 0 active tools after all finished, got %d: %v", len(c.activeTools), c.activeTools)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_ParallelSameToolName_UsesToolCallID verifies finishing one
|
||||
// tool call does not remove another concurrent call with the same tool name.
|
||||
func TestStreamComponent_ParallelSameToolName_UsesToolCallID(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-1", ToolName: "read", IsStarting: true})
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-2", ToolName: "read", IsStarting: true})
|
||||
|
||||
tools := c.activeToolDisplays()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("expected 2 active read calls, got %d (%v)", len(tools), tools)
|
||||
}
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-1", ToolName: "read", IsStarting: false})
|
||||
tools = c.activeToolDisplays()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 active read call after finishing one ID, got %d (%v)", len(tools), tools)
|
||||
}
|
||||
|
||||
c = sendStreamMsg(c, app.ToolExecutionEvent{ToolCallID: "call-read-2", ToolName: "read", IsStarting: false})
|
||||
if len(c.activeToolDisplays()) != 0 {
|
||||
t.Fatalf("expected no active tools after finishing both IDs, got %v", c.activeToolDisplays())
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TestStreamComponent_GetRenderedContent verifies the method returns rendered
|
||||
// text when content is accumulated, and empty string when not.
|
||||
@@ -621,3 +650,43 @@ func TestStreamComponent_StaleTick_Discarded(t *testing.T) {
|
||||
t.Fatal("current-gen tick should reschedule")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamComponent_StaleFlushTick_Discarded verifies that flush ticks from a
|
||||
// previous generation (e.g. pre-Reset) are ignored.
|
||||
func TestStreamComponent_StaleFlushTick_Discarded(t *testing.T) {
|
||||
c := newTestStream()
|
||||
|
||||
// Start a pending flush and capture its generation.
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "old"})
|
||||
staleGen := c.flushGeneration
|
||||
if !c.flushPending {
|
||||
t.Fatal("precondition: expected flushPending=true after first chunk")
|
||||
}
|
||||
|
||||
// Reset should invalidate in-flight flush ticks.
|
||||
c.Reset()
|
||||
if c.flushGeneration == staleGen {
|
||||
t.Fatal("expected flushGeneration to change after Reset")
|
||||
}
|
||||
|
||||
// New content in a new generation.
|
||||
c = sendStreamMsg(c, app.StreamChunkEvent{Content: "new"})
|
||||
if got := c.pendingStream.String(); got != "new" {
|
||||
t.Fatalf("expected pendingStream='new', got %q", got)
|
||||
}
|
||||
|
||||
// Stale flush tick should be ignored.
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: staleGen})
|
||||
if got := c.pendingStream.String(); got != "new" {
|
||||
t.Fatalf("stale flush tick should not commit pending stream, got %q", got)
|
||||
}
|
||||
|
||||
// Current generation flush should commit.
|
||||
c = sendStreamMsg(c, streamFlushTickMsg{generation: c.flushGeneration})
|
||||
if got := c.pendingStream.String(); got != "" {
|
||||
t.Fatalf("expected pendingStream empty after current flush, got %q", got)
|
||||
}
|
||||
if got := c.streamContent.String(); got != "new" {
|
||||
t.Fatalf("expected streamContent='new' after current flush, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
+2
-1
@@ -192,7 +192,8 @@ func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText stri
|
||||
outputTokens := int(usage.OutputTokens)
|
||||
|
||||
// Validate that the metadata seems reasonable
|
||||
if inputTokens > 0 && outputTokens > 0 {
|
||||
// Use API-reported tokens if input tokens are available (output may be 0 in some cases)
|
||||
if inputTokens > 0 {
|
||||
cacheReadTokens := int(usage.CacheReadTokens)
|
||||
cacheWriteTokens := int(usage.CacheCreationTokens)
|
||||
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
|
||||
+14
-1
@@ -65,6 +65,10 @@ type InputComponent struct {
|
||||
// hideHint suppresses the "enter submit · ctrl+j..." hint text.
|
||||
hideHint bool
|
||||
|
||||
// agentBusy indicates the agent is currently working. When true, the
|
||||
// hint text shows steering shortcut (Ctrl+S) instead of submit.
|
||||
agentBusy bool
|
||||
|
||||
// pendingImages holds clipboard images attached to the next submission.
|
||||
// Images are added via Ctrl+V and cleared on submit or Ctrl+U.
|
||||
pendingImages []ImageAttachment
|
||||
@@ -514,7 +518,16 @@ func (s *InputComponent) View() tea.View {
|
||||
// Adapt hint text to available width (accounting for left padding of 3).
|
||||
var hint string
|
||||
availableHintWidth := s.width - 3
|
||||
if availableHintWidth >= 67 {
|
||||
if s.agentBusy {
|
||||
// When the agent is working, show steering shortcut.
|
||||
if availableHintWidth >= 55 {
|
||||
hint = "enter queue • ctrl+s steer • esc esc cancel"
|
||||
} else if availableHintWidth >= 35 {
|
||||
hint = "↵ queue • ^S steer • esc×2 cancel"
|
||||
} else {
|
||||
hint = "^S steer"
|
||||
}
|
||||
} else if availableHintWidth >= 67 {
|
||||
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
|
||||
} else if availableHintWidth >= 40 {
|
||||
hint = "↵ submit • ctrl+j newline • ctrl+v image"
|
||||
|
||||
+14
-2
@@ -111,14 +111,26 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
result.WriteString(primaryVal)
|
||||
}
|
||||
|
||||
// Collect remaining parameters (skip large values like file content)
|
||||
// Collect remaining parameters, skipping body-content keys (already
|
||||
// rendered in the tool body) and any values that are too large.
|
||||
bodyKeys := map[string]bool{
|
||||
"content": true,
|
||||
"old_text": true,
|
||||
"new_text": true,
|
||||
"oldText": true,
|
||||
"newText": true,
|
||||
"edits": true,
|
||||
"todos": true,
|
||||
}
|
||||
var remaining []string
|
||||
for key, val := range params {
|
||||
if key == primaryKey {
|
||||
continue
|
||||
}
|
||||
if bodyKeys[key] {
|
||||
continue
|
||||
}
|
||||
valStr := fmt.Sprintf("%v", val)
|
||||
// Skip very large values (e.g., oldString, newString, content, todos)
|
||||
if len(valStr) > 100 {
|
||||
continue
|
||||
}
|
||||
|
||||
+219
-38
@@ -3,11 +3,11 @@ package ui
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
@@ -83,6 +83,9 @@ type AppController interface {
|
||||
// GetTreeSession returns the tree session manager, or nil if tree sessions
|
||||
// are not enabled. Used by slash commands like /tree, /fork, /session.
|
||||
GetTreeSession() *session.TreeManager
|
||||
// SwitchTreeSession replaces the active tree session with a new one,
|
||||
// closing the old session. Used by /new to create a completely fresh session.
|
||||
SwitchTreeSession(ts *session.TreeManager)
|
||||
// SendEvent sends a tea.Msg to the program asynchronously. Safe to call
|
||||
// from any goroutine. Used by extension command goroutines to deliver
|
||||
// results back to the TUI without going through tea.Cmd (which can stall
|
||||
@@ -98,6 +101,12 @@ type AppController interface {
|
||||
// alongside the text. Returns the current queue depth (0 = started
|
||||
// immediately, >0 = queued).
|
||||
RunWithFiles(prompt string, files []fantasy.FilePart) int
|
||||
// Steer injects a steering message into the currently running agent
|
||||
// turn. If the agent is busy, the message is delivered between steps
|
||||
// (after current tool finishes, before next LLM call). If idle, the
|
||||
// message starts executing immediately. Returns 0 if started
|
||||
// immediately, >0 if injected/pending.
|
||||
Steer(prompt string) int
|
||||
}
|
||||
|
||||
// SkillItem holds display metadata about a loaded skill for the startup
|
||||
@@ -415,6 +424,11 @@ type AppModel struct {
|
||||
// the input and move to scrollback when the agent picks them up.
|
||||
queuedMessages []string
|
||||
|
||||
// steeringMessages stores the text of prompts that were sent as steer
|
||||
// messages (injected mid-turn via Ctrl+S). Rendered with a "STEERING"
|
||||
// badge above the input. Cleared when the steer is consumed.
|
||||
steeringMessages []string
|
||||
|
||||
// pendingUserPrints holds user messages that have been consumed from the
|
||||
// queue but not yet printed to scrollback. They are deferred until
|
||||
// SpinnerEvent{Show: true} so the previous assistant response can be
|
||||
@@ -569,8 +583,10 @@ type AppModel struct {
|
||||
streamingBashStderr []string
|
||||
// streamingBashMaxLines caps how many lines to accumulate to prevent memory issues.
|
||||
streamingBashMaxLines int
|
||||
// streamingMu protects the streaming bash output fields from concurrent access.
|
||||
streamingMu sync.RWMutex
|
||||
// streaming bash fields are only mutated/read from the Bubble Tea event loop
|
||||
// (Update/View), so no mutex is required here.
|
||||
// streamingBashCommand holds the command being executed for display as a header.
|
||||
streamingBashCommand string
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -1070,6 +1086,45 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
// In other states pass ESC through to children below.
|
||||
|
||||
case "ctrl+s":
|
||||
// Steer: inject the current input as a steering message into the
|
||||
// running agent turn. Only active during stateWorking — in input
|
||||
// state, Ctrl+S is passed through to children (no-op by default).
|
||||
if m.state == stateWorking && m.appCtrl != nil {
|
||||
var text string
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
text = strings.TrimSpace(ic.textarea.Value())
|
||||
}
|
||||
if text != "" {
|
||||
// Clear the input and push to history.
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.pushHistory(text)
|
||||
ic.textarea.SetValue("")
|
||||
}
|
||||
|
||||
// Preprocess @file references.
|
||||
processedText := text
|
||||
if m.cwd != "" {
|
||||
processedText = ProcessFileAttachments(text, m.cwd)
|
||||
}
|
||||
|
||||
// Inject the steer message.
|
||||
sLen := m.appCtrl.Steer(processedText)
|
||||
if sLen > 0 {
|
||||
m.steeringMessages = append(m.steeringMessages, text)
|
||||
m.distributeHeight()
|
||||
} else {
|
||||
// Started immediately (agent was idle).
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, text)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
}
|
||||
|
||||
// Route key events to the focused child. Check for editor
|
||||
@@ -1316,6 +1371,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// rendered when the ToolResultEvent arrives.
|
||||
m.flushStreamContent()
|
||||
|
||||
// For bash commands, extract and store the command for the streaming output header.
|
||||
if msg.ToolName == "bash" {
|
||||
var args struct {
|
||||
Command string `json:"command"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(msg.ToolArgs), &args); err == nil && args.Command != "" {
|
||||
m.streamingBashCommand = args.Command
|
||||
}
|
||||
}
|
||||
|
||||
case app.ToolExecutionEvent:
|
||||
// Pass to stream component for execution spinner display.
|
||||
if m.stream != nil {
|
||||
@@ -1327,10 +1392,9 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Buffer tool result for scrollback.
|
||||
m.printToolResult(msg)
|
||||
// Clear streaming bash output since tool completed.
|
||||
m.streamingMu.Lock()
|
||||
m.streamingBashOutput = nil
|
||||
m.streamingBashStderr = nil
|
||||
m.streamingMu.Unlock()
|
||||
m.streamingBashCommand = ""
|
||||
// Start spinner again while waiting for the next LLM response.
|
||||
if m.stream != nil {
|
||||
_, cmd := m.stream.Update(app.SpinnerEvent{Show: true})
|
||||
@@ -1339,7 +1403,6 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
case app.ToolOutputEvent:
|
||||
// Accumulate streaming bash output for display.
|
||||
m.streamingMu.Lock()
|
||||
if msg.IsStderr {
|
||||
m.streamingBashStderr = append(m.streamingBashStderr, msg.Chunk)
|
||||
// Cap stderr lines to prevent memory issues.
|
||||
@@ -1353,7 +1416,6 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.streamingBashOutput = m.streamingBashOutput[len(m.streamingBashOutput)-m.streamingBashMaxLines:]
|
||||
}
|
||||
}
|
||||
m.streamingMu.Unlock()
|
||||
|
||||
case app.ToolCallContentEvent:
|
||||
// In streaming mode this text was already delivered via StreamChunkEvents
|
||||
@@ -1389,6 +1451,38 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
m.distributeHeight()
|
||||
|
||||
case app.SteerConsumedEvent:
|
||||
// Steering messages were consumed — either injected mid-turn via
|
||||
// PrepareStep, or drained into the queue after a text-only turn.
|
||||
//
|
||||
// Two cases:
|
||||
//
|
||||
// 1. Mid-turn (stateWorking, PrepareStep fired): no SpinnerEvent{Show:
|
||||
// true} will follow within this turn, so we cannot rely on
|
||||
// flushStreamAndPendingUserMessages() being called. Flush any live
|
||||
// stream content first (assistant text up to the steer point), then
|
||||
// render the steering user messages immediately to scrollback.
|
||||
//
|
||||
// 2. Post-turn (text-only response, drained after StepComplete): a
|
||||
// SpinnerEvent{Show: true} for the next turn is already in flight.
|
||||
// Defer to pendingUserPrints so the previous assistant response is
|
||||
// flushed first, preserving chronological order.
|
||||
if m.state == stateWorking {
|
||||
// Case 1: mid-turn — flush + print immediately.
|
||||
m.flushStreamContent()
|
||||
for _, text := range m.steeringMessages {
|
||||
m.printUserMessage(text)
|
||||
}
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
cmds = append(cmds, m.drainScrollback())
|
||||
} else {
|
||||
// Case 2: post-turn — defer so SpinnerEvent orders correctly.
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, m.steeringMessages...)
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
}
|
||||
|
||||
case app.StepCompleteEvent:
|
||||
// Keep stream content visible in the view — don't flush to scrollback
|
||||
// yet. Flushing + resetting in the same frame would shrink the view
|
||||
@@ -1641,6 +1735,7 @@ func (m *AppModel) View() tea.View {
|
||||
// Propagate hint visibility to the input component before rendering.
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.hideHint = vis.HideInputHint
|
||||
ic.agentBusy = m.state == stateWorking
|
||||
}
|
||||
|
||||
// When a prompt is active, it replaces the input area for consistency
|
||||
@@ -1742,20 +1837,23 @@ func (m *AppModel) renderStream() string {
|
||||
|
||||
// renderStreamingBashOutput renders accumulated streaming bash output (stdout + stderr)
|
||||
// below the LLM streaming text. Returns empty string if no bash output is present.
|
||||
// Lines are truncated to the terminal width and capped to maxBashLines to prevent
|
||||
// long-running commands from blowing up the TUI layout.
|
||||
func (m *AppModel) renderStreamingBashOutput(theme Theme) string {
|
||||
m.streamingMu.RLock()
|
||||
stdoutLines := make([]string, len(m.streamingBashOutput))
|
||||
copy(stdoutLines, m.streamingBashOutput)
|
||||
stderrLines := make([]string, len(m.streamingBashStderr))
|
||||
copy(stderrLines, m.streamingBashStderr)
|
||||
m.streamingMu.RUnlock()
|
||||
command := m.streamingBashCommand
|
||||
|
||||
if len(stdoutLines) == 0 && len(stderrLines) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
const lineIndent = " "
|
||||
width := m.width - 2 // Account for indent and padding
|
||||
lineWidth := max(m.width-2-len(lineIndent), 20)
|
||||
// Account for PaddingLeft(1) on the output/stderr styles.
|
||||
maxLineChars := lineWidth - 1
|
||||
|
||||
outputStyle := lipgloss.NewStyle().
|
||||
Background(theme.CodeBg).
|
||||
@@ -1766,17 +1864,59 @@ func (m *AppModel) renderStreamingBashOutput(theme Theme) string {
|
||||
Background(theme.CodeBg).
|
||||
PaddingLeft(1)
|
||||
|
||||
// Header style for the command - muted text with a subtle indicator.
|
||||
headerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
PaddingLeft(1)
|
||||
|
||||
// Cap displayed lines to maxBashLines (show the tail, since streaming
|
||||
// output is most useful at the end). The buffer itself is larger to
|
||||
// preserve context, but we only render the last N lines.
|
||||
totalLines := len(stdoutLines) + len(stderrLines)
|
||||
var hiddenCount int
|
||||
if totalLines > maxBashLines {
|
||||
hiddenCount = totalLines - maxBashLines
|
||||
// Trim from stdout first (older output), then stderr.
|
||||
remaining := maxBashLines
|
||||
if len(stderrLines) >= remaining {
|
||||
stdoutLines = nil
|
||||
stderrLines = stderrLines[len(stderrLines)-remaining:]
|
||||
} else {
|
||||
remaining -= len(stderrLines)
|
||||
if len(stdoutLines) > remaining {
|
||||
stdoutLines = stdoutLines[len(stdoutLines)-remaining:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var lines []string
|
||||
|
||||
// Command header - show the bash command being executed.
|
||||
if command != "" {
|
||||
headerText := fmt.Sprintf("$ %s", command)
|
||||
headerContent := headerStyle.Width(lineWidth).Render(truncateLine(headerText, maxLineChars))
|
||||
lines = append(lines, lineIndent+headerContent)
|
||||
}
|
||||
|
||||
// Truncation hint at the top.
|
||||
if hiddenCount > 0 {
|
||||
hint := fmt.Sprintf("...(%d more lines above)", hiddenCount)
|
||||
hintContent := outputStyle.Width(lineWidth).
|
||||
Foreground(theme.Muted).Italic(true).Render(hint)
|
||||
lines = append(lines, lineIndent+hintContent)
|
||||
}
|
||||
|
||||
// Render stdout lines.
|
||||
for _, line := range stdoutLines {
|
||||
styled := outputStyle.Width(width - len(lineIndent)).Render(line)
|
||||
line = truncateLine(strings.TrimRight(line, "\n"), maxLineChars)
|
||||
styled := outputStyle.Width(lineWidth).Render(line)
|
||||
lines = append(lines, lineIndent+styled)
|
||||
}
|
||||
|
||||
// Render stderr lines with error styling.
|
||||
for _, line := range stderrLines {
|
||||
styled := stderrStyle.Width(width - len(lineIndent)).Render(line)
|
||||
line = truncateLine(strings.TrimRight(line, "\n"), maxLineChars)
|
||||
styled := stderrStyle.Width(lineWidth).Render(line)
|
||||
lines = append(lines, lineIndent+styled)
|
||||
}
|
||||
|
||||
@@ -1901,16 +2041,26 @@ func (m *AppModel) cycleThinkingLevel() {
|
||||
go func() { _ = SaveThinkingLevelPreference(next) }()
|
||||
}
|
||||
|
||||
// renderSeparator renders the separator line with an optional queue count badge.
|
||||
// renderSeparator renders the separator line with an optional queue/steer count badge.
|
||||
func (m *AppModel) renderSeparator() string {
|
||||
theme := GetTheme()
|
||||
lineStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
queueLen := len(m.queuedMessages)
|
||||
steerLen := len(m.steeringMessages)
|
||||
|
||||
if queueLen > 0 {
|
||||
badge := lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Render(fmt.Sprintf("%d queued", queueLen))
|
||||
if steerLen > 0 || queueLen > 0 {
|
||||
var parts []string
|
||||
if steerLen > 0 {
|
||||
parts = append(parts, lipgloss.NewStyle().
|
||||
Foreground(theme.Warning).
|
||||
Render(fmt.Sprintf("%d steering", steerLen)))
|
||||
}
|
||||
if queueLen > 0 {
|
||||
parts = append(parts, lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Render(fmt.Sprintf("%d queued", queueLen)))
|
||||
}
|
||||
badge := strings.Join(parts, " ")
|
||||
|
||||
// Fill the separator with dashes up to the badge.
|
||||
dashWidth := max(m.width-lipgloss.Width(badge)-1, 0)
|
||||
@@ -2009,27 +2159,47 @@ func (m *AppModel) renderHeaderFooter(getter func() *WidgetData) string {
|
||||
return renderContentBlock(data.Text, m.width, opts...)
|
||||
}
|
||||
|
||||
// renderQueuedMessages renders queued prompts as styled content blocks with a
|
||||
// "QUEUED" badge, anchored between the separator and input. Each message is
|
||||
// displayed in a bordered block matching the overall message styling.
|
||||
// renderQueuedMessages renders queued and steering prompts as styled content
|
||||
// blocks with badges, anchored between the separator and input. Steering
|
||||
// messages use a distinct "STEERING" badge to differentiate from queued ones.
|
||||
func (m *AppModel) renderQueuedMessages() string {
|
||||
if len(m.queuedMessages) == 0 {
|
||||
if len(m.queuedMessages) == 0 && len(m.steeringMessages) == 0 {
|
||||
return ""
|
||||
}
|
||||
theme := GetTheme()
|
||||
badge := CreateBadge("QUEUED", theme.Accent)
|
||||
|
||||
var blocks []string
|
||||
for _, msg := range m.queuedMessages {
|
||||
content := msg + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Muted),
|
||||
)
|
||||
blocks = append(blocks, rendered)
|
||||
|
||||
// Render steering messages first (higher priority).
|
||||
if len(m.steeringMessages) > 0 {
|
||||
badge := CreateBadge("STEERING", theme.Warning)
|
||||
for _, msg := range m.steeringMessages {
|
||||
content := msg + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Warning),
|
||||
)
|
||||
blocks = append(blocks, rendered)
|
||||
}
|
||||
}
|
||||
|
||||
// Render queued messages.
|
||||
if len(m.queuedMessages) > 0 {
|
||||
badge := CreateBadge("QUEUED", theme.Accent)
|
||||
for _, msg := range m.queuedMessages {
|
||||
content := msg + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Muted),
|
||||
)
|
||||
blocks = append(blocks, rendered)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(blocks, "\n")
|
||||
}
|
||||
|
||||
@@ -2100,6 +2270,7 @@ func (m *AppModel) handleSlashCommand(sc *SlashCommand) tea.Cmd {
|
||||
m.appCtrl.ClearQueue()
|
||||
}
|
||||
m.queuedMessages = m.queuedMessages[:0]
|
||||
m.steeringMessages = m.steeringMessages[:0]
|
||||
m.distributeHeight()
|
||||
|
||||
case "/tree":
|
||||
@@ -2246,7 +2417,7 @@ func (m *AppModel) printHelpMessage() {
|
||||
"**Navigation:**\n" +
|
||||
"- `/tree`: Navigate session tree (switch branches)\n" +
|
||||
"- `/fork`: Branch from an earlier message\n" +
|
||||
"- `/new`: Start a new branch (preserves history)\n" +
|
||||
"- `/new`: Start a new session (discards context, saves old session)\n" +
|
||||
"- `/resume`: Open session picker to switch sessions\n" +
|
||||
"- `/name <name>`: Set a display name for this session\n\n" +
|
||||
"**System:**\n" +
|
||||
@@ -2287,7 +2458,9 @@ func (m *AppModel) printHelpMessage() {
|
||||
"- `!!command`: Run shell command, output excluded from LLM context\n\n" +
|
||||
"**Keys:**\n" +
|
||||
"- `Ctrl+C`: Exit at any time\n" +
|
||||
"- `ESC` (x2): Cancel ongoing LLM generation\n\n" +
|
||||
"- `ESC` (x2): Cancel ongoing LLM generation\n" +
|
||||
"- `Ctrl+S`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
|
||||
"- `Enter` (while working): Queue message for after the agent finishes\n\n" +
|
||||
"You can also just type your message to chat with the AI assistant."
|
||||
m.printSystemMessage(help)
|
||||
}
|
||||
@@ -2812,7 +2985,8 @@ func (m *AppModel) handleForkCommand() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleNewCommand starts a fresh session by resetting the tree leaf.
|
||||
// handleNewCommand starts a completely new session (Pi-style /new behavior).
|
||||
// Creates a new session file, discarding all context from the previous conversation.
|
||||
func (m *AppModel) handleNewCommand() tea.Cmd {
|
||||
// Emit before-session-switch event in a goroutine so that extension
|
||||
// handlers can call blocking operations (e.g. ctx.PromptConfirm) without
|
||||
@@ -2835,6 +3009,8 @@ func (m *AppModel) handleNewCommand() tea.Cmd {
|
||||
|
||||
// performNewSession performs the actual session reset. Called either directly
|
||||
// (when no before-hook exists) or after the async hook completes.
|
||||
// Matches Pi behavior: creates a completely new session file, discarding all
|
||||
// context from the previous conversation.
|
||||
func (m *AppModel) performNewSession() tea.Cmd {
|
||||
ts := m.appCtrl.GetTreeSession()
|
||||
if ts == nil {
|
||||
@@ -2846,11 +3022,16 @@ func (m *AppModel) performNewSession() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
ts.ResetLeaf()
|
||||
if m.appCtrl != nil {
|
||||
m.appCtrl.ClearMessages()
|
||||
// Create a brand new session file (Pi-style /new behavior)
|
||||
newTs, err := session.CreateTreeSession(m.cwd)
|
||||
if err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to create new session: %v", err))
|
||||
return nil
|
||||
}
|
||||
m.printSystemMessage("New branch started. Previous conversation is preserved in the tree.")
|
||||
|
||||
// Switch to the new session, closing the old one
|
||||
m.appCtrl.SwitchTreeSession(newTs)
|
||||
m.printSystemMessage("New session started. Previous conversation saved.")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -54,6 +54,10 @@ func (s *stubAppController) GetTreeSession() *session.TreeManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAppController) SwitchTreeSession(_ *session.TreeManager) {
|
||||
// no-op in tests
|
||||
}
|
||||
|
||||
func (s *stubAppController) SendEvent(_ tea.Msg) {
|
||||
// no-op in tests
|
||||
}
|
||||
@@ -67,6 +71,11 @@ func (s *stubAppController) RunWithFiles(prompt string, _ []fantasy.FilePart) in
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
func (s *stubAppController) Steer(prompt string) int {
|
||||
s.runCalls = append(s.runCalls, prompt)
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Stub child components
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -679,6 +688,57 @@ func TestToolResult_clearsStreamingBashOutput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallStarted_extractsBashCommand verifies that ToolCallStartedEvent
|
||||
// extracts the bash command from ToolArgs and stores it for the streaming output header.
|
||||
func TestToolCallStarted_extractsBashCommand(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
// Send ToolCallStartedEvent with bash command.
|
||||
m = sendMsg(m, app.ToolCallStartedEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "bash",
|
||||
ToolArgs: `{"command":"ls -la /home"}`,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "ls -la /home" {
|
||||
t.Fatalf("expected streamingBashCommand='ls -la /home', got %q", m.streamingBashCommand)
|
||||
}
|
||||
|
||||
// ToolResultEvent should clear the command.
|
||||
m = sendMsg(m, app.ToolResultEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "bash",
|
||||
ToolArgs: `{"command":"ls -la /home"}`,
|
||||
Result: "output",
|
||||
IsError: false,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "" {
|
||||
t.Fatalf("expected streamingBashCommand cleared, got %q", m.streamingBashCommand)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallStarted_nonBashTool_doesNotSetCommand verifies that non-bash tools
|
||||
// do not set the streamingBashCommand field.
|
||||
func TestToolCallStarted_nonBashTool_doesNotSetCommand(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.state = stateWorking
|
||||
|
||||
// Send ToolCallStartedEvent with a non-bash tool.
|
||||
m = sendMsg(m, app.ToolCallStartedEvent{
|
||||
ToolCallID: "call-1",
|
||||
ToolName: "read",
|
||||
ToolArgs: `{"file":"/etc/passwd"}`,
|
||||
})
|
||||
|
||||
if m.streamingBashCommand != "" {
|
||||
t.Fatalf("expected streamingBashCommand to remain empty for non-bash tools, got %q", m.streamingBashCommand)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStepError_printCmd verifies that StepErrorEvent with a non-nil error
|
||||
// produces a non-nil cmd (the tea.Println call for the error message).
|
||||
func TestStepError_printCmd(t *testing.T) {
|
||||
|
||||
+109
-59
@@ -79,7 +79,12 @@ func streamSpinnerTickCmd(generation uint64) tea.Cmd {
|
||||
// streamFlushTickMsg fires when it's time to commit pending chunks to the
|
||||
// main content builders and trigger a re-render. This coalesces rapid
|
||||
// streaming chunks into fewer expensive markdown re-renders.
|
||||
type streamFlushTickMsg struct{}
|
||||
//
|
||||
// generation ties the tick to the pending flush session that created it so
|
||||
// stale ticks from a prior Reset() are discarded.
|
||||
type streamFlushTickMsg struct {
|
||||
generation uint64
|
||||
}
|
||||
|
||||
// streamFlushInterval is the coalescing window for stream chunks. Chunks
|
||||
// arriving within this window are batched into a single render pass.
|
||||
@@ -89,9 +94,9 @@ const streamFlushInterval = 16 * time.Millisecond
|
||||
|
||||
// streamFlushTickCmd returns a tea.Cmd that fires streamFlushTickMsg after
|
||||
// the coalescing interval.
|
||||
func streamFlushTickCmd() tea.Cmd {
|
||||
func streamFlushTickCmd(generation uint64) tea.Cmd {
|
||||
return tea.Tick(streamFlushInterval, func(_ time.Time) tea.Msg {
|
||||
return streamFlushTickMsg{}
|
||||
return streamFlushTickMsg{generation: generation}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -149,9 +154,11 @@ type StreamComponent struct {
|
||||
// spinnerFrame is the current frame index.
|
||||
spinnerFrame int
|
||||
|
||||
// activeTools tracks the names of tools currently executing in parallel.
|
||||
// When multiple tools run concurrently, all are displayed in the spinner.
|
||||
activeTools []string
|
||||
// activeTools maps ToolCallID -> display label for currently running tools.
|
||||
activeTools map[string]string
|
||||
|
||||
// activeToolOrder preserves deterministic display order for active tools.
|
||||
activeToolOrder []string
|
||||
|
||||
// streamContent holds committed streaming text (flushed from pending).
|
||||
streamContent strings.Builder
|
||||
@@ -172,6 +179,10 @@ type StreamComponent struct {
|
||||
// the same coalescing window.
|
||||
flushPending bool
|
||||
|
||||
// flushGeneration is incremented when stream state resets so stale flush
|
||||
// ticks from a previous step can be discarded.
|
||||
flushGeneration uint64
|
||||
|
||||
// renderCache holds the last rendered output string. Reused by View()
|
||||
// between flush ticks to avoid redundant markdown re-parsing.
|
||||
renderCache string
|
||||
@@ -190,14 +201,8 @@ type StreamComponent struct {
|
||||
// reasoningDuration holds the total reasoning time, frozen when streaming text begins.
|
||||
reasoningDuration time.Duration
|
||||
|
||||
// messageRenderer renders assistant messages in standard mode.
|
||||
messageRenderer *MessageRenderer
|
||||
|
||||
// compactRenderer renders assistant messages in compact mode.
|
||||
compactRenderer *CompactRenderer
|
||||
|
||||
// compactMode selects which renderer to use.
|
||||
compactMode bool
|
||||
// renderer renders streaming assistant text in either compact or standard mode.
|
||||
renderer Renderer
|
||||
|
||||
// modelName is displayed in the streaming text header.
|
||||
modelName string
|
||||
@@ -218,13 +223,19 @@ func NewStreamComponent(compactMode bool, width int, modelName string) *StreamCo
|
||||
if width == 0 {
|
||||
width = 80
|
||||
}
|
||||
|
||||
var renderer Renderer
|
||||
if compactMode {
|
||||
renderer = NewCompactRenderer(width, false)
|
||||
} else {
|
||||
renderer = newMessageRenderer(width, false)
|
||||
}
|
||||
|
||||
return &StreamComponent{
|
||||
spinnerFrames: knightRiderFrames(),
|
||||
compactMode: compactMode,
|
||||
modelName: modelName,
|
||||
messageRenderer: newMessageRenderer(width, false),
|
||||
compactRenderer: NewCompactRenderer(width, false),
|
||||
width: width,
|
||||
spinnerFrames: knightRiderFrames(),
|
||||
modelName: modelName,
|
||||
renderer: renderer,
|
||||
width: width,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,11 +262,13 @@ func (s *StreamComponent) Reset() {
|
||||
s.spinnerGeneration++ // invalidate any in-flight tick commands
|
||||
s.spinnerFrame = 0
|
||||
s.activeTools = nil
|
||||
s.activeToolOrder = nil
|
||||
s.streamContent.Reset()
|
||||
s.reasoningContent.Reset()
|
||||
s.pendingStream.Reset()
|
||||
s.pendingReasoning.Reset()
|
||||
s.flushPending = false
|
||||
s.flushGeneration++
|
||||
s.renderCache = ""
|
||||
s.renderDirty = false
|
||||
s.timestamp = time.Time{}
|
||||
@@ -282,7 +295,8 @@ func (s *StreamComponent) GetRenderedContent() string {
|
||||
|
||||
text := s.streamContent.String()
|
||||
if text != "" {
|
||||
sections = append(sections, s.renderStreamingText(text))
|
||||
rendered := s.renderStreamingText(text)
|
||||
sections = append(sections, rendered)
|
||||
}
|
||||
|
||||
if len(sections) == 0 {
|
||||
@@ -322,8 +336,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
case tea.WindowSizeMsg:
|
||||
s.width = msg.Width
|
||||
s.messageRenderer.SetWidth(s.width)
|
||||
s.compactRenderer.SetWidth(s.width)
|
||||
if s.renderer != nil {
|
||||
s.renderer.SetWidth(s.width)
|
||||
}
|
||||
// Invalidate render cache — width change affects wrapping/styling.
|
||||
s.renderCache = ""
|
||||
s.renderDirty = true
|
||||
@@ -359,6 +374,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
case streamFlushTickMsg:
|
||||
if msg.generation != s.flushGeneration {
|
||||
break
|
||||
}
|
||||
s.flushPending = false
|
||||
s.commitPending()
|
||||
|
||||
@@ -373,7 +391,7 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.pendingReasoning.WriteString(msg.Delta)
|
||||
if !s.flushPending {
|
||||
s.flushPending = true
|
||||
return s, streamFlushTickCmd()
|
||||
return s, streamFlushTickCmd(s.flushGeneration)
|
||||
}
|
||||
|
||||
case app.StreamChunkEvent:
|
||||
@@ -388,14 +406,25 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.pendingStream.WriteString(msg.Content)
|
||||
if !s.flushPending {
|
||||
s.flushPending = true
|
||||
return s, streamFlushTickCmd()
|
||||
return s, streamFlushTickCmd(s.flushGeneration)
|
||||
}
|
||||
|
||||
case app.ToolExecutionEvent:
|
||||
toolID := msg.ToolCallID
|
||||
if toolID == "" {
|
||||
// Defensive fallback for older/third-party emitters that may omit
|
||||
// ToolCallID. Best-effort only: same-name+args concurrent calls can
|
||||
// still collide without a stable ID.
|
||||
toolID = fmt.Sprintf("%s|%s", msg.ToolName, msg.ToolArgs)
|
||||
}
|
||||
if msg.IsStarting {
|
||||
// Add tool to active list for parallel execution display.
|
||||
toolDisplay := formatToolExecutionMessage(msg.ToolName, msg.ToolArgs)
|
||||
s.activeTools = append(s.activeTools, toolDisplay)
|
||||
if s.activeTools == nil {
|
||||
s.activeTools = make(map[string]string)
|
||||
}
|
||||
if _, exists := s.activeTools[toolID]; !exists {
|
||||
s.activeToolOrder = append(s.activeToolOrder, toolID)
|
||||
}
|
||||
s.activeTools[toolID] = formatToolExecutionMessage(msg.ToolName)
|
||||
s.spinnerFrame = 0
|
||||
if !s.spinning {
|
||||
s.phase = streamPhaseActive
|
||||
@@ -404,9 +433,10 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return s, streamSpinnerTickCmd(s.spinnerGeneration)
|
||||
}
|
||||
} else {
|
||||
// Tool finished — remove from active list but keep spinning if others remain.
|
||||
toolDisplay := formatToolExecutionMessage(msg.ToolName, msg.ToolArgs)
|
||||
s.activeTools = removeFromSlice(s.activeTools, toolDisplay)
|
||||
if s.activeTools != nil {
|
||||
delete(s.activeTools, toolID)
|
||||
}
|
||||
s.activeToolOrder = removeToolID(s.activeToolOrder, toolID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -415,7 +445,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// View implements tea.Model. Renders the current stream region content.
|
||||
func (s *StreamComponent) View() tea.View {
|
||||
return tea.NewView(s.render())
|
||||
fullContent := s.render()
|
||||
visibleContent := s.viewContent(fullContent)
|
||||
return tea.NewView(visibleContent)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -458,21 +490,27 @@ func (s *StreamComponent) render() string {
|
||||
|
||||
content := strings.Join(sections, "\n")
|
||||
|
||||
// Clamp to height if constrained: keep the last h lines so the most
|
||||
// recent output is always visible.
|
||||
if s.height > 0 && content != "" {
|
||||
lines := strings.Split(content, "\n")
|
||||
if len(lines) > s.height {
|
||||
lines = lines[len(lines)-s.height:]
|
||||
content = strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Cache FULL content without height clamping.
|
||||
// Height clamping is applied in View() for display only.
|
||||
s.renderCache = content
|
||||
s.renderDirty = false
|
||||
return content
|
||||
}
|
||||
|
||||
// viewContent returns the visible portion of content based on height constraint.
|
||||
// This is called by View() to get the slice that fits in the terminal.
|
||||
func (s *StreamComponent) viewContent(fullContent string) string {
|
||||
if s.height > 0 && fullContent != "" {
|
||||
lines := strings.Split(fullContent, "\n")
|
||||
if len(lines) > s.height {
|
||||
// Keep only the last h lines so the most recent output is visible.
|
||||
lines = lines[len(lines)-s.height:]
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
return fullContent
|
||||
}
|
||||
|
||||
// renderReasoningBlock renders the reasoning/thinking content in a surface-tinted
|
||||
// box. When collapsed, shows the last 10 lines with a truncation hint. When
|
||||
// expanded, shows all lines. Includes a "Thought for Xs" duration footer.
|
||||
@@ -559,7 +597,8 @@ func (s *StreamComponent) SpinnerView() string {
|
||||
return ""
|
||||
}
|
||||
frame := s.spinnerFrames[s.spinnerFrame%len(s.spinnerFrames)]
|
||||
if len(s.activeTools) == 0 {
|
||||
tools := s.activeToolDisplays()
|
||||
if len(tools) == 0 {
|
||||
return " " + frame
|
||||
}
|
||||
theme := GetTheme()
|
||||
@@ -569,10 +608,10 @@ func (s *StreamComponent) SpinnerView() string {
|
||||
|
||||
// Format active tools list
|
||||
var toolsMsg string
|
||||
if len(s.activeTools) == 1 {
|
||||
toolsMsg = s.activeTools[0]
|
||||
if len(tools) == 1 {
|
||||
toolsMsg = tools[0]
|
||||
} else {
|
||||
toolsMsg = "Running: " + strings.Join(s.activeTools, ", ")
|
||||
toolsMsg = "Running: " + strings.Join(tools, ", ")
|
||||
}
|
||||
return " " + frame + " " + msgStyle.Render(toolsMsg)
|
||||
}
|
||||
@@ -584,28 +623,39 @@ func (s *StreamComponent) renderStreamingText(text string) string {
|
||||
if ts.IsZero() {
|
||||
ts = time.Now()
|
||||
}
|
||||
|
||||
if s.compactMode {
|
||||
msg := s.compactRenderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
return msg.Content
|
||||
if s.renderer == nil {
|
||||
return text
|
||||
}
|
||||
msg := s.messageRenderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
msg := s.renderer.RenderAssistantMessage(text, ts, s.modelName)
|
||||
return msg.Content
|
||||
}
|
||||
|
||||
// removeFromSlice removes the first occurrence of a string from a slice.
|
||||
func removeFromSlice(slice []string, s string) []string {
|
||||
for i, v := range slice {
|
||||
if v == s {
|
||||
return append(slice[:i], slice[i+1:]...)
|
||||
func (s *StreamComponent) activeToolDisplays() []string {
|
||||
if len(s.activeTools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(s.activeToolOrder))
|
||||
for _, id := range s.activeToolOrder {
|
||||
if display, ok := s.activeTools[id]; ok {
|
||||
out = append(out, display)
|
||||
}
|
||||
}
|
||||
return slice
|
||||
return out
|
||||
}
|
||||
|
||||
// removeToolID removes the first occurrence of a tool ID from a slice.
|
||||
func removeToolID(ids []string, id string) []string {
|
||||
for i, v := range ids {
|
||||
if v == id {
|
||||
return append(ids[:i], ids[i+1:]...)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// formatToolExecutionMessage creates a descriptive spinner message for tool execution.
|
||||
// For spawn_subagent, it shows simply as "Subagent" with optional task preview.
|
||||
func formatToolExecutionMessage(toolName, toolArgs string) string {
|
||||
// For spawn_subagent, it shows simply as "Subagent".
|
||||
func formatToolExecutionMessage(toolName string) string {
|
||||
if toolName == "spawn_subagent" {
|
||||
return "Subagent"
|
||||
}
|
||||
|
||||
+110
-15
@@ -7,11 +7,86 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Color derivation helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// parseHexColor parses a "#RRGGBB" hex string into r, g, b components (0-255).
|
||||
func parseHexColor(hex string) (r, g, b int) {
|
||||
hex = strings.TrimPrefix(hex, "#")
|
||||
if len(hex) == 6 {
|
||||
if v, err := strconv.ParseUint(hex[0:2], 16, 8); err == nil {
|
||||
r = int(v)
|
||||
}
|
||||
if v, err := strconv.ParseUint(hex[2:4], 16, 8); err == nil {
|
||||
g = int(v)
|
||||
}
|
||||
if v, err := strconv.ParseUint(hex[4:6], 16, 8); err == nil {
|
||||
b = int(v)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// blendHex linearly interpolates between two hex colors by amount (0.0–1.0).
|
||||
func blendHex(base, tint string, amount float64) string {
|
||||
br, bg, bb := parseHexColor(base)
|
||||
tr, tg, tb := parseHexColor(tint)
|
||||
clamp := func(v int) int {
|
||||
if v < 0 {
|
||||
return 0
|
||||
}
|
||||
if v > 255 {
|
||||
return 255
|
||||
}
|
||||
return v
|
||||
}
|
||||
r := clamp(int(float64(br)*(1-amount) + float64(tr)*amount))
|
||||
g := clamp(int(float64(bg)*(1-amount) + float64(tg)*amount))
|
||||
b := clamp(int(float64(bb)*(1-amount) + float64(tb)*amount))
|
||||
return fmt.Sprintf("#%02x%02x%02x", r, g, b)
|
||||
}
|
||||
|
||||
// deriveDiffBg computes diff / code background colors from the theme's
|
||||
// background, success, and error hex pairs. Returns an adaptive color for each
|
||||
// diff element. The tint amounts are tuned for subtle differentiation.
|
||||
func deriveDiffBg(bgPair, successPair, errorPair [2]string) (diffInsert, diffDelete, diffEqual, diffMissing, codeBg, gutterBg, writeBg color.Color) {
|
||||
derive := func(idx int) (color.Color, color.Color, color.Color, color.Color) {
|
||||
bg := bgPair[idx]
|
||||
// Contrast target: darken for light mode (idx 0), lighten for dark (idx 1).
|
||||
contrast := "#000000"
|
||||
if idx == 1 {
|
||||
contrast = "#ffffff"
|
||||
}
|
||||
ins := blendHex(bg, successPair[idx], 0.13)
|
||||
del := blendHex(bg, errorPair[idx], 0.13)
|
||||
eq := blendHex(bg, contrast, 0.05)
|
||||
miss := blendHex(bg, contrast, 0.03)
|
||||
return AdaptiveColor(ins, ins), AdaptiveColor(del, del), AdaptiveColor(eq, eq), AdaptiveColor(miss, miss)
|
||||
}
|
||||
|
||||
// Pick the correct index based on detected background.
|
||||
idx := 0
|
||||
if isDarkBg {
|
||||
idx = 1
|
||||
}
|
||||
insL, delL, eqL, missL := derive(idx)
|
||||
diffInsert = insL
|
||||
diffDelete = delL
|
||||
diffEqual = eqL
|
||||
diffMissing = missL
|
||||
codeBg = eqL
|
||||
gutterBg = missL
|
||||
writeBg = insL
|
||||
return
|
||||
}
|
||||
|
||||
// ThemeEntry is a named, loadable theme — either built-in or discovered from disk.
|
||||
type ThemeEntry struct {
|
||||
Name string // Display name (filename stem or preset name)
|
||||
@@ -80,14 +155,9 @@ func makeTheme(p presetColors) Theme {
|
||||
Accent: acOr(p.accent, ac(p.primary)),
|
||||
Highlight: acOr(p.highlight, def.Highlight),
|
||||
}
|
||||
// Derive diff/code backgrounds from the base background.
|
||||
t.DiffInsertBg = def.DiffInsertBg
|
||||
t.DiffDeleteBg = def.DiffDeleteBg
|
||||
t.DiffEqualBg = def.DiffEqualBg
|
||||
t.DiffMissingBg = def.DiffMissingBg
|
||||
t.CodeBg = def.CodeBg
|
||||
t.GutterBg = def.GutterBg
|
||||
t.WriteBg = def.WriteBg
|
||||
// Derive diff/code backgrounds from the theme's own palette.
|
||||
t.DiffInsertBg, t.DiffDeleteBg, t.DiffEqualBg, t.DiffMissingBg,
|
||||
t.CodeBg, t.GutterBg, t.WriteBg = deriveDiffBg(p.background, p.success, p.error_)
|
||||
// Markdown colors.
|
||||
t.Markdown = MarkdownThemeColors{
|
||||
Text: t.Text,
|
||||
@@ -609,6 +679,17 @@ func loadThemeFile(path string) (Theme, error) {
|
||||
|
||||
func fileConfigToTheme(cfg themeFileConfig) Theme {
|
||||
def := DefaultTheme()
|
||||
|
||||
// Resolve the base background/success/error hex pairs for diff derivation.
|
||||
// We need the raw hex strings to feed deriveDiffBg.
|
||||
bgPair := resolveHexPair(cfg.Background, [2]string{"#F0F0F0", "#0D0D0D"})
|
||||
successPair := resolveHexPair(cfg.Success, [2]string{"#998800", "#CCAA00"})
|
||||
errorPair := resolveHexPair(cfg.Error, [2]string{"#CC0000", "#FF3333"})
|
||||
|
||||
// Derive diff backgrounds from the theme's own palette.
|
||||
derivedInsert, derivedDelete, derivedEqual, derivedMissing,
|
||||
derivedCodeBg, derivedGutterBg, derivedWriteBg := deriveDiffBg(bgPair, successPair, errorPair)
|
||||
|
||||
return Theme{
|
||||
Primary: cfg.Primary.resolve(def.Primary),
|
||||
Secondary: cfg.Secondary.resolve(def.Secondary),
|
||||
@@ -627,13 +708,13 @@ func fileConfigToTheme(cfg themeFileConfig) Theme {
|
||||
Accent: cfg.Accent.resolve(def.Accent),
|
||||
Highlight: cfg.Highlight.resolve(def.Highlight),
|
||||
|
||||
DiffInsertBg: cfg.DiffInsertBg.resolve(def.DiffInsertBg),
|
||||
DiffDeleteBg: cfg.DiffDeleteBg.resolve(def.DiffDeleteBg),
|
||||
DiffEqualBg: cfg.DiffEqualBg.resolve(def.DiffEqualBg),
|
||||
DiffMissingBg: cfg.DiffMissingBg.resolve(def.DiffMissingBg),
|
||||
CodeBg: cfg.CodeBg.resolve(def.CodeBg),
|
||||
GutterBg: cfg.GutterBg.resolve(def.GutterBg),
|
||||
WriteBg: cfg.WriteBg.resolve(def.WriteBg),
|
||||
DiffInsertBg: cfg.DiffInsertBg.resolve(derivedInsert),
|
||||
DiffDeleteBg: cfg.DiffDeleteBg.resolve(derivedDelete),
|
||||
DiffEqualBg: cfg.DiffEqualBg.resolve(derivedEqual),
|
||||
DiffMissingBg: cfg.DiffMissingBg.resolve(derivedMissing),
|
||||
CodeBg: cfg.CodeBg.resolve(derivedCodeBg),
|
||||
GutterBg: cfg.GutterBg.resolve(derivedGutterBg),
|
||||
WriteBg: cfg.WriteBg.resolve(derivedWriteBg),
|
||||
|
||||
Markdown: MarkdownThemeColors{
|
||||
Text: cfg.Markdown.Text.resolve(def.Markdown.Text),
|
||||
@@ -651,3 +732,17 @@ func fileConfigToTheme(cfg themeFileConfig) Theme {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// resolveHexPair returns the hex pair from an adaptiveColorPair, falling back
|
||||
// to defaults when the pair is empty.
|
||||
func resolveHexPair(a adaptiveColorPair, fallback [2]string) [2]string {
|
||||
light := a.Light
|
||||
if light == "" {
|
||||
light = fallback[0]
|
||||
}
|
||||
dark := a.Dark
|
||||
if dark == "" {
|
||||
dark = fallback[1]
|
||||
}
|
||||
return [2]string{light, dark}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseHexColor(t *testing.T) {
|
||||
tests := []struct {
|
||||
hex string
|
||||
r, g, b int
|
||||
}{
|
||||
{"#000000", 0, 0, 0},
|
||||
{"#ffffff", 255, 255, 255},
|
||||
{"#1e1e2e", 0x1e, 0x1e, 0x2e},
|
||||
{"#a6e3a1", 0xa6, 0xe3, 0xa1},
|
||||
{"#f38ba8", 0xf3, 0x8b, 0xa8},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
r, g, b := parseHexColor(tt.hex)
|
||||
if r != tt.r || g != tt.g || b != tt.b {
|
||||
t.Errorf("parseHexColor(%q) = (%d,%d,%d), want (%d,%d,%d)",
|
||||
tt.hex, r, g, b, tt.r, tt.g, tt.b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlendHex(t *testing.T) {
|
||||
// Blending with 0 amount should return the base color.
|
||||
got := blendHex("#1e1e2e", "#a6e3a1", 0.0)
|
||||
if got != "#1e1e2e" {
|
||||
t.Errorf("blendHex with 0.0 = %q, want #1e1e2e", got)
|
||||
}
|
||||
|
||||
// Blending with 1.0 amount should return the tint color.
|
||||
got = blendHex("#1e1e2e", "#a6e3a1", 1.0)
|
||||
if got != "#a6e3a1" {
|
||||
t.Errorf("blendHex with 1.0 = %q, want #a6e3a1", got)
|
||||
}
|
||||
|
||||
// Blending black and white at 0.5 should give mid gray.
|
||||
got = blendHex("#000000", "#ffffff", 0.5)
|
||||
// 127 = int(0 + 255*0.5) — truncated, so #7f7f7f
|
||||
if got != "#7f7f7f" {
|
||||
t.Errorf("blendHex black/white at 0.5 = %q, want #7f7f7f", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveDiffBgProducesDifferentColorsPerTheme(t *testing.T) {
|
||||
// Catppuccin palette
|
||||
catBg := [2]string{"#eff1f5", "#1e1e2e"}
|
||||
catSuccess := [2]string{"#40a02b", "#a6e3a1"}
|
||||
catError := [2]string{"#d20f39", "#f38ba8"}
|
||||
|
||||
// KITT palette
|
||||
kittBg := [2]string{"#F0F0F0", "#0D0D0D"}
|
||||
kittSuccess := [2]string{"#998800", "#CCAA00"}
|
||||
kittError := [2]string{"#CC0000", "#FF3333"}
|
||||
|
||||
catInsert, catDelete, _, _, _, _, _ := deriveDiffBg(catBg, catSuccess, catError)
|
||||
kittInsert, kittDelete, _, _, _, _, _ := deriveDiffBg(kittBg, kittSuccess, kittError)
|
||||
|
||||
if catInsert == kittInsert {
|
||||
t.Error("catppuccin DiffInsertBg should differ from kitt DiffInsertBg")
|
||||
}
|
||||
if catDelete == kittDelete {
|
||||
t.Error("catppuccin DiffDeleteBg should differ from kitt DiffDeleteBg")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeThemeDerivesUniqueDiffColors(t *testing.T) {
|
||||
themes := builtinThemes()
|
||||
kitt := themes["kitt"]
|
||||
cat := themes["catppuccin"]
|
||||
|
||||
// The catppuccin diff backgrounds should NOT equal the kitt defaults.
|
||||
if cat.DiffInsertBg == kitt.DiffInsertBg {
|
||||
t.Error("catppuccin DiffInsertBg should differ from kitt default")
|
||||
}
|
||||
if cat.DiffDeleteBg == kitt.DiffDeleteBg {
|
||||
t.Error("catppuccin DiffDeleteBg should differ from kitt default")
|
||||
}
|
||||
if cat.DiffEqualBg == kitt.DiffEqualBg {
|
||||
t.Error("catppuccin DiffEqualBg should differ from kitt default")
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ const (
|
||||
maxCodeLines = 20 // lines for Read / code blocks
|
||||
maxWriteLines = 10 // lines for Write blocks
|
||||
maxBashLines = 20 // lines for Bash output (matches Read)
|
||||
maxLsLines = 20 // lines for Ls directory listings
|
||||
)
|
||||
|
||||
// renderToolBody dispatches to tool-specific body renderers based on tool name.
|
||||
@@ -63,21 +64,44 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderEditBody renders a side-by-side diff from old_text/new_text in toolArgs.
|
||||
// Supports both single-edit mode and multi-edit mode (edits array).
|
||||
func renderEditBody(toolArgs, toolResult string, width int) string {
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
// Check for multi-edit mode (edits array)
|
||||
if editsArr, ok := args["edits"].([]any); ok && len(editsArr) > 0 {
|
||||
var results []string
|
||||
for _, edit := range editsArr {
|
||||
if e, ok := edit.(map[string]any); ok {
|
||||
oldText, _ := e["old_text"].(string)
|
||||
newText, _ := e["new_text"].(string)
|
||||
if oldText != "" || newText != "" {
|
||||
diff := renderDiffBlock(oldText, newText, startLine, width)
|
||||
if diff != "" {
|
||||
results = append(results, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(results) > 0 {
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Single-edit mode (legacy)
|
||||
oldText, _ := args["old_text"].(string)
|
||||
newText, _ := args["new_text"].(string)
|
||||
if oldText == "" && newText == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
return renderDiffBlock(oldText, newText, startLine, width)
|
||||
}
|
||||
|
||||
@@ -315,6 +339,13 @@ func renderLsBody(toolResult string, width int) string {
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
|
||||
// Truncate to maxLsLines for display
|
||||
var hiddenCount int
|
||||
if len(lines) > maxLsLines {
|
||||
hiddenCount = len(lines) - maxLsLines
|
||||
lines = lines[:maxLsLines]
|
||||
}
|
||||
|
||||
const indent = " "
|
||||
codeWidth := max(width-len(indent), 20)
|
||||
|
||||
@@ -329,6 +360,13 @@ func renderLsBody(toolResult string, width int) string {
|
||||
result = append(result, indent+styled)
|
||||
}
|
||||
|
||||
if hiddenCount > 0 {
|
||||
hint := fmt.Sprintf("...(%d more entries)", hiddenCount)
|
||||
hintContent := codeStyle.Width(codeWidth).
|
||||
Foreground(theme.Muted).Italic(true).Render(hint)
|
||||
result = append(result, indent+hintContent)
|
||||
}
|
||||
|
||||
return strings.Join(result, "\n")
|
||||
}
|
||||
|
||||
|
||||
@@ -151,10 +151,6 @@ func (ut *UsageTracker) RenderUsageInfo() string {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
if ut.sessionStats.RequestCount == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
// Display the current context window token count (from the last API call),
|
||||
@@ -266,3 +262,14 @@ func (ut *UsageTracker) SetWidth(width int) {
|
||||
defer ut.mu.Unlock()
|
||||
ut.width = width
|
||||
}
|
||||
|
||||
// UpdateModelInfo updates the model information and OAuth status when the model
|
||||
// is switched mid-session. This ensures token costs and context limits are
|
||||
// calculated correctly for the new model.
|
||||
func (ut *UsageTracker) UpdateModelInfo(modelInfo *models.ModelInfo, provider string, isOAuth bool) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
ut.modelInfo = modelInfo
|
||||
ut.provider = provider
|
||||
ut.isOAuth = isOAuth
|
||||
}
|
||||
|
||||
@@ -67,3 +67,62 @@ func TestUsageTracker_RenderUsageInfo_OAuth(t *testing.T) {
|
||||
t.Errorf("Expected regular rendered output to show actual cost, got: %s", regularRendered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTracker_RenderUsageInfo_StartupState(t *testing.T) {
|
||||
// Create a mock model info with costs and context limit
|
||||
modelInfo := &models.ModelInfo{
|
||||
ID: "claude-3-5-sonnet-20241022",
|
||||
Name: "Claude 3.5 Sonnet v2",
|
||||
Cost: models.Cost{
|
||||
Input: 3.0,
|
||||
Output: 15.0,
|
||||
},
|
||||
Limit: models.Limit{
|
||||
Context: 200000,
|
||||
Output: 8192,
|
||||
},
|
||||
}
|
||||
|
||||
// Test startup state (no requests made yet) - Regular API key
|
||||
regularTracker := NewUsageTracker(modelInfo, "anthropic", 80, false)
|
||||
rendered := stripAnsi(regularTracker.RenderUsageInfo())
|
||||
|
||||
// Should NOT return empty string on startup
|
||||
if rendered == "" {
|
||||
t.Errorf("Expected non-empty output on startup, got empty string")
|
||||
}
|
||||
|
||||
// Should show 0 tokens
|
||||
if !strings.Contains(rendered, "Tokens: 0") {
|
||||
t.Errorf("Expected 'Tokens: 0' on startup, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Should NOT show percentage when tokens are 0
|
||||
if strings.Contains(rendered, "(%") {
|
||||
t.Errorf("Expected no percentage on startup with 0 tokens, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Should show $0.0000 cost for regular API key
|
||||
if !strings.Contains(rendered, "Cost: $0.0000") {
|
||||
t.Errorf("Expected 'Cost: $0.0000' on startup, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Test startup state (no requests made yet) - OAuth
|
||||
oauthTracker := NewUsageTracker(modelInfo, "anthropic", 80, true)
|
||||
oauthRendered := stripAnsi(oauthTracker.RenderUsageInfo())
|
||||
|
||||
// Should NOT return empty string on startup
|
||||
if oauthRendered == "" {
|
||||
t.Errorf("Expected non-empty output on startup for OAuth, got empty string")
|
||||
}
|
||||
|
||||
// Should show 0 tokens for OAuth
|
||||
if !strings.Contains(oauthRendered, "Tokens: 0") {
|
||||
t.Errorf("Expected 'Tokens: 0' on startup for OAuth, got: %s", oauthRendered)
|
||||
}
|
||||
|
||||
// Should show $0.00 cost for OAuth
|
||||
if !strings.Contains(oauthRendered, "Cost: $0.00") {
|
||||
t.Errorf("Expected 'Cost: $0.00' on startup for OAuth, got: %s", oauthRendered)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,10 @@ type CredentialManager = auth.CredentialManager
|
||||
// and API key authentication methods.
|
||||
type AnthropicCredentials = auth.AnthropicCredentials
|
||||
|
||||
// OpenAICredentials holds OpenAI API credentials supporting both OAuth
|
||||
// and API key authentication methods.
|
||||
type OpenAICredentials = auth.OpenAICredentials
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
type CredentialStore = auth.CredentialStore
|
||||
|
||||
@@ -42,3 +46,34 @@ func GetAnthropicAPIKey() string {
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// HasOpenAICredentials checks if valid OpenAI credentials are stored
|
||||
// (either OAuth token or API key).
|
||||
func HasOpenAICredentials() bool {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
has, err := cm.HasOpenAICredentials()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return has
|
||||
}
|
||||
|
||||
// GetOpenAIAPIKey resolves the OpenAI API key using the standard
|
||||
// resolution order: stored credentials -> OPENAI_API_KEY env var.
|
||||
// Returns an empty string if no key is found.
|
||||
func GetOpenAIAPIKey() string {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// Try to get valid access token (handles OAuth refresh)
|
||||
token, err := cm.GetValidOpenAIAccessToken()
|
||||
if err == nil && token != "" {
|
||||
return token
|
||||
}
|
||||
// Fall back to environment variable
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -41,6 +41,10 @@ const (
|
||||
EventReasoningDelta EventType = "reasoning_delta"
|
||||
// EventToolOutput fires when a tool produces streaming output chunks.
|
||||
EventToolOutput EventType = "tool_output"
|
||||
EventStepUsage EventType = "step_usage"
|
||||
// EventSteerConsumed fires when one or more steering messages have been
|
||||
// injected into the agent turn via PrepareStep.
|
||||
EventSteerConsumed EventType = "steer_consumed"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -249,6 +253,19 @@ type ResponseEvent struct {
|
||||
// EventType implements Event.
|
||||
func (e ResponseEvent) EventType() EventType { return EventResponse }
|
||||
|
||||
// StepUsageEvent fires after each complete step in a multi-step agent turn,
|
||||
// carrying the token usage for that specific step. This enables real-time
|
||||
// cost tracking during long-running tool-calling conversations.
|
||||
type StepUsageEvent struct {
|
||||
InputTokens uint64
|
||||
OutputTokens uint64
|
||||
CacheReadTokens uint64
|
||||
CacheWriteTokens uint64
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e StepUsageEvent) EventType() EventType { return EventStepUsage }
|
||||
|
||||
// CompactionEvent fires after a successful compaction.
|
||||
type CompactionEvent struct {
|
||||
Summary string
|
||||
@@ -262,6 +279,16 @@ type CompactionEvent struct {
|
||||
// EventType implements Event.
|
||||
func (e CompactionEvent) EventType() EventType { return EventCompaction }
|
||||
|
||||
// SteerConsumedEvent fires when one or more steering messages have been
|
||||
// injected into the agent turn via PrepareStep. The Count indicates how
|
||||
// many messages were consumed in this batch.
|
||||
type SteerConsumedEvent struct {
|
||||
Count int
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e SteerConsumedEvent) EventType() EventType { return EventSteerConsumed }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EventBus
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -2,6 +2,7 @@ package kit
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
@@ -119,6 +120,125 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
})
|
||||
}
|
||||
|
||||
// --- Subagent lifecycle events ---
|
||||
// When an extension registers OnSubagentStart/Chunk/End handlers, bridge
|
||||
// the SDK's per-subagent event stream (SubscribeSubagent) into the
|
||||
// extension runner.
|
||||
//
|
||||
// Flow:
|
||||
// ToolExecutionStartEvent(spawn_subagent) → emit SubagentStartEvent
|
||||
// → SubscribeSubagent → emit SubagentChunkEvents
|
||||
// ToolResultEvent(spawn_subagent) → emit SubagentEndEvent
|
||||
//
|
||||
// We use ToolExecutionStart (not ToolCall) for SubagentStart because that
|
||||
// is when the subagent actually begins running. We use ToolResult for
|
||||
// SubagentEnd because that carries the final response text.
|
||||
wantsSubagent := runner.HasHandlers(extensions.SubagentStart) ||
|
||||
runner.HasHandlers(extensions.SubagentChunk) ||
|
||||
runner.HasHandlers(extensions.SubagentEnd)
|
||||
|
||||
if wantsSubagent {
|
||||
// taskByCallID tracks the task description extracted from ToolCall input,
|
||||
// keyed by toolCallID. Populated on ToolCall, consumed on ToolResult.
|
||||
taskByCallID := make(map[string]string)
|
||||
var taskMu = &taskMutex{}
|
||||
|
||||
// Intercept ToolCall to capture the task and subscribe to child events.
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolCallEvent)
|
||||
if !ok || ev.ToolName != "spawn_subagent" {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract task from parsed args.
|
||||
task := ""
|
||||
if ev.ParsedArgs != nil {
|
||||
if t, ok := ev.ParsedArgs["task"].(string); ok {
|
||||
task = t
|
||||
}
|
||||
}
|
||||
taskMu.set(taskByCallID, ev.ToolCallID, task)
|
||||
|
||||
// Subscribe to child events so we can forward them as SubagentChunkEvents.
|
||||
if runner.HasHandlers(extensions.SubagentChunk) {
|
||||
m.SubscribeSubagent(ev.ToolCallID, func(childEvent Event) {
|
||||
chunk := extensions.SubagentChunkEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Task: task,
|
||||
}
|
||||
switch ce := childEvent.(type) {
|
||||
case MessageUpdateEvent:
|
||||
chunk.ChunkType = "text"
|
||||
chunk.Content = ce.Chunk
|
||||
case TurnStartEvent:
|
||||
chunk.ChunkType = "turn_start"
|
||||
case TurnEndEvent:
|
||||
chunk.ChunkType = "turn_end"
|
||||
case ToolCallEvent:
|
||||
chunk.ChunkType = "tool_call"
|
||||
chunk.ToolName = ce.ToolName
|
||||
chunk.ToolArgs = ce.ToolArgs
|
||||
case ToolExecutionStartEvent:
|
||||
chunk.ChunkType = "tool_execution_start"
|
||||
chunk.ToolName = ce.ToolName
|
||||
case ToolExecutionEndEvent:
|
||||
chunk.ChunkType = "tool_execution_end"
|
||||
chunk.ToolName = ce.ToolName
|
||||
case ToolResultEvent:
|
||||
chunk.ChunkType = "tool_result"
|
||||
chunk.ToolName = ce.ToolName
|
||||
chunk.ToolResult = ce.Result
|
||||
chunk.IsError = ce.IsError
|
||||
default:
|
||||
return // skip unknown event types
|
||||
}
|
||||
_, _ = runner.Emit(chunk)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Emit SubagentStartEvent when execution begins.
|
||||
if runner.HasHandlers(extensions.SubagentStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolExecutionStartEvent)
|
||||
if !ok || ev.ToolName != "spawn_subagent" {
|
||||
return
|
||||
}
|
||||
task := taskMu.get(taskByCallID, ev.ToolCallID)
|
||||
_, _ = runner.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Task: task,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Emit SubagentEndEvent when the tool result arrives.
|
||||
if runner.HasHandlers(extensions.SubagentEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
ev, ok := e.(ToolResultEvent)
|
||||
if !ok || ev.ToolName != "spawn_subagent" {
|
||||
return
|
||||
}
|
||||
task := taskMu.get(taskByCallID, ev.ToolCallID)
|
||||
taskMu.del(taskByCallID, ev.ToolCallID)
|
||||
errMsg := ""
|
||||
if ev.IsError {
|
||||
errMsg = ev.Result
|
||||
}
|
||||
response := ""
|
||||
if !ev.IsError {
|
||||
response = ev.Result
|
||||
}
|
||||
_, _ = runner.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Task: task,
|
||||
Response: response,
|
||||
ErrorMsg: errMsg,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Context filtering hook ---
|
||||
// Extension ContextPrepare → SDK ContextPrepare hook.
|
||||
if runner.HasHandlers(extensions.ContextPrepare) {
|
||||
@@ -204,3 +324,27 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// taskMutex is a simple mutex-protected map helper used by bridgeExtensions.
|
||||
// It lives in this file to avoid polluting the kit package with unexported types.
|
||||
type taskMutex struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *taskMutex) set(m map[string]string, key, val string) {
|
||||
t.mu.Lock()
|
||||
m[key] = val
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
func (t *taskMutex) get(m map[string]string, key string) string {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return m[key]
|
||||
}
|
||||
|
||||
func (t *taskMutex) del(m map[string]string, key string) {
|
||||
t.mu.Lock()
|
||||
delete(m, key)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
+122
@@ -66,6 +66,13 @@ type Kit struct {
|
||||
// subagentListeners holds per-tool-call event listeners registered via
|
||||
// SubscribeSubagent(). Keyed by toolCallID → *subagentListenerSet.
|
||||
subagentListeners sync.Map
|
||||
|
||||
// steerCh is a buffered channel used to inject steering messages into
|
||||
// the running agent turn via Fantasy's PrepareStep. Created fresh for
|
||||
// each generate() call and set to nil when idle. Protected by steerMu.
|
||||
steerMu sync.Mutex
|
||||
steerCh chan string
|
||||
leftoverSteer []string // unconsumed steer messages from the last turn
|
||||
}
|
||||
|
||||
// Subscribe registers an EventListener that will be called for every lifecycle
|
||||
@@ -529,8 +536,11 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
}
|
||||
|
||||
// Build a provider config from current settings, overriding the model.
|
||||
// Load system prompt properly (handles both file paths and inline content).
|
||||
systemPrompt, _ := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
config := &models.ProviderConfig{
|
||||
ModelString: modelString,
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
@@ -1053,6 +1063,15 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
// Bridge extension events to SDK hooks.
|
||||
if agentResult.ExtRunner != nil {
|
||||
k.bridgeExtensions(agentResult.ExtRunner)
|
||||
|
||||
// Initialize extension context with minimal defaults. SDK users can call
|
||||
// SetExtensionContext to override with richer implementations (TUI callbacks,
|
||||
// prompts, etc.). This ensures extensions never crash on nil function fields.
|
||||
k.SetExtensionContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: k.modelString,
|
||||
Interactive: false, // SDK mode defaults to non-interactive
|
||||
})
|
||||
}
|
||||
|
||||
return k, nil
|
||||
@@ -1405,6 +1424,35 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
// All prompt modes (Prompt, Steer, FollowUp, PromptWithOptions) share this
|
||||
// single code path so callback wiring is never duplicated.
|
||||
func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.GenerateWithLoopResult, error) {
|
||||
// Create a per-turn steer channel and attach it to the context so the
|
||||
// agent's PrepareStep can inject steering messages between steps.
|
||||
steerCh := make(chan string, 16)
|
||||
m.steerMu.Lock()
|
||||
m.steerCh = steerCh
|
||||
m.steerMu.Unlock()
|
||||
defer func() {
|
||||
// Drain any unconsumed steer messages before nilling the channel.
|
||||
// These are stored in leftoverSteer so DrainSteer() can return them.
|
||||
var leftover []string
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
leftover = append(leftover, msg)
|
||||
default:
|
||||
goto drained
|
||||
}
|
||||
}
|
||||
drained:
|
||||
m.steerMu.Lock()
|
||||
m.steerCh = nil
|
||||
m.leftoverSteer = leftover
|
||||
m.steerMu.Unlock()
|
||||
}()
|
||||
ctx = agent.ContextWithSteerCh(ctx, steerCh)
|
||||
ctx = agent.ContextWithSteerConsumed(ctx, func(count int) {
|
||||
m.events.emit(SteerConsumedEvent{Count: count})
|
||||
})
|
||||
|
||||
// Inject the in-process subagent spawner into the context so the
|
||||
// spawn_subagent core tool can create child Kit instances without
|
||||
// importing pkg/kit (which would create an import cycle).
|
||||
@@ -1491,6 +1539,15 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
IsStderr: isStderr,
|
||||
})
|
||||
},
|
||||
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
|
||||
// Emit step usage event for real-time cost tracking
|
||||
m.events.emit(StepUsageEvent{
|
||||
InputTokens: uint64(inputTokens),
|
||||
OutputTokens: uint64(outputTokens),
|
||||
CacheReadTokens: uint64(cacheReadTokens),
|
||||
CacheWriteTokens: uint64(cacheCreationTokens),
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1705,6 +1762,71 @@ func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
|
||||
return result.Response, nil
|
||||
}
|
||||
|
||||
// InjectSteer sends a steering message into the currently active agent turn.
|
||||
// The message will be injected as a user message between steps (after the
|
||||
// current tool execution finishes, before the next LLM call). If no turn is
|
||||
// active the message is silently dropped — callers should check IsGenerating()
|
||||
// or use Prompt()/Steer() for idle-state messaging.
|
||||
//
|
||||
// InjectSteer is safe to call from any goroutine. Multiple calls queue
|
||||
// messages in order; all pending steer messages are drained and injected
|
||||
// together at the next step boundary.
|
||||
//
|
||||
// This is the preferred way to redirect an agent mid-turn without cancelling
|
||||
// in-progress tool execution.
|
||||
func (m *Kit) InjectSteer(message string) {
|
||||
m.steerMu.Lock()
|
||||
ch := m.steerCh
|
||||
m.steerMu.Unlock()
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- message:
|
||||
default:
|
||||
// Channel full — extremely unlikely with buffer of 16, but don't block.
|
||||
}
|
||||
}
|
||||
|
||||
// IsGenerating returns true if an agent turn is currently in progress.
|
||||
// Use this to decide between InjectSteer (mid-turn) and Prompt (new turn).
|
||||
func (m *Kit) IsGenerating() bool {
|
||||
m.steerMu.Lock()
|
||||
defer m.steerMu.Unlock()
|
||||
return m.steerCh != nil
|
||||
}
|
||||
|
||||
// DrainSteer removes and returns all unconsumed steer messages. Called after
|
||||
// a turn completes so the app layer can process any steer messages that
|
||||
// arrived after the last PrepareStep fired (e.g. during a text-only response
|
||||
// with no tool calls, or after the agent finished its last step).
|
||||
func (m *Kit) DrainSteer() []string {
|
||||
m.steerMu.Lock()
|
||||
defer m.steerMu.Unlock()
|
||||
|
||||
// First check leftover messages saved when generate() returned.
|
||||
if len(m.leftoverSteer) > 0 {
|
||||
msgs := m.leftoverSteer
|
||||
m.leftoverSteer = nil
|
||||
return msgs
|
||||
}
|
||||
|
||||
// If a turn is still active, drain from the live channel.
|
||||
if m.steerCh != nil {
|
||||
var msgs []string
|
||||
for {
|
||||
select {
|
||||
case msg := <-m.steerCh:
|
||||
msgs = append(msgs, msg)
|
||||
default:
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PromptOptions configures a single PromptWithOptions call.
|
||||
type PromptOptions struct {
|
||||
// SystemMessage is prepended as a system message before the user prompt.
|
||||
|
||||
@@ -59,6 +59,79 @@ result := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
})
|
||||
```
|
||||
|
||||
### Monitoring subagents from extensions
|
||||
|
||||
When the LLM (not the extension itself) spawns a subagent using the `spawn_subagent` tool, extensions can monitor its activity in real-time using three lifecycle event handlers:
|
||||
|
||||
```go
|
||||
// Track active subagents and display their output
|
||||
var subagentWidgets map[string]*SubagentWidget
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Subagent started by the main agent
|
||||
api.OnSubagentStart(func(e ext.SubagentStartEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — unique ID for this subagent invocation
|
||||
// e.Task — the task/prompt sent to the subagent
|
||||
widget := NewWidget(e.ToolCallID, e.Task)
|
||||
subagentWidgets[e.ToolCallID] = widget
|
||||
ctx.SetWidget(widget.Config())
|
||||
})
|
||||
|
||||
// Real-time streaming from subagent
|
||||
api.OnSubagentChunk(func(e ext.SubagentChunkEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — matches the start event
|
||||
// e.ChunkType — "text", "tool_call", "tool_execution_start", "tool_result"
|
||||
// e.Content — text content
|
||||
// e.ToolName — tool name (for tool chunks)
|
||||
// e.IsError — true if tool result failed
|
||||
widget := subagentWidgets[e.ToolCallID]
|
||||
if widget != nil {
|
||||
widget.AddOutput(e)
|
||||
ctx.SetWidget(widget.Config())
|
||||
}
|
||||
})
|
||||
|
||||
// Subagent completed
|
||||
api.OnSubagentEnd(func(e ext.SubagentEndEvent, ctx ext.Context) {
|
||||
// e.Response — final response from subagent
|
||||
// e.ErrorMsg — error message if subagent failed
|
||||
widget := subagentWidgets[e.ToolCallID]
|
||||
if widget != nil {
|
||||
widget.MarkComplete(e.Response, e.ErrorMsg)
|
||||
ctx.SetWidget(widget.Config())
|
||||
delete(subagentWidgets, e.ToolCallID)
|
||||
}
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**Event structs:**
|
||||
|
||||
```go
|
||||
type SubagentStartEvent struct {
|
||||
ToolCallID string // Unique ID for this subagent invocation
|
||||
Task string // The task/prompt sent to subagent
|
||||
}
|
||||
|
||||
type SubagentChunkEvent struct {
|
||||
ToolCallID string // Matches SubagentStartEvent.ToolCallID
|
||||
Task string // Task description
|
||||
ChunkType string // "text", "tool_call", "tool_execution_start", "tool_result"
|
||||
Content string // For text chunks
|
||||
ToolName string // For tool-related chunks
|
||||
IsError bool // For tool_result chunks
|
||||
}
|
||||
|
||||
type SubagentEndEvent struct {
|
||||
ToolCallID string // Matches start event
|
||||
Task string // Task description
|
||||
Response string // Final response from subagent
|
||||
ErrorMsg string // Error message if failed
|
||||
}
|
||||
```
|
||||
|
||||
This enables building monitoring widgets that display real-time activity from all subagents spawned by the main agent.
|
||||
|
||||
## Go SDK subagents
|
||||
|
||||
The SDK provides in-process subagent spawning:
|
||||
|
||||
@@ -74,7 +74,7 @@ These commands are available inside the Kit TUI during an interactive session:
|
||||
| `/reset-usage` | Reset usage statistics |
|
||||
| `/tree` | Navigate session tree |
|
||||
| `/fork` | Branch from an earlier message |
|
||||
| `/new` | Start a new session |
|
||||
| `/new` | Start a new session (creates new session file) |
|
||||
| `/name [name]` | Set or show session display name |
|
||||
| `/resume` | Open session picker to switch sessions (alias: `/r`) |
|
||||
| `/session` | Show session info |
|
||||
@@ -95,9 +95,17 @@ Press **ESC twice** to cancel the current operation:
|
||||
|
||||
This ensures that `tool_use` and `tool_result` messages are always sent to the API as matched pairs, avoiding errors from orphaned tool calls.
|
||||
|
||||
## Prompt templates
|
||||
### Mid-turn steering
|
||||
|
||||
Create reusable prompt templates with shell-style argument substitution. Templates are loaded from `~/.kit/prompts/*.md` and `.kit/prompts/*.md`.
|
||||
Press **Ctrl+S** during streaming to inject a system-level instruction mid-turn. This allows you to steer the conversation direction without waiting for the model to finish:
|
||||
|
||||
- Works during streaming output
|
||||
- Sends a steering instruction as a system message
|
||||
- Model continues from the interruption point with the new guidance
|
||||
|
||||
Example: While the model is writing code, press Ctrl+S and type "Use async/await instead" to change the implementation approach.
|
||||
|
||||
## Prompt templates
|
||||
|
||||
### Creating templates
|
||||
|
||||
|
||||
@@ -96,9 +96,45 @@ mcpServers:
|
||||
|
||||
A legacy format with `transport`, `args`, `env`, and `headers` fields is also supported.
|
||||
|
||||
## Theme configuration
|
||||
## Custom models
|
||||
|
||||
Set theme colors inline or reference an external file:
|
||||
Define custom models in your `.kit.yml` for use with the `custom` provider. This is useful for self-hosted models or API endpoints not in the built-in database:
|
||||
|
||||
```yaml
|
||||
customModels:
|
||||
my-model:
|
||||
name: "My Custom Model"
|
||||
reasoning: true
|
||||
temperature: true
|
||||
cost:
|
||||
input: 0.002
|
||||
output: 0.004
|
||||
limit:
|
||||
context: 128000
|
||||
output: 32000
|
||||
```
|
||||
|
||||
### Custom model fields
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
|-------|------|----------|-------------|
|
||||
| `name` | string | Yes | Display name for the model |
|
||||
| `reasoning` | bool | No | Whether the model supports reasoning/thinking |
|
||||
| `temperature` | bool | No | Whether the model supports temperature adjustment |
|
||||
| `cost.input` | float | No | Cost per 1K input tokens |
|
||||
| `cost.output` | float | No | Cost per 1K output tokens |
|
||||
| `limit.context` | int | Yes | Maximum context window in tokens |
|
||||
| `limit.output` | int | No | Maximum output tokens |
|
||||
|
||||
Use with a custom provider URL:
|
||||
|
||||
```bash
|
||||
kit --provider-url "http://localhost:8080/v1" --model custom/my-model "Hello"
|
||||
```
|
||||
|
||||
When `--provider-url` is specified without `--model`, Kit defaults to `custom/custom` which has zero cost tracking and a 262K context window.
|
||||
|
||||
## Theme configuration
|
||||
|
||||
```yaml
|
||||
# Inline partial overrides (unspecified fields inherit from default)
|
||||
|
||||
@@ -7,7 +7,7 @@ description: All extension capabilities — lifecycle events, tools, commands, w
|
||||
|
||||
## Lifecycle events
|
||||
|
||||
Extensions can hook into 20 lifecycle events:
|
||||
Extensions can hook into 23 lifecycle events:
|
||||
|
||||
| Event | Description |
|
||||
|-------|-------------|
|
||||
@@ -31,6 +31,9 @@ Extensions can hook into 20 lifecycle events:
|
||||
| `OnBeforeSessionSwitch` | Before switching sessions |
|
||||
| `OnBeforeCompact` | Before conversation compaction |
|
||||
| `OnCustomEvent` | Custom inter-extension event received |
|
||||
| `OnSubagentStart` | Subagent spawned by the main agent |
|
||||
| `OnSubagentChunk` | Real-time output from subagent (text, tool calls, results) |
|
||||
| `OnSubagentEnd` | Subagent completed with final response/error |
|
||||
|
||||
### Example
|
||||
|
||||
@@ -234,6 +237,54 @@ result := ctx.SpawnSubagent(ext.SubagentConfig{
|
||||
})
|
||||
```
|
||||
|
||||
### Monitoring subagents spawned by the main agent
|
||||
|
||||
When the LLM uses the built-in `spawn_subagent` tool, extensions can monitor the subagent's activity in real-time using three lifecycle events:
|
||||
|
||||
```go
|
||||
// Subagent started
|
||||
api.OnSubagentStart(func(e ext.SubagentStartEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — unique ID for this subagent invocation
|
||||
// e.Task — the task/prompt sent to the subagent
|
||||
ctx.PrintInfo(fmt.Sprintf("Subagent started: %s", e.Task))
|
||||
})
|
||||
|
||||
// Real-time streaming output from subagent
|
||||
api.OnSubagentChunk(func(e ext.SubagentChunkEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — matches the start event
|
||||
// e.Task — task description
|
||||
// e.ChunkType — "text", "tool_call", "tool_execution_start", "tool_result"
|
||||
// e.Content — text content (for text chunks)
|
||||
// e.ToolName — tool name (for tool-related chunks)
|
||||
// e.IsError — true if tool result is an error
|
||||
switch e.ChunkType {
|
||||
case "text":
|
||||
// Streaming text output
|
||||
case "tool_call":
|
||||
// Subagent is calling a tool
|
||||
case "tool_execution_start":
|
||||
// Tool execution started
|
||||
case "tool_result":
|
||||
// Tool execution completed (check e.IsError)
|
||||
}
|
||||
})
|
||||
|
||||
// Subagent completed
|
||||
api.OnSubagentEnd(func(e ext.SubagentEndEvent, ctx ext.Context) {
|
||||
// e.ToolCallID — matches start event
|
||||
// e.Task — task description
|
||||
// e.Response — final response from subagent
|
||||
// e.ErrorMsg — error message if subagent failed
|
||||
if e.ErrorMsg != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Subagent failed: %s", e.ErrorMsg))
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Subagent completed: %s", e.Response))
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
This enables building widgets that display real-time subagent activity.
|
||||
|
||||
## LLM completion
|
||||
|
||||
Make direct model calls without going through the agent loop:
|
||||
|
||||
+11
-3
@@ -30,12 +30,20 @@ When conversations grow long, Kit can compact them to free up context window spa
|
||||
|
||||
Use `/compact [focus]` to manually compact, or enable `--auto-compact` to compact automatically near the context limit.
|
||||
|
||||
## Auto-cleanup
|
||||
|
||||
Kit automatically cleans up empty sessions on shutdown and when using `/resume`. A session is considered empty if it has no messages beyond the initial system prompt. This prevents cluttering your sessions directory with unused files.
|
||||
|
||||
To start fresh without creating a session file at all, use ephemeral mode:
|
||||
|
||||
```bash
|
||||
kit --no-session
|
||||
```
|
||||
|
||||
## Resuming sessions
|
||||
|
||||
### Continue most recent
|
||||
|
||||
Resume the most recent session for the current directory:
|
||||
|
||||
```bash
|
||||
kit --continue
|
||||
kit -c
|
||||
@@ -73,7 +81,7 @@ These slash commands are available during an interactive session:
|
||||
| `/share` | Upload session to GitHub Gist and get a shareable viewer URL |
|
||||
| `/tree` | Navigate the session tree |
|
||||
| `/fork` | Branch from an earlier message |
|
||||
| `/new` | Start a fresh session |
|
||||
| `/new` | Start a new session (creates new session file) |
|
||||
|
||||
## Ephemeral mode
|
||||
|
||||
|
||||
Reference in New Issue
Block a user