mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 11:40:13 +00:00
Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7f366eab84 | |||
| e8e99b19a8 | |||
| ef072f6e59 | |||
| 49f8b485be | |||
| febdc530e1 | |||
| e610bdd2d0 | |||
| 6100e8b3a8 | |||
| 9f125f3400 | |||
| 00eab47218 | |||
| 06bf6d087a | |||
| fd960921ca | |||
| 0b651a8df9 | |||
| 7315c1dea7 | |||
| 0313fa03ad | |||
| d27022bcfb | |||
| ae722d520f | |||
| 7a04bdfeba | |||
| 7e4708f511 | |||
| 1e12102b92 | |||
| ab2a77c95e | |||
| 1e78153b50 | |||
| a613361969 | |||
| 67722b0c24 | |||
| 1a2f6da40f | |||
| 747f5be099 | |||
| d7c4565999 | |||
| bd24f3315c | |||
| 592f8dc84f | |||
| 66c4a1eb15 |
@@ -127,6 +127,13 @@ max-tokens: 4096
|
||||
temperature: 0.7
|
||||
stream: true
|
||||
thinking-level: off # off, none, minimal, low, medium, high
|
||||
no-core-tools: false # set to true to disable all built-in core tools
|
||||
|
||||
# Skills — all three keys are optional
|
||||
no-skills: false # set to true to disable all skill loading
|
||||
skill: # explicit skill files/dirs (disables auto-discovery)
|
||||
- /path/to/skill.md
|
||||
skills-dir: "" # override project-local directory for auto-discovery
|
||||
```
|
||||
|
||||
All of the above keys can also be set programmatically via the SDK
|
||||
@@ -195,12 +202,18 @@ mcpServers:
|
||||
--compact Enable compact output mode
|
||||
--auto-compact Auto-compact conversation near context limit
|
||||
|
||||
# Extensions
|
||||
# Extensions and tools
|
||||
--extension, -e Load additional extension file(s) (repeatable)
|
||||
--no-extensions Disable all extensions
|
||||
--no-core-tools Disable all built-in core tools (bash, read, write, edit, grep, find, ls, subagent)
|
||||
--prompt-template Load a specific prompt template by name
|
||||
--no-prompt-templates Disable prompt template loading
|
||||
|
||||
# Skills
|
||||
--skill Load skill file or directory (repeatable)
|
||||
--skills-dir Override the project-local skills directory for auto-discovery
|
||||
--no-skills Disable skill loading (auto-discovery and explicit)
|
||||
|
||||
# Generation parameters
|
||||
--max-tokens Maximum tokens in response (default: 8192, auto-raised up to 32768 for models with larger known output limits)
|
||||
--temperature Randomness 0.0-1.0 (default: 0.7)
|
||||
@@ -226,6 +239,10 @@ kit auth login [provider] --set-default # Set provider's default model as syste
|
||||
kit auth logout [provider] # Remove credentials for provider
|
||||
kit auth status # Check authentication status
|
||||
|
||||
# GitHub Copilot login (experimental; requires active Copilot subscription)
|
||||
kit auth login copilot
|
||||
kit --model copilot/gpt-5.5 "Hello"
|
||||
|
||||
# Model database
|
||||
kit models [provider] # List available models (optionally filter by provider)
|
||||
kit models --all # Show all providers (not just LLM-compatible)
|
||||
@@ -306,12 +323,15 @@ kit -e examples/extensions/minimal.go
|
||||
|
||||
### Extension Capabilities
|
||||
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolCallInputStart, OnToolCallInputDelta, OnToolCallInputEnd, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnLLMUsage, OnToolCall, OnToolCallInputStart, OnToolCallInputDelta, OnToolCallInputEnd, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
|
||||
|
||||
`OnAgentEnd` carries per-turn aggregates (`ToolCallCount`, `ToolNames`, `LLMCallCount`, `InputTokensDelta`, `OutputTokensDelta`, `CostDelta`, `DurationMs`) so observers don't need to maintain parallel bookkeeping. `OnLLMUsage` fires after each LLM provider call with token + cost deltas attributed to that specific call/model — use it for accurate budget enforcement *between* calls instead of waiting for the turn to finish.
|
||||
|
||||
**Custom Components**:
|
||||
- **Tools**: Add new tools the LLM can invoke
|
||||
- **Commands**: Register slash commands (e.g., `/mycommand`)
|
||||
- **Options**: Register configurable extension options
|
||||
- **Session State**: Last-write-wins key-value store via `ctx.SetState` / `GetState` / `DeleteState` / `ListState`, persisted to a per-session sidecar file outside the conversation tree
|
||||
- **Widgets**: Persistent status displays above/below input
|
||||
- **Headers/Footers**: Persistent content above/below the conversation
|
||||
- **Status Bar**: Custom status bar entries
|
||||
@@ -367,6 +387,7 @@ See the `examples/extensions/` directory:
|
||||
- [`tool-logger.go`](examples/extensions/tool-logger.go) - Log all tool calls
|
||||
- [`neon-theme.go`](examples/extensions/neon-theme.go) - Custom theme registration and switching
|
||||
- [`tool-renderer-demo.go`](examples/extensions/tool-renderer-demo.go) - Custom tool call rendering
|
||||
- [`usage-budget.go`](examples/extensions/usage-budget.go) - Per-call usage callback (`OnLLMUsage`), session state, and enriched `OnAgentEnd` per-turn report
|
||||
- [`widget-status.go`](examples/extensions/widget-status.go) - Persistent status widgets
|
||||
|
||||
Also see [`.kit/extensions/go-edit-lint.go`](.kit/extensions/go-edit-lint.go) (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
|
||||
@@ -507,6 +528,8 @@ During an interactive session, use these slash commands:
|
||||
|
||||
| Shortcut | Description |
|
||||
|----------|-------------|
|
||||
| `Ctrl+V` | Paste an image from the clipboard — shows an inline low-res thumbnail preview (tmux/zellij-safe) |
|
||||
| `Ctrl+U` | Clear all pending image attachments |
|
||||
| `Ctrl+X e` | Open `$VISUAL`/`$EDITOR` to compose or edit your prompt |
|
||||
| `Ctrl+X s` | Steer — inject a system-level instruction mid-turn |
|
||||
| `ESC ESC` | Cancel the current operation (tool call or streaming) |
|
||||
@@ -554,7 +577,7 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
SystemPrompt: "You are a helpful bot",
|
||||
ConfigFile: "/path/to/config.yml",
|
||||
MaxSteps: 10,
|
||||
Streaming: true,
|
||||
Streaming: ptr(true), // *bool: nil = unset (default true), &false = off
|
||||
Quiet: true,
|
||||
|
||||
// Generation parameters (override env/config/per-model defaults)
|
||||
@@ -579,7 +602,9 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
// Tool options
|
||||
Tools: []kit.Tool{...}, // Replace default tool set entirely
|
||||
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
|
||||
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
|
||||
DisableCoreTools: true, // Disable all built-in core tools; also controllable via
|
||||
// --no-core-tools flag, KIT_NO_CORE_TOOLS env var,
|
||||
// or no-core-tools: true in .kit.yml
|
||||
|
||||
// Configuration
|
||||
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
|
||||
@@ -599,6 +624,38 @@ are pointer types so explicit `0.0` is distinguishable from "leave alone"; a
|
||||
non-zero `MaxTokens` suppresses automatic right-sizing the same way `--max-tokens`
|
||||
does on the CLI.
|
||||
|
||||
### Functional options (`NewAgent`)
|
||||
|
||||
For simple programmatic setups, `kit.NewAgent` offers an ergonomic
|
||||
functional-options front door over `kit.New`. Streaming is **enabled by
|
||||
default**; pass `kit.WithStreaming(false)` to opt out.
|
||||
|
||||
```go
|
||||
host, err := kit.NewAgent(ctx,
|
||||
kit.WithModel("anthropic/claude-sonnet-4-5-20250929"),
|
||||
kit.WithSystemPrompt("You are a helpful assistant."),
|
||||
kit.WithMaxTokens(8192),
|
||||
kit.WithThinkingLevel("medium"),
|
||||
kit.Ephemeral(), // in-memory session, no persistence
|
||||
)
|
||||
```
|
||||
|
||||
Available options: `WithModel`, `WithSystemPrompt`, `WithStreaming`,
|
||||
`WithMaxTokens`, `WithThinkingLevel`, `WithTools`, `WithExtraTools`,
|
||||
`WithProviderAPIKey`, `WithProviderURL`, `WithConfigFile`, `WithDebug`, and
|
||||
`Ephemeral`. For advanced configuration not covered by the helpers (custom MCP
|
||||
config, in-process MCP servers, session backends, MCP task tuning) construct an
|
||||
`Options` value explicitly and call `kit.New`.
|
||||
|
||||
### Per-instance config isolation
|
||||
|
||||
Each `kit.New` / `kit.NewAgent` call owns an **isolated configuration store**,
|
||||
so constructing multiple Kit instances in the same process is safe: setting the
|
||||
model, thinking level, or generation parameters on one never affects another,
|
||||
and runtime mutators (`SetModel`, `SetThinkingLevel`) only touch the owning
|
||||
instance. This makes subagent spawning and multi-Kit embedding race-free with
|
||||
no external synchronization required.
|
||||
|
||||
### MCP OAuth (remote MCP servers)
|
||||
|
||||
When a remote MCP server returns 401, Kit runs the full OAuth flow (dynamic
|
||||
@@ -756,6 +813,45 @@ host, _ := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
### Runtime Skills & Context Files
|
||||
|
||||
For multi-tenant hosts (chatbots, per-user agents, web services), the SDK
|
||||
lets you swap skills and `AGENTS.md`-style context files **after** Kit
|
||||
construction. Every mutation recomposes the system prompt and applies it to
|
||||
the agent so the next turn picks up the new instructions — no restart needed.
|
||||
|
||||
```go
|
||||
// Programmatic skill (no file on disk required).
|
||||
host.AddSkill(&kit.Skill{
|
||||
Name: "polite-french",
|
||||
Description: "Respond in French and always greet the user.",
|
||||
Content: "Always reply in French. Open every response with 'Bonjour'.",
|
||||
})
|
||||
|
||||
// Or load one from disk.
|
||||
host.LoadAndAddSkill("/var/skills/refund-policy.md")
|
||||
|
||||
// Per-user AGENTS.md content pulled from a database.
|
||||
host.AddContextFileContent(
|
||||
fmt.Sprintf("session://%s/AGENTS.md", userID),
|
||||
rulesFromDB,
|
||||
)
|
||||
|
||||
// Tear down session-specific state on logout.
|
||||
host.RemoveSkill("polite-french")
|
||||
host.RemoveContextFile(fmt.Sprintf("session://%s/AGENTS.md", userID))
|
||||
|
||||
// Or replace the whole set atomically.
|
||||
host.SetSkills(activeSkillsForUser)
|
||||
host.SetContextFiles(activeContextForUser)
|
||||
```
|
||||
|
||||
Skills dedupe by `Name`, context files dedupe by `Path` (which can be any
|
||||
opaque identifier — it doesn't have to be a real filesystem path). All
|
||||
mutators and readers (`GetSkills`, `GetContextFiles`) are safe to call
|
||||
concurrently from multiple goroutines. See the [SDK overview docs](/sdk/overview#runtime-skills-and-context-files)
|
||||
for the full reference.
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Subagent Pattern
|
||||
@@ -872,6 +968,7 @@ npm/ - NPM package wrapper for distribution
|
||||
|
||||
- **Anthropic** - Claude models (native, prompt caching, OAuth)
|
||||
- **OpenAI** - GPT models
|
||||
- **Copilot** - GitHub Copilot models (`copilot`, requires active Copilot subscription)
|
||||
- **Google** - Gemini models
|
||||
- **Ollama** - Local models
|
||||
- **Azure OpenAI** - Azure-hosted OpenAI
|
||||
@@ -897,6 +994,31 @@ This automatically defaults to `custom/custom` without needing to specify a mode
|
||||
- Reasoning and temperature support
|
||||
- Optional `CUSTOM_API_KEY` environment variable or `--provider-api-key` flag
|
||||
|
||||
### Auto-routed Providers
|
||||
|
||||
Any provider in the [models.dev](https://models.dev) database can be used as
|
||||
`provider/model` without a dedicated native integration. Kit auto-routes the
|
||||
request through the matching **wire protocol** based on the provider's npm package
|
||||
(or per-model override), using its `api` URL as the base:
|
||||
|
||||
| npm package | Wire protocol |
|
||||
|-------------|---------------|
|
||||
| `@ai-sdk/openai` | OpenAI (Responses API) |
|
||||
| `@ai-sdk/openai-compatible` | OpenAI (chat completions) |
|
||||
| `@ai-sdk/anthropic` | Anthropic |
|
||||
| `@ai-sdk/google` | Google Gemini |
|
||||
|
||||
Providers with an `api` URL but an unrecognized npm package fall back to the
|
||||
OpenAI-compatible wire. Because routing follows the wire protocol, aggregator/proxy
|
||||
providers work across all of their models — including Claude, GPT, *and* Gemini
|
||||
routes:
|
||||
|
||||
```bash
|
||||
kit --model opencode/claude-haiku-4-5 "Hello" # → Anthropic wire
|
||||
kit --model opencode/gpt-5 "Hello" # → OpenAI wire
|
||||
kit --model opencode/gemini-3.5-flash "Hello" # → Google wire
|
||||
```
|
||||
|
||||
### Model String Format
|
||||
|
||||
```bash
|
||||
|
||||
+157
-4
@@ -31,10 +31,12 @@ using OAuth flows. Stored credentials take precedence over environment variables
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI API (OAuth and API key)
|
||||
- copilot: GitHub Copilot (GitHub device login)
|
||||
|
||||
Examples:
|
||||
kit auth login anthropic
|
||||
kit auth login openai
|
||||
kit auth login copilot
|
||||
kit auth logout anthropic
|
||||
kit auth status`,
|
||||
}
|
||||
@@ -54,6 +56,7 @@ environment variables when making API calls.
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI ChatGPT Plus/Pro (Codex OAuth)
|
||||
- copilot: GitHub Copilot (GitHub device login, experimental)
|
||||
|
||||
Flags:
|
||||
--set-default Set this provider's default model as the system default
|
||||
@@ -61,7 +64,8 @@ Flags:
|
||||
Examples:
|
||||
kit auth login anthropic
|
||||
kit auth login openai
|
||||
kit auth login openai --set-default`,
|
||||
kit auth login copilot
|
||||
kit auth login copilot --set-default`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogin,
|
||||
}
|
||||
@@ -80,10 +84,12 @@ You will need to use environment variables or command-line flags for authenticat
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API
|
||||
- openai: OpenAI API
|
||||
- copilot: GitHub Copilot
|
||||
|
||||
Example:
|
||||
kit auth logout anthropic
|
||||
kit auth logout openai`,
|
||||
kit auth logout openai
|
||||
kit auth logout copilot`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogout,
|
||||
}
|
||||
@@ -113,6 +119,7 @@ var (
|
||||
var defaultModels = map[string]string{
|
||||
"anthropic": "anthropic/claude-sonnet-4-5-20250929",
|
||||
"openai": "openai/gpt-5.4",
|
||||
"copilot": "copilot/gpt-5.5",
|
||||
}
|
||||
|
||||
// setDefaultModelIfRequested sets the default model for the given provider
|
||||
@@ -143,6 +150,7 @@ func init() {
|
||||
authLoginCmd.Flags().BoolVar(&loginSetDefault, "set-default", false, "Set this provider's default model as the system default after login")
|
||||
}
|
||||
|
||||
// runAuthLogin dispatches OAuth login to the selected provider.
|
||||
func runAuthLogin(cmd *cobra.Command, args []string) error {
|
||||
provider := strings.ToLower(args[0])
|
||||
|
||||
@@ -151,8 +159,10 @@ func runAuthLogin(cmd *cobra.Command, args []string) error {
|
||||
return loginAnthropic()
|
||||
case "openai":
|
||||
return loginOpenAI()
|
||||
case "copilot":
|
||||
return loginCopilot(cmd.Context())
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai, copilot", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,8 +174,10 @@ func runAuthLogout(cmd *cobra.Command, args []string) error {
|
||||
return logoutAnthropic()
|
||||
case "openai":
|
||||
return logoutOpenAI()
|
||||
case "copilot":
|
||||
return logoutCopilot()
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai, copilot", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,9 +256,31 @@ func runAuthStatus(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check GitHub Copilot credentials
|
||||
fmt.Print("\nGitHub Copilot: ")
|
||||
if hasCopilotCreds, err := cm.HasCopilotCredentials(); err != nil {
|
||||
fmt.Printf("Error checking credentials: %v\n", err)
|
||||
} else if hasCopilotCreds {
|
||||
if creds, err := cm.GetCopilotCredentials(); err != nil {
|
||||
fmt.Printf("Error reading credentials: %v\n", err)
|
||||
} else {
|
||||
status := "✓ Authenticated"
|
||||
if creds.IsExpired() {
|
||||
status = "⚠️ Token expired (will refresh automatically)"
|
||||
} else if creds.NeedsRefresh() {
|
||||
status = "⚠️ Token expires soon (will refresh automatically)"
|
||||
}
|
||||
|
||||
fmt.Printf("%s (GitHub OAuth, stored %s)\n", status, creds.CreatedAt.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
} else {
|
||||
fmt.Println("✗ Not authenticated")
|
||||
}
|
||||
|
||||
fmt.Println("\nTo authenticate with a provider:")
|
||||
fmt.Println(" kit auth login anthropic")
|
||||
fmt.Println(" kit auth login openai")
|
||||
fmt.Println(" kit auth login copilot")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -517,6 +551,85 @@ func loginOpenAI() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// loginCopilot authenticates GitHub Copilot using GitHub device flow.
|
||||
func loginCopilot(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
if hasAuth, err := cm.HasCopilotCredentials(); err == nil && hasAuth {
|
||||
var reauth bool
|
||||
err := huh.NewConfirm().
|
||||
Title("You are already authenticated with GitHub Copilot").
|
||||
Description("Do you want to re-authenticate?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&reauth).
|
||||
Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prompt for re-authentication: %w", err)
|
||||
}
|
||||
if !reauth {
|
||||
fmt.Println("Authentication cancelled.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
client := auth.NewCopilotOAuthClient()
|
||||
|
||||
fmt.Println("🔐 Starting GitHub Copilot authentication...")
|
||||
fmt.Println("This uses GitHub device login and requires an active GitHub Copilot subscription.")
|
||||
fmt.Println("Experimental: this uses VS Code Copilot Chat client identifiers.")
|
||||
fmt.Println()
|
||||
|
||||
deviceCode, err := client.StartDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start GitHub device login: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("📱 Open this page and enter the code:")
|
||||
fmt.Printf("\n%s\n\n", deviceCode.VerificationURI)
|
||||
fmt.Printf("Code: %s\n\n", deviceCode.UserCode)
|
||||
auth.TryOpenBrowser(deviceCode.VerificationURI)
|
||||
|
||||
fmt.Println("Waiting for GitHub authorization...")
|
||||
githubToken, err := client.PollDeviceToken(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to complete GitHub device login: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n🔄 Exchanging GitHub token for Copilot access token...")
|
||||
creds, err := client.ExchangeGitHubToken(ctx, githubToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get GitHub Copilot token: %w", err)
|
||||
}
|
||||
|
||||
if err := cm.SetCopilotOAuthCredentials(creds); err != nil {
|
||||
return fmt.Errorf("failed to store credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✅ Successfully authenticated with GitHub Copilot!")
|
||||
fmt.Printf("📁 Credentials stored in: %s\n", cm.GetCredentialsPath())
|
||||
fmt.Println("\n🎉 Your GitHub Copilot credentials will now be used for copilot/* models.")
|
||||
fmt.Println("💡 You can check your authentication status with: kit auth status")
|
||||
|
||||
if err := setDefaultModelIfRequested("copilot"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !loginSetDefault {
|
||||
fmt.Println("\n💡 To set Copilot as your default model, run:")
|
||||
fmt.Println(" kit auth login copilot --set-default")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callbackServer holds the HTTP server and channel for receiving the OAuth callback
|
||||
type callbackServer struct {
|
||||
Server *http.Server
|
||||
@@ -635,3 +748,43 @@ func logoutOpenAI() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func logoutCopilot() error {
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
hasAuth, err := cm.HasCopilotCredentials()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check authentication status: %w", err)
|
||||
}
|
||||
|
||||
if !hasAuth {
|
||||
fmt.Println("You are not currently authenticated with GitHub Copilot.")
|
||||
return nil
|
||||
}
|
||||
|
||||
var confirm bool
|
||||
err = huh.NewConfirm().
|
||||
Title("Remove GitHub Copilot credentials").
|
||||
Description("Are you sure you want to remove your stored credentials?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&confirm).
|
||||
Run()
|
||||
if err != nil || !confirm {
|
||||
fmt.Println("Logout cancelled.")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := cm.RemoveCopilotCredentials(); err != nil {
|
||||
return fmt.Errorf("failed to remove credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ Successfully logged out from GitHub Copilot!")
|
||||
fmt.Println("You will need to authenticate again with 'kit auth login copilot'.")
|
||||
fmt.Println("Tip: this removes local credentials only. Revoke the GitHub OAuth grant at https://github.com/settings/applications")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+264
-429
@@ -4,13 +4,11 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -35,439 +33,276 @@ type extensionContextDeps struct {
|
||||
// the three print routes appropriately for their phase (startup buffering
|
||||
// vs. live runtime routing).
|
||||
//
|
||||
// This consolidates two near-identical 400-line literal expressions that
|
||||
// previously appeared inline in runNormalMode.
|
||||
// The headless half (data access, state, options, tree navigation, skills,
|
||||
// templates, model resolution, subagents) comes from extbridge.BaseContext;
|
||||
// this function overlays the TUI-specific fields and overrides SetModel /
|
||||
// ReloadExtensions with TUI-aware versions.
|
||||
func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Context {
|
||||
kitInstance := deps.kitInstance
|
||||
appInstance := deps.appInstance
|
||||
usageTracker := deps.usageTracker
|
||||
ctx := deps.ctx
|
||||
|
||||
return extensions.Context{
|
||||
CWD: deps.cwd,
|
||||
Model: deps.modelName,
|
||||
Interactive: deps.interactive,
|
||||
PrintBlock: func(opts extensions.PrintBlockOpts) {
|
||||
appInstance.PrintBlockFromExtension(opts)
|
||||
},
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
},
|
||||
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
},
|
||||
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
SetUIVisibility: func(v extensions.UIVisibility) {
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
SetEditor: func(config extensions.EditorConfig) {
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
|
||||
// deadlocks if called synchronously from inside BubbleTea's
|
||||
// Update() handler. All call sites use go-routines uniformly.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
ResetEditor: func() {
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
SetEditorText: func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
},
|
||||
SetStatus: func(key string, text string, priority int) {
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveStatus: func(key string) {
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
SetModel: func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI so it updates model in status bar.
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SuspendTUI: func(callback func()) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
},
|
||||
RenderMessage: func(rendererName, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
}
|
||||
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if w == 0 {
|
||||
w = 80
|
||||
}
|
||||
rendered := renderer.Render(content, w)
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
},
|
||||
ReloadExtensions: func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
ui.RegisterThemeFromConfig(name,
|
||||
tc(config.Primary), tc(config.Secondary),
|
||||
tc(config.Success), tc(config.Warning),
|
||||
tc(config.Error), tc(config.Info),
|
||||
tc(config.Text), tc(config.Muted),
|
||||
tc(config.VeryMuted), tc(config.Background),
|
||||
tc(config.Border), tc(config.MutedBorder),
|
||||
tc(config.System), tc(config.Tool),
|
||||
tc(config.Accent), tc(config.Highlight),
|
||||
tc(config.MdHeading), tc(config.MdLink),
|
||||
tc(config.MdKeyword), tc(config.MdString),
|
||||
tc(config.MdNumber), tc(config.MdComment),
|
||||
)
|
||||
},
|
||||
SetTheme: func(name string) error {
|
||||
return ui.ApplyTheme(name)
|
||||
},
|
||||
ListThemes: func() []string {
|
||||
return ui.ListThemes()
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return extbridge.SpawnSubagent(ctx, kitInstance, config)
|
||||
},
|
||||
// -------------------------------------------------------------------
|
||||
// Tree Navigation API
|
||||
// -------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: func(parentID string) []string {
|
||||
return kitInstance.GetChildren(parentID)
|
||||
},
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
ec := extbridge.BaseContext(deps.ctx, kitInstance)
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Skill Loading API
|
||||
// -------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
ec.CWD = deps.cwd
|
||||
ec.Model = deps.modelName
|
||||
ec.Interactive = deps.interactive
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Template Parsing API
|
||||
// -------------------------------------------------------------------
|
||||
ParseTemplate: func(name, content string) extensions.PromptTemplate {
|
||||
return kit.ParseTemplate(name, content)
|
||||
},
|
||||
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return kit.RenderTemplate(tpl, vars)
|
||||
},
|
||||
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
|
||||
return kit.ParseArguments(input, pattern)
|
||||
},
|
||||
SimpleParseArguments: func(input string, count int) []string {
|
||||
return kit.SimpleParseArguments(input, count)
|
||||
},
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Model Resolution API
|
||||
// -------------------------------------------------------------------
|
||||
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
|
||||
return kit.ResolveModelChain(preferences)
|
||||
},
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: func(model string) bool {
|
||||
return kit.CheckModelAvailable(model)
|
||||
},
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
ec.PrintBlock = func(opts extensions.PrintBlockOpts) {
|
||||
appInstance.PrintBlockFromExtension(opts)
|
||||
}
|
||||
ec.SendMessage = func(text string) { appInstance.Run(text) }
|
||||
ec.CancelAndSend = func(text string) { appInstance.InterruptAndSend(text) }
|
||||
ec.Abort = func() { appInstance.Abort() }
|
||||
ec.IsIdle = func() bool { return !appInstance.IsBusy() }
|
||||
ec.Compact = func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
}
|
||||
ec.SendMultimodalMessage = func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
}
|
||||
ec.GetSessionUsage = func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
}
|
||||
ec.Exit = func() { appInstance.QuitFromExtension() }
|
||||
|
||||
// TUI widgets/chrome — mutate runner state, then notify the TUI.
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send() deadlocks
|
||||
// if called synchronously from inside BubbleTea's Update() handler.
|
||||
// All call sites use go-routines uniformly.
|
||||
ec.SetWidget = func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveWidget = func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetHeader = func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveHeader = func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetFooter = func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveFooter = func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetUIVisibility = func(v extensions.UIVisibility) {
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetEditor = func(config extensions.EditorConfig) {
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.ResetEditor = func() {
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetEditorText = func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
}
|
||||
ec.SetStatus = func(key string, text string, priority int) {
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveStatus = func(key string) {
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
|
||||
// Interactive prompts — channel-based round trips through the TUI.
|
||||
ec.PromptSelect = func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
}
|
||||
ec.PromptConfirm = func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
}
|
||||
ec.PromptInput = func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
}
|
||||
ec.ShowOverlay = func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
}
|
||||
ec.SuspendTUI = func(callback func()) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
}
|
||||
|
||||
// TUI-aware model switch: also notifies the TUI status bar and
|
||||
// refreshes the usage tracker for correct token counting.
|
||||
ec.SetModel = func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI so it updates model in status bar.
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
ui.UpdateUsageTrackerForModel(usageTracker, modelString, viper.GetString("provider-api-key"))
|
||||
return nil
|
||||
}
|
||||
|
||||
ec.RenderMessage = func(rendererName, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
}
|
||||
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if w == 0 {
|
||||
w = 80
|
||||
}
|
||||
rendered := renderer.Render(content, w)
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
}
|
||||
ec.ReloadExtensions = func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Theme management (TUI only).
|
||||
ec.RegisterTheme = func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
ui.RegisterThemeFromConfig(name,
|
||||
tc(config.Primary), tc(config.Secondary),
|
||||
tc(config.Success), tc(config.Warning),
|
||||
tc(config.Error), tc(config.Info),
|
||||
tc(config.Text), tc(config.Muted),
|
||||
tc(config.VeryMuted), tc(config.Background),
|
||||
tc(config.Border), tc(config.MutedBorder),
|
||||
tc(config.System), tc(config.Tool),
|
||||
tc(config.Accent), tc(config.Highlight),
|
||||
tc(config.MdHeading), tc(config.MdLink),
|
||||
tc(config.MdKeyword), tc(config.MdString),
|
||||
tc(config.MdNumber), tc(config.MdComment),
|
||||
)
|
||||
}
|
||||
ec.SetTheme = func(name string) error {
|
||||
return ui.ApplyTheme(name)
|
||||
}
|
||||
ec.ListThemes = func() []string {
|
||||
return ui.ListThemes()
|
||||
}
|
||||
|
||||
// Skill context-injection (drives a new agent turn through the TUI).
|
||||
ec.InjectSkillAsContext = func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
}
|
||||
ec.InjectRawSkillAsContext = func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
|
||||
return ec
|
||||
}
|
||||
|
||||
+248
-102
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -71,8 +70,14 @@ var (
|
||||
|
||||
// Extensions control
|
||||
noExtensionsFlag bool
|
||||
noCoreToolsFlag bool
|
||||
extensionPaths []string
|
||||
|
||||
// Skills control
|
||||
noSkillsFlag bool
|
||||
skillsPaths []string
|
||||
skillsDir string
|
||||
|
||||
// TLS configuration
|
||||
tlsSkipVerify bool
|
||||
|
||||
@@ -278,9 +283,19 @@ func init() {
|
||||
BoolVar(&noSessionFlag, "no-session", false, "ephemeral mode — no session persistence")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&noExtensionsFlag, "no-extensions", false, "disable all extensions")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&noCoreToolsFlag, "no-core-tools", false, "disable all built-in core tools (bash, read, write, edit, grep, find, ls, subagent)")
|
||||
rootCmd.PersistentFlags().
|
||||
StringSliceVarP(&extensionPaths, "extension", "e", nil, "load additional extension file(s)")
|
||||
|
||||
// Skills flags
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&noSkillsFlag, "no-skills", false, "disable skill loading (auto-discovery and explicit)")
|
||||
rootCmd.PersistentFlags().
|
||||
StringSliceVar(&skillsPaths, "skill", nil, "load skill file or directory (repeatable)")
|
||||
rootCmd.PersistentFlags().
|
||||
StringVar(&skillsDir, "skills-dir", "", "override the project-local skills directory for auto-discovery")
|
||||
|
||||
flags := rootCmd.PersistentFlags()
|
||||
flags.StringVar(&providerURL, "provider-url", "", "base URL for the provider API (applies to OpenAI, Anthropic, Ollama, and Google)")
|
||||
flags.StringVar(&providerAPIKey, "provider-api-key", "", "API key for the provider (applies to OpenAI, Anthropic, and Google)")
|
||||
@@ -327,9 +342,13 @@ func init() {
|
||||
_ = viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu"))
|
||||
_ = viper.BindPFlag("tls-skip-verify", rootCmd.PersistentFlags().Lookup("tls-skip-verify"))
|
||||
_ = viper.BindPFlag("no-extensions", rootCmd.PersistentFlags().Lookup("no-extensions"))
|
||||
_ = viper.BindPFlag("no-core-tools", rootCmd.PersistentFlags().Lookup("no-core-tools"))
|
||||
_ = viper.BindPFlag("extension", rootCmd.PersistentFlags().Lookup("extension"))
|
||||
_ = viper.BindPFlag("prompt-template", rootCmd.PersistentFlags().Lookup("prompt-template"))
|
||||
_ = viper.BindPFlag("no-prompt-templates", rootCmd.PersistentFlags().Lookup("no-prompt-templates"))
|
||||
_ = viper.BindPFlag("no-skills", rootCmd.PersistentFlags().Lookup("no-skills"))
|
||||
_ = viper.BindPFlag("skill", rootCmd.PersistentFlags().Lookup("skill"))
|
||||
_ = viper.BindPFlag("skills-dir", rootCmd.PersistentFlags().Lookup("skills-dir"))
|
||||
|
||||
// Defaults are already set in flag definitions, no need to duplicate in viper
|
||||
|
||||
@@ -673,8 +692,8 @@ func globalShortcutsProviderForUI(k *kit.Kit) func() map[string]func() {
|
||||
}
|
||||
}
|
||||
|
||||
func runNormalMode(ctx context.Context) error {
|
||||
// Validate flag combinations
|
||||
// validateModeFlags rejects invalid flag combinations for the root command.
|
||||
func validateModeFlags() error {
|
||||
if quietFlag && positionalPrompt == "" {
|
||||
return fmt.Errorf("--quiet requires a prompt (e.g. kit \"your question\" --quiet)")
|
||||
}
|
||||
@@ -687,21 +706,14 @@ func runNormalMode(ctx context.Context) error {
|
||||
if noExitFlag && positionalPrompt == "" {
|
||||
return fmt.Errorf("--no-exit requires a prompt (e.g. kit \"your question\" --no-exit)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set up logging
|
||||
if debugMode {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
}
|
||||
|
||||
// Update debug mode from viper
|
||||
if viper.GetBool("debug") && !debugMode {
|
||||
debugMode = viper.GetBool("debug")
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
}
|
||||
|
||||
// Restore persisted model preference when no explicit --model flag or
|
||||
// config file model is set. Precedence: CLI flag > config file > saved
|
||||
// preference > built-in default. This mirrors how themes are persisted.
|
||||
// restorePersistedPreferences applies saved model / thinking-level
|
||||
// preferences into viper when neither a CLI flag nor a config-file value
|
||||
// takes precedence. Precedence: CLI flag > config file > saved preference >
|
||||
// built-in default. This mirrors how themes are persisted.
|
||||
func restorePersistedPreferences() {
|
||||
// Skip custom/* models unless --provider-url is also provided, since the
|
||||
// custom provider requires a URL that was only valid for the previous session.
|
||||
if !modelFlagChanged && !viper.InConfig("model") {
|
||||
@@ -720,6 +732,15 @@ func runNormalMode(ctx context.Context) error {
|
||||
viper.Set("thinking-level", pref)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// applyProviderURLRouting rewrites the model in viper when --provider-url
|
||||
// is set, routing requests through the "custom" (OpenAI-compatible)
|
||||
// provider. Must run after restorePersistedPreferences.
|
||||
func applyProviderURLRouting() {
|
||||
if viper.GetString("provider-url") == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// When --provider-url is set but no explicit --model was provided,
|
||||
// default to "custom/custom" so the user doesn't need to remember a
|
||||
@@ -727,18 +748,53 @@ func runNormalMode(ctx context.Context) error {
|
||||
// This intentionally overrides saved preferences but respects config-file
|
||||
// models — if you specify a model in ~/.kit.yml, it will be used with
|
||||
// custom/custom's provider routing.
|
||||
if viper.GetString("provider-url") != "" && !modelFlagChanged && !viper.InConfig("model") {
|
||||
if !modelFlagChanged && !viper.InConfig("model") {
|
||||
viper.Set("model", "custom/custom")
|
||||
}
|
||||
|
||||
// When --provider-url is set with an explicit --model that lacks a provider
|
||||
// prefix (no "/"), auto-prefix with "custom/" for OpenAI-compatible endpoints.
|
||||
if viper.GetString("provider-url") != "" && modelFlagChanged {
|
||||
// When --provider-url is set with an explicit --model, route through the
|
||||
// "custom" provider (OpenAI-compatible wire). This honors the user's
|
||||
// intent: passing a custom URL means "use THIS endpoint", not "speak
|
||||
// the Google/Anthropic/etc. wire protocol against this endpoint".
|
||||
//
|
||||
// Any provider prefix on the model is stripped so a model name that
|
||||
// happens to collide with a known provider (e.g. `google/gemma-4-12b`
|
||||
// served by LM Studio) still resolves correctly. If you genuinely need
|
||||
// to point a non-OpenAI wire (Anthropic, Google, ...) at a proxy URL,
|
||||
// use the explicit `custom/<name>` form to opt out of the rewrite by
|
||||
// configuring the proxy as that provider in your config file instead.
|
||||
if modelFlagChanged {
|
||||
model := viper.GetString("model")
|
||||
if model != "" && !strings.Contains(model, "/") {
|
||||
viper.Set("model", "custom/"+model)
|
||||
if model != "" {
|
||||
name := model
|
||||
if _, after, ok := strings.Cut(model, "/"); ok {
|
||||
name = after
|
||||
}
|
||||
if !strings.HasPrefix(model, "custom/") {
|
||||
viper.Set("model", "custom/"+name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runNormalMode(ctx context.Context) error {
|
||||
if err := validateModeFlags(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set up logging
|
||||
if debugMode {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
}
|
||||
|
||||
// Update debug mode from viper
|
||||
if viper.GetBool("debug") && !debugMode {
|
||||
debugMode = viper.GetBool("debug")
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
}
|
||||
|
||||
restorePersistedPreferences()
|
||||
applyProviderURLRouting()
|
||||
|
||||
// Load MCP configuration.
|
||||
mcpConfig, err := config.LoadAndValidateConfig()
|
||||
@@ -772,13 +828,17 @@ func runNormalMode(ctx context.Context) error {
|
||||
var appInstancePtr *app.App
|
||||
|
||||
kitOpts := &kit.Options{
|
||||
Quiet: quietFlag,
|
||||
Debug: debugMode,
|
||||
NoSession: noSessionFlag,
|
||||
Continue: continueFlag,
|
||||
SessionPath: sessionPath,
|
||||
AutoCompact: autoCompactFlag,
|
||||
MCPAuthHandler: authHandler,
|
||||
Quiet: quietFlag,
|
||||
Debug: debugMode,
|
||||
NoSession: noSessionFlag,
|
||||
Continue: continueFlag,
|
||||
SessionPath: sessionPath,
|
||||
AutoCompact: autoCompactFlag,
|
||||
MCPAuthHandler: authHandler,
|
||||
DisableCoreTools: viper.GetBool("no-core-tools"),
|
||||
NoSkills: noSkillsFlag,
|
||||
Skills: skillsPaths,
|
||||
SkillsDir: skillsDir,
|
||||
// This callback is called when each MCP server finishes loading.
|
||||
// We use a closure that captures appInstancePtr which is set after
|
||||
// app.New() is called below.
|
||||
@@ -899,8 +959,9 @@ func runNormalMode(ctx context.Context) error {
|
||||
appInstance: appInstance,
|
||||
usageTracker: usageTracker,
|
||||
})
|
||||
|
||||
// During startup, buffer extension messages so they appear after the banner.
|
||||
extCtx.Print = func(text string) {
|
||||
// Capture messages during startup, print after startup banner.
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
}
|
||||
extCtx.PrintInfo = func(text string) {
|
||||
@@ -910,18 +971,12 @@ func runNormalMode(ctx context.Context) error {
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
}
|
||||
kitInstance.Extensions().SetContext(extCtx)
|
||||
if err := kitInstance.Extensions().InitStatePersistence(); err != nil {
|
||||
log.Printf("WARN extension state init failed: %v", err)
|
||||
}
|
||||
kitInstance.Extensions().EmitSessionStart()
|
||||
|
||||
// Restore normal print functions for runtime use.
|
||||
extCtx = buildInteractiveExtensionContext(extensionContextDeps{
|
||||
ctx: ctx,
|
||||
cwd: cwd,
|
||||
modelName: modelName,
|
||||
interactive: positionalPrompt == "",
|
||||
kitInstance: kitInstance,
|
||||
appInstance: appInstance,
|
||||
usageTracker: usageTracker,
|
||||
})
|
||||
extCtx.Print = func(text string) { appInstance.PrintFromExtension("", text) }
|
||||
extCtx.PrintInfo = func(text string) { appInstance.PrintFromExtension("info", text) }
|
||||
extCtx.PrintError = func(text string) { appInstance.PrintFromExtension("error", text) }
|
||||
@@ -1149,23 +1204,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
// NotifyModelChanged calls prog.Send() which deadlocks. The UI layer
|
||||
// updates m.providerName and m.modelName directly after setModel returns.
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
ui.UpdateUsageTrackerForModel(usageTracker, modelString, viper.GetString("provider-api-key"))
|
||||
return nil
|
||||
}
|
||||
emitModelChangeForUI := func(newModel, previousModel, source string) {
|
||||
@@ -1265,9 +1304,57 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Bundle all the shared dependencies into a single struct that both
|
||||
// run-mode entry points consume. This keeps the dispatch site and the
|
||||
// function signatures readable.
|
||||
deps := runModeDeps{
|
||||
appInstance: appInstance,
|
||||
cli: cli,
|
||||
modelName: modelName,
|
||||
providerName: parsedProvider,
|
||||
loadingMessage: kitInstance.GetLoadingMessage(),
|
||||
serverNames: serverNames,
|
||||
toolNames: toolNames,
|
||||
mcpToolCount: mcpToolCount,
|
||||
extensionToolCount: extensionToolCount,
|
||||
usageTracker: usageTracker,
|
||||
extCommands: extCommands,
|
||||
promptTemplates: promptTemplates,
|
||||
contextPaths: contextPaths,
|
||||
skillItems: skillItems,
|
||||
extensionItems: extensionItems,
|
||||
getPromptTemplates: getPromptTemplates,
|
||||
getSkillItems: getSkillItems,
|
||||
getExtensionItems: getExtensionItems,
|
||||
getToolNames: getToolNames,
|
||||
getMCPToolCount: getMCPToolCount,
|
||||
mcpPrompts: mcpPrompts,
|
||||
getMCPPrompts: getMCPPrompts,
|
||||
expandMCPPrompt: expandMCPPrompt,
|
||||
getWidgets: getWidgets,
|
||||
getHeader: getHeader,
|
||||
getFooter: getFooter,
|
||||
getToolRenderer: getToolRenderer,
|
||||
getEditorInterceptor: getEditorInterceptor,
|
||||
getUIVisibility: getUIVisibility,
|
||||
getStatusBarEntries: getStatusBarEntries,
|
||||
emitBeforeFork: emitBeforeFork,
|
||||
emitBeforeSessionSwitch: emitBeforeSessionSwitch,
|
||||
getGlobalShortcuts: getGlobalShortcuts,
|
||||
getExtensionCommands: getExtensionCommands,
|
||||
setModel: setModelForUI,
|
||||
emitModelChange: emitModelChangeForUI,
|
||||
isReasoningModel: kitInstance.IsReasoningModel(),
|
||||
thinkingLevel: kitInstance.GetThinkingLevel(),
|
||||
setThinkingLevel: setThinkingLevelForUI,
|
||||
switchSession: switchSessionForUI,
|
||||
reloadExtensions: reloadExtensionsForUI,
|
||||
startupExtensionMessages: startupExtensionMessages,
|
||||
}
|
||||
|
||||
// Check if running in non-interactive mode
|
||||
if positionalPrompt != "" {
|
||||
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, extensionItems, getPromptTemplates, getSkillItems, getExtensionItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI)
|
||||
return runNonInteractiveModeApp(ctx, deps, positionalPrompt, quietFlag, jsonFlag, noExitFlag)
|
||||
}
|
||||
|
||||
// Quiet mode is not allowed in interactive mode
|
||||
@@ -1275,7 +1362,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return fmt.Errorf("--quiet requires a prompt")
|
||||
}
|
||||
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, extensionItems, getPromptTemplates, getSkillItems, getExtensionItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages)
|
||||
return runInteractiveModeBubbleTea(ctx, deps)
|
||||
}
|
||||
|
||||
// runNonInteractiveModeApp executes a single prompt via the app layer and exits,
|
||||
@@ -1288,7 +1375,10 @@ func runNormalMode(ctx context.Context) error {
|
||||
//
|
||||
// When --no-exit is set, after the prompt completes the interactive BubbleTea
|
||||
// TUI is started so the user can continue the conversation.
|
||||
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, extensionItems []ui.ExtensionItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getExtensionItems func() []ui.ExtensionItem, getToolNames func() []string, getMCPToolCount func() int, mcpPrompts []ui.MCPPromptInfo, getMCPPrompts func() []ui.MCPPromptInfo, expandMCPPrompt func(string, string, map[string]string) (*ui.MCPPromptExpandResult, error), getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error {
|
||||
func runNonInteractiveModeApp(ctx context.Context, deps runModeDeps, prompt string, quiet, jsonOutput, noExit bool) error {
|
||||
appInstance := deps.appInstance
|
||||
cli := deps.cli
|
||||
modelName := deps.modelName
|
||||
// Expand @file references in the prompt before sending to the agent.
|
||||
// Text files are XML-inlined; binary files are extracted as multimodal parts.
|
||||
var fileParts []kit.LLMFilePart
|
||||
@@ -1349,12 +1439,67 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
|
||||
|
||||
// If --no-exit was requested, hand off to the interactive TUI.
|
||||
if noExit {
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, extensionItems, getPromptTemplates, getSkillItems, getExtensionItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil)
|
||||
// Drop the cli (interactive mode doesn't use it) and clear the
|
||||
// interactive-only fields explicitly; deps carries everything else.
|
||||
interactive := deps
|
||||
interactive.cli = nil
|
||||
interactive.startupExtensionMessages = nil
|
||||
return runInteractiveModeBubbleTea(ctx, interactive)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runModeDeps bundles the shared dependencies that runNormalMode wires up
|
||||
// once and threads to both runNonInteractiveModeApp and
|
||||
// runInteractiveModeBubbleTea. Grouping them into a single struct keeps the
|
||||
// call sites and signatures readable and makes it trivial to add a new
|
||||
// provider callback without touching every call chain.
|
||||
type runModeDeps struct {
|
||||
appInstance *app.App
|
||||
cli *ui.CLI // non-interactive only
|
||||
modelName string
|
||||
providerName string
|
||||
loadingMessage string
|
||||
serverNames []string
|
||||
toolNames []string
|
||||
mcpToolCount int
|
||||
extensionToolCount int
|
||||
usageTracker *ui.UsageTracker
|
||||
extCommands []commands.ExtensionCommand
|
||||
promptTemplates []*prompts.PromptTemplate
|
||||
contextPaths []string
|
||||
skillItems []ui.SkillItem
|
||||
extensionItems []ui.ExtensionItem
|
||||
getPromptTemplates func() []*prompts.PromptTemplate
|
||||
getSkillItems func() []ui.SkillItem
|
||||
getExtensionItems func() []ui.ExtensionItem
|
||||
getToolNames func() []string
|
||||
getMCPToolCount func() int
|
||||
mcpPrompts []ui.MCPPromptInfo
|
||||
getMCPPrompts func() []ui.MCPPromptInfo
|
||||
expandMCPPrompt func(string, string, map[string]string) (*ui.MCPPromptExpandResult, error)
|
||||
getWidgets func(string) []ui.WidgetData
|
||||
getHeader func() *ui.WidgetData
|
||||
getFooter func() *ui.WidgetData
|
||||
getToolRenderer func(string) *ui.ToolRendererData
|
||||
getEditorInterceptor func() *ui.EditorInterceptor
|
||||
getUIVisibility func() *ui.UIVisibility
|
||||
getStatusBarEntries func() []ui.StatusBarEntryData
|
||||
emitBeforeFork func(string, bool, string) (bool, string)
|
||||
emitBeforeSessionSwitch func(string) (bool, string)
|
||||
getGlobalShortcuts func() map[string]func()
|
||||
getExtensionCommands func() []commands.ExtensionCommand
|
||||
setModel func(string) error
|
||||
emitModelChange func(string, string, string)
|
||||
isReasoningModel bool
|
||||
thinkingLevel string
|
||||
setThinkingLevel func(string) error
|
||||
switchSession func(string) error
|
||||
reloadExtensions func() error
|
||||
startupExtensionMessages []string // interactive only
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// JSON output helpers (--json mode)
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -1447,7 +1592,8 @@ func writeJSONError(err error) {
|
||||
// 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit).
|
||||
//
|
||||
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, extensionItems []ui.ExtensionItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getExtensionItems func() []ui.ExtensionItem, getToolNames func() []string, getMCPToolCount func() int, mcpPrompts []ui.MCPPromptInfo, getMCPPrompts func() []ui.MCPPromptInfo, expandMCPPrompt func(string, string, map[string]string) (*ui.MCPPromptExpandResult, error), getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
|
||||
func runInteractiveModeBubbleTea(_ context.Context, deps runModeDeps) error {
|
||||
appInstance := deps.appInstance
|
||||
// Redirect all log output (stdlib and charm) to a file so that log
|
||||
// messages don't write to stderr and corrupt the TUI. Bubble Tea
|
||||
// captures stdout for rendering; any stray stderr output from
|
||||
@@ -1470,49 +1616,49 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
|
||||
cwd, _ := os.Getwd()
|
||||
|
||||
appModel := ui.NewAppModel(appInstance, ui.AppModelOptions{
|
||||
ModelName: modelName,
|
||||
ProviderName: providerName,
|
||||
LoadingMessage: loadingMessage,
|
||||
ModelName: deps.modelName,
|
||||
ProviderName: deps.providerName,
|
||||
LoadingMessage: deps.loadingMessage,
|
||||
Cwd: cwd,
|
||||
Width: termWidth,
|
||||
Height: termHeight,
|
||||
ServerNames: serverNames,
|
||||
ToolNames: toolNames,
|
||||
GetToolNames: getToolNames,
|
||||
GetMCPToolCount: getMCPToolCount,
|
||||
MCPToolCount: mcpToolCount,
|
||||
ExtensionToolCount: extensionToolCount,
|
||||
UsageTracker: usageTracker,
|
||||
ExtensionCommands: extCommands,
|
||||
PromptTemplates: promptTemplates,
|
||||
GetPromptTemplates: getPromptTemplates,
|
||||
MCPPrompts: mcpPrompts,
|
||||
GetMCPPrompts: getMCPPrompts,
|
||||
ExpandMCPPrompt: expandMCPPrompt,
|
||||
ContextPaths: contextPaths,
|
||||
SkillItems: skillItems,
|
||||
GetSkillItems: getSkillItems,
|
||||
ExtensionItems: extensionItems,
|
||||
GetExtensionItems: getExtensionItems,
|
||||
StartupExtensionMessages: startupExtensionMessages,
|
||||
GetWidgets: getWidgets,
|
||||
GetHeader: getHeader,
|
||||
GetFooter: getFooter,
|
||||
GetToolRenderer: getToolRenderer,
|
||||
GetEditorInterceptor: getEditorInterceptor,
|
||||
GetUIVisibility: getUIVisibility,
|
||||
GetStatusBarEntries: getStatusBarEntries,
|
||||
EmitBeforeFork: emitBeforeFork,
|
||||
EmitBeforeSessionSwitch: emitBeforeSessionSwitch,
|
||||
GetGlobalShortcuts: getGlobalShortcuts,
|
||||
GetExtensionCommands: getExtensionCommands,
|
||||
SetModel: setModel,
|
||||
EmitModelChange: emitModelChange,
|
||||
ThinkingLevel: thinkingLevel,
|
||||
IsReasoningModel: isReasoningModel,
|
||||
SetThinkingLevel: setThinkingLevel,
|
||||
SwitchSession: switchSession,
|
||||
ReloadExtensions: reloadExtensions,
|
||||
ServerNames: deps.serverNames,
|
||||
ToolNames: deps.toolNames,
|
||||
GetToolNames: deps.getToolNames,
|
||||
GetMCPToolCount: deps.getMCPToolCount,
|
||||
MCPToolCount: deps.mcpToolCount,
|
||||
ExtensionToolCount: deps.extensionToolCount,
|
||||
UsageTracker: deps.usageTracker,
|
||||
ExtensionCommands: deps.extCommands,
|
||||
PromptTemplates: deps.promptTemplates,
|
||||
GetPromptTemplates: deps.getPromptTemplates,
|
||||
MCPPrompts: deps.mcpPrompts,
|
||||
GetMCPPrompts: deps.getMCPPrompts,
|
||||
ExpandMCPPrompt: deps.expandMCPPrompt,
|
||||
ContextPaths: deps.contextPaths,
|
||||
SkillItems: deps.skillItems,
|
||||
GetSkillItems: deps.getSkillItems,
|
||||
ExtensionItems: deps.extensionItems,
|
||||
GetExtensionItems: deps.getExtensionItems,
|
||||
StartupExtensionMessages: deps.startupExtensionMessages,
|
||||
GetWidgets: deps.getWidgets,
|
||||
GetHeader: deps.getHeader,
|
||||
GetFooter: deps.getFooter,
|
||||
GetToolRenderer: deps.getToolRenderer,
|
||||
GetEditorInterceptor: deps.getEditorInterceptor,
|
||||
GetUIVisibility: deps.getUIVisibility,
|
||||
GetStatusBarEntries: deps.getStatusBarEntries,
|
||||
EmitBeforeFork: deps.emitBeforeFork,
|
||||
EmitBeforeSessionSwitch: deps.emitBeforeSessionSwitch,
|
||||
GetGlobalShortcuts: deps.getGlobalShortcuts,
|
||||
GetExtensionCommands: deps.getExtensionCommands,
|
||||
SetModel: deps.setModel,
|
||||
EmitModelChange: deps.emitModelChange,
|
||||
ThinkingLevel: deps.thinkingLevel,
|
||||
IsReasoningModel: deps.isReasoningModel,
|
||||
SetThinkingLevel: deps.setThinkingLevel,
|
||||
SwitchSession: deps.switchSession,
|
||||
ReloadExtensions: deps.reloadExtensions,
|
||||
ShowSessionPicker: resumeFlag,
|
||||
GetMCPResources: mcpGetResources,
|
||||
MCPResourceReader: mcpResourceReader,
|
||||
|
||||
@@ -58,6 +58,7 @@ kit install github.com/mark3labs/kit/examples/extensions --local
|
||||
| `project-rules.go` | Project-specific rules | Session data, file reading |
|
||||
| `protected-paths.go` | Block dangerous operations | `OnToolCall` with blocking |
|
||||
| `permission-gate.go` | Confirm destructive actions | `OnToolCall` with confirmation |
|
||||
| `usage-budget.go` | Soft cost cap + per-turn report | `OnLLMUsage`, `SetState`/`GetState`, enriched `AgentEndEvent` |
|
||||
|
||||
### Tools & Commands
|
||||
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// Init demonstrates the three primitives added in issue #53:
|
||||
//
|
||||
// 1. api.OnLLMUsage(...) — per-LLM-call usage callback with token + cost
|
||||
// deltas. Use this for budget enforcement that reacts between calls
|
||||
// within a single agent turn, rather than only at turn boundaries.
|
||||
//
|
||||
// 2. ctx.SetState / ctx.GetState / ctx.DeleteState / ctx.ListState —
|
||||
// last-write-wins, session-scoped key-value store backed by a sidecar
|
||||
// file. Use this for snapshot state (current value of X) instead of
|
||||
// ctx.AppendEntry, which is append-only and bloats branch reads.
|
||||
//
|
||||
// 3. ext.AgentEndEvent.ToolCallCount / .ToolNames / .LLMCallCount /
|
||||
// .InputTokensDelta / .OutputTokensDelta / .CostDelta / .DurationMs —
|
||||
// per-turn aggregates so observer extensions don't need to maintain
|
||||
// parallel bookkeeping.
|
||||
//
|
||||
// Together these support a simple soft-budget cap: warn when the
|
||||
// cumulative cost in this session exceeds a threshold, and print a
|
||||
// per-turn report on AgentEnd.
|
||||
//
|
||||
// Usage: kit -e examples/extensions/usage-budget.go
|
||||
func Init(api ext.API) {
|
||||
const warnAtKey = "usage-budget:warn-at-usd"
|
||||
|
||||
// 1. Print per-LLM-call usage with provider, model, and cost.
|
||||
api.OnLLMUsage(func(e ext.LLMUsageEvent, ctx ext.Context) {
|
||||
ctx.Print(fmt.Sprintf(
|
||||
"[usage] step=%d %s/%s tokens=↑%d ↓%d cache=↑%d/↓%d cost=$%.4f (%s)",
|
||||
e.StepNumber, e.Provider, e.Model,
|
||||
e.InputTokens, e.OutputTokens,
|
||||
e.CacheWriteTokens, e.CacheReadTokens,
|
||||
e.Cost, e.FinishReason,
|
||||
))
|
||||
|
||||
// 2. Persist running total in last-write-wins state.
|
||||
current := 0.0
|
||||
if raw, ok := ctx.GetState("usage-budget:total-cost"); ok {
|
||||
current, _ = strconv.ParseFloat(raw, 64)
|
||||
}
|
||||
current += e.Cost
|
||||
ctx.SetState("usage-budget:total-cost", strconv.FormatFloat(current, 'f', 6, 64))
|
||||
|
||||
// Soft warn-at threshold (configurable via state).
|
||||
warnAt := 0.50
|
||||
if raw, ok := ctx.GetState(warnAtKey); ok {
|
||||
if v, err := strconv.ParseFloat(raw, 64); err == nil {
|
||||
warnAt = v
|
||||
}
|
||||
}
|
||||
if current > warnAt {
|
||||
ctx.PrintError(fmt.Sprintf(
|
||||
"[usage] session cost $%.4f exceeds soft cap $%.2f",
|
||||
current, warnAt,
|
||||
))
|
||||
}
|
||||
})
|
||||
|
||||
// 3. Print a per-turn summary using the enriched AgentEndEvent.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
ctx.Print(fmt.Sprintf(
|
||||
"[turn] stop=%s tools=%d llm-calls=%d tokens=↑%d ↓%d cost=$%.4f duration=%dms",
|
||||
e.StopReason, e.ToolCallCount, e.LLMCallCount,
|
||||
e.InputTokensDelta, e.OutputTokensDelta, e.CostDelta, e.DurationMs,
|
||||
))
|
||||
if len(e.ToolNames) > 0 {
|
||||
ctx.Print(fmt.Sprintf("[turn] tool order: %v", e.ToolNames))
|
||||
}
|
||||
})
|
||||
|
||||
// Bootstrap default soft cap once per session.
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
if _, ok := ctx.GetState(warnAtKey); !ok {
|
||||
ctx.SetState(warnAtKey, "0.50")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -42,4 +42,14 @@ defer host.Close()
|
||||
response, err := host.Prompt(ctx, "Hello!")
|
||||
```
|
||||
|
||||
Or use the functional-options constructor for quick setups (streaming defaults on):
|
||||
|
||||
```go
|
||||
host, err := kit.NewAgent(ctx,
|
||||
kit.WithModel("anthropic/claude-sonnet-4-5-20250929"),
|
||||
kit.WithSystemPrompt("You are a helpful assistant."),
|
||||
kit.Ephemeral(),
|
||||
)
|
||||
```
|
||||
|
||||
See the [SDK README](../../pkg/kit/README.md) for the full API reference.
|
||||
|
||||
@@ -1,32 +1,34 @@
|
||||
module github.com/mark3labs/kit
|
||||
|
||||
go 1.26.2
|
||||
go 1.26.3
|
||||
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.1.0
|
||||
charm.land/bubbletea/v2 v2.0.6
|
||||
charm.land/fantasy v0.23.0
|
||||
charm.land/bubbletea/v2 v2.0.7
|
||||
charm.land/fantasy v0.25.0
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.3
|
||||
github.com/alecthomas/chroma/v2 v2.24.1
|
||||
github.com/alecthomas/chroma/v2 v2.26.1
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/aymanbagabas/go-udiff v0.4.1
|
||||
github.com/charmbracelet/colorprofile v0.4.3
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260601155805-6cf7526a1b3f
|
||||
github.com/charmbracelet/x/editor v0.2.0
|
||||
github.com/clipperhouse/displaywidth v0.11.0
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0
|
||||
github.com/coder/acp-go-sdk v0.12.2
|
||||
github.com/coder/acp-go-sdk v0.13.5
|
||||
github.com/fsnotify/fsnotify v1.10.1
|
||||
github.com/indaco/herald v0.13.0
|
||||
github.com/indaco/herald-md v0.3.0
|
||||
github.com/mark3labs/mcp-go v0.51.0
|
||||
github.com/mark3labs/mcp-go v0.54.1
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
golang.org/x/term v0.42.0
|
||||
golang.org/x/image v0.41.0
|
||||
golang.org/x/term v0.43.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -37,39 +39,39 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect
|
||||
github.com/aws/smithy-go v1.25.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2 // indirect
|
||||
github.com/aws/smithy-go v1.26.0 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.3 // indirect
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260602025833-85a30b5e440a // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
github.com/charmbracelet/x/windows v0.2.2 // indirect
|
||||
github.com/dlclark/regexp2 v1.12.0 // indirect
|
||||
github.com/dlclark/regexp2/v2 v2.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
@@ -79,13 +81,13 @@ require (
|
||||
github.com/google/jsonschema-go v0.4.3 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.22.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.7 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.21 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.5 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.25 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.13 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.6.3 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.6.0 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
github.com/muesli/mango-cobra v1.3.0 // indirect
|
||||
@@ -97,7 +99,7 @@ require (
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/gjson v1.19.0 // indirect
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
@@ -105,21 +107,21 @@ require (
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
|
||||
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0 // indirect
|
||||
go.opentelemetry.io/otel v1.44.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.44.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.44.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.50.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect
|
||||
golang.org/x/net v0.53.0 // indirect
|
||||
golang.org/x/crypto v0.52.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260603202125-055de637280b // indirect
|
||||
golang.org/x/net v0.55.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.277.0 // indirect
|
||||
google.golang.org/genai v1.55.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 // indirect
|
||||
google.golang.org/grpc v1.81.0 // indirect
|
||||
google.golang.org/api v0.282.0 // indirect
|
||||
google.golang.org/genai v1.58.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa // indirect
|
||||
google.golang.org/grpc v1.81.1 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
@@ -131,12 +133,12 @@ require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.22 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.23 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.24 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.36.0
|
||||
golang.org/x/sys v0.45.0 // indirect
|
||||
golang.org/x/text v0.37.0
|
||||
)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
cel.dev/expr v0.25.2/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4=
|
||||
charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g=
|
||||
charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY=
|
||||
charm.land/bubbletea/v2 v2.0.6 h1:UHN/91OyuhaOFGSrBXQ/hMZD8IO1Uc4BvHlgHXL2WJo=
|
||||
charm.land/bubbletea/v2 v2.0.6/go.mod h1:MH/D8ZLlN3op37vQvijKuU29g3rqTp+aQapURFonF9g=
|
||||
charm.land/fantasy v0.23.0 h1:pocjwC5CxfEg1Bpwb0raML2d5ijo3op33Mmd6hYJyo4=
|
||||
charm.land/fantasy v0.23.0/go.mod h1:4yzSsd9XmFEVjRnF1P0LTEbLTmQX6OLnPkrHaf7iruo=
|
||||
charm.land/bubbletea/v2 v2.0.7 h1:7qw2tTAVar7m7klOPBYfTB0mniv/RuexsYwMRNxSeL0=
|
||||
charm.land/bubbletea/v2 v2.0.7/go.mod h1:DGW2q8gvzHnOpMpZTORs0aySVHCox5C+2Svk0fci1qs=
|
||||
charm.land/fantasy v0.25.0 h1:oXOWY1ivmTSnhYGzAolscF8zKtavWZyBWv0LHRSwN5Q=
|
||||
charm.land/fantasy v0.25.0/go.mod h1:8QrWUzIcKwZQP+aAnC9vLu3iID6hu9/Jt+rPMiieBkc=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU=
|
||||
charm.land/lipgloss/v2 v2.0.3/go.mod h1:7myLU9iG/3xluAWzpY/fSxYYHCgoKTie7laxk6ATwXA=
|
||||
charm.land/x/vcr v0.1.1/go.mod h1:eByq2gqzWvcct/8XE2XO5KznoWEBiXH56+y2gphbltM=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
|
||||
@@ -16,6 +18,11 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIi
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
cloud.google.com/go/iam v1.11.0/go.mod h1:KP+nKGugNJW4LcLx1uEZcq1ok5sQHFaQehQNl4QDgV4=
|
||||
cloud.google.com/go/longrunning v0.5.6/go.mod h1:vUaDrWYOMKRuhiv6JBnn49YxCPz2Ayn9GqyjaBT8/mA=
|
||||
cloud.google.com/go/monitoring v1.29.0/go.mod h1:72NOVjJXHY/HBfoLT0+qlCZBT059+9VXLeAnL2PeeVM=
|
||||
cloud.google.com/go/storage v1.62.1/go.mod h1:cpYz/kRVZ+UQAF1uHeea10/9ewcRbxGoGNKsS9daSXA=
|
||||
cloud.google.com/go/translate v1.10.3/go.mod h1:GW0vC1qvPtd3pgtypCv4k4U8B7EdgK9/QEF2aJEUovs=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 h1:jHb/wfvRikGdxMXYV3QG/SzUOPYN9KEUUuC0Yd0/vC0=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1/go.mod h1:pzBXCYn05zvYIrwLgtK8Ap8QcjRg+0i76tMQdWN6wOk=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
|
||||
@@ -24,52 +31,68 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6Xu
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0/go.mod h1:7dCRMLwisfRH3dBupKeNCioWYUZ4SS09Z14H+7i8ZoY=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.32.0/go.mod h1:RD2SsorTmYhF6HkTmDw7KmPYQk8OBYwTkuasChwv7R4=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.56.0/go.mod h1:hEpiGU18xf70qb3jbTcIggWAiEfX/cOIVc2OTe4OegA=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.56.0/go.mod h1:6ZZMQhZKDvUvkJw2rc+oDP90tMMzuU/J+5HG1ZmPOmE=
|
||||
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
|
||||
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||
github.com/Rhymond/go-money v1.0.15/go.mod h1:iHvCuIvitxu2JIlAlhF0g9jHqjRSr+rpdOs7Omqlupg=
|
||||
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
|
||||
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
|
||||
github.com/alecthomas/chroma/v2 v2.24.1 h1:m5ffpfZbIb++k8AqFEKy9uVgY12xIQtBsQlc6DfZJQM=
|
||||
github.com/alecthomas/chroma/v2 v2.24.1/go.mod h1:l+ohZ9xRXIbGe7cIW+YZgOGbvuVLjMps/FYN/CwuabI=
|
||||
github.com/alecthomas/chroma/v2 v2.26.1 h1:2X21EdxGZNv5GF9mG5u+uzc02GCFyGxbcBm3Grd9A78=
|
||||
github.com/alecthomas/chroma/v2 v2.26.1/go.mod h1:lxhRRa9H4hPmRLOOdYga4zkQIQjq3dtrrdwQeCfu78Y=
|
||||
github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
|
||||
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/ardanlabs/jinja v1.2.0/go.mod h1:aXXzlJfjA+T3XNKA/YT5ZtDq2VJxt5a5siZ8cl9B35Q=
|
||||
github.com/ardanlabs/kronk v1.25.2/go.mod h1:b5Gg4jDqvHDklkeHNB8+7treZRxUiCFsV65zphrTloY=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8 h1:sRs7nG6/RiEBZ/K5UO2sNw0w40U02Nmz1VtARloTZXk=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 h1:gx1AwW1Iyk9Z9dD9F4akX5gnN3QZwUB20GGKH/I+Rho=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10/go.mod h1:qqY157uZoqm5OXq/amuaBJyC9hgBCBQnsaWnPe905GY=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17 h1:FpL4/758/diKwqbytU0prpuiu60fgXKUWCpDJtApclU=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17/go.mod h1:OXqUMzgXytfoF9JaKkhrOYsyh72t9G+MJH8mMRaexOE=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 h1:r3RJBuU7X9ibt8RHbMjWE6y60QbKBiII6wSrXnapxSU=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16/go.mod h1:6cx7zqDENJDbBIIWX6P8s0h6hqHC8Avbjh9Dseo27ug=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19 h1:qRhIJMbevHUvIE7X4TK8N8zye5+5AhapcslPrvB+qKE=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19/go.mod h1:RbJ24nfoya63+Mf5VI+CGCGk9vEdv28xPeii+gojRYs=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18 h1:GcXQz2M/0ZvMo0v5DakUqbDBeBM1ZNaivkolEF4Esgw=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18/go.mod h1:sHJ06tMGcD3ZpmMyJqV+VBsGilhSIZPIN+ZFy5Dg0C4=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24 h1:FQm5ApnyzkuJdXLGskPce83CK1CQKC4RUnIHKVe4BU4=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24/go.mod h1:JsC7dqQc55MlZ5mvNsDMMge71u8pVcSzU3RNz2h/5yQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24 h1:u6kJU2i0va1AgtJsH3RdWKWqHULlTh7zHwb35Womf74=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24/go.mod h1:7GY+xLcXOFUpCkNwDReft9qOAVg54A4/AnjHIU7sSAY=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24 h1:Xhbcf3KugX6vX7SDyUK205Oicyfg7EGuvoVNyP5L6DM=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24/go.mod h1:rwDgb2HNOGZsnTHylOUedM7Vnl+bCfnXDqUNPsFWYfk=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25 h1:54CTMmlJ71Rk2dYvM9qZOob+39wjlVja2zDLxCu69Ew=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25/go.mod h1:BZaHqxsS9vN1fvV5EfEl0OBLOk5+AajWsMu6MjqnZB4=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 h1:+1Kl1zx6bWi4X7cKi3VYh29h8BvsCoHQEQ6ST9X8w7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio=
|
||||
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
|
||||
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.15/go.mod h1:e3IzZvQ3kAWNykvE0Tr0RDZCMFInMvhku3qNpcIQXhM=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24 h1:CQW2FTrflfoslYWLf3fv7vG28Q219+v8YJS5QTQb2+Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24/go.mod h1:Xfx13T+u3nH6EEzgl9fBSO6nDRmze1FvnZNYkctQ2zw=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.23/go.mod h1:M8l3mwgx5ToK7wot2sBBce/ojzgnPzZXUV445gTSyE8=
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.101.0/go.mod h1:L2dcoOgS2VSgbPLvpak2NyUPsO1TBN7M45Z4H7DlRc4=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0 h1:yQo3eZ5qFaL1sJWqs1nL6j3yPHA2/R7c6tQ4T+0IO10=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0/go.mod h1:3Zzou41Qt/ueXfIzHvTEjDNuR5IjCUBVF01SNhrt1e8=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18 h1:ApLTFdAZfDhZSiY5uskwECKHkSNNF83y2Ru2r7SezWA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18/go.mod h1:A9K9qx2l6nK89hp+a350FdGfRkrkH5HdiEjHbiy/Q/c=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1 h1:4VD7TIZOGzehrgQ8vDE+1c6BQW4ErZPGY8ohZT5LXEE=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1/go.mod h1:er0SFJfdV89Rit5hIJu/EXtv+qC2XMnxoksLmcUFkqM=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2 h1:XKnxlM4KZH1gktcsh3zSWc7GW4KivEv/OkifmHOhCUY=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2/go.mod h1:KJYmkQaFB3SUW2j3aBkPsxNmAb4ZsSOvbvCpuxzHJA0=
|
||||
github.com/aws/smithy-go v1.26.0 h1:9ouqbi+NyKP7fV3Te7UElCwdAb6Y8uk7LGwPE5tVe/s=
|
||||
github.com/aws/smithy-go v1.26.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d/go.mod h1:6QX/PXZ00z/TKoufEY6K/a0k6AhaJrQKdFe6OfVXsa4=
|
||||
github.com/bits-and-blooms/bitset v1.24.4/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
|
||||
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab h1:J7XQLgl9sefgTnTGrmX3xqvp5o6MCiBzEjGv5igAlc4=
|
||||
@@ -86,8 +109,8 @@ github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdR
|
||||
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 h1:BW/sZtyd1JyYy0h5adMm3tzpNyL857LWjuTRET6OhpY=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be h1:j7w8VP/D4lu5+/4GamMmFy8nrtadcl82/fjvDgSHwLo=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be/go.mod h1:3YdTxlnV/L0bQ3VN8WOSw8doF7LZV/xawUQ4MuAPDvo=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260601155805-6cf7526a1b3f h1:vKsPSlO4g4jKfJ9enESgNZ45BkbHngTIq3UxNOzic74=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260601155805-6cf7526a1b3f/go.mod h1:hFpumms29Smx3LStRfku8vcCTBe1Kq8aCXtHUJa3mjY=
|
||||
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
|
||||
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
@@ -98,14 +121,14 @@ github.com/charmbracelet/x/editor v0.2.0 h1:7XLUKtaRaB8jN7bWU2p2UChiySyaAuIfYiIR
|
||||
github.com/charmbracelet/x/editor v0.2.0/go.mod h1:p3oQ28TSL3YPd+GKJ1fHWcp+7bVGpedHpXmo0D6t1dY=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310 h1:rByFKh9JgQScu7oy0+TlUbC2e93woW/QNZmNXbbbw/E=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260602025833-85a30b5e440a h1:aVvnksCVgxB2igk7jERL9ARIkbDXccp1gXCFqhGlamQ=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260602025833-85a30b5e440a/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310 h1:PMjHdSo8Vpq9psUw9BoHo9JLPMkm9Hqb+Whk64n3AQQ=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d h1:RxcAR+vJCoD8QqT1cqLtkQKw+1cqvjqnu5IpPqYzPco=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -120,12 +143,13 @@ github.com/charmbracelet/x/xpty v0.1.3 h1:eGSitii4suhzrISYH50ZfufV3v085BXQwIytcO
|
||||
github.com/charmbracelet/x/xpty v0.1.3/go.mod h1:poPYpWuLDBFCKmKLDnhBp51ATa0ooD8FhypRwEFtH3Y=
|
||||
github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8=
|
||||
github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 h1:aBangftG7EVZoUb69Os8IaYg++6uMOdKK83QtkkvJik=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2/go.mod h1:qwXFYgsP6T7XnJtbKlf1HP8AjxZZyzxMmc+Lq5GjlU4=
|
||||
github.com/coder/acp-go-sdk v0.12.2 h1:fpRJ8Z5HMSr5cZ5IywzFlFZcIxZOsto+laNVu7XelFA=
|
||||
github.com/coder/acp-go-sdk v0.12.2/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
|
||||
github.com/coder/acp-go-sdk v0.13.5 h1:LI9jq5xon7xslaYlnoktvTVyDlE37yIk2daT7N9ASYk=
|
||||
github.com/coder/acp-go-sdk v0.13.5/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
@@ -133,13 +157,20 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.12.0 h1:0j4c5qQmnC6XOWNjP3PIXURXN2gWx76rd3KvgdPkCz8=
|
||||
github.com/dlclark/regexp2 v1.12.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dlclark/regexp2/v2 v2.1.1 h1:LCUGyd9Wf+r+VVOl8Ny38JTpWJcAsdVnCIuhhtthmKw=
|
||||
github.com/dlclark/regexp2/v2 v2.1.1/go.mod h1:avUrQvPaLz2DrFNHJF0taWAFFX2C1GMSSoeiqFjcBmU=
|
||||
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
|
||||
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
|
||||
github.com/dromara/carbon/v2 v2.6.16/go.mod h1:NGo3reeV5vhWCYWcSqbJRZm46MEwyfYI5EJRdVFoLJo=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/eliben/go-sentencepiece v0.6.0/go.mod h1:nNYk4aMzgBoI6QFp4LUG8Eu1uO9fHD9L5ZEre93o9+c=
|
||||
github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA=
|
||||
github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.37.0 h1:u3riX6BoYRfF4Dr7dwSOroNfdSbEPe9Yyl09/B6wBrQ=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.37.0/go.mod h1:DReE9MMrmecPy+YvQOAOHNYMALuowAnbjjEMkkWOi6A=
|
||||
github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.3 h1:MVQghNeW+LZcmXe7SY1V36Z+WFMDjpqGAGacLe2T0ds=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.3/go.mod h1:TsndJ/ngyIdQRhMcVVGDDHINPLWB7C82oDArY51KfB0=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
@@ -148,8 +179,9 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho=
|
||||
github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo=
|
||||
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4 h1:2WmHkJINIjgXXYDGik8d3oJvFA3DAwPy00csDJ3vo+o=
|
||||
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686 h1:NZBJxCpbHS1gzS6xAmyxbJznosZIIPk9IB42v62UvKA=
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||
github.com/go-logfmt/logfmt v0.6.1 h1:4hvbpePJKnIzH1B+8OR/JPbTx37NktoI9LE2QZBBkvE=
|
||||
github.com/go-logfmt/logfmt v0.6.1/go.mod h1:EV2pOAQoZaT1ZXZbqDl5hrymndi4SY9ED9/z6CO0XAk=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
@@ -163,38 +195,53 @@ github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/glog v1.2.5/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-pkcs11 v0.3.0/go.mod h1:6eQoGcuNJpa7jnd5pMGdkSaQpNDYvPlXWMcjXXThLlY=
|
||||
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
|
||||
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15 h1:xolVQTEXusUcAA5UgtyRLjelpFFHWlPQ4XfWGc7MBas=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16 h1:F/VPrx0YPBdksZJQdCAp0WUsqnNmZpUZszzfYt0M5Dw=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16/go.mod h1:9Yb0eAkH/Xqhvv3zbeKf/+wMJqCeocWc6KIhDvEAuYE=
|
||||
github.com/googleapis/gax-go/v2 v2.22.0 h1:PjIWBpgGIVKGoCXuiCoP64altEJCj3/Ei+kSU5vlZD4=
|
||||
github.com/googleapis/gax-go/v2 v2.22.0/go.mod h1:irWBbALSr0Sk3qlqb9SyJ1h68WjgeFuiOzI4Rqw5+aY=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0/go.mod h1:Hyl3n6Twe1hvtd9XUXDec4pTvgMSEixRuQKPTMH2bNs=
|
||||
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.72/go.mod h1:Vn+BBgKQHVQYdVQ4NZDICE1Brb+JfaONyDHr3q07oQc=
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
|
||||
github.com/hashicorp/go-getter v1.8.6/go.mod h1:nVH12eOV2P58dIiL3rsU6Fh3wLeJEKBOJzhMmzlSWoo=
|
||||
github.com/hashicorp/go-version v1.9.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/hybridgroup/yzma v1.13.0/go.mod h1:zrzMgv/KVQz23+s6l16b+vJ+9uJVBdWtGcGkwRTMeiQ=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
|
||||
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
|
||||
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
|
||||
github.com/kaptinlin/go-i18n v0.4.7 h1:apjIIZHnGRyrkiX3vHj07F1BF6D0JLmV+VGSr1781Jc=
|
||||
github.com/kaptinlin/go-i18n v0.4.7/go.mod h1:+i1J0pFq/9i9ESC5qRMVkKwC+mdQTABhhBExpYOlbeM=
|
||||
github.com/kaptinlin/jsonpointer v0.4.21 h1:WVkwQbeerbHFcoXG7Yo/mlQhhZjWiTnagECEfwDXXa0=
|
||||
github.com/kaptinlin/jsonpointer v0.4.21/go.mod h1:Mo7+DX8RlQTFqS4dnYJl0izSP4ob+Rl5xO/mGDETgaU=
|
||||
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/jupiterrider/ffi v0.7.0/go.mod h1:9dauhpOfNqrqk28fxuu0kkdeFtT9Qr4vbfigiuIXN7c=
|
||||
github.com/kaptinlin/go-i18n v0.4.5 h1:9tIlo5A0RXth+yZJO2MG7Bhpu/X9PlzQnGz/qyYWNoY=
|
||||
github.com/kaptinlin/go-i18n v0.4.5/go.mod h1:mU/7BH4molY5lGZYBwBRKAaiJ70dWRHuqmQ0/pFLGno=
|
||||
github.com/kaptinlin/jsonpointer v0.4.25 h1:iJ197e8n+WwqaqBsa53FqG3rPJCg5oijyFXEXNWWC3E=
|
||||
github.com/kaptinlin/jsonpointer v0.4.25/go.mod h1:wVOBaXGGnP42YsMb6zev/3W5POTvspdNfh8DXzf8XS8=
|
||||
github.com/kaptinlin/jsonschema v0.7.13 h1:kahVXTy/rURL0XJjyQ9WELm59wEmXi6IY0TWswQEFvU=
|
||||
github.com/kaptinlin/jsonschema v0.7.13/go.mod h1:Uh0aUBusnhXDCEXJ2oimL/hx7YTo7F+sKniE+tM0ERc=
|
||||
github.com/kaptinlin/messageformat-go v0.6.3 h1:m9ZE/fCjnsk8bdkv7Qs56L/ZoHbmQqhz9mRZSAQLU5g=
|
||||
github.com/kaptinlin/messageformat-go v0.6.3/go.mod h1:2KOZ/hgo/SveZ+uyi7vPUpUXieX65Mppzbc3VpGyqKs=
|
||||
github.com/kaptinlin/messageformat-go v0.6.0 h1:D6jiXFsKW4/JG2CMddv/F6Rev9KVbCRKEzzV5QOAcpc=
|
||||
github.com/kaptinlin/messageformat-go v0.6.0/go.mod h1:NKjwS6e9u7DRhAK+vydjDDwJ7UbdHhYjk/yk2WPuZPs=
|
||||
github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
@@ -203,12 +250,14 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.51.0 h1:e8AhEfxzcYt7XqYzwT7uzWNhnqpu3H1Tn7dEJB9Ygj8=
|
||||
github.com/mark3labs/mcp-go v0.51.0/go.mod h1:Zg9cB2HdwdMMVgY0xtTzq3KvYIOJQDsaut+jWjwDaQY=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mark3labs/mcp-go v0.54.1 h1:Ap/ptEB9FtWzFKM8NDsTA7QDxerQOC06eZigrTldVj0=
|
||||
github.com/mark3labs/mcp-go v0.54.1/go.mod h1:+8WclSK1ZUweCP3hvktSji8n8ABG/95QaEkeVE/Uwas=
|
||||
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
|
||||
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
|
||||
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mattn/go-runewidth v0.0.24 h1:cpokDiIn0MGnhdHwuWnJBITySJ20QyNGnY2kR/ay2DU=
|
||||
github.com/mattn/go-runewidth v0.0.24/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@@ -223,6 +272,7 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
|
||||
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
@@ -231,6 +281,10 @@ github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgm
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
|
||||
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
@@ -238,8 +292,10 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
|
||||
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
|
||||
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 h1:KRzFb2m7YtdldCEkzs6KqmJw4nqEVZGK7IN2kJkjTuQ=
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
@@ -251,13 +307,14 @@ github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||
github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
|
||||
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
@@ -268,67 +325,87 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/traefik/yaegi v0.16.1 h1:f1De3DVJqIDKmnasUF6MwmWv1dSEEat0wcpXhD2On3E=
|
||||
github.com/traefik/yaegi v0.16.1/go.mod h1:4eVhbPb3LnD2VigQjhYbEJ69vDRFdT2HQNrXx8eEwUY=
|
||||
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4=
|
||||
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0/go.mod h1:Sje3i3MjSPKTSPvVWCaL8ugBzJwik3u4smCjUeuupqg=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.43.0/go.mod h1:RyaZMFY7yi1kAs45S6mbFGz8O8rqB0dTY14uzvG4LCs=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0 h1:2yEATaop1/a1I4psnSLgWVPLWwCzkqWakgJy7xTDVy0=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0/go.mod h1:D7J12YRapIekYyPWgGPlA/23pRmpSEZC5xJC/TTLI9U=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0 h1:8tvICD4vSTOOsNrsI4Ljf6C+6UKvpTEH5XY3JMoyPoo=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0/go.mod h1:z9+yiacE0IHRqM4qFfkbt/JYlmYXgss8GY/jXoNuPJI=
|
||||
go.opentelemetry.io/otel v1.44.0 h1:JjwHmHpA4iZ3wBxluu2fbbE7j4kqlE8jXyAyPXH7HqU=
|
||||
go.opentelemetry.io/otel v1.44.0/go.mod h1:BMgjTHL9WPRlRjL2oZCBTL4whCGtXch2H4BhOPIAyYc=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0/go.mod h1:AGmbycVGEsRx9mXMZ75CsOyhSP6MFIcj/6dnG+vhVjk=
|
||||
go.opentelemetry.io/otel/metric v1.44.0 h1:1w0gILTcHdr3YI+ixLyjemwrVnsMURbTZFrSYCdDdmc=
|
||||
go.opentelemetry.io/otel/metric v1.44.0/go.mod h1:8O7hanEPBNgEMmybD3s2VBKcgWOCsA6tzHBPODAiquo=
|
||||
go.opentelemetry.io/otel/sdk v1.44.0 h1:nHYwb9lK+fJPU/dnT6s7W7Z8itMWyqrnVfbheVYrZ58=
|
||||
go.opentelemetry.io/otel/sdk v1.44.0/go.mod h1:Osuydd3Se74nqjAKxid74N5eC+jfEqfTegHRnq58oK0=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.44.0 h1:3LlKgI+VjbVsjNRFZJZAJ30WjXC5VkNRks6si09iEfI=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.44.0/go.mod h1:5B5pMARnXxKhltooO4xUuCBorl65a4EpnTalObqOigA=
|
||||
go.opentelemetry.io/otel/trace v1.44.0 h1:jxF5CsGYCe74MCRx2X4g7WsY/VBKRqqpNvXlX/6gtIk=
|
||||
go.opentelemetry.io/otel/trace v1.44.0/go.mod h1:oLl1jrMQAVo6v3GAggN+1VH9VIz9iUSvW53sW1Q8PIE=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||
go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
go.yaml.in/yaml/v4 v4.0.0-rc.3/go.mod h1:aZqd9kCMsGL7AuUv/m/PvWLdg5sjJsZ4oHDEnfPPfY0=
|
||||
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
|
||||
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
|
||||
golang.org/x/exp v0.0.0-20260603202125-055de637280b h1:v1uXiEBHo8QA0LiGCo7UgHMzHT4Kdfpl2zmtH5vaP1Q=
|
||||
golang.org/x/exp v0.0.0-20260603202125-055de637280b/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw=
|
||||
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
|
||||
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
|
||||
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
|
||||
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
|
||||
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
||||
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
|
||||
golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.277.0 h1:HJfyJUiNeBBUMai7ez8u14wkp/gH/I4wpGbbO9o+cSk=
|
||||
google.golang.org/api v0.277.0/go.mod h1:B9TqLBwJqVjp1mtt7WeoQwWRwvu/400y5lETOql+giQ=
|
||||
google.golang.org/genai v1.55.0 h1:iLHGk4Bj/IZ/GNNZb7hYqwSJMRBvqLeu2Hb6YQ+rYGw=
|
||||
google.golang.org/genai v1.55.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260427160629-7cedc36a6bc4 h1:2iMJZntwvmfgtse+s744JY7v7PgEdSBuFYXucvpOHNM=
|
||||
google.golang.org/genproto v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:v14kaaboYyXQ1Gsu489Q+Hg/oN4B33mWtuOhF1HCeXA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4 h1:yOzSCGPx+cp5VO7IxvZ9SBFF7j1tZVcNtlHR2iYKtVo=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:Q9HWtNeE7tM9npdIsEvqXj1QJIvVoeAV3rtXtS715Cw=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 h1:tEkOQcXgF6dH1G+MVKZrfpYvozGrzb91k6ha7jireSM=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.81.0 h1:W3G9N3KQf3BU+YuCtGKJk0CmxQNbAISICD/9AORxLIw=
|
||||
google.golang.org/grpc v1.81.0/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
|
||||
google.golang.org/api v0.282.0 h1:WmJiSVqUnKqJCpJOx7YADbXaC+9DDsnGSfllFSj7R2I=
|
||||
google.golang.org/api v0.282.0/go.mod h1:6Wssta4c5n9qHq5CBhmlai5h/PUa1djdDAIhYEHyvcM=
|
||||
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
|
||||
google.golang.org/genai v1.58.0 h1:MNA3ZkRyr7MnRwZ9RNZ60p4+UMKV3yYRw6pyHq4pp0U=
|
||||
google.golang.org/genai v1.58.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260504160031-60b97b32f348 h1:JjVGDZYWkJWZcxveJGzfkXC5myDVWAd4dZdgbzrDUv8=
|
||||
google.golang.org/genproto v0.0.0-20260504160031-60b97b32f348/go.mod h1:95PqD4xM+AdOcBGsmgfaofXsiA37uXDtDufVbntT3TU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260504160031-60b97b32f348 h1:U8orV30l6KpDsi9dxU0CoJZGbjS8EEpw+6ba+XwGPQA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260504160031-60b97b32f348/go.mod h1:Yzdzr5OOZFgSsEV2D/Xi9NL3bszpXFAg0hFJiRohcD8=
|
||||
google.golang.org/genproto/googleapis/bytestream v0.0.0-20260523011958-0a33c5d7ca68/go.mod h1:6TABGosqSqU2l1+fJ3jdvOYPPVryeKybxYF0cCZkTBE=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa h1:mZHHdPZl0dbGHCflZgAq/Q468DWVFcU2whhB2KAo8fk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ=
|
||||
google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20251110073552-01de4eb40290/go.mod h1:sbq5oMEcM4PXngbcNbHhzfCP9OdZodLhrbRYoyg09HY=
|
||||
gopkg.in/ini.v1 v1.67.1/go.mod h1:x/cyOwCgZqOkJoDIJ3c1KNHMo10+nLGAhh+kn3Zizss=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -61,6 +61,12 @@ func (a *Agent) Authenticate(_ context.Context, _ acp.AuthenticateRequest) (acp.
|
||||
return acp.AuthenticateResponse{}, nil
|
||||
}
|
||||
|
||||
// Logout handles logout requests. Kit doesn't require auth for local stdio
|
||||
// usage, so this is a no-op.
|
||||
func (a *Agent) Logout(_ context.Context, _ acp.LogoutRequest) (acp.LogoutResponse, error) {
|
||||
return acp.LogoutResponse{}, nil
|
||||
}
|
||||
|
||||
// Initialize negotiates capabilities with the ACP client.
|
||||
func (a *Agent) Initialize(_ context.Context, params acp.InitializeRequest) (acp.InitializeResponse, error) {
|
||||
log.Debug("acp: initialize", "protocol_version", params.ProtocolVersion)
|
||||
|
||||
+71
-100
@@ -7,6 +7,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
@@ -38,10 +39,21 @@ func newSessionRegistry() *sessionRegistry {
|
||||
// given working directory. The Kit-generated session ID is used as the ACP
|
||||
// session ID so the mapping is 1:1.
|
||||
func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession, error) {
|
||||
// Each ACP session gets its own isolated config store (CLI is left nil) so
|
||||
// per-session SetModel / SetThinkingLevel calls cannot race or bleed across
|
||||
// the sessionRegistry. We seed the relevant root-command flag values from
|
||||
// the process-global store (which cobra populated from flags) so launching
|
||||
// `kit acp -m <model> [--thinking-level ...] [--provider-url ...]` is still
|
||||
// honored; .kit.yml and KIT_* env vars are loaded per session by kit.New.
|
||||
streamOn := true
|
||||
kitInstance, err := kit.New(ctx, &kit.Options{
|
||||
SessionDir: cwd,
|
||||
Quiet: true,
|
||||
Streaming: true,
|
||||
SessionDir: cwd,
|
||||
Quiet: true,
|
||||
Streaming: &streamOn,
|
||||
Model: viper.GetString("model"),
|
||||
ThinkingLevel: viper.GetString("thinking-level"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
})
|
||||
if err != nil {
|
||||
// Provide actionable guidance for provider auth errors, which are
|
||||
@@ -61,111 +73,70 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
|
||||
// Wire extension context with headless implementations so extensions
|
||||
// work in ACP mode. TUI-dependent features (widgets, prompts, editor)
|
||||
// become no-ops or return cancelled; all data/model/tool APIs work
|
||||
// identically to interactive mode.
|
||||
// become no-ops or return cancelled; all data/model/tool APIs come from
|
||||
// extbridge.BaseContext and work identically to interactive mode.
|
||||
if kitInstance.Extensions().HasExtensions() {
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
SessionID: sessionID,
|
||||
CWD: cwd,
|
||||
Model: kitInstance.GetModelString(),
|
||||
Interactive: false,
|
||||
// Use a background context for subagent spawns: the create() ctx is
|
||||
// request-scoped and may be cancelled before extensions spawn anything.
|
||||
ec := extbridge.BaseContext(context.Background(), kitInstance)
|
||||
|
||||
// Output — route through structured logger.
|
||||
Print: func(text string) { log.Debug("extension: print", "text", text) },
|
||||
PrintInfo: func(text string) { log.Info("extension: info", "text", text) },
|
||||
PrintError: func(text string) { log.Error("extension: error", "text", text) },
|
||||
PrintBlock: func(opts extensions.PrintBlockOpts) {
|
||||
log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text)
|
||||
},
|
||||
ec.SessionID = sessionID
|
||||
ec.CWD = cwd
|
||||
ec.Model = kitInstance.GetModelString()
|
||||
ec.Interactive = false
|
||||
|
||||
// Message injection — no-ops for now; ACP clients drive prompts.
|
||||
SendMessage: func(string) {},
|
||||
CancelAndSend: func(string) {},
|
||||
Exit: func() {},
|
||||
// Output — route through structured logger.
|
||||
ec.Print = func(text string) { log.Debug("extension: print", "text", text) }
|
||||
ec.PrintInfo = func(text string) { log.Info("extension: info", "text", text) }
|
||||
ec.PrintError = func(text string) { log.Error("extension: error", "text", text) }
|
||||
ec.PrintBlock = func(opts extensions.PrintBlockOpts) {
|
||||
log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text)
|
||||
}
|
||||
|
||||
// TUI widgets/chrome — silent no-ops (no TUI in ACP).
|
||||
SetWidget: func(extensions.WidgetConfig) {},
|
||||
RemoveWidget: func(string) {},
|
||||
SetHeader: func(extensions.HeaderFooterConfig) {},
|
||||
RemoveHeader: func() {},
|
||||
SetFooter: func(extensions.HeaderFooterConfig) {},
|
||||
RemoveFooter: func() {},
|
||||
SetEditor: func(extensions.EditorConfig) {},
|
||||
ResetEditor: func() {},
|
||||
SetEditorText: func(string) {},
|
||||
SetUIVisibility: func(extensions.UIVisibility) {},
|
||||
SetStatus: func(string, string, int) {},
|
||||
RemoveStatus: func(string) {},
|
||||
// Message injection — no-ops for now; ACP clients drive prompts.
|
||||
ec.SendMessage = func(string) {}
|
||||
ec.CancelAndSend = func(string) {}
|
||||
ec.Exit = func() {}
|
||||
|
||||
// Interactive prompts — return cancelled (no user to prompt).
|
||||
PromptSelect: func(extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
},
|
||||
PromptConfirm: func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
},
|
||||
PromptInput: func(extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
},
|
||||
ShowOverlay: func(extensions.OverlayConfig) extensions.OverlayResult {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
},
|
||||
SuspendTUI: func(callback func()) error { callback(); return nil },
|
||||
// TUI widgets/chrome — silent no-ops (no TUI in ACP).
|
||||
ec.SetWidget = func(extensions.WidgetConfig) {}
|
||||
ec.RemoveWidget = func(string) {}
|
||||
ec.SetHeader = func(extensions.HeaderFooterConfig) {}
|
||||
ec.RemoveHeader = func() {}
|
||||
ec.SetFooter = func(extensions.HeaderFooterConfig) {}
|
||||
ec.RemoveFooter = func() {}
|
||||
ec.SetEditor = func(extensions.EditorConfig) {}
|
||||
ec.ResetEditor = func() {}
|
||||
ec.SetEditorText = func(string) {}
|
||||
ec.SetUIVisibility = func(extensions.UIVisibility) {}
|
||||
ec.SetStatus = func(string, string, int) {}
|
||||
ec.RemoveStatus = func(string) {}
|
||||
|
||||
// Data access — delegate to Kit instance.
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.Extensions().GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionPath() },
|
||||
AppendEntry: func(entryType, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
// Interactive prompts — return cancelled (no user to prompt).
|
||||
ec.PromptSelect = func(extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
ec.PromptConfirm = func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
ec.PromptInput = func(extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
ec.ShowOverlay = func(extensions.OverlayConfig) extensions.OverlayResult {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
ec.SuspendTUI = func(callback func()) error { callback(); return nil }
|
||||
|
||||
// Options, model, and tool management.
|
||||
GetOption: func(name string) string { return kitInstance.Extensions().GetOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.Extensions().SetOption(name, value) },
|
||||
SetModel: func(modelString string) error {
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
|
||||
return err
|
||||
}
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry { return kitInstance.GetAvailableModels() },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.Extensions().EmitCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.Extensions().GetToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.Extensions().SetActiveTools(names) },
|
||||
// Render — fall back to logging.
|
||||
ec.RenderMessage = func(name, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(name)
|
||||
if renderer != nil && renderer.Render != nil {
|
||||
content = renderer.Render(content, 80)
|
||||
}
|
||||
log.Info("extension: message", "renderer", name, "content", content)
|
||||
}
|
||||
|
||||
// LLM completions and subagents.
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return extbridge.SpawnSubagent(context.Background(), kitInstance, config)
|
||||
},
|
||||
|
||||
// Render — fall back to logging.
|
||||
RenderMessage: func(name, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(name)
|
||||
if renderer != nil && renderer.Render != nil {
|
||||
content = renderer.Render(content, 80)
|
||||
}
|
||||
log.Info("extension: message", "renderer", name, "content", content)
|
||||
},
|
||||
ReloadExtensions: func() error { return kitInstance.Extensions().Reload() },
|
||||
})
|
||||
kitInstance.Extensions().SetContext(ec)
|
||||
kitInstance.Extensions().EmitSessionStart()
|
||||
}
|
||||
|
||||
|
||||
+47
-46
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -168,9 +169,9 @@ type RetryHandler func(attempt int, err error)
|
||||
type PrepareStepHandler func(stepNumber int, messages []fantasy.Message) []fantasy.Message
|
||||
|
||||
// GenerateCallbacks consolidates all callback functions for
|
||||
// GenerateWithLoopAndStreaming into a single struct. This replaces the previous
|
||||
// 16+ positional callback parameters, making it easier to add new callbacks
|
||||
// without breaking existing callers (new fields default to nil).
|
||||
// GenerateWithCallbacks into a single struct, replacing what was previously
|
||||
// 16+ positional callback parameters. New fields default to nil, so adding
|
||||
// new callbacks does not break existing callers.
|
||||
type GenerateCallbacks struct {
|
||||
OnToolCall ToolCallHandler
|
||||
OnToolExecution ToolExecutionHandler
|
||||
@@ -245,6 +246,12 @@ type Agent struct {
|
||||
mcpReady chan struct{}
|
||||
// mcpErr holds any error from background MCP loading.
|
||||
mcpErr error
|
||||
|
||||
// promptMu serializes runtime updates to systemPrompt and the
|
||||
// accompanying fantasy agent rebuild so concurrent SetSystemPrompt
|
||||
// callers (e.g. Kit.applyComposedSystemPrompt invoked from multiple
|
||||
// goroutines) don't race on a.systemPrompt / a.fantasyAgent.
|
||||
promptMu sync.Mutex
|
||||
}
|
||||
|
||||
// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
|
||||
@@ -515,44 +522,6 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally.
|
||||
//
|
||||
// Deprecated: Use GenerateWithCallbacks instead, which takes a GenerateCallbacks
|
||||
// struct and is easier to extend with new callbacks.
|
||||
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fantasy.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
onStreamingResponse StreamingResponseHandler,
|
||||
onReasoningDelta ReasoningDeltaHandler,
|
||||
onReasoningComplete ReasoningCompleteHandler,
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepMessages StepMessagesHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
onPasswordPrompt PasswordPromptHandler,
|
||||
onToolCallStart ToolCallStartHandler,
|
||||
onToolCallDelta ToolCallDeltaHandler,
|
||||
onToolCallEnd ToolCallEndHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithCallbacks(ctx, messages, GenerateCallbacks{
|
||||
OnToolCall: onToolCall,
|
||||
OnToolExecution: onToolExecution,
|
||||
OnToolResult: onToolResult,
|
||||
OnResponse: onResponse,
|
||||
OnToolCallContent: onToolCallContent,
|
||||
OnStreamingResponse: onStreamingResponse,
|
||||
OnReasoningDelta: onReasoningDelta,
|
||||
OnReasoningComplete: onReasoningComplete,
|
||||
OnToolOutput: onToolOutput,
|
||||
OnStepMessages: onStepMessages,
|
||||
OnStepUsage: onStepUsage,
|
||||
OnPasswordPrompt: onPasswordPrompt,
|
||||
OnToolCallStart: onToolCallStart,
|
||||
OnToolCallDelta: onToolCallDelta,
|
||||
OnToolCallEnd: onToolCallEnd,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCallbacks processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally. We map the rich callback system
|
||||
// to kit's existing callback interface for UI integration.
|
||||
@@ -585,8 +554,13 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
||||
// This avoids type conflicts with provider-level options.
|
||||
history = applyCacheControlToMessages(history)
|
||||
|
||||
// Track current tool call args for callbacks
|
||||
var currentToolArgs string
|
||||
// Track tool call args per-ToolCallID so parallel tool calls in a single
|
||||
// step don't clobber each other. Without this, OnToolResult callbacks would
|
||||
// all see the args of the last OnToolCall in the step. The mutex guards
|
||||
// against the possibility that the underlying streaming layer dispatches
|
||||
// callbacks from multiple goroutines.
|
||||
toolCallArgs := make(map[string]string)
|
||||
var toolCallArgsMu sync.Mutex
|
||||
|
||||
// Use the streaming path when streaming is enabled OR when any callbacks are
|
||||
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
|
||||
@@ -773,7 +747,9 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
currentToolArgs = tc.Input
|
||||
toolCallArgsMu.Lock()
|
||||
toolCallArgs[tc.ToolCallID] = tc.Input
|
||||
toolCallArgsMu.Unlock()
|
||||
|
||||
// Notify about the tool call
|
||||
if cb.OnToolCall != nil {
|
||||
@@ -793,15 +769,22 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
// Look up the args recorded for this specific tool call. Delete
|
||||
// the entry so the map doesn't accumulate across steps.
|
||||
toolCallArgsMu.Lock()
|
||||
args := toolCallArgs[tr.ToolCallID]
|
||||
delete(toolCallArgs, tr.ToolCallID)
|
||||
toolCallArgsMu.Unlock()
|
||||
|
||||
// Notify tool execution finished
|
||||
if cb.OnToolExecution != nil {
|
||||
cb.OnToolExecution(tr.ToolCallID, tr.ToolName, currentToolArgs, false)
|
||||
cb.OnToolExecution(tr.ToolCallID, tr.ToolName, args, false)
|
||||
}
|
||||
|
||||
if cb.OnToolResult != nil {
|
||||
// Extract result text and error status
|
||||
resultText, isError := extractToolResultText(tr)
|
||||
cb.OnToolResult(tr.ToolCallID, tr.ToolName, currentToolArgs, resultText, tr.ClientMetadata, isError)
|
||||
cb.OnToolResult(tr.ToolCallID, tr.ToolName, args, resultText, tr.ClientMetadata, isError)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1303,6 +1286,24 @@ func (a *Agent) GetModel() fantasy.LanguageModel {
|
||||
return a.model
|
||||
}
|
||||
|
||||
// SetSystemPrompt updates the agent's system prompt and rebuilds the underlying
|
||||
// fantasy agent so subsequent turns use the new prompt. Safe to call while the
|
||||
// agent is idle; if invoked during an in-flight turn the new prompt takes
|
||||
// effect on the next LLM call.
|
||||
func (a *Agent) SetSystemPrompt(prompt string) {
|
||||
a.promptMu.Lock()
|
||||
defer a.promptMu.Unlock()
|
||||
a.systemPrompt = prompt
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// GetSystemPrompt returns the agent's current system prompt.
|
||||
func (a *Agent) GetSystemPrompt() string {
|
||||
a.promptMu.Lock()
|
||||
defer a.promptMu.Unlock()
|
||||
return a.systemPrompt
|
||||
}
|
||||
|
||||
// GetMaxTokens returns the effective max output tokens the agent currently
|
||||
// sends to the LLM provider, after per-model defaults, right-sizing, and any
|
||||
// Anthropic thinking-budget adjustments. Returns 0 when no ModelConfig is
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// fakeParallelAgent simulates a provider that emits two parallel tool_use
|
||||
// blocks in a single step. It invokes the streaming callbacks in the order:
|
||||
//
|
||||
// OnToolCall(A) -> OnToolCall(B) -> OnToolResult(A) -> OnToolResult(B)
|
||||
//
|
||||
// Before the fix in #33 the agent-layer wrapper recorded a single
|
||||
// `currentToolArgs` variable that was clobbered by the second OnToolCall, so
|
||||
// both OnToolResult callbacks received B's args instead of their own.
|
||||
type fakeParallelAgent struct {
|
||||
calls []fantasy.ToolCallContent
|
||||
results []fantasy.ToolResultContent
|
||||
}
|
||||
|
||||
func (f *fakeParallelAgent) Generate(_ context.Context, _ fantasy.AgentCall) (*fantasy.AgentResult, error) {
|
||||
return &fantasy.AgentResult{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeParallelAgent) Stream(_ context.Context, opts fantasy.AgentStreamCall) (*fantasy.AgentResult, error) {
|
||||
for _, tc := range f.calls {
|
||||
if opts.OnToolCall != nil {
|
||||
if err := opts.OnToolCall(tc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tr := range f.results {
|
||||
if opts.OnToolResult != nil {
|
||||
if err := opts.OnToolResult(tr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return &fantasy.AgentResult{}, nil
|
||||
}
|
||||
|
||||
// TestGenerateWithCallbacks_ParallelToolArgs is the regression test for #33.
|
||||
// It drives the streaming-callback wiring inside GenerateWithCallbacks with a
|
||||
// fake fantasy.Agent that emits two parallel tool calls before either result.
|
||||
// Each OnToolResult must receive the args of its own tool call (matched by
|
||||
// ToolCallID), not the args of the last OnToolCall in the step.
|
||||
func TestGenerateWithCallbacks_ParallelToolArgs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
argsA := `{"name":"scheduled_jobs"}`
|
||||
argsB := `{"name":"gmail_trigger"}`
|
||||
|
||||
fake := &fakeParallelAgent{
|
||||
calls: []fantasy.ToolCallContent{
|
||||
{ToolCallID: "kit-A", ToolName: "load_skill", Input: argsA},
|
||||
{ToolCallID: "kit-B", ToolName: "load_skill", Input: argsB},
|
||||
},
|
||||
results: []fantasy.ToolResultContent{
|
||||
{ToolCallID: "kit-A", ToolName: "load_skill", Result: fantasy.ToolResultOutputContentText{Text: "ok-A"}},
|
||||
{ToolCallID: "kit-B", ToolName: "load_skill", Result: fantasy.ToolResultOutputContentText{Text: "ok-B"}},
|
||||
},
|
||||
}
|
||||
|
||||
a := &Agent{
|
||||
fantasyAgent: fake,
|
||||
streamingEnabled: false, // exercise the "hasCallbacks" branch
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
resultArgs := map[string]string{}
|
||||
executionArgs := map[string]string{} // captured when running == false
|
||||
|
||||
cb := GenerateCallbacks{
|
||||
OnToolExecution: func(id, _, args string, running bool) {
|
||||
if running {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
executionArgs[id] = args
|
||||
},
|
||||
OnToolResult: func(id, _, args, _, _ string, _ bool) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
resultArgs[id] = args
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := a.GenerateWithCallbacks(context.Background(), nil, cb); err != nil {
|
||||
t.Fatalf("GenerateWithCallbacks returned error: %v", err)
|
||||
}
|
||||
|
||||
if got, want := resultArgs["kit-A"], argsA; got != want {
|
||||
t.Errorf("OnToolResult for kit-A: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := resultArgs["kit-B"], argsB; got != want {
|
||||
t.Errorf("OnToolResult for kit-B: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := executionArgs["kit-A"], argsA; got != want {
|
||||
t.Errorf("OnToolExecution(finish) for kit-A: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := executionArgs["kit-B"], argsB; got != want {
|
||||
t.Errorf("OnToolExecution(finish) for kit-B: args = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
@@ -343,6 +344,90 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
}
|
||||
}
|
||||
|
||||
// PopLastUserMessage truncates the tree session back to the parent of the
|
||||
// most recent user message on the current branch, syncs the in-memory
|
||||
// message store, and returns the user prompt text plus any image file
|
||||
// parts so the caller can resubmit via Run/RunWithFiles.
|
||||
//
|
||||
// This is the building block for /retry: the user message and any orphaned
|
||||
// assistant/tool entries produced by a failed turn become unreachable on
|
||||
// the current branch (they remain in the session file under a different
|
||||
// leaf) and are excluded from the next LLM context.
|
||||
//
|
||||
// Returns an error when:
|
||||
// - the agent is currently working (busy)
|
||||
// - the app has been closed
|
||||
// - no tree session is active (sessions disabled via --no-session)
|
||||
// - no user message exists on the current branch
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) PopLastUserMessage() (string, []kit.LLMFilePart, error) {
|
||||
a.mu.Lock()
|
||||
if a.closed {
|
||||
a.mu.Unlock()
|
||||
return "", nil, fmt.Errorf("app is closed")
|
||||
}
|
||||
if a.busy {
|
||||
a.mu.Unlock()
|
||||
return "", nil, fmt.Errorf("cannot retry while the agent is working")
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
ts := a.opts.TreeSession
|
||||
if ts == nil {
|
||||
return "", nil, fmt.Errorf("no tree session active; /retry requires a session")
|
||||
}
|
||||
|
||||
// Walk the current branch backwards to find the most recent user message.
|
||||
branch := ts.GetBranch("")
|
||||
var target *session.MessageEntry
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
me, ok := branch[i].(*session.MessageEntry)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if me.Role == string(message.RoleUser) {
|
||||
target = me
|
||||
break
|
||||
}
|
||||
}
|
||||
if target == nil {
|
||||
return "", nil, fmt.Errorf("no user message to retry")
|
||||
}
|
||||
|
||||
// Extract the prompt text and any image parts from the target entry.
|
||||
msg, err := target.ToMessage()
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("decode user message: %w", err)
|
||||
}
|
||||
prompt := msg.Content()
|
||||
var files []kit.LLMFilePart
|
||||
for _, part := range msg.Parts {
|
||||
if ic, ok := part.(message.ImageContent); ok {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Data: ic.Data,
|
||||
MediaType: ic.MediaType,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Move the leaf to the parent of the user message. The failed turn's
|
||||
// entries (user message + any partial assistant/tool entries) are still
|
||||
// in the tree file but no longer on the active branch, so they will not
|
||||
// be re-sent to the LLM. runTurn() will append a fresh user message on
|
||||
// the next call.
|
||||
if err := ts.Branch(target.ParentID); err != nil {
|
||||
return "", nil, fmt.Errorf("branch to parent: %w", err)
|
||||
}
|
||||
|
||||
// Sync the in-memory store with the new branch position so subsequent
|
||||
// reads (and ReloadMessagesFromTree() consumers) see the truncated view.
|
||||
a.store.Clear()
|
||||
a.store.Replace(ts.GetLLMMessages())
|
||||
|
||||
return prompt, files, nil
|
||||
}
|
||||
|
||||
// AddContextMessage adds a user-role message to the conversation history
|
||||
// without triggering an LLM response. Used by the ! shell command prefix
|
||||
// to inject command output into context so the LLM can reference it in
|
||||
|
||||
@@ -9,7 +9,10 @@ import (
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -969,3 +972,146 @@ func TestReleaseBusyAfterCompact_dropsQueueWhenClosed(t *testing.T) {
|
||||
t.Fatalf("expected 0 PromptFunc calls on closed app, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// PopLastUserMessage (/retry building block)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TestPopLastUserMessage_NoTreeSession verifies that PopLastUserMessage
|
||||
// returns an error when no tree session is active.
|
||||
func TestPopLastUserMessage_NoTreeSession(t *testing.T) {
|
||||
app := newTestApp(newStub())
|
||||
defer app.Close()
|
||||
|
||||
prompt, files, err := app.PopLastUserMessage()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when no tree session is active")
|
||||
}
|
||||
if prompt != "" || files != nil {
|
||||
t.Fatalf("expected zero values on error, got prompt=%q files=%v", prompt, files)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopLastUserMessage_WhileBusy verifies that PopLastUserMessage
|
||||
// refuses to truncate while the agent is busy (would race with executeBatch).
|
||||
func TestPopLastUserMessage_WhileBusy(t *testing.T) {
|
||||
app := newTestApp(newStub())
|
||||
defer app.Close()
|
||||
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.mu.Unlock()
|
||||
|
||||
_, _, err := app.PopLastUserMessage()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when agent is busy")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "working") {
|
||||
t.Fatalf("expected error mentioning busy/working, got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopLastUserMessage_WhenClosed verifies that PopLastUserMessage
|
||||
// returns an error after Close().
|
||||
func TestPopLastUserMessage_WhenClosed(t *testing.T) {
|
||||
app := newTestApp(newStub())
|
||||
app.Close()
|
||||
|
||||
_, _, err := app.PopLastUserMessage()
|
||||
if err == nil {
|
||||
t.Fatal("expected error on closed app")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopLastUserMessage_TruncatesAndReturnsPrompt verifies the happy path:
|
||||
// a real tree session with user→assistant→user→assistant entries is
|
||||
// truncated back to before the most recent user message, and that user's
|
||||
// text is returned.
|
||||
func TestPopLastUserMessage_TruncatesAndReturnsPrompt(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
ts, err := session.CreateTreeSession(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("create tree session: %v", err)
|
||||
}
|
||||
defer func() { _ = ts.Close() }()
|
||||
|
||||
// Build history: user "first" → assistant "ack 1" → user "second" → assistant "ack 2".
|
||||
if _, err := ts.AppendLLMMessage(fantasy.NewUserMessage("first")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := ts.AppendLLMMessage(fantasy.Message{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "ack 1"}},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := ts.AppendLLMMessage(fantasy.NewUserMessage("second")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := ts.AppendLLMMessage(fantasy.Message{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "ack 2"}},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
app := New(Options{TreeSession: ts, PromptFunc: newStub().fn}, nil)
|
||||
defer app.Close()
|
||||
|
||||
prompt, files, err := app.PopLastUserMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("PopLastUserMessage: %v", err)
|
||||
}
|
||||
if prompt != "second" {
|
||||
t.Fatalf("expected prompt=%q, got %q", "second", prompt)
|
||||
}
|
||||
if files != nil {
|
||||
t.Fatalf("expected no files, got %v", files)
|
||||
}
|
||||
|
||||
// After truncation the branch should only contain the first user
|
||||
// message and its assistant response (the "second" turn is orphaned).
|
||||
msgs := ts.GetLLMMessages()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("expected 2 messages on truncated branch, got %d", len(msgs))
|
||||
}
|
||||
if got := messageText(msgs[0]); got != "first" {
|
||||
t.Fatalf("expected first message %q, got %q", "first", got)
|
||||
}
|
||||
if got := messageText(msgs[1]); got != "ack 1" {
|
||||
t.Fatalf("expected second message %q, got %q", "ack 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// messageText extracts concatenated TextPart content from a fantasy.Message.
|
||||
func messageText(m fantasy.Message) string {
|
||||
var out strings.Builder
|
||||
for _, p := range m.Content {
|
||||
if tp, ok := p.(fantasy.TextPart); ok {
|
||||
out.WriteString(tp.Text)
|
||||
}
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
// TestPopLastUserMessage_NoUserOnBranch verifies that an empty tree (no
|
||||
// user messages at all) returns a friendly error rather than panicking.
|
||||
func TestPopLastUserMessage_NoUserOnBranch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
ts, err := session.CreateTreeSession(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("create tree session: %v", err)
|
||||
}
|
||||
defer func() { _ = ts.Close() }()
|
||||
|
||||
app := New(Options{TreeSession: ts, PromptFunc: newStub().fn}, nil)
|
||||
defer app.Close()
|
||||
|
||||
_, _, err = app.PopLastUserMessage()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when no user message exists on branch")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no user message") {
|
||||
t.Fatalf("expected error mentioning missing user message, got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,11 +13,6 @@ type MessageStore struct {
|
||||
messages []kit.LLMMessage
|
||||
}
|
||||
|
||||
// NewMessageStore creates an empty MessageStore.
|
||||
func NewMessageStore() *MessageStore {
|
||||
return &MessageStore{}
|
||||
}
|
||||
|
||||
// NewMessageStoreWithMessages creates a MessageStore pre-populated with the
|
||||
// given messages. This is used when loading an existing session at startup.
|
||||
func NewMessageStoreWithMessages(msgs []kit.LLMMessage) *MessageStore {
|
||||
|
||||
@@ -29,7 +29,7 @@ func textOf(msg kit.LLMMessage) string {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestNewMessageStore_empty(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil store")
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestAdd_appendsMessage(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "first"))
|
||||
s.Add(makeTextMsg("assistant", "second"))
|
||||
|
||||
@@ -82,7 +82,7 @@ func TestAdd_appendsMessage(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAdd_preservesOrder(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
texts := []string{"a", "b", "c"}
|
||||
for _, t2 := range texts {
|
||||
s.Add(makeTextMsg("user", t2))
|
||||
@@ -100,7 +100,7 @@ func TestAdd_preservesOrder(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestReplace_swapsHistory(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "old"))
|
||||
|
||||
replacement := []kit.LLMMessage{
|
||||
@@ -120,7 +120,7 @@ func TestReplace_swapsHistory(t *testing.T) {
|
||||
|
||||
// Replace must deep-copy the incoming slice.
|
||||
func TestReplace_isolatesInput(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
replacement := []kit.LLMMessage{makeTextMsg("user", "original")}
|
||||
s.Replace(replacement)
|
||||
|
||||
@@ -137,7 +137,7 @@ func TestReplace_isolatesInput(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestGetAll_returnsCopy(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "hello"))
|
||||
|
||||
got := s.GetAll()
|
||||
@@ -151,7 +151,7 @@ func TestGetAll_returnsCopy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetAll_emptyStore(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
got := s.GetAll()
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty slice, got %d elements", len(got))
|
||||
@@ -163,7 +163,7 @@ func TestGetAll_emptyStore(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestClear_removesAllMessages(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "a"))
|
||||
s.Add(makeTextMsg("user", "b"))
|
||||
s.Clear()
|
||||
@@ -174,7 +174,7 @@ func TestClear_removesAllMessages(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClear_allowsSubsequentAdds(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "before"))
|
||||
s.Clear()
|
||||
s.Add(makeTextMsg("user", "after"))
|
||||
@@ -193,7 +193,7 @@ func TestClear_allowsSubsequentAdds(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
done := make(chan struct{})
|
||||
|
||||
// Writer goroutine.
|
||||
|
||||
+135
-45
@@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -9,11 +10,11 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
// Currently supports Anthropic and OpenAI credentials with both OAuth and API key authentication methods.
|
||||
// CredentialStore holds stored credentials for Anthropic, OpenAI, and GitHub Copilot.
|
||||
type CredentialStore struct {
|
||||
Anthropic *AnthropicCredentials `json:"anthropic,omitempty"`
|
||||
OpenAI *OpenAICredentials `json:"openai,omitempty"`
|
||||
Copilot *CopilotCredentials `json:"copilot,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicCredentials holds Anthropic API credentials supporting both OAuth
|
||||
@@ -43,6 +44,16 @@ type OpenAICredentials struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// CopilotCredentials holds GitHub OAuth credentials and the short-lived
|
||||
// GitHub Copilot API token derived from them.
|
||||
type CopilotCredentials struct {
|
||||
Type string `json:"type"` // "oauth"
|
||||
GitHubToken string `json:"github_token,omitempty"` // GitHub device-flow OAuth token
|
||||
CopilotAccessToken string `json:"copilot_access_token,omitempty"` // Short-lived Copilot API token
|
||||
ExpiresAt int64 `json:"expires_at,omitempty"` // Copilot token expiry
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// oauthTokenExpired reports whether an OAuth token with the given type and
|
||||
// expiry unix timestamp is past its expiry. Returns false for API key
|
||||
// credentials or when no expiry is set.
|
||||
@@ -91,6 +102,16 @@ func (c *OpenAICredentials) NeedsRefresh() bool {
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsExpired checks if the Copilot API token is expired.
|
||||
func (c *CopilotCredentials) IsExpired() bool {
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh reports whether the Copilot API token should be renewed.
|
||||
func (c *CopilotCredentials) NeedsRefresh() bool {
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// CredentialManager handles secure storage and retrieval of authentication credentials.
|
||||
// It manages a JSON file stored in the user's config directory with appropriate
|
||||
// file permissions for security.
|
||||
@@ -222,7 +243,7 @@ func (cm *CredentialManager) RemoveAnthropicCredentials() error {
|
||||
store.Anthropic = nil
|
||||
|
||||
// If store is empty, remove the file entirely
|
||||
if store.Anthropic == nil {
|
||||
if store.Anthropic == nil && store.OpenAI == nil && store.Copilot == nil {
|
||||
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
@@ -255,29 +276,6 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAICredentials stores OpenAI API key credentials. It validates the
|
||||
// API key format before storing. The API key must start with "sk-" and be
|
||||
// at least 20 characters long. Returns an error if the API key is invalid or
|
||||
// if storage fails.
|
||||
func (cm *CredentialManager) SetOpenAICredentials(apiKey string) error {
|
||||
if err := validateOpenAIAPIKey(apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = &OpenAICredentials{
|
||||
Type: "api_key",
|
||||
APIKey: apiKey,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetOpenAICredentials retrieves stored OpenAI credentials. Returns nil if
|
||||
// no credentials are stored. The returned credentials may be either OAuth or API
|
||||
// key type, check the Type field to determine which.
|
||||
@@ -302,7 +300,7 @@ func (cm *CredentialManager) RemoveOpenAICredentials() error {
|
||||
store.OpenAI = nil
|
||||
|
||||
// If store is empty, remove the file entirely
|
||||
if store.Anthropic == nil && store.OpenAI == nil {
|
||||
if store.Anthropic == nil && store.OpenAI == nil && store.Copilot == nil {
|
||||
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
@@ -312,6 +310,104 @@ func (cm *CredentialManager) RemoveOpenAICredentials() error {
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetCopilotCredentials retrieves stored GitHub Copilot credentials.
|
||||
func (cm *CredentialManager) GetCopilotCredentials() (*CopilotCredentials, error) {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return store.Copilot, nil
|
||||
}
|
||||
|
||||
// RemoveCopilotCredentials removes stored GitHub Copilot credentials.
|
||||
func (cm *CredentialManager) RemoveCopilotCredentials() error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.Copilot = nil
|
||||
|
||||
if store.Anthropic == nil && store.OpenAI == nil && store.Copilot == nil {
|
||||
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// HasCopilotCredentials checks if valid GitHub Copilot credentials are stored.
|
||||
func (cm *CredentialManager) HasCopilotCredentials() (bool, error) {
|
||||
creds, err := cm.GetCopilotCredentials()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if creds == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return creds.Type == "oauth" && creds.GitHubToken != "", nil
|
||||
}
|
||||
|
||||
// SetCopilotOAuthCredentials stores GitHub Copilot OAuth credentials.
|
||||
func (cm *CredentialManager) SetCopilotOAuthCredentials(creds *CopilotCredentials) error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.Copilot = creds
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetValidCopilotAccessToken returns a fresh Copilot API token, renewing it
|
||||
// with the stored GitHub OAuth token when needed.
|
||||
func (cm *CredentialManager) GetValidCopilotAccessToken() (string, error) {
|
||||
return cm.GetValidCopilotAccessTokenContext(context.Background())
|
||||
}
|
||||
|
||||
// GetValidCopilotAccessTokenContext returns a fresh Copilot API token, renewing
|
||||
// it with the stored GitHub OAuth token when needed.
|
||||
func (cm *CredentialManager) GetValidCopilotAccessTokenContext(ctx context.Context) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
creds, err := cm.GetCopilotCredentials()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if creds == nil {
|
||||
return "", fmt.Errorf("no Copilot credentials found")
|
||||
}
|
||||
if creds.Type != "oauth" {
|
||||
return "", fmt.Errorf("unknown credential type: %s", creds.Type)
|
||||
}
|
||||
if creds.GitHubToken == "" {
|
||||
return "", fmt.Errorf("GitHub OAuth token missing from Copilot credentials")
|
||||
}
|
||||
|
||||
if creds.CopilotAccessToken == "" || creds.NeedsRefresh() {
|
||||
client := NewCopilotOAuthClient()
|
||||
newCreds, err := client.RefreshCopilotToken(ctx, creds.GitHubToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh Copilot token: %w", err)
|
||||
}
|
||||
newCreds.CreatedAt = creds.CreatedAt
|
||||
|
||||
if err := cm.SetCopilotOAuthCredentials(newCreds); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed Copilot token: %w", err)
|
||||
}
|
||||
|
||||
return newCreds.CopilotAccessToken, nil
|
||||
}
|
||||
|
||||
return creds.CopilotAccessToken, nil
|
||||
}
|
||||
|
||||
// HasOpenAICredentials checks if valid OpenAI credentials are stored.
|
||||
// Returns true if either a non-empty OAuth access token or API key is present,
|
||||
// false otherwise. Returns an error if credentials cannot be loaded.
|
||||
@@ -417,24 +513,18 @@ func validateAnthropicAPIKey(apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOpenAIAPIKey validates the format of an OpenAI API key
|
||||
func validateOpenAIAPIKey(apiKey string) error {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
// CredentialSourceOAuth is the source description returned by
|
||||
// GetAnthropicAPIKey when the key resolves to stored OAuth credentials.
|
||||
// Consumers should compare against this constant (or use IsAnthropicOAuth)
|
||||
// rather than matching the string literal.
|
||||
const CredentialSourceOAuth = "stored OAuth credentials"
|
||||
|
||||
if apiKey == "" {
|
||||
return fmt.Errorf("API key cannot be empty")
|
||||
}
|
||||
|
||||
// OpenAI API keys typically start with "sk-" and are quite long
|
||||
if !strings.HasPrefix(apiKey, "sk-") {
|
||||
return fmt.Errorf("invalid OpenAI API key format (should start with 'sk-')")
|
||||
}
|
||||
|
||||
if len(apiKey) < 20 {
|
||||
return fmt.Errorf("API key appears to be too short")
|
||||
}
|
||||
|
||||
return nil
|
||||
// IsAnthropicOAuth reports whether the active Anthropic credential resolves
|
||||
// to a stored OAuth token (in which case the user is not billed per-token).
|
||||
// flagValue is the --provider-api-key flag value (may be empty).
|
||||
func IsAnthropicOAuth(flagValue string) bool {
|
||||
_, source, err := GetAnthropicAPIKey(flagValue)
|
||||
return err == nil && source == CredentialSourceOAuth
|
||||
}
|
||||
|
||||
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
|
||||
@@ -459,7 +549,7 @@ func GetAnthropicAPIKey(flagValue string) (string, string, error) {
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
return token, "stored OAuth credentials", nil
|
||||
return token, CredentialSourceOAuth, nil
|
||||
} else if creds.Type == "api_key" && creds.APIKey != "" {
|
||||
return creds.APIKey, "stored API key", nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCredentialManager(t *testing.T) {
|
||||
@@ -215,6 +216,7 @@ func TestCredentialStorePersistence(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||
|
||||
credentialsPath := filepath.Join(tempDir, "credentials.json")
|
||||
@@ -252,3 +254,98 @@ func TestCredentialStorePersistence(t *testing.T) {
|
||||
t.Errorf("Expected file permissions 0600, got %v", info.Mode().Perm())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotCredentials(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "kit-auth-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||
|
||||
cm := &CredentialManager{
|
||||
credentialsPath: filepath.Join(tempDir, "credentials.json"),
|
||||
}
|
||||
|
||||
creds := &CopilotCredentials{
|
||||
Type: "oauth",
|
||||
GitHubToken: "github-token",
|
||||
CopilotAccessToken: "copilot-token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := cm.SetCopilotOAuthCredentials(creds); err != nil {
|
||||
t.Fatalf("SetCopilotOAuthCredentials failed: %v", err)
|
||||
}
|
||||
|
||||
hasAuth, err := cm.HasCopilotCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("HasCopilotCredentials failed: %v", err)
|
||||
}
|
||||
if !hasAuth {
|
||||
t.Fatal("Expected Copilot credentials")
|
||||
}
|
||||
|
||||
token, err := cm.GetValidCopilotAccessToken()
|
||||
if err != nil {
|
||||
t.Fatalf("GetValidCopilotAccessToken failed: %v", err)
|
||||
}
|
||||
if token != creds.CopilotAccessToken {
|
||||
t.Fatalf("Expected Copilot token %q, got %q", creds.CopilotAccessToken, token)
|
||||
}
|
||||
|
||||
if err := cm.RemoveCopilotCredentials(); err != nil {
|
||||
t.Fatalf("RemoveCopilotCredentials failed: %v", err)
|
||||
}
|
||||
hasAuth, err = cm.HasCopilotCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("HasCopilotCredentials after removal failed: %v", err)
|
||||
}
|
||||
if hasAuth {
|
||||
t.Fatal("Expected no Copilot credentials after removal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveCredentialsPreservesOtherProviders(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "kit-auth-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||
|
||||
cm := &CredentialManager{
|
||||
credentialsPath: filepath.Join(tempDir, "credentials.json"),
|
||||
}
|
||||
|
||||
if err := cm.SetOpenAIOAuthCredentials(&OpenAICredentials{
|
||||
Type: "oauth",
|
||||
AccessToken: "openai-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
AccountID: "account",
|
||||
CreatedAt: time.Now(),
|
||||
}); err != nil {
|
||||
t.Fatalf("SetOpenAIOAuthCredentials failed: %v", err)
|
||||
}
|
||||
if err := cm.SetCopilotOAuthCredentials(&CopilotCredentials{
|
||||
Type: "oauth",
|
||||
GitHubToken: "github-token",
|
||||
CopilotAccessToken: "copilot-token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
CreatedAt: time.Now(),
|
||||
}); err != nil {
|
||||
t.Fatalf("SetCopilotOAuthCredentials failed: %v", err)
|
||||
}
|
||||
|
||||
if err := cm.RemoveCopilotCredentials(); err != nil {
|
||||
t.Fatalf("RemoveCopilotCredentials failed: %v", err)
|
||||
}
|
||||
|
||||
hasOpenAI, err := cm.HasOpenAICredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("HasOpenAICredentials failed: %v", err)
|
||||
}
|
||||
if !hasOpenAI {
|
||||
t.Fatal("Expected OpenAI credentials to remain after removing Copilot credentials")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -211,6 +212,262 @@ type OpenAIOAuthClient struct {
|
||||
Scopes string
|
||||
}
|
||||
|
||||
// CopilotOAuthClient handles GitHub device-flow OAuth and exchanges the
|
||||
// GitHub token for a short-lived GitHub Copilot API token.
|
||||
//
|
||||
// The GitHub token comes from GitHub's OAuth device flow. It is then presented
|
||||
// to GitHub's internal Copilot token endpoint, which returns the bearer token
|
||||
// used by api.githubcopilot.com.
|
||||
type CopilotOAuthClient struct {
|
||||
ClientID string
|
||||
DeviceURL string
|
||||
TokenURL string
|
||||
CopilotURL string
|
||||
Scopes string
|
||||
PollTimeout time.Duration
|
||||
ClientTimeout time.Duration
|
||||
}
|
||||
|
||||
// CopilotDeviceCode contains data returned by GitHub's device-code endpoint.
|
||||
type CopilotDeviceCode struct {
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// NewCopilotOAuthClient creates a GitHub Copilot OAuth client.
|
||||
func NewCopilotOAuthClient() *CopilotOAuthClient {
|
||||
return &CopilotOAuthClient{
|
||||
ClientID: "Iv1.b507a08c87ecfe98",
|
||||
DeviceURL: "https://github.com/login/device/code",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
CopilotURL: "https://api.github.com/copilot_internal/v2/token",
|
||||
Scopes: "read:user",
|
||||
PollTimeout: 15 * time.Minute,
|
||||
ClientTimeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// StartDeviceFlow requests a GitHub device code for browser login.
|
||||
//
|
||||
// The returned user code and verification URI are displayed by loginCopilot.
|
||||
// GitHub's response may omit interval, so this method normalizes it to the
|
||||
// documented five-second default.
|
||||
func (c *CopilotOAuthClient) StartDeviceFlow(ctx context.Context) (*CopilotDeviceCode, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
data := url.Values{
|
||||
"client_id": {c.ClientID},
|
||||
"scope": {c.Scopes},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.DeviceURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create device-code request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := (&http.Client{Timeout: c.ClientTimeout}).Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to request device code: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("device-code request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var code CopilotDeviceCode
|
||||
if err := json.NewDecoder(resp.Body).Decode(&code); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode device-code response: %w", err)
|
||||
}
|
||||
if code.DeviceCode == "" || code.UserCode == "" || code.VerificationURI == "" {
|
||||
return nil, fmt.Errorf("device-code response missing required fields")
|
||||
}
|
||||
if code.Interval <= 0 {
|
||||
code.Interval = 5
|
||||
}
|
||||
return &code, nil
|
||||
}
|
||||
|
||||
// PollDeviceToken waits until the user authorizes the device code and returns
|
||||
// the resulting GitHub OAuth token.
|
||||
//
|
||||
// It follows GitHub's device-flow polling contract: authorization_pending keeps
|
||||
// polling, slow_down increases the interval, and polling stops at the earlier of
|
||||
// the client timeout or the device-code expiry.
|
||||
func (c *CopilotOAuthClient) PollDeviceToken(ctx context.Context, deviceCode *CopilotDeviceCode) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
if deviceCode == nil || deviceCode.DeviceCode == "" {
|
||||
return "", fmt.Errorf("device code missing")
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(c.PollTimeout)
|
||||
if deviceCode.ExpiresIn > 0 {
|
||||
expiresAt := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
|
||||
if expiresAt.Before(deadline) {
|
||||
deadline = expiresAt
|
||||
}
|
||||
}
|
||||
|
||||
interval := time.Duration(deviceCode.Interval) * time.Second
|
||||
if interval <= 0 {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
wait := interval
|
||||
if remaining := time.Until(deadline); remaining < wait {
|
||||
wait = remaining
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case <-time.After(wait):
|
||||
}
|
||||
|
||||
data := url.Values{
|
||||
"client_id": {c.ClientID},
|
||||
"device_code": {deviceCode.DeviceCode},
|
||||
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create device-token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := (&http.Client{Timeout: c.ClientTimeout}).Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to poll device token: %w", err)
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description"`
|
||||
}
|
||||
decodeErr := json.NewDecoder(resp.Body).Decode(&tokenResp)
|
||||
_ = resp.Body.Close()
|
||||
if decodeErr != nil {
|
||||
return "", fmt.Errorf("failed to decode device-token response: %w", decodeErr)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken != "" {
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
switch tokenResp.Error {
|
||||
case "authorization_pending":
|
||||
continue
|
||||
case "slow_down":
|
||||
interval += 5 * time.Second
|
||||
continue
|
||||
case "expired_token":
|
||||
return "", fmt.Errorf("device code expired; restart login")
|
||||
case "access_denied":
|
||||
return "", fmt.Errorf("github login denied")
|
||||
case "":
|
||||
return "", fmt.Errorf("device-token request failed with status %d", resp.StatusCode)
|
||||
default:
|
||||
if tokenResp.Description != "" {
|
||||
return "", fmt.Errorf("device-token request failed: %s: %s", tokenResp.Error, tokenResp.Description)
|
||||
}
|
||||
return "", fmt.Errorf("device-token request failed: %s", tokenResp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("timed out waiting for github device authorization")
|
||||
}
|
||||
|
||||
// ExchangeGitHubToken converts a GitHub OAuth token into a Copilot API token.
|
||||
// It is a semantic wrapper over RefreshCopilotToken used by the login flow.
|
||||
func (c *CopilotOAuthClient) ExchangeGitHubToken(ctx context.Context, githubToken string) (*CopilotCredentials, error) {
|
||||
return c.RefreshCopilotToken(ctx, githubToken)
|
||||
}
|
||||
|
||||
// RefreshCopilotToken obtains a fresh short-lived Copilot token from GitHub.
|
||||
//
|
||||
// GitHub may return expires_at as either a Unix timestamp or RFC3339 string.
|
||||
// parseCopilotExpiry handles both forms and falls back to a conservative
|
||||
// 20-minute lifetime when the field is absent or unrecognized.
|
||||
func (c *CopilotOAuthClient) RefreshCopilotToken(ctx context.Context, githubToken string) (*CopilotCredentials, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", c.CopilotURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create copilot token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "token "+githubToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", "kit")
|
||||
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
||||
|
||||
resp, err := (&http.Client{Timeout: c.ClientTimeout}).Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to request copilot token: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("copilot token request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt any `json:"expires_at"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode copilot token response: %w", err)
|
||||
}
|
||||
if tokenResp.Token == "" {
|
||||
return nil, fmt.Errorf("copilot token response missing token")
|
||||
}
|
||||
|
||||
expiresAt := parseCopilotExpiry(tokenResp.ExpiresAt)
|
||||
if expiresAt == 0 {
|
||||
expiresAt = time.Now().Add(20 * time.Minute).Unix()
|
||||
}
|
||||
|
||||
return &CopilotCredentials{
|
||||
Type: "oauth",
|
||||
GitHubToken: githubToken,
|
||||
CopilotAccessToken: tokenResp.Token,
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseCopilotExpiry normalizes GitHub's expires_at variants to a Unix second.
|
||||
func parseCopilotExpiry(value any) int64 {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return int64(v)
|
||||
case string:
|
||||
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return parsed
|
||||
}
|
||||
if parsed, err := time.Parse(time.RFC3339, v); err == nil {
|
||||
return parsed.Unix()
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OAuth client configured for OpenAI Codex OAuth.
|
||||
// This uses the public client ID for CLI applications with PKCE for security.
|
||||
func NewOpenAIOAuthClient() *OpenAIOAuthClient {
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCopilotStartDeviceFlow(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("ParseForm failed: %v", err)
|
||||
}
|
||||
if r.Form.Get("client_id") != "client-id" {
|
||||
t.Fatalf("expected client id, got %q", r.Form.Get("client_id"))
|
||||
}
|
||||
if r.Form.Get("scope") != "read:user" {
|
||||
t.Fatalf("expected scope, got %q", r.Form.Get("scope"))
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"device_code": "device-code",
|
||||
"user_code": "USER-CODE",
|
||||
"verification_uri": "https://github.com/login/device",
|
||||
"expires_in": 600,
|
||||
"interval": 1,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewCopilotOAuthClient()
|
||||
client.ClientID = "client-id"
|
||||
client.DeviceURL = server.URL
|
||||
|
||||
code, err := client.StartDeviceFlow(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("StartDeviceFlow failed: %v", err)
|
||||
}
|
||||
if code.DeviceCode != "device-code" || code.UserCode != "USER-CODE" || code.Interval != 1 {
|
||||
t.Fatalf("unexpected device code: %#v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotPollDeviceToken(t *testing.T) {
|
||||
polls := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
polls++
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("ParseForm failed: %v", err)
|
||||
}
|
||||
if r.Form.Get("grant_type") != "urn:ietf:params:oauth:grant-type:device_code" {
|
||||
t.Fatalf("unexpected grant type: %q", r.Form.Get("grant_type"))
|
||||
}
|
||||
if polls == 1 {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"error": "authorization_pending"})
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"access_token": "github-token"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewCopilotOAuthClient()
|
||||
client.ClientID = "client-id"
|
||||
client.TokenURL = server.URL
|
||||
client.PollTimeout = 5 * time.Second
|
||||
client.ClientTimeout = time.Second
|
||||
|
||||
token, err := client.PollDeviceToken(context.Background(), &CopilotDeviceCode{
|
||||
DeviceCode: "device-code",
|
||||
ExpiresIn: 10,
|
||||
Interval: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PollDeviceToken failed: %v", err)
|
||||
}
|
||||
if token != "github-token" {
|
||||
t.Fatalf("expected github-token, got %q", token)
|
||||
}
|
||||
if polls != 2 {
|
||||
t.Fatalf("expected 2 polls, got %d", polls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotRefreshToken(t *testing.T) {
|
||||
expiresAt := time.Now().Add(time.Hour).Unix()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
t.Fatalf("expected GET, got %s", r.Method)
|
||||
}
|
||||
if r.Header.Get("Authorization") != "token github-token" {
|
||||
t.Fatalf("unexpected authorization header: %q", r.Header.Get("Authorization"))
|
||||
}
|
||||
if r.Header.Get("User-Agent") != "kit" {
|
||||
t.Fatalf("unexpected user agent: %q", r.Header.Get("User-Agent"))
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"token": "copilot-token",
|
||||
"expires_at": expiresAt,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewCopilotOAuthClient()
|
||||
client.CopilotURL = server.URL
|
||||
|
||||
creds, err := client.RefreshCopilotToken(context.Background(), "github-token")
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshCopilotToken failed: %v", err)
|
||||
}
|
||||
if creds.GitHubToken != "github-token" || creds.CopilotAccessToken != "copilot-token" {
|
||||
t.Fatalf("unexpected credentials: %#v", creds)
|
||||
}
|
||||
if creds.ExpiresAt != expiresAt {
|
||||
t.Fatalf("expected expires_at %d, got %d", expiresAt, creds.ExpiresAt)
|
||||
}
|
||||
}
|
||||
@@ -493,6 +493,12 @@ mcpServers:
|
||||
# maxTokens: 16384
|
||||
# systemPrompt: "You are a deep reasoning assistant." # or a file path
|
||||
|
||||
# Skills configuration (all optional)
|
||||
# no-skills: false # Set to true to disable all skill loading
|
||||
# skill: # Explicit skill files/dirs (disables auto-discovery)
|
||||
# - "/path/to/skill.md"
|
||||
# skills-dir: "/path/to/skills" # Override project-local directory for auto-discovery
|
||||
|
||||
# API Configuration (can also use environment variables)
|
||||
# provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google
|
||||
# provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama
|
||||
|
||||
@@ -205,6 +205,9 @@ func TestEnsureConfigExists(t *testing.T) {
|
||||
"type: \"local\"",
|
||||
"type: \"remote\"",
|
||||
"Core tools",
|
||||
"# Skills configuration",
|
||||
"no-skills:",
|
||||
"skills-dir:",
|
||||
}
|
||||
|
||||
for _, expected := range expectedSections {
|
||||
|
||||
@@ -7,32 +7,48 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// LoadAndValidateConfig loads configuration from viper, fixes environment variable
|
||||
// casing issues, and validates the configuration. Returns an error if loading or
|
||||
// validation fails.
|
||||
// LoadAndValidateConfig loads configuration from the process-global viper
|
||||
// store, fixes environment variable casing issues, and validates the
|
||||
// configuration. Returns an error if loading or validation fails.
|
||||
//
|
||||
// This is a convenience wrapper around [LoadAndValidateConfigFrom] using the
|
||||
// shared global store; it is retained for the CLI and other callers that rely
|
||||
// on viper's process-global state.
|
||||
func LoadAndValidateConfig() (*Config, error) {
|
||||
return LoadAndValidateConfigFrom(viper.GetViper())
|
||||
}
|
||||
|
||||
// LoadAndValidateConfigFrom loads configuration from the supplied per-instance
|
||||
// store, fixes environment variable casing issues, and validates the
|
||||
// configuration. When v is nil, the process-global store is used. Threading an
|
||||
// explicit store lets each Kit instance own an isolated configuration without
|
||||
// clobbering other instances in the same process.
|
||||
func LoadAndValidateConfigFrom(v *viper.Viper) (*Config, error) {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
config := &Config{
|
||||
MCPServers: make(map[string]MCPServerConfig),
|
||||
}
|
||||
if err := viper.Unmarshal(config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %v", err)
|
||||
if err := v.Unmarshal(config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// Fix environment variable case sensitivity issue
|
||||
// Viper lowercases all keys, but we need to preserve the original case for environment variables
|
||||
fixEnvironmentCase(config)
|
||||
fixEnvironmentCase(v, config)
|
||||
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %v", err)
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// fixEnvironmentCase fixes the case of environment variable keys that were lowercased by Viper
|
||||
func fixEnvironmentCase(config *Config) {
|
||||
func fixEnvironmentCase(v *viper.Viper, config *Config) {
|
||||
// Get the raw config data from viper
|
||||
rawConfig := viper.AllSettings()
|
||||
rawConfig := v.AllSettings()
|
||||
|
||||
// Check if we have mcpServers in the raw config
|
||||
if mcpServersRaw, ok := rawConfig["mcpservers"]; ok {
|
||||
|
||||
@@ -56,9 +56,3 @@ func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) {
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// HasEnvVars checks if content contains environment variable patterns (${env://...}).
|
||||
// This is useful for determining if substitution is needed before processing.
|
||||
func HasEnvVars(content string) bool {
|
||||
return envVarPattern.MatchString(content)
|
||||
}
|
||||
|
||||
@@ -187,41 +187,3 @@ func TestEnvSubstituter_SubstituteEnvVars(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "has env vars",
|
||||
content: `{"token": "${env://GITHUB_TOKEN}"}`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "has env vars with default",
|
||||
content: `{"debug": "${env://DEBUG:-false}"}`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no env vars",
|
||||
content: `{"name": "${username}", "normal": "value"}`,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := HasEnvVars(tt.content)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+57
-78
@@ -59,12 +59,6 @@ func passwordPromptFromContext(ctx context.Context) PasswordPromptCallback {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextWithSudoPassword returns a new context with the sudo password set.
|
||||
// When present, the bash tool will use sudo -S to pipe this password to sudo commands.
|
||||
func ContextWithSudoPassword(ctx context.Context, password string) context.Context {
|
||||
return context.WithValue(ctx, sudoPasswordKey, password)
|
||||
}
|
||||
|
||||
// sudoPasswordFromContext retrieves the sudo password from context.
|
||||
func sudoPasswordFromContext(ctx context.Context) string {
|
||||
if pw, ok := ctx.Value(sudoPasswordKey).(string); ok {
|
||||
@@ -160,15 +154,6 @@ func rewriteSudoForStdin(command string) string {
|
||||
return result
|
||||
}
|
||||
|
||||
// SudoPasswordRequiredResult is a special marker that indicates sudo needs a password.
|
||||
// This is stored in tool response metadata to signal the TUI to prompt for password.
|
||||
const SudoPasswordRequiredMetadata = `{"sudo_password_required":true}`
|
||||
|
||||
// IsSudoPasswordRequiredResult checks if a tool response indicates sudo password is needed.
|
||||
func IsSudoPasswordRequiredResult(resp fantasy.ToolResponse) bool {
|
||||
return resp.Metadata == SudoPasswordRequiredMetadata
|
||||
}
|
||||
|
||||
func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args bashArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
@@ -258,34 +243,37 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
return executeBashBuffered(cmdCtx, call, cmd, sudoPassword)
|
||||
}
|
||||
|
||||
// executeBashBuffered collects all output before returning (original behavior).
|
||||
// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly
|
||||
// close them when grandchild processes hold pipe handles open after the
|
||||
// direct child exits.
|
||||
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
// setupBashPipes opens stdout/stderr pipes (plus an optional sudo stdin),
|
||||
// starts the command, and asynchronously writes the sudo password if any.
|
||||
// Returns the readers ready for the caller to consume. If setup fails,
|
||||
// errResp is non-nil and the readers must not be used; the caller should
|
||||
// return the response directly.
|
||||
func setupBashPipes(cmd *exec.Cmd, sudoPassword string) (stdout, stderr io.Reader, errResp *fantasy.ToolResponse) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
r := fantasy.NewTextErrorResponse("failed to create stdout pipe")
|
||||
return nil, nil, &r
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
r := fantasy.NewTextErrorResponse("failed to create stderr pipe")
|
||||
return nil, nil, &r
|
||||
}
|
||||
|
||||
// If we have a sudo password, create a stdin pipe and write the password
|
||||
var stdinPipe io.WriteCloser
|
||||
if sudoPassword != "" {
|
||||
stdinPipe, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
|
||||
r := fantasy.NewTextErrorResponse("failed to create stdin pipe")
|
||||
return nil, nil, &r
|
||||
}
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
r := fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err))
|
||||
return nil, nil, &r
|
||||
}
|
||||
|
||||
// Write password to stdin if needed, then close stdin
|
||||
if sudoPassword != "" && stdinPipe != nil {
|
||||
go func() {
|
||||
defer func() { _ = stdinPipe.Close() }()
|
||||
@@ -293,19 +281,49 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
}()
|
||||
}
|
||||
|
||||
return stdoutPipe, stderrPipe, nil
|
||||
}
|
||||
|
||||
// interpretBashExit decodes cmd.Wait()'s error into an exit code, mapping
|
||||
// context-deadline-exceeded to a friendly "command timed out" response.
|
||||
// errResp is non-nil only when the caller should short-circuit and return
|
||||
// it directly (e.g. timeout).
|
||||
func interpretBashExit(waitErr error, cmdCtx context.Context) (exitCode int, errResp *fantasy.ToolResponse) {
|
||||
if waitErr == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
return exitErr.ExitCode(), nil
|
||||
}
|
||||
if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
r := fantasy.NewTextErrorResponse("command timed out")
|
||||
return 0, &r
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// executeBashBuffered collects all output before returning (original behavior).
|
||||
// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly
|
||||
// close them when grandchild processes hold pipe handles open after the
|
||||
// direct child exits.
|
||||
func executeBashBuffered(cmdCtx context.Context, _ fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, stderrPipe, errResp := setupBashPipes(cmd, sudoPassword)
|
||||
if errResp != nil {
|
||||
return *errResp, nil
|
||||
}
|
||||
|
||||
// Read pipes concurrently
|
||||
var wg sync.WaitGroup
|
||||
var stdout, stderr strings.Builder
|
||||
var stdoutErr, stderrErr error
|
||||
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, stdoutErr = io.Copy(&stdout, stdoutPipe)
|
||||
_, _ = io.Copy(&stdout, stdoutPipe)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, stderrErr = io.Copy(&stderr, stderrPipe)
|
||||
_, _ = io.Copy(&stderr, stderrPipe)
|
||||
}()
|
||||
|
||||
// Wait for the process to exit first. cmd.WaitDelay ensures that if
|
||||
@@ -316,18 +334,9 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
// Wait for pipe readers to finish draining.
|
||||
wg.Wait()
|
||||
|
||||
// Ignore pipe read errors caused by WaitDelay force-closing —
|
||||
// we still have whatever was read before the close.
|
||||
_ = stdoutErr
|
||||
_ = stderrErr
|
||||
|
||||
exitCode := 0
|
||||
if waitErr != nil {
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
return fantasy.NewTextErrorResponse("command timed out"), nil
|
||||
}
|
||||
exitCode, errResp := interpretBashExit(waitErr, cmdCtx)
|
||||
if errResp != nil {
|
||||
return *errResp, nil
|
||||
}
|
||||
|
||||
return buildBashResponse(stdout.String(), stderr.String(), exitCode)
|
||||
@@ -335,35 +344,9 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
|
||||
// executeBashStreaming streams output as it arrives via the callback.
|
||||
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
}
|
||||
|
||||
// If we have a sudo password, create a stdin pipe
|
||||
var stdinPipe io.WriteCloser
|
||||
if sudoPassword != "" {
|
||||
stdinPipe, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Start command execution
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
}
|
||||
|
||||
// Write password to stdin if needed, then close stdin
|
||||
if sudoPassword != "" && stdinPipe != nil {
|
||||
go func() {
|
||||
defer func() { _ = stdinPipe.Close() }()
|
||||
_, _ = io.WriteString(stdinPipe, sudoPassword+"\n")
|
||||
}()
|
||||
stdoutPipe, stderrPipe, errResp := setupBashPipes(cmd, sudoPassword)
|
||||
if errResp != nil {
|
||||
return *errResp, nil
|
||||
}
|
||||
|
||||
// Stream stdout and stderr concurrently
|
||||
@@ -400,20 +383,16 @@ func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *ex
|
||||
// Wait for the process to exit. cmd.WaitDelay ensures that if pipes
|
||||
// remain open (held by grandchild processes), they'll be forcibly closed
|
||||
// after the grace period, which unblocks the scanners above.
|
||||
err = cmd.Wait()
|
||||
waitErr := cmd.Wait()
|
||||
|
||||
// Wait for the pipe readers to finish draining. This will complete
|
||||
// quickly since cmd.Wait() (with WaitDelay) has already ensured
|
||||
// the pipes are closed.
|
||||
wg.Wait()
|
||||
|
||||
exitCode := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
return fantasy.NewTextErrorResponse("command timed out"), nil
|
||||
}
|
||||
exitCode, errResp := interpretBashExit(waitErr, cmdCtx)
|
||||
if errResp != nil {
|
||||
return *errResp, nil
|
||||
}
|
||||
|
||||
return buildBashResponse(strings.Join(stdoutChunks, "\n"), strings.Join(stderrChunks, "\n"), exitCode)
|
||||
|
||||
@@ -183,7 +183,7 @@ func TestRewriteSudoForStdin(t *testing.T) {
|
||||
|
||||
func TestSudoPasswordFromContext(t *testing.T) {
|
||||
// Test with password in context
|
||||
ctx := ContextWithSudoPassword(context.Background(), "secret123")
|
||||
ctx := context.WithValue(context.Background(), sudoPasswordKey, "secret123")
|
||||
pw := sudoPasswordFromContext(ctx)
|
||||
if pw != "secret123" {
|
||||
t.Errorf("expected password 'secret123', got %q", pw)
|
||||
|
||||
@@ -83,6 +83,9 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
}
|
||||
|
||||
func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return fantasy.ToolResponse{}, err
|
||||
}
|
||||
var args editArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to parse arguments: " + err.Error()), nil
|
||||
|
||||
@@ -42,6 +42,9 @@ func NewLsTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
}
|
||||
|
||||
func executeLs(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return fantasy.ToolResponse{}, err
|
||||
}
|
||||
var args lsArgs
|
||||
_ = parseArgs(call.Input, &args) // optional args
|
||||
|
||||
|
||||
@@ -47,6 +47,9 @@ func NewReadTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
}
|
||||
|
||||
func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return fantasy.ToolResponse{}, err
|
||||
}
|
||||
var args readArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("path parameter is required"), nil
|
||||
|
||||
@@ -41,6 +41,9 @@ func NewWriteTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
}
|
||||
|
||||
func executeWrite(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return fantasy.ToolResponse{}, err
|
||||
}
|
||||
var args writeArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("path and content parameters are required"), nil
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
package extbridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// BaseContext returns an extensions.Context populated with the headless,
|
||||
// TUI-independent delegation fields: data access, state, options,
|
||||
// model/tool management, completions, subagents, tree navigation, skills,
|
||||
// template parsing, and model resolution.
|
||||
//
|
||||
// Callers overlay their UI-specific fields (print routes, widgets, prompts,
|
||||
// editor, TUI-aware SetModel/ReloadExtensions, etc.) on the returned value:
|
||||
// cmd/extension_context.go for the interactive TUI and
|
||||
// internal/acpserver/session.go for headless ACP mode. Keeping the shared
|
||||
// half here means a new data-access Context field only has to be wired once.
|
||||
//
|
||||
// ctx is used for subagent spawns; pass a long-lived context (not a
|
||||
// per-request one) so later spawns aren't cancelled prematurely.
|
||||
func BaseContext(ctx context.Context, kitInstance *kit.Kit) extensions.Context {
|
||||
return extensions.Context{
|
||||
// -------------------------------------------------------------------
|
||||
// Data access
|
||||
// -------------------------------------------------------------------
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Extension state
|
||||
// -------------------------------------------------------------------
|
||||
SetState: func(key string, value string) {
|
||||
kitInstance.Extensions().SetState(key, value)
|
||||
},
|
||||
GetState: func(key string) (string, bool) {
|
||||
return kitInstance.Extensions().GetState(key)
|
||||
},
|
||||
DeleteState: func(key string) {
|
||||
kitInstance.Extensions().DeleteState(key)
|
||||
},
|
||||
ListState: func() []string {
|
||||
return kitInstance.Extensions().ListState()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Options, model, and tool management
|
||||
// -------------------------------------------------------------------
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
// Headless model switch. The interactive TUI overrides this with a
|
||||
// version that also notifies the TUI and refreshes the usage tracker.
|
||||
SetModel: func(modelString string) error {
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
|
||||
return err
|
||||
}
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
// Headless reload. The interactive TUI overrides this to also
|
||||
// refresh widgets/status/commands.
|
||||
ReloadExtensions: func() error {
|
||||
return kitInstance.Extensions().Reload()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// LLM completions and subagents
|
||||
// -------------------------------------------------------------------
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return SpawnSubagent(ctx, kitInstance, config)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Tree Navigation API
|
||||
// -------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: func(parentID string) []string {
|
||||
return kitInstance.GetChildren(parentID)
|
||||
},
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Skill Loading API (context-injection variants are TUI-specific and
|
||||
// wired by the interactive overlay)
|
||||
// -------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Template Parsing API
|
||||
// -------------------------------------------------------------------
|
||||
ParseTemplate: func(name, content string) extensions.PromptTemplate {
|
||||
return kit.ParseTemplate(name, content)
|
||||
},
|
||||
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return kit.RenderTemplate(tpl, vars)
|
||||
},
|
||||
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
|
||||
return kit.ParseArguments(input, pattern)
|
||||
},
|
||||
SimpleParseArguments: func(input string, count int) []string {
|
||||
return kit.SimpleParseArguments(input, count)
|
||||
},
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Model Resolution API
|
||||
// -------------------------------------------------------------------
|
||||
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
|
||||
return kit.ResolveModelChain(preferences)
|
||||
},
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: func(model string) bool {
|
||||
return kit.CheckModelAvailable(model)
|
||||
},
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -66,6 +66,7 @@ func SpawnSubagent(ctx context.Context, k *kit.Kit, cfg extensions.SubagentConfi
|
||||
SystemPrompt: cfg.SystemPrompt,
|
||||
Timeout: cfg.Timeout,
|
||||
NoSession: cfg.NoSession,
|
||||
Tools: k.GetToolsForSubagent(),
|
||||
}
|
||||
if cfg.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
|
||||
+135
-1
@@ -341,6 +341,13 @@ type Context struct {
|
||||
// The data survives across session restarts and can be retrieved via
|
||||
// GetEntries. Use entryType to namespace your data (e.g. "myext:state").
|
||||
//
|
||||
// AppendEntry is append-only and lives in the conversation tree, which
|
||||
// makes it the right tool for audit logs and event histories. For
|
||||
// last-write-wins snapshot state — "what's the current value of X?" —
|
||||
// prefer SetState / GetState instead. Those primitives store data in a
|
||||
// sidecar file outside the conversation tree, are O(1) to read/write,
|
||||
// and do not bloat branch reads or duplicate on fork.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// data, _ := json.Marshal(myState)
|
||||
@@ -360,6 +367,45 @@ type Context struct {
|
||||
// }
|
||||
GetEntries func(entryType string) []ExtensionEntry
|
||||
|
||||
// SetState stores a key-value pair in session-scoped, last-write-wins
|
||||
// extension state. Unlike AppendEntry the value is kept in a sidecar
|
||||
// file outside the conversation tree, so:
|
||||
// - reads are O(1) (no branch walk)
|
||||
// - writes don't bloat the session JSONL
|
||||
// - state is not duplicated on fork (branches share the sidecar)
|
||||
// - state is invisible to the LLM
|
||||
//
|
||||
// Use SetState for snapshot state ("current value of X"); use
|
||||
// AppendEntry for audit logs and event histories. Namespace keys with
|
||||
// your extension name to avoid collisions (e.g. "myext:budget-cap").
|
||||
//
|
||||
// State persists for the lifetime of the session. For ephemeral or
|
||||
// in-memory sessions the state lives only in memory.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ctx.SetState("myext:budget-cap", "10.00")
|
||||
SetState func(key string, value string)
|
||||
|
||||
// GetState returns the value previously stored via SetState. The bool
|
||||
// is false when the key was never written. Returns ("", false) when
|
||||
// state is unavailable.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// if cap, ok := ctx.GetState("myext:budget-cap"); ok {
|
||||
// fmt.Println("current cap:", cap)
|
||||
// }
|
||||
GetState func(key string) (string, bool)
|
||||
|
||||
// DeleteState removes a key from session-scoped extension state.
|
||||
// No-op when the key is missing.
|
||||
DeleteState func(key string)
|
||||
|
||||
// ListState returns all keys currently stored in session-scoped
|
||||
// extension state, in unspecified order.
|
||||
ListState func() []string
|
||||
|
||||
// SetEditorText sets the text content of the input editor. This can
|
||||
// be used to pre-fill the editor with suggested text (e.g. extracted
|
||||
// questions, handoff prompts). The cursor is moved to the end.
|
||||
@@ -1102,6 +1148,7 @@ type API struct {
|
||||
onError func(func(ErrorEvent, Context))
|
||||
onRetry func(func(RetryEvent, Context))
|
||||
onPrepareStep func(func(PrepareStepEvent, Context) *PrepareStepResult)
|
||||
onLLMUsage func(func(LLMUsageEvent, Context))
|
||||
}
|
||||
|
||||
// OnToolCall registers a handler that fires before a tool executes.
|
||||
@@ -1359,6 +1406,19 @@ func (a *API) OnPrepareStep(handler func(PrepareStepEvent, Context) *PrepareStep
|
||||
a.onPrepareStep(handler)
|
||||
}
|
||||
|
||||
// OnLLMUsage registers a handler that fires after each LLM provider call
|
||||
// with the token and cost deltas for that single call. Use this for
|
||||
// per-call usage attribution, real-time budget enforcement, and cost
|
||||
// dashboards that need to react between calls within a single agent turn.
|
||||
//
|
||||
// Handlers receive an LLMUsageEvent describing the call's input/output
|
||||
// tokens, cache tokens, computed cost, model, and provider. A single agent
|
||||
// turn typically fires multiple LLMUsageEvents (one per tool-loop
|
||||
// iteration).
|
||||
func (a *API) OnLLMUsage(handler func(LLMUsageEvent, Context)) {
|
||||
a.onLLMUsage(handler)
|
||||
}
|
||||
|
||||
// RegisterToolRenderer registers a custom renderer for a specific tool's
|
||||
// display in the TUI. The renderer controls the header (parameter summary)
|
||||
// and/or body (result display) of the tool's output block. If multiple
|
||||
@@ -2091,10 +2151,47 @@ type AgentStartEvent struct {
|
||||
|
||||
func (e AgentStartEvent) Type() EventType { return AgentStart }
|
||||
|
||||
// AgentEndEvent fires when the agent finishes responding.
|
||||
// AgentEndEvent fires when the agent finishes responding. In addition to the
|
||||
// final response and stop reason, the event carries per-turn aggregates so
|
||||
// observer-style extensions don't have to maintain parallel bookkeeping in
|
||||
// OnToolResult / OnStepFinish handlers.
|
||||
type AgentEndEvent struct {
|
||||
Response string
|
||||
StopReason string // "completed", "cancelled", "error"
|
||||
|
||||
// ToolCallCount is the total number of tool invocations observed during
|
||||
// this turn (sum across all steps).
|
||||
ToolCallCount int
|
||||
|
||||
// ToolNames lists the tool names invoked during this turn, in call order.
|
||||
// Duplicates are preserved (e.g. two bash calls produce ["bash", "bash"]).
|
||||
ToolNames []string
|
||||
|
||||
// LLMCallCount is the number of LLM round-trips (tool-loop iterations)
|
||||
// performed during this turn. Always >= 1 for a successful turn.
|
||||
LLMCallCount int
|
||||
|
||||
// InputTokensDelta is the sum of input tokens consumed during this turn
|
||||
// across every LLM call (including cache-hit input tokens).
|
||||
InputTokensDelta int
|
||||
|
||||
// OutputTokensDelta is the sum of output tokens generated during this turn.
|
||||
OutputTokensDelta int
|
||||
|
||||
// CacheReadTokensDelta is the sum of cache-read tokens during this turn.
|
||||
CacheReadTokensDelta int
|
||||
|
||||
// CacheWriteTokensDelta is the sum of cache-write tokens during this turn.
|
||||
CacheWriteTokensDelta int
|
||||
|
||||
// CostDelta is the total cost in USD attributable to this turn. Computed
|
||||
// from per-step usage and current model pricing. Zero when pricing is
|
||||
// unknown or OAuth credentials are in use.
|
||||
CostDelta float64
|
||||
|
||||
// DurationMs is the elapsed wall-clock time from AgentStart to AgentEnd,
|
||||
// in milliseconds.
|
||||
DurationMs int64
|
||||
}
|
||||
|
||||
func (e AgentEndEvent) Type() EventType { return AgentEnd }
|
||||
@@ -2403,6 +2500,43 @@ type PrepareStepResult struct {
|
||||
|
||||
func (PrepareStepResult) isResult() {}
|
||||
|
||||
// LLMUsageEvent fires after each LLM provider call with the per-call token
|
||||
// and cost deltas. Use this for accurate budget tracking, cost dashboards,
|
||||
// and any logic that needs to react between LLM calls within a single agent
|
||||
// turn (rather than only at turn boundaries).
|
||||
//
|
||||
// A single agent turn typically produces multiple LLMUsageEvents (one per
|
||||
// tool-loop iteration). The Model and Provider fields reflect the model used
|
||||
// for that specific call, which may differ from earlier calls if the
|
||||
// extension switched models mid-turn via ctx.SetModel().
|
||||
type LLMUsageEvent struct {
|
||||
// InputTokens is the number of input tokens for this call.
|
||||
InputTokens int
|
||||
// OutputTokens is the number of output tokens generated by this call.
|
||||
OutputTokens int
|
||||
// CacheReadTokens is the number of cache-hit input tokens (provider-specific).
|
||||
CacheReadTokens int
|
||||
// CacheWriteTokens is the number of cache-write tokens.
|
||||
CacheWriteTokens int
|
||||
// Cost is the USD cost of this call computed from the model's per-token
|
||||
// pricing. Zero when pricing is unknown or OAuth credentials are in use.
|
||||
Cost float64
|
||||
// Model is the model identifier used for this call (e.g. "claude-sonnet-4-5-20250929").
|
||||
Model string
|
||||
// Provider is the provider identifier (e.g. "anthropic", "openai").
|
||||
Provider string
|
||||
// RequestID is an optional correlation id for the underlying provider
|
||||
// call. May be empty when the provider does not surface one.
|
||||
RequestID string
|
||||
// StepNumber is the zero-based step index within the current agent turn.
|
||||
StepNumber int
|
||||
// FinishReason mirrors the provider's finish reason for this call
|
||||
// (e.g. "stop", "tool_calls", "length"). May be empty.
|
||||
FinishReason string
|
||||
}
|
||||
|
||||
func (e LLMUsageEvent) Type() EventType { return LLMUsage }
|
||||
|
||||
// ThemeColor is an adaptive color pair with light and dark hex values.
|
||||
// Either field may be empty to inherit from the default theme.
|
||||
type ThemeColor struct {
|
||||
|
||||
@@ -125,6 +125,11 @@ const (
|
||||
// after steering messages are injected and before messages are sent
|
||||
// to the LLM. Handlers can replace the context window for this step.
|
||||
PrepareStep EventType = "prepare_step"
|
||||
|
||||
// LLMUsage fires after each LLM provider call with the token and cost
|
||||
// deltas for that single call. Extensions use it to attribute usage to
|
||||
// specific calls/models and to drive budget enforcement between calls.
|
||||
LLMUsage EventType = "llm_usage"
|
||||
)
|
||||
|
||||
// AllEventTypes returns every supported event type.
|
||||
@@ -139,7 +144,7 @@ func AllEventTypes() []EventType {
|
||||
BeforeFork, BeforeSessionSwitch, BeforeCompact,
|
||||
SubagentStart, SubagentChunk, SubagentEnd,
|
||||
StepStart, StepFinish, ReasoningStart, Warnings, Source, Error, Retry,
|
||||
PrepareStep,
|
||||
PrepareStep, LLMUsage,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import "testing"
|
||||
|
||||
func TestAllEventTypes_Count(t *testing.T) {
|
||||
all := AllEventTypes()
|
||||
if len(all) != 32 {
|
||||
t.Fatalf("expected 32 event types, got %d", len(all))
|
||||
if len(all) != 33 {
|
||||
t.Fatalf("expected 33 event types, got %d", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
package extensions
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRunner_EmitLLMUsage(t *testing.T) {
|
||||
var got LLMUsageEvent
|
||||
var called bool
|
||||
ext := makeHandlerExt("llmusage.go", map[EventType][]HandlerFunc{
|
||||
LLMUsage: {
|
||||
func(e Event, c Context) Result {
|
||||
got = e.(LLMUsageEvent)
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
_, err := r.Emit(LLMUsageEvent{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
Cost: 0.0012,
|
||||
Model: "claude-sonnet-4-5-20250929",
|
||||
Provider: "anthropic",
|
||||
StepNumber: 2,
|
||||
FinishReason: "tool_calls",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("emit: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("expected LLMUsage handler to be called")
|
||||
}
|
||||
if got.InputTokens != 100 || got.OutputTokens != 50 {
|
||||
t.Errorf("token fields not propagated: %+v", got)
|
||||
}
|
||||
if got.Cost != 0.0012 {
|
||||
t.Errorf("cost not propagated, got %v", got.Cost)
|
||||
}
|
||||
if got.Model != "claude-sonnet-4-5-20250929" || got.Provider != "anthropic" {
|
||||
t.Errorf("model/provider not propagated: %+v", got)
|
||||
}
|
||||
if got.StepNumber != 2 || got.FinishReason != "tool_calls" {
|
||||
t.Errorf("step/finish reason not propagated: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_LLMUsageRegisteredViaTestAPI(t *testing.T) {
|
||||
// Verify NewTestAPI wires up onLLMUsage so the extension can call
|
||||
// api.OnLLMUsage during Init.
|
||||
ext := &LoadedExtension{Handlers: make(map[EventType][]HandlerFunc)}
|
||||
api := NewTestAPI(ext)
|
||||
|
||||
var calls int
|
||||
api.OnLLMUsage(func(e LLMUsageEvent, c Context) {
|
||||
calls++
|
||||
})
|
||||
|
||||
if len(ext.Handlers[LLMUsage]) != 1 {
|
||||
t.Fatalf("expected 1 LLMUsage handler registered, got %d", len(ext.Handlers[LLMUsage]))
|
||||
}
|
||||
|
||||
r := makeRunner(*ext)
|
||||
_, _ = r.Emit(LLMUsageEvent{InputTokens: 1})
|
||||
if calls != 1 {
|
||||
t.Errorf("expected handler called once, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentEndEvent_EnrichedFields(t *testing.T) {
|
||||
// Verify the enriched event carries through Emit without mangling.
|
||||
var got AgentEndEvent
|
||||
ext := makeHandlerExt("end.go", map[EventType][]HandlerFunc{
|
||||
AgentEnd: {
|
||||
func(e Event, c Context) Result {
|
||||
got = e.(AgentEndEvent)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
r := makeRunner(ext)
|
||||
_, err := r.Emit(AgentEndEvent{
|
||||
Response: "done",
|
||||
StopReason: "completed",
|
||||
ToolCallCount: 3,
|
||||
ToolNames: []string{"bash", "read", "bash"},
|
||||
LLMCallCount: 4,
|
||||
InputTokensDelta: 1500,
|
||||
OutputTokensDelta: 400,
|
||||
CacheReadTokensDelta: 200,
|
||||
CacheWriteTokensDelta: 100,
|
||||
CostDelta: 0.0123,
|
||||
DurationMs: 2500,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("emit: %v", err)
|
||||
}
|
||||
if got.ToolCallCount != 3 {
|
||||
t.Errorf("ToolCallCount: got %d want 3", got.ToolCallCount)
|
||||
}
|
||||
if len(got.ToolNames) != 3 || got.ToolNames[0] != "bash" || got.ToolNames[2] != "bash" {
|
||||
t.Errorf("ToolNames: %v", got.ToolNames)
|
||||
}
|
||||
if got.LLMCallCount != 4 {
|
||||
t.Errorf("LLMCallCount: got %d want 4", got.LLMCallCount)
|
||||
}
|
||||
if got.InputTokensDelta != 1500 || got.OutputTokensDelta != 400 {
|
||||
t.Errorf("token deltas: %+v", got)
|
||||
}
|
||||
if got.CacheReadTokensDelta != 200 || got.CacheWriteTokensDelta != 100 {
|
||||
t.Errorf("cache deltas: %+v", got)
|
||||
}
|
||||
if got.CostDelta != 0.0123 {
|
||||
t.Errorf("CostDelta: got %v", got.CostDelta)
|
||||
}
|
||||
if got.DurationMs != 2500 {
|
||||
t.Errorf("DurationMs: got %d", got.DurationMs)
|
||||
}
|
||||
}
|
||||
@@ -669,6 +669,12 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onLLMUsage: func(h func(LLMUsageEvent, Context)) {
|
||||
reg(LLMUsage, func(e Event, c Context) Result {
|
||||
h(e.(LLMUsageEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// Call Init — the extension registers its handlers, tools, commands.
|
||||
|
||||
@@ -2,9 +2,12 @@ package extensions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -98,9 +101,24 @@ type Runner struct {
|
||||
disabledTools map[string]bool // nil = all tools enabled
|
||||
customEventSubs map[string][]func(string) // inter-extension event bus
|
||||
optionOverrides map[string]string // runtime option overrides
|
||||
configStore *viper.Viper // per-instance config store (nil = global)
|
||||
state map[string]string // session-scoped extension state (last-write-wins)
|
||||
stateMu sync.RWMutex // guards state independently of mu
|
||||
saverMu sync.Mutex // serializes stateSaver invocations so atomic-rename writes don't interleave
|
||||
stateSaver func() // optional persistence hook invoked after each state mutation
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SetConfigStore sets the per-instance configuration store used by GetOption
|
||||
// to resolve "options.<name>" config values. When unset (nil), GetOption falls
|
||||
// back to the process-global viper store. Threading a per-Kit store keeps
|
||||
// extension option resolution isolated between Kit instances.
|
||||
func (r *Runner) SetConfigStore(v *viper.Viper) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.configStore = v
|
||||
}
|
||||
|
||||
// ShortcutEntry pairs a shortcut definition with its handler.
|
||||
type ShortcutEntry struct {
|
||||
Def ShortcutDef
|
||||
@@ -253,6 +271,18 @@ func normalizeContext(ctx Context) Context {
|
||||
if ctx.GetEntries == nil {
|
||||
ctx.GetEntries = func(string) []ExtensionEntry { return nil }
|
||||
}
|
||||
if ctx.SetState == nil {
|
||||
ctx.SetState = func(string, string) {}
|
||||
}
|
||||
if ctx.GetState == nil {
|
||||
ctx.GetState = func(string) (string, bool) { return "", false }
|
||||
}
|
||||
if ctx.DeleteState == nil {
|
||||
ctx.DeleteState = func(string) {}
|
||||
}
|
||||
if ctx.ListState == nil {
|
||||
ctx.ListState = func() []string { return nil }
|
||||
}
|
||||
if ctx.GetOption == nil {
|
||||
ctx.GetOption = func(string) string { return "" }
|
||||
}
|
||||
@@ -734,6 +764,168 @@ func (r *Runner) GetMessageRenderer(name string) *MessageRendererConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Extension state store (session-scoped, last-write-wins)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SetState records a key-value pair in the runner's session-scoped extension
|
||||
// state store. The store is in-memory; callers wire SetStateSaver to persist
|
||||
// changes to a sidecar file. Thread-safe.
|
||||
//
|
||||
// When a saver is installed, concurrent SetState/DeleteState invocations are
|
||||
// serialized through saverMu so that overlapping snapshot-and-rename writes
|
||||
// cannot interleave (which would otherwise race on the shared tmp file and
|
||||
// risk persisting an older snapshot after a newer one).
|
||||
func (r *Runner) SetState(key, value string) {
|
||||
r.stateMu.Lock()
|
||||
if r.state == nil {
|
||||
r.state = make(map[string]string)
|
||||
}
|
||||
r.state[key] = value
|
||||
saver := r.stateSaver
|
||||
r.stateMu.Unlock()
|
||||
r.runSaver(saver)
|
||||
}
|
||||
|
||||
// GetState returns the value previously stored via SetState, plus a bool
|
||||
// indicating whether the key was present. Thread-safe.
|
||||
func (r *Runner) GetState(key string) (string, bool) {
|
||||
r.stateMu.RLock()
|
||||
defer r.stateMu.RUnlock()
|
||||
v, ok := r.state[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// DeleteState removes a key from the state store. No-op if the key is
|
||||
// missing. Thread-safe. Saver invocations are serialized via saverMu — see
|
||||
// SetState for the rationale.
|
||||
func (r *Runner) DeleteState(key string) {
|
||||
r.stateMu.Lock()
|
||||
_, existed := r.state[key]
|
||||
if existed {
|
||||
delete(r.state, key)
|
||||
}
|
||||
saver := r.stateSaver
|
||||
r.stateMu.Unlock()
|
||||
if !existed {
|
||||
return
|
||||
}
|
||||
r.runSaver(saver)
|
||||
}
|
||||
|
||||
// runSaver invokes the optional persistence callback under saverMu so
|
||||
// concurrent SetState/DeleteState writers cannot race on the shared tmp
|
||||
// file used by SaveStateToFile's atomic rename. The deferred Unlock
|
||||
// guarantees saverMu is released even if the saver panics.
|
||||
func (r *Runner) runSaver(saver func()) {
|
||||
if saver == nil {
|
||||
return
|
||||
}
|
||||
r.saverMu.Lock()
|
||||
defer r.saverMu.Unlock()
|
||||
saver()
|
||||
}
|
||||
|
||||
// ListState returns all keys currently in the state store, in unspecified
|
||||
// order. Thread-safe.
|
||||
func (r *Runner) ListState() []string {
|
||||
r.stateMu.RLock()
|
||||
defer r.stateMu.RUnlock()
|
||||
if len(r.state) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := make([]string, 0, len(r.state))
|
||||
for k := range r.state {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// SetStateSaver installs an optional persistence hook invoked after each
|
||||
// mutation to the state store (SetState / DeleteState / LoadStateFromFile).
|
||||
// Pass nil to disable persistence. Thread-safe.
|
||||
func (r *Runner) SetStateSaver(saver func()) {
|
||||
r.stateMu.Lock()
|
||||
defer r.stateMu.Unlock()
|
||||
r.stateSaver = saver
|
||||
}
|
||||
|
||||
// SnapshotState returns a copy of the current state store as a
|
||||
// fresh map. Useful for persisting to disk without holding the lock.
|
||||
// Thread-safe.
|
||||
func (r *Runner) SnapshotState() map[string]string {
|
||||
r.stateMu.RLock()
|
||||
defer r.stateMu.RUnlock()
|
||||
if len(r.state) == 0 {
|
||||
return nil
|
||||
}
|
||||
copyMap := make(map[string]string, len(r.state))
|
||||
maps.Copy(copyMap, r.state)
|
||||
return copyMap
|
||||
}
|
||||
|
||||
// LoadStateFromFile reads a JSON map from path and replaces the in-memory
|
||||
// state store with its contents. Missing or empty files are treated as
|
||||
// "no prior state": the in-memory store is replaced with an empty map so
|
||||
// callers can safely switch sessions without leaking keys from a prior
|
||||
// session into a new one. Malformed JSON returns the parse error without
|
||||
// touching the existing store. Thread-safe.
|
||||
func (r *Runner) LoadStateFromFile(path string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
r.stateMu.Lock()
|
||||
r.state = map[string]string{}
|
||||
r.stateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("reading extension state: %w", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
r.stateMu.Lock()
|
||||
r.state = map[string]string{}
|
||||
r.stateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
var loaded map[string]string
|
||||
if err := json.Unmarshal(data, &loaded); err != nil {
|
||||
return fmt.Errorf("parsing extension state: %w", err)
|
||||
}
|
||||
r.stateMu.Lock()
|
||||
r.state = loaded
|
||||
r.stateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveStateToFile writes the current state store to path as JSON, creating
|
||||
// parent directories as needed. An empty store writes an empty object so
|
||||
// that consumers can distinguish "loaded but empty" from "never saved".
|
||||
// Writes are atomic via a tmp-file-and-rename sequence. Thread-safe.
|
||||
func (r *Runner) SaveStateToFile(path string) error {
|
||||
snap := r.SnapshotState()
|
||||
if snap == nil {
|
||||
snap = map[string]string{}
|
||||
}
|
||||
data, err := json.MarshalIndent(snap, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling extension state: %w", err)
|
||||
}
|
||||
if dir := filepath.Dir(path); dir != "." && dir != "" {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("creating state directory: %w", err)
|
||||
}
|
||||
}
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, 0o644); err != nil {
|
||||
return fmt.Errorf("writing extension state: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmp, path); err != nil {
|
||||
_ = os.Remove(tmp)
|
||||
return fmt.Errorf("renaming extension state: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hot-reload
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -757,7 +949,9 @@ func (r *Runner) Reload(exts []LoadedExtension) {
|
||||
r.uiVisibility = nil
|
||||
r.disabledTools = nil
|
||||
r.customEventSubs = nil
|
||||
// optionOverrides are intentionally preserved.
|
||||
// optionOverrides and state are intentionally preserved across reloads:
|
||||
// they represent user/session intent (not extension code) and would be
|
||||
// surprising to lose on a hot-reload.
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -872,7 +1066,13 @@ func (r *Runner) GetOption(name string) string {
|
||||
|
||||
// 3. Viper config: options.<name>
|
||||
configKey := "options." + name
|
||||
if v := viper.GetString(configKey); v != "" {
|
||||
r.mu.RLock()
|
||||
store := r.configStore
|
||||
r.mu.RUnlock()
|
||||
if store == nil {
|
||||
store = viper.GetViper()
|
||||
}
|
||||
if v := store.GetString(configKey); v != "" {
|
||||
return v
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRunner_State_BasicSetGetDelete(t *testing.T) {
|
||||
r := NewRunner(nil)
|
||||
|
||||
if _, ok := r.GetState("missing"); ok {
|
||||
t.Fatal("expected GetState to return ok=false for missing key")
|
||||
}
|
||||
|
||||
r.SetState("a", "1")
|
||||
r.SetState("b", "2")
|
||||
r.SetState("a", "3") // last-write-wins
|
||||
|
||||
if v, ok := r.GetState("a"); !ok || v != "3" {
|
||||
t.Errorf("expected GetState(a)=(3,true), got (%q,%v)", v, ok)
|
||||
}
|
||||
if v, ok := r.GetState("b"); !ok || v != "2" {
|
||||
t.Errorf("expected GetState(b)=(2,true), got (%q,%v)", v, ok)
|
||||
}
|
||||
|
||||
keys := r.ListState()
|
||||
if len(keys) != 2 {
|
||||
t.Errorf("expected 2 keys, got %d (%v)", len(keys), keys)
|
||||
}
|
||||
|
||||
r.DeleteState("a")
|
||||
if _, ok := r.GetState("a"); ok {
|
||||
t.Error("expected key a to be gone after DeleteState")
|
||||
}
|
||||
if len(r.ListState()) != 1 {
|
||||
t.Errorf("expected 1 key after delete, got %v", r.ListState())
|
||||
}
|
||||
|
||||
// Deleting missing key is a no-op.
|
||||
r.DeleteState("never-there")
|
||||
}
|
||||
|
||||
func TestRunner_State_SaverFires(t *testing.T) {
|
||||
r := NewRunner(nil)
|
||||
var calls int
|
||||
var mu sync.Mutex
|
||||
r.SetStateSaver(func() {
|
||||
mu.Lock()
|
||||
calls++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
r.SetState("a", "1")
|
||||
r.SetState("a", "2")
|
||||
r.DeleteState("a")
|
||||
r.DeleteState("a") // missing → no save
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if calls != 3 {
|
||||
t.Errorf("expected saver to fire 3 times (2 sets + 1 delete), got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_SaveAndLoadRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "ext-state.json")
|
||||
|
||||
r1 := NewRunner(nil)
|
||||
r1.SetState("k1", "v1")
|
||||
r1.SetState("k2", `{"json":"value"}`)
|
||||
if err := r1.SaveStateToFile(path); err != nil {
|
||||
t.Fatalf("SaveStateToFile: %v", err)
|
||||
}
|
||||
|
||||
// Verify file contains JSON map.
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("reading saved file: %v", err)
|
||||
}
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("unmarshalling: %v", err)
|
||||
}
|
||||
if parsed["k1"] != "v1" || parsed["k2"] != `{"json":"value"}` {
|
||||
t.Errorf("unexpected file contents: %v", parsed)
|
||||
}
|
||||
|
||||
r2 := NewRunner(nil)
|
||||
if err := r2.LoadStateFromFile(path); err != nil {
|
||||
t.Fatalf("LoadStateFromFile: %v", err)
|
||||
}
|
||||
if v, ok := r2.GetState("k1"); !ok || v != "v1" {
|
||||
t.Errorf("expected k1=v1 after load, got (%q,%v)", v, ok)
|
||||
}
|
||||
if v, ok := r2.GetState("k2"); !ok || v != `{"json":"value"}` {
|
||||
t.Errorf("expected k2 to round-trip, got %q", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_LoadMissingFileClearsState(t *testing.T) {
|
||||
// LoadStateFromFile is documented to "replace the in-memory state store
|
||||
// with its contents"; for a missing file that means clearing the store.
|
||||
// This is what makes session-switching safe: a new session that has not
|
||||
// yet written a sidecar must not inherit keys from a prior session.
|
||||
r := NewRunner(nil)
|
||||
r.SetState("a", "1")
|
||||
if err := r.LoadStateFromFile(filepath.Join(t.TempDir(), "does-not-exist.json")); err != nil {
|
||||
t.Errorf("expected nil error for missing file, got %v", err)
|
||||
}
|
||||
if _, ok := r.GetState("a"); ok {
|
||||
t.Error("expected pre-existing state to be cleared when target file is missing")
|
||||
}
|
||||
if keys := r.ListState(); keys != nil {
|
||||
t.Errorf("expected ListState() to be nil after clearing, got %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_LoadEmptyFileClearsState(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty.json")
|
||||
if err := os.WriteFile(path, nil, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r := NewRunner(nil)
|
||||
r.SetState("a", "1")
|
||||
if err := r.LoadStateFromFile(path); err != nil {
|
||||
t.Errorf("expected nil error for empty file, got %v", err)
|
||||
}
|
||||
if _, ok := r.GetState("a"); ok {
|
||||
t.Error("expected pre-existing state to be cleared when target file is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_LoadMalformedFileError(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "bad.json")
|
||||
if err := os.WriteFile(path, []byte("{not json"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r := NewRunner(nil)
|
||||
if err := r.LoadStateFromFile(path); err == nil {
|
||||
t.Error("expected error loading malformed JSON, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_PersistenceViaSaver(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "ext-state.json")
|
||||
|
||||
r := NewRunner(nil)
|
||||
r.SetStateSaver(func() {
|
||||
_ = r.SaveStateToFile(path)
|
||||
})
|
||||
r.SetState("hello", "world")
|
||||
|
||||
// File should exist with the value already.
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("reading saved file: %v", err)
|
||||
}
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("unmarshalling: %v", err)
|
||||
}
|
||||
if parsed["hello"] != "world" {
|
||||
t.Errorf("expected file to contain hello=world, got %v", parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_ConcurrentSet(t *testing.T) {
|
||||
r := NewRunner(nil)
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 16
|
||||
const iterations = 100
|
||||
wg.Add(goroutines)
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
r.SetState("k", "v")
|
||||
_, _ = r.GetState("k")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if v, ok := r.GetState("k"); !ok || v != "v" {
|
||||
t.Errorf("expected k=v after concurrent writes, got (%q,%v)", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_ContextNoOpsWhenUnset(t *testing.T) {
|
||||
// Verify normalizeContext installs safe no-ops for SetState/GetState/etc.
|
||||
// when not provided by the caller.
|
||||
ext := makeHandlerExt("state.go", map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result {
|
||||
// All four state functions should be non-nil and safe to call.
|
||||
c.SetState("a", "b")
|
||||
if v, ok := c.GetState("a"); ok || v != "" {
|
||||
t.Errorf("no-op GetState should return (\"\", false); got (%q,%v)", v, ok)
|
||||
}
|
||||
c.DeleteState("a")
|
||||
if keys := c.ListState(); keys != nil {
|
||||
t.Errorf("no-op ListState should return nil; got %v", keys)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
r := makeRunner(ext)
|
||||
// SetContext with empty Context to exercise normalizeContext defaults.
|
||||
r.SetContext(Context{})
|
||||
_, err := r.Emit(SessionStartEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("emit: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_SaverPanicReleasesSaverMu(t *testing.T) {
|
||||
// If the saver callback panics (e.g. disk full mid-write), runSaver
|
||||
// must still release saverMu so subsequent SetState/DeleteState calls
|
||||
// can make progress. Without `defer Unlock()` the lock would be
|
||||
// permanently held and the next write would deadlock.
|
||||
r := NewRunner(nil)
|
||||
var calls int
|
||||
r.SetStateSaver(func() {
|
||||
calls++
|
||||
if calls == 1 {
|
||||
panic("simulated disk-write failure")
|
||||
}
|
||||
})
|
||||
|
||||
// First call panics. Recover, then verify a follow-up call still works
|
||||
// without blocking (proving saverMu was released).
|
||||
func() {
|
||||
defer func() {
|
||||
if rec := recover(); rec == nil {
|
||||
t.Fatal("expected panic from first saver invocation")
|
||||
}
|
||||
}()
|
||||
r.SetState("a", "1")
|
||||
}()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.SetState("b", "2") // would deadlock if saverMu were still held
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("SetState after saver panic blocked — saverMu was not released")
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Errorf("expected saver to fire twice (panic + recovery write), got %d", calls)
|
||||
}
|
||||
}
|
||||
@@ -183,6 +183,7 @@ func Symbols() interp.Exports {
|
||||
"RetryEvent": reflect.ValueOf((*RetryEvent)(nil)),
|
||||
"PrepareStepEvent": reflect.ValueOf((*PrepareStepEvent)(nil)),
|
||||
"PrepareStepResult": reflect.ValueOf((*PrepareStepResult)(nil)),
|
||||
"LLMUsageEvent": reflect.ValueOf((*LLMUsageEvent)(nil)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,5 +189,11 @@ func NewTestAPI(ext *LoadedExtension) API {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onLLMUsage: func(h func(LLMUsageEvent, Context)) {
|
||||
reg(LLMUsage, func(e Event, c Context) Result {
|
||||
h(e.(LLMUsageEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
package extensions
|
||||
|
||||
// ToolKind constants classify what a tool does, enabling UIs to render
|
||||
// appropriate visualizations (e.g. diff view for edit tools, command+output
|
||||
// for execute tools) and file trackers to identify which results contain
|
||||
// modifications.
|
||||
//
|
||||
// This is the single source of truth for tool-kind classification; the
|
||||
// pkg/kit SDK re-exports these constants.
|
||||
const (
|
||||
ToolKindExecute = "execute" // Shell execution (bash)
|
||||
ToolKindEdit = "edit" // File modification (edit, write)
|
||||
ToolKindRead = "read" // File reading (read, ls)
|
||||
ToolKindSearch = "search" // Content/file search (grep, find)
|
||||
ToolKindSubagent = "agent" // Subagent spawning (subagent)
|
||||
)
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind classification.
|
||||
// MCP and extension tools without an entry default to ToolKindExecute.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": ToolKindExecute,
|
||||
"edit": ToolKindEdit,
|
||||
"write": ToolKindEdit,
|
||||
"read": ToolKindRead,
|
||||
"ls": ToolKindRead,
|
||||
"grep": ToolKindSearch,
|
||||
"find": ToolKindSearch,
|
||||
"subagent": ToolKindSubagent,
|
||||
}
|
||||
|
||||
// ToolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
// ToolKindExecute for unknown tools (including MCP tools).
|
||||
func ToolKindFor(toolName string) string {
|
||||
if kind, ok := coreToolKinds[toolName]; ok {
|
||||
return kind
|
||||
}
|
||||
return ToolKindExecute
|
||||
}
|
||||
+24
-157
@@ -1,143 +1,32 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/mark3labs/kit/internal/watcher"
|
||||
)
|
||||
|
||||
// Watcher monitors extension directories for file changes and triggers
|
||||
// a reload callback when .go files are created, modified, or removed.
|
||||
// It uses fsnotify for kernel-level file notifications (inotify on Linux,
|
||||
// kqueue on macOS) with debouncing to coalesce rapid editor writes.
|
||||
type Watcher struct {
|
||||
watcher *fsnotify.Watcher
|
||||
onReload func()
|
||||
debounce time.Duration
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
// Watcher monitors extension directories for .go file changes and triggers
|
||||
// a reload callback when changes are detected. It is implemented in terms
|
||||
// of the general-purpose internal/watcher.ContentWatcher.
|
||||
//
|
||||
// Type-aliasing here lets existing call sites (cmd/root.go and the
|
||||
// watcher_test.go suite) keep using `extensions.NewWatcher` / `*Watcher`
|
||||
// without knowing about the underlying implementation.
|
||||
type Watcher = watcher.ContentWatcher
|
||||
|
||||
// NewWatcher creates a file watcher that monitors the given directories
|
||||
// for .go file changes. When a change is detected (after debouncing),
|
||||
// onReload is called. The watcher must be started with Start() and
|
||||
// stopped with Close().
|
||||
func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
|
||||
fsw, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating file watcher: %w", err)
|
||||
}
|
||||
|
||||
for _, dir := range dirs {
|
||||
// Watch the directory itself.
|
||||
if err := fsw.Add(dir); err != nil {
|
||||
log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Also watch immediate subdirectories (for */main.go pattern).
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
subdir := filepath.Join(dir, entry.Name())
|
||||
if err := fsw.Add(subdir); err != nil {
|
||||
log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Watcher{
|
||||
watcher: fsw,
|
||||
onReload: onReload,
|
||||
debounce: 300 * time.Millisecond,
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins watching for file changes. It blocks until the context
|
||||
// is cancelled or Close() is called. Typically called in a goroutine.
|
||||
func (w *Watcher) Start(ctx context.Context) {
|
||||
w.mu.Lock()
|
||||
ctx, w.cancel = context.WithCancel(ctx)
|
||||
w.mu.Unlock()
|
||||
|
||||
defer close(w.done)
|
||||
|
||||
var timer *time.Timer
|
||||
var timerC <-chan time.Time
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
return
|
||||
|
||||
case event, ok := <-w.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Only care about .go files.
|
||||
if !strings.HasSuffix(event.Name, ".go") {
|
||||
continue
|
||||
}
|
||||
|
||||
// React to write, create, remove, rename events.
|
||||
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op)
|
||||
|
||||
// Debounce: reset timer on each event.
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
timer = time.NewTimer(w.debounce)
|
||||
timerC = timer.C
|
||||
|
||||
case <-timerC:
|
||||
timerC = nil
|
||||
timer = nil
|
||||
log.Printf("DEBUG watcher: reloading extensions")
|
||||
w.onReload()
|
||||
|
||||
case err, ok := <-w.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Printf("WARN watcher: error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the watcher and releases resources.
|
||||
func (w *Watcher) Close() error {
|
||||
w.mu.Lock()
|
||||
cancel := w.cancel
|
||||
w.mu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Wait for the event loop to finish.
|
||||
<-w.done
|
||||
return w.watcher.Close()
|
||||
return watcher.New(watcher.Options{
|
||||
Dirs: dirs,
|
||||
Extensions: []string{".go"},
|
||||
OnReload: onReload,
|
||||
Label: "extensions",
|
||||
})
|
||||
}
|
||||
|
||||
// WatchedDirs returns the directories to watch for extension changes.
|
||||
@@ -146,47 +35,25 @@ func (w *Watcher) Close() error {
|
||||
// point to directories are also included; explicit file paths cause
|
||||
// their parent directory to be watched instead.
|
||||
func WatchedDirs(extraPaths []string) []string {
|
||||
var dirs []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
add := func(dir string) {
|
||||
abs, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if seen[abs] {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the directory exists.
|
||||
info, err := os.Stat(abs)
|
||||
if err != nil || !info.IsDir() {
|
||||
return
|
||||
}
|
||||
|
||||
seen[abs] = true
|
||||
dirs = append(dirs, abs)
|
||||
standard := []string{
|
||||
globalExtensionsDir(),
|
||||
filepath.Join(".kit", "extensions"),
|
||||
}
|
||||
|
||||
// Global extensions dir.
|
||||
add(globalExtensionsDir())
|
||||
|
||||
// Project-local extensions dir.
|
||||
add(filepath.Join(".kit", "extensions"))
|
||||
|
||||
// Explicit paths that are directories.
|
||||
// Filter explicit paths into directories (passed through) and files
|
||||
// (parent dir watched) for CollectDirs to dedupe.
|
||||
var extras []string
|
||||
for _, p := range extraPaths {
|
||||
info, err := os.Stat(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.IsDir() {
|
||||
add(p)
|
||||
extras = append(extras, p)
|
||||
} else {
|
||||
// For explicit files, watch the parent directory.
|
||||
add(filepath.Dir(p))
|
||||
extras = append(extras, filepath.Dir(p))
|
||||
}
|
||||
}
|
||||
|
||||
return dirs
|
||||
return watcher.CollectDirs(standard, extras)
|
||||
}
|
||||
|
||||
@@ -40,27 +40,6 @@ func ExtensionToolsAsLLMTools(defs []ToolDef, runner *Runner) []fantasy.AgentToo
|
||||
return tools
|
||||
}
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind classification.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"subagent": "agent",
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
// "execute" for unknown tools (including MCP tools).
|
||||
func toolKindFor(toolName string) string {
|
||||
if kind, ok := coreToolKinds[toolName]; ok {
|
||||
return kind
|
||||
}
|
||||
return "execute"
|
||||
}
|
||||
|
||||
// parseToolArgsJSON attempts to parse JSON-encoded tool args into a map.
|
||||
// Returns nil on failure (non-fatal convenience parsing).
|
||||
func parseToolArgsJSON(input string) map[string]any {
|
||||
@@ -93,7 +72,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
fmt.Sprintf("Error: tool %q is currently disabled", toolName)), nil
|
||||
}
|
||||
|
||||
kind := toolKindFor(toolName)
|
||||
kind := ToolKindFor(toolName)
|
||||
|
||||
// 1. Emit ToolCall — extensions can block execution.
|
||||
if w.runner.HasHandlers(ToolCall) {
|
||||
|
||||
+65
-42
@@ -46,9 +46,9 @@ type AgentSetupOptions struct {
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
// ProviderConfig, when non-nil, is used directly instead of calling
|
||||
// BuildProviderConfig(). Callers that already hold viperInitMu can
|
||||
// pre-build this and release the lock before calling SetupAgent, so the
|
||||
// slow agent/MCP initialisation runs concurrently with other New() calls.
|
||||
// BuildProviderConfig(). Callers (e.g. Kit.New) pre-build this from their
|
||||
// per-instance config store and pass it here, so the slow agent/MCP
|
||||
// initialisation can run without further config reads.
|
||||
ProviderConfig *models.ProviderConfig
|
||||
// Debug enables debug logging. When zero-value, viper is consulted.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
@@ -75,6 +75,11 @@ type AgentSetupOptions struct {
|
||||
// MCPTaskConfig configures task-augmented tools/call execution. The
|
||||
// zero value preserves historical synchronous-only behaviour.
|
||||
MCPTaskConfig tools.MCPTaskConfig
|
||||
// Viper is the per-instance configuration store. When set, it is used for
|
||||
// any fallback config reads (debug, no-extensions, max-steps, stream,
|
||||
// extension paths) and is attached to the extension runner. When nil, the
|
||||
// process-global viper store is used.
|
||||
Viper *viper.Viper
|
||||
}
|
||||
|
||||
// AgentSetupResult bundles the created agent and any debug logger so the caller
|
||||
@@ -87,57 +92,62 @@ type AgentSetupResult struct {
|
||||
ExtRunner *extensions.Runner
|
||||
}
|
||||
|
||||
// BuildProviderConfig creates a *models.ProviderConfig from the current viper
|
||||
// state. All entry points (root, script, SDK) converge through this function.
|
||||
// BuildProviderConfig creates a *models.ProviderConfig from the supplied viper
|
||||
// store (or the process-global store when v is nil). All entry points (root,
|
||||
// script, SDK) converge through this function.
|
||||
//
|
||||
// Generation parameter pointers (Temperature, TopP, etc.) are only set when
|
||||
// the user has explicitly configured them via CLI flag, environment variable,
|
||||
// or global config file. This allows per-model defaults from modelSettings
|
||||
// and customModels to fill in unset parameters downstream.
|
||||
func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
systemPrompt, err := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
func BuildProviderConfig(v *viper.Viper) (*models.ProviderConfig, string, error) {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
systemPrompt, err := config.LoadSystemPrompt(v.GetString("system-prompt"))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to load system prompt: %w", err)
|
||||
}
|
||||
|
||||
numGPU := int32(viper.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(viper.GetInt("main-gpu"))
|
||||
numGPU := int32(v.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(v.GetInt("main-gpu"))
|
||||
|
||||
cfg := &models.ProviderConfig{
|
||||
ModelString: viper.GetString("model"),
|
||||
ModelString: v.GetString("model"),
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
ProviderAPIKey: v.GetString("provider-api-key"),
|
||||
ProviderURL: v.GetString("provider-url"),
|
||||
MaxTokens: v.GetInt("max-tokens"),
|
||||
StopSequences: v.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
TLSSkipVerify: v.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(v.GetString("thinking-level")),
|
||||
ConfigStore: v,
|
||||
}
|
||||
|
||||
// Only set generation parameter pointers when the user has explicitly
|
||||
// provided a value. This leaves nil pointers for unset params, allowing
|
||||
// per-model defaults (modelSettings / customModels params) to apply.
|
||||
if viper.IsSet("temperature") {
|
||||
v := float32(viper.GetFloat64("temperature"))
|
||||
cfg.Temperature = &v
|
||||
if v.IsSet("temperature") {
|
||||
val := float32(v.GetFloat64("temperature"))
|
||||
cfg.Temperature = &val
|
||||
}
|
||||
if viper.IsSet("top-p") {
|
||||
v := float32(viper.GetFloat64("top-p"))
|
||||
cfg.TopP = &v
|
||||
if v.IsSet("top-p") {
|
||||
val := float32(v.GetFloat64("top-p"))
|
||||
cfg.TopP = &val
|
||||
}
|
||||
if viper.IsSet("top-k") {
|
||||
v := int32(viper.GetInt("top-k"))
|
||||
cfg.TopK = &v
|
||||
if v.IsSet("top-k") {
|
||||
val := int32(v.GetInt("top-k"))
|
||||
cfg.TopK = &val
|
||||
}
|
||||
if viper.IsSet("frequency-penalty") {
|
||||
v := float32(viper.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &v
|
||||
if v.IsSet("frequency-penalty") {
|
||||
val := float32(v.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &val
|
||||
}
|
||||
if viper.IsSet("presence-penalty") {
|
||||
v := float32(viper.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &v
|
||||
if v.IsSet("presence-penalty") {
|
||||
val := float32(v.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &val
|
||||
}
|
||||
|
||||
return cfg, systemPrompt, nil
|
||||
@@ -149,14 +159,21 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
var modelConfig *models.ProviderConfig
|
||||
var systemPrompt string
|
||||
|
||||
// Resolve the config store: prefer the per-instance store, falling back to
|
||||
// the process-global store.
|
||||
v := opts.Viper
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
|
||||
if opts.ProviderConfig != nil {
|
||||
// Pre-built config supplied by caller (e.g. Kit.New after releasing
|
||||
// viperInitMu). Use it directly — no viper reads needed here.
|
||||
// Pre-built config supplied by caller (e.g. Kit.New after building the
|
||||
// per-instance store). Use it directly — no viper reads needed here.
|
||||
modelConfig = opts.ProviderConfig
|
||||
systemPrompt = modelConfig.SystemPrompt
|
||||
} else {
|
||||
var err error
|
||||
modelConfig, systemPrompt, err = BuildProviderConfig()
|
||||
modelConfig, systemPrompt, err = BuildProviderConfig(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -164,13 +181,13 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
|
||||
// Resolve debug / no-extensions / max-steps / streaming: prefer explicit
|
||||
// fields (set when ProviderConfig was pre-built) over viper fallback.
|
||||
debugEnabled := opts.Debug || viper.GetBool("debug")
|
||||
noExtensions := opts.NoExtensions || viper.GetBool("no-extensions")
|
||||
debugEnabled := opts.Debug || v.GetBool("debug")
|
||||
noExtensions := opts.NoExtensions || v.GetBool("no-extensions")
|
||||
maxSteps := opts.MaxSteps
|
||||
if maxSteps == 0 {
|
||||
maxSteps = viper.GetInt("max-steps")
|
||||
maxSteps = v.GetInt("max-steps")
|
||||
}
|
||||
streamingEnabled := opts.StreamingEnabled || viper.GetBool("stream")
|
||||
streamingEnabled := opts.StreamingEnabled || v.GetBool("stream")
|
||||
|
||||
// Create the appropriate debug logger.
|
||||
var debugLogger tools.DebugLogger
|
||||
@@ -189,7 +206,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
var extCreationOpts extensionCreationOpts
|
||||
if !noExtensions {
|
||||
var extErr error
|
||||
extRunner, extCreationOpts, extErr = loadExtensions()
|
||||
extRunner, extCreationOpts, extErr = loadExtensions(v)
|
||||
if extErr != nil {
|
||||
fmt.Printf("Warning: Failed to load extensions: %v\n", extErr)
|
||||
}
|
||||
@@ -253,9 +270,14 @@ type extensionCreationOpts struct {
|
||||
}
|
||||
|
||||
// loadExtensions discovers and loads Yaegi extensions, builds the runner,
|
||||
// and returns the tool wrapper/extra tools.
|
||||
func loadExtensions() (*extensions.Runner, extensionCreationOpts, error) {
|
||||
extraPaths := viper.GetStringSlice("extension")
|
||||
// and returns the tool wrapper/extra tools. The supplied store is used to
|
||||
// resolve the "extension" config key and is attached to the runner so
|
||||
// extension option lookups stay isolated to this Kit instance.
|
||||
func loadExtensions(v *viper.Viper) (*extensions.Runner, extensionCreationOpts, error) {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
extraPaths := v.GetStringSlice("extension")
|
||||
loaded, err := extensions.LoadExtensions(extraPaths)
|
||||
if err != nil {
|
||||
return nil, extensionCreationOpts{}, err
|
||||
@@ -266,6 +288,7 @@ func loadExtensions() (*extensions.Runner, extensionCreationOpts, error) {
|
||||
}
|
||||
|
||||
runner := extensions.NewRunner(loaded)
|
||||
runner.SetConfigStore(v)
|
||||
|
||||
wrapper := func(tools []fantasy.AgentTool) []fantasy.AgentTool {
|
||||
return extensions.WrapToolsWithExtensions(tools, runner)
|
||||
|
||||
@@ -0,0 +1,282 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNpmToWireProtocol documents the wire protocols that the auto-router
|
||||
// understands. Provider-specific bundles that need bespoke auth or URL
|
||||
// templating (azure, bedrock, openrouter, google-vertex*, @ai-sdk/gateway)
|
||||
// are intentionally absent — they have native top-level cases in
|
||||
// CreateProvider and never reach the auto-router.
|
||||
func TestNpmToWireProtocol(t *testing.T) {
|
||||
want := map[string]wireProtocol{
|
||||
"@ai-sdk/openai": wireOpenAI,
|
||||
"@ai-sdk/openai-compatible": wireOpenAI,
|
||||
"@ai-sdk/anthropic": wireAnthropic,
|
||||
"@ai-sdk/google": wireGoogle,
|
||||
|
||||
// Thin OpenAI-compatible wrappers — routed via openaicompat using
|
||||
// the SDK's hard-coded default base URL (sdkDefaultBaseURL).
|
||||
"@ai-sdk/groq": wireOpenAI,
|
||||
"@ai-sdk/cerebras": wireOpenAI,
|
||||
"@ai-sdk/perplexity": wireOpenAI,
|
||||
"@ai-sdk/togetherai": wireOpenAI,
|
||||
"@ai-sdk/xai": wireOpenAI,
|
||||
"@ai-sdk/deepinfra": wireOpenAI,
|
||||
"@ai-sdk/mistral": wireOpenAI,
|
||||
"@ai-sdk/cohere": wireOpenAI,
|
||||
"@ai-sdk/vercel": wireOpenAI,
|
||||
"@aihubmix/ai-sdk-provider": wireOpenAI,
|
||||
"venice-ai-sdk-provider": wireOpenAI,
|
||||
"merge-gateway-ai-sdk-provider": wireOpenAI,
|
||||
}
|
||||
for npm, wire := range want {
|
||||
if got := npmToWireProtocol[npm]; got != wire {
|
||||
t.Errorf("npmToWireProtocol[%q] = %d, want %d", npm, got, wire)
|
||||
}
|
||||
}
|
||||
|
||||
// Bundle packages must NOT be in the table — they need bespoke auth or
|
||||
// URL templating that the auto-router cannot satisfy.
|
||||
for _, npm := range []string{
|
||||
"@ai-sdk/google-vertex",
|
||||
"@ai-sdk/google-vertex/anthropic",
|
||||
"@ai-sdk/amazon-bedrock",
|
||||
"@ai-sdk/azure",
|
||||
"@openrouter/ai-sdk-provider",
|
||||
"@ai-sdk/gateway",
|
||||
} {
|
||||
if _, ok := npmToWireProtocol[npm]; ok {
|
||||
t.Errorf("npmToWireProtocol unexpectedly contains bundle package %q", npm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newTestRegistry builds a registry containing a single proxy-style provider
|
||||
// ("testproxy") with the given default npm, plus one model that carries the
|
||||
// given per-model npm override.
|
||||
func newTestRegistry(api, defaultNPM, modelID, modelNPMOverride string) *ModelsRegistry {
|
||||
return &ModelsRegistry{
|
||||
providers: map[string]ProviderInfo{
|
||||
"testproxy": {
|
||||
ID: "testproxy",
|
||||
Name: "Test Proxy",
|
||||
Env: []string{"TESTPROXY_API_KEY"},
|
||||
NPM: defaultNPM,
|
||||
API: api,
|
||||
Models: map[string]ModelInfo{
|
||||
modelID: {
|
||||
ID: modelID,
|
||||
Name: modelID,
|
||||
ProviderNPM: modelNPMOverride,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoRouteProvider_WireRouting verifies that autoRouteProvider routes each
|
||||
// npm package to the correct fantasy provider implementation. This is the core
|
||||
// regression test for issue #41: previously any npm that resolved to a
|
||||
// non-openai/anthropic/openaicompat LLM provider (notably @ai-sdk/google) hit a
|
||||
// dead `default` branch and failed with "has no LLM provider mapping".
|
||||
func TestAutoRouteProvider_WireRouting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelID string
|
||||
defaultNPM string
|
||||
overrideNPM string
|
||||
// wantType is the concrete fantasy LanguageModel type the model should
|
||||
// be routed to, identified by reflect type string.
|
||||
wantType string
|
||||
}{
|
||||
{
|
||||
name: "openai-compatible default",
|
||||
modelID: "test-model",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
wantType: "openai.languageModel",
|
||||
},
|
||||
{
|
||||
name: "anthropic override",
|
||||
modelID: "test-model",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
overrideNPM: "@ai-sdk/anthropic",
|
||||
wantType: "anthropic.languageModel",
|
||||
},
|
||||
{
|
||||
name: "openai (responses) override",
|
||||
modelID: "gpt-4o",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
overrideNPM: "@ai-sdk/openai",
|
||||
wantType: "openai.responsesLanguageModel",
|
||||
},
|
||||
{
|
||||
// The bug: opencode's gemini-* models override the default
|
||||
// openai-compatible npm with @ai-sdk/google.
|
||||
name: "google override (issue #41)",
|
||||
modelID: "gemini-3.5-flash",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
overrideNPM: "@ai-sdk/google",
|
||||
wantType: "*google.languageModel",
|
||||
},
|
||||
{
|
||||
// Unknown npm but provider has an API URL → openai-compatible fallback.
|
||||
name: "unknown npm with API URL falls back to openai-compat",
|
||||
modelID: "test-model",
|
||||
defaultNPM: "@ai-sdk/some-future-thing",
|
||||
wantType: "openai.languageModel",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reg := newTestRegistry("https://proxy.example/v1", tt.defaultNPM, tt.modelID, tt.overrideNPM)
|
||||
config := &ProviderConfig{ProviderAPIKey: "test-key"}
|
||||
|
||||
result, err := autoRouteProvider(context.Background(), config, "testproxy", tt.modelID, reg)
|
||||
if err != nil {
|
||||
t.Fatalf("autoRouteProvider returned error: %v", err)
|
||||
}
|
||||
if result == nil || result.Model == nil {
|
||||
t.Fatalf("autoRouteProvider returned nil model")
|
||||
}
|
||||
|
||||
gotType := reflect.TypeOf(result.Model).String()
|
||||
if gotType != tt.wantType {
|
||||
t.Errorf("routed to %s, want %s", gotType, tt.wantType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoRouteProvider_UnknownNpmNoAPI verifies the improved error message for
|
||||
// a provider whose npm has no known wire protocol and that has no API URL to
|
||||
// fall back on.
|
||||
func TestAutoRouteProvider_UnknownNpmNoAPI(t *testing.T) {
|
||||
reg := newTestRegistry("", "@ai-sdk/unmapped", "test-model", "")
|
||||
config := &ProviderConfig{ProviderAPIKey: "test-key"}
|
||||
|
||||
_, err := autoRouteProvider(context.Background(), config, "testproxy", "test-model", reg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown npm with no API URL, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "cannot auto-route provider testproxy") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "--provider-url") {
|
||||
t.Errorf("error should suggest --provider-url, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoRouteProvider_UnknownProvider verifies the not-in-database error.
|
||||
func TestAutoRouteProvider_UnknownProvider(t *testing.T) {
|
||||
reg := newTestRegistry("https://proxy.example/v1", "@ai-sdk/openai-compatible", "test-model", "")
|
||||
config := &ProviderConfig{ProviderAPIKey: "test-key"}
|
||||
|
||||
_, err := autoRouteProvider(context.Background(), config, "does-not-exist", "test-model", reg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown provider, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found in model database") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsProviderLLMSupported_Google verifies that a provider whose npm is
|
||||
// @ai-sdk/google is reported as supported (it now maps to a wire protocol).
|
||||
func TestIsProviderLLMSupported_Google(t *testing.T) {
|
||||
info := &ProviderInfo{ID: "testproxy", NPM: "@ai-sdk/google"}
|
||||
if !isProviderLLMSupported("testproxy", info) {
|
||||
t.Error("expected @ai-sdk/google provider to be LLM-supported")
|
||||
}
|
||||
}
|
||||
|
||||
// TestVersionedBasePath verifies detection of proxy base URLs that already
|
||||
// carry an API version segment (which collides with the genai SDK's injected
|
||||
// version).
|
||||
func TestVersionedBasePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
rawURL string
|
||||
want string
|
||||
}{
|
||||
{"https://opencode.ai/zen/v1", "/zen/v1"},
|
||||
{"https://opencode.ai/zen/v1/", "/zen/v1"},
|
||||
{"https://example.com/api/v1beta", "/api/v1beta"},
|
||||
{"https://example.com/api/v2alpha", "/api/v2alpha"},
|
||||
{"https://generativelanguage.googleapis.com", ""},
|
||||
{"https://proxy.example/openai", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := versionedBasePath(tt.rawURL); got != tt.want {
|
||||
t.Errorf("versionedBasePath(%q) = %q, want %q", tt.rawURL, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordingRoundTripper captures the path of the request it receives.
|
||||
type recordingRoundTripper struct{ gotPath string }
|
||||
|
||||
func (r *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
r.gotPath = req.URL.Path
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestGeminiProxyTransport_StripsInjectedVersion verifies that the transport
|
||||
// collapses the genai-injected "/v1beta" segment that follows a proxy base
|
||||
// URL which already carries its own version segment. This is the second-order
|
||||
// fix that makes opencode/gemini-* actually reach the proxy (issue #41).
|
||||
func TestGeminiProxyTransport_StripsInjectedVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
basePath string
|
||||
reqPath string
|
||||
wantPath string
|
||||
}{
|
||||
{
|
||||
name: "strips doubled v1beta after /zen/v1",
|
||||
basePath: "/zen/v1",
|
||||
reqPath: "/zen/v1/v1beta/models/gemini-3.5-flash:generateContent",
|
||||
wantPath: "/zen/v1/models/gemini-3.5-flash:generateContent",
|
||||
},
|
||||
{
|
||||
name: "strips doubled v1beta1 after /zen/v1",
|
||||
basePath: "/zen/v1",
|
||||
reqPath: "/zen/v1/v1beta1/models/gemini-3.5-flash:generateContent",
|
||||
wantPath: "/zen/v1/models/gemini-3.5-flash:generateContent",
|
||||
},
|
||||
{
|
||||
name: "leaves non-matching path untouched",
|
||||
basePath: "/zen/v1",
|
||||
reqPath: "/other/v1beta/models/x:generateContent",
|
||||
wantPath: "/other/v1beta/models/x:generateContent",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := &recordingRoundTripper{}
|
||||
tr := &geminiProxyTransport{base: rec, basePath: tt.basePath}
|
||||
req, err := http.NewRequest(http.MethodPost, "https://host"+tt.reqPath, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
if _, err := tr.RoundTrip(req); err != nil {
|
||||
t.Fatalf("RoundTrip: %v", err)
|
||||
}
|
||||
if rec.gotPath != tt.wantPath {
|
||||
t.Errorf("forwarded path = %q, want %q", rec.gotPath, tt.wantPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCopilotProviderAliasUsesCatalog(t *testing.T) {
|
||||
registry := NewModelsRegistry()
|
||||
|
||||
models, err := registry.GetModelsForProvider("copilot")
|
||||
if err != nil {
|
||||
t.Fatalf("GetModelsForProvider(copilot) failed: %v", err)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
t.Fatal("expected copilot alias to return github-copilot catalog models")
|
||||
}
|
||||
if registry.LookupModel("copilot", "gpt-5.5") == nil {
|
||||
t.Fatal("expected copilot/gpt-5.5 to resolve through github-copilot catalog")
|
||||
}
|
||||
if registry.GetProviderInfo("copilot") == nil {
|
||||
t.Fatal("expected copilot alias to return github-copilot provider info")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotRejectsNonGPTModels(t *testing.T) {
|
||||
_, err := CreateProvider(t.Context(), &ProviderConfig{ModelString: "copilot/claude-sonnet-4.6"})
|
||||
if err == nil {
|
||||
t.Fatal("expected non-GPT Copilot model to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotHTTPClientCachesToken(t *testing.T) {
|
||||
client := createCopilotHTTPClient("cached-token", time.Now().Add(time.Hour).Unix(), false)
|
||||
transport, ok := client.Transport.(*copilotTransport)
|
||||
if !ok {
|
||||
t.Fatal("expected *copilotTransport")
|
||||
}
|
||||
|
||||
token := transport.cachedToken(t.Context())
|
||||
if token != "cached-token" {
|
||||
t.Fatalf("expected cached token, got %q", token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotTransportHeaders(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest failed: %v", err)
|
||||
}
|
||||
|
||||
transport := &copilotTransport{
|
||||
base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.Header.Get("Authorization") != "Bearer cached-token" {
|
||||
t.Fatalf("unexpected Authorization header: %q", req.Header.Get("Authorization"))
|
||||
}
|
||||
if req.Header.Get("Copilot-Integration-Id") != copilotIntegrationID {
|
||||
t.Fatalf("unexpected Copilot-Integration-Id header: %q", req.Header.Get("Copilot-Integration-Id"))
|
||||
}
|
||||
if req.Header.Get("Editor-Version") != copilotEditorVersion {
|
||||
t.Fatalf("unexpected Editor-Version header: %q", req.Header.Get("Editor-Version"))
|
||||
}
|
||||
if req.Header.Get("User-Agent") != copilotUserAgent {
|
||||
t.Fatalf("unexpected User-Agent header: %q", req.Header.Get("User-Agent"))
|
||||
}
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||
}),
|
||||
token: "cached-token",
|
||||
expiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("RoundTrip failed: %v", err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
+46
-20
@@ -10,14 +10,24 @@ import (
|
||||
|
||||
// loadCustomModelsFromConfig loads custom model definitions from the config file
|
||||
// and returns them as a map of model ID -> ModelInfo. Returns nil if no custom
|
||||
// models are configured.
|
||||
// models are configured. Reads from the process-global viper store (the model
|
||||
// registry is a process-global singleton).
|
||||
func loadCustomModelsFromConfig() map[string]ModelInfo {
|
||||
if !viper.IsSet("customModels") {
|
||||
return loadCustomModelsFrom(viper.GetViper())
|
||||
}
|
||||
|
||||
// loadCustomModelsFrom loads custom model definitions from the supplied store.
|
||||
// When v is nil the process-global store is used.
|
||||
func loadCustomModelsFrom(v *viper.Viper) map[string]ModelInfo {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
if !v.IsSet("customModels") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var customModels map[string]CustomModelConfig
|
||||
if err := viper.UnmarshalKey("customModels", &customModels); err != nil {
|
||||
if err := v.UnmarshalKey("customModels", &customModels); err != nil {
|
||||
log.Printf("Warning: Failed to parse customModels: %v", err)
|
||||
return nil
|
||||
}
|
||||
@@ -59,16 +69,20 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
// LoadModelSettingsFromConfig loads per-model generation parameter overrides
|
||||
// from the config file. Keys are "provider/model" strings. Returns nil if
|
||||
// no model settings are configured.
|
||||
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
|
||||
if !viper.IsSet("modelSettings") {
|
||||
// LoadModelSettingsFrom loads per-model generation parameter overrides from the
|
||||
// supplied per-instance store. When v is nil the process-global store is used.
|
||||
// Keys are "provider/model" strings. Returns nil if no model settings are
|
||||
// configured.
|
||||
func LoadModelSettingsFrom(v *viper.Viper) map[string]*GenerationParams {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
if !v.IsSet("modelSettings") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings map[string]GenerationParamsConfig
|
||||
if err := viper.UnmarshalKey("modelSettings", &settings); err != nil {
|
||||
if err := v.UnmarshalKey("modelSettings", &settings); err != nil {
|
||||
log.Printf("Warning: Failed to parse modelSettings: %v", err)
|
||||
return nil
|
||||
}
|
||||
@@ -148,12 +162,17 @@ func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve the config store: prefer the per-instance store carried on the
|
||||
// ProviderConfig (set by BuildProviderConfig / Kit.New), falling back to
|
||||
// the process-global store for callers that don't thread one through.
|
||||
store := config.ConfigStore
|
||||
|
||||
// Collect model-level params: modelSettings override > custom model params.
|
||||
// modelSettings takes priority because it's the more specific/intentional config.
|
||||
var params *GenerationParams
|
||||
|
||||
// First check modelSettings from config.
|
||||
if settings := LoadModelSettingsFromConfig(); settings != nil {
|
||||
if settings := LoadModelSettingsFrom(store); settings != nil {
|
||||
modelKey := provider + "/" + modelName
|
||||
if p, ok := settings[modelKey]; ok {
|
||||
params = p
|
||||
@@ -173,28 +192,28 @@ func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
// We check viper.IsSet() which returns true only when the key was
|
||||
// set via CLI flag, environment variable, or config file global section.
|
||||
|
||||
if params.MaxTokens != nil && !isExplicitlySet("max-tokens") {
|
||||
if params.MaxTokens != nil && !isExplicitlySet(store, "max-tokens") {
|
||||
config.MaxTokens = *params.MaxTokens
|
||||
}
|
||||
if params.Temperature != nil && !isExplicitlySet("temperature") {
|
||||
if params.Temperature != nil && !isExplicitlySet(store, "temperature") {
|
||||
config.Temperature = params.Temperature
|
||||
}
|
||||
if params.TopP != nil && !isExplicitlySet("top-p") {
|
||||
if params.TopP != nil && !isExplicitlySet(store, "top-p") {
|
||||
config.TopP = params.TopP
|
||||
}
|
||||
if params.TopK != nil && !isExplicitlySet("top-k") {
|
||||
if params.TopK != nil && !isExplicitlySet(store, "top-k") {
|
||||
config.TopK = params.TopK
|
||||
}
|
||||
if params.FrequencyPenalty != nil && !isExplicitlySet("frequency-penalty") {
|
||||
if params.FrequencyPenalty != nil && !isExplicitlySet(store, "frequency-penalty") {
|
||||
config.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
if params.PresencePenalty != nil && !isExplicitlySet("presence-penalty") {
|
||||
if params.PresencePenalty != nil && !isExplicitlySet(store, "presence-penalty") {
|
||||
config.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
if len(params.StopSequences) > 0 && !isExplicitlySet("stop-sequences") {
|
||||
if len(params.StopSequences) > 0 && !isExplicitlySet(store, "stop-sequences") {
|
||||
config.StopSequences = params.StopSequences
|
||||
}
|
||||
if params.ThinkingLevel != "" && !isExplicitlySet("thinking-level") {
|
||||
if params.ThinkingLevel != "" && !isExplicitlySet(store, "thinking-level") {
|
||||
config.ThinkingLevel = params.ThinkingLevel
|
||||
}
|
||||
if params.SystemPrompt != "" && config.SystemPrompt == "" {
|
||||
@@ -228,7 +247,14 @@ func LoadSystemPromptValue(input string) string {
|
||||
// isExplicitlySet returns true when the user has explicitly set a config key
|
||||
// via CLI flag, environment variable, or the global section of the config file.
|
||||
// Model-level defaults should not override explicitly set values.
|
||||
func isExplicitlySet(key string) bool {
|
||||
//
|
||||
// The check runs against the supplied per-instance store when non-nil,
|
||||
// otherwise the process-global store. This keeps the "explicit vs unset"
|
||||
// precedence contract per-Kit-instance once a store is threaded through.
|
||||
func isExplicitlySet(v *viper.Viper, key string) bool {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
// viper.IsSet returns true if the key has been set in any of the
|
||||
// data stores (flag, env, config file, default). We need to check
|
||||
// whether the value was set at the global config level (not just
|
||||
@@ -239,7 +265,7 @@ func isExplicitlySet(key string) bool {
|
||||
// file values. This means global config file values (e.g.
|
||||
// temperature: 0.7 at the top level) will correctly take precedence
|
||||
// over model-level defaults, which is the desired behavior.
|
||||
return viper.IsSet(key)
|
||||
return v.IsSet(key)
|
||||
}
|
||||
|
||||
// GenerationParams holds per-model generation parameter defaults.
|
||||
|
||||
File diff suppressed because one or more lines are too long
+83
-14
@@ -48,18 +48,87 @@ type modelsDBLimit struct {
|
||||
Output int `json:"output"`
|
||||
}
|
||||
|
||||
// npmToLLMProvider maps npm package names from models.dev to LLM
|
||||
// provider identifiers. Providers not in this map but with an api URL
|
||||
// can be auto-routed through openaicompat.
|
||||
var npmToLLMProvider = map[string]string{
|
||||
"@ai-sdk/anthropic": "anthropic",
|
||||
"@ai-sdk/openai": "openai",
|
||||
"@ai-sdk/google": "google",
|
||||
"@ai-sdk/google-vertex": "google-vertex",
|
||||
"@ai-sdk/google-vertex/anthropic": "google-vertex-anthropic",
|
||||
"@ai-sdk/amazon-bedrock": "bedrock",
|
||||
"@ai-sdk/azure": "azure",
|
||||
"@openrouter/ai-sdk-provider": "openrouter",
|
||||
"@ai-sdk/vercel": "vercel",
|
||||
"@ai-sdk/openai-compatible": "openaicompat",
|
||||
// wireProtocol identifies which LLM API protocol an npm package speaks.
|
||||
// Fantasy implements three native protocols (openai, anthropic, google);
|
||||
// everything else in its providers/ tree is a thin wrapper around one of
|
||||
// them with a pre-baked default URL or auth scheme.
|
||||
type wireProtocol int
|
||||
|
||||
const (
|
||||
wireUnknown wireProtocol = iota
|
||||
wireOpenAI
|
||||
wireAnthropic
|
||||
wireGoogle
|
||||
)
|
||||
|
||||
// npmToWireProtocol maps npm package names from models.dev to the wire
|
||||
// protocol they speak. Provider-specific bundles that need bespoke auth or
|
||||
// URL templating (azure, bedrock, openrouter, google-vertex, google-vertex-
|
||||
// anthropic, and @ai-sdk/gateway which is the Vercel AI Gateway) are
|
||||
// intentionally absent — they have native top-level cases in CreateProvider
|
||||
// and never reach the auto-router. Providers not in this map but with an
|
||||
// api URL are auto-routed through the OpenAI-compatible wire.
|
||||
//
|
||||
// The thin OpenAI-compatible npm wrappers (groq, cerebras, mistral, …) are
|
||||
// listed explicitly so that auto-routing can recover their hard-coded base
|
||||
// URL from sdkDefaultBaseURL when the registry entry has no api field.
|
||||
var npmToWireProtocol = map[string]wireProtocol{
|
||||
// Native wires.
|
||||
"@ai-sdk/openai": wireOpenAI,
|
||||
"@ai-sdk/openai-compatible": wireOpenAI,
|
||||
"@ai-sdk/anthropic": wireAnthropic,
|
||||
"@ai-sdk/google": wireGoogle,
|
||||
|
||||
// Thin OpenAI-compatible wrappers. Each ships with a hard-coded base URL
|
||||
// in its JS SDK (see sdkDefaultBaseURL) but speaks the plain OpenAI chat
|
||||
// completions wire — so we can route them all through fantasy's
|
||||
// openaicompat provider once we supply the URL.
|
||||
"@ai-sdk/groq": wireOpenAI,
|
||||
"@ai-sdk/cerebras": wireOpenAI,
|
||||
"@ai-sdk/perplexity": wireOpenAI,
|
||||
"@ai-sdk/togetherai": wireOpenAI,
|
||||
"@ai-sdk/xai": wireOpenAI,
|
||||
"@ai-sdk/deepinfra": wireOpenAI,
|
||||
"@ai-sdk/mistral": wireOpenAI,
|
||||
"@ai-sdk/cohere": wireOpenAI,
|
||||
"@ai-sdk/vercel": wireOpenAI, // v0 API (api.v0.dev), distinct from @ai-sdk/gateway
|
||||
"@aihubmix/ai-sdk-provider": wireOpenAI,
|
||||
"venice-ai-sdk-provider": wireOpenAI,
|
||||
"merge-gateway-ai-sdk-provider": wireOpenAI,
|
||||
}
|
||||
|
||||
// sdkDefaultBaseURL maps an npm package name to the base URL its JavaScript
|
||||
// SDK uses by default. This lets us recover a working endpoint for providers
|
||||
// whose models.dev entry omits the `api` field because the JS SDK hard-codes
|
||||
// the URL (e.g. groq, cerebras, mistral, x.ai…).
|
||||
//
|
||||
// Only OpenAI-compatible and native-wire SDKs are listed; providers needing
|
||||
// bespoke auth or URL templating (bedrock SigV4, azure resource URLs,
|
||||
// google-vertex project/location, cloudflare gateway account IDs, gitlab,
|
||||
// sap-ai-core) are handled by native CreateProvider cases or surface a
|
||||
// targeted error that asks the user to supply --provider-url.
|
||||
var sdkDefaultBaseURL = map[string]string{
|
||||
// Native wires.
|
||||
"@ai-sdk/openai": "https://api.openai.com/v1",
|
||||
"@ai-sdk/anthropic": "https://api.anthropic.com/v1",
|
||||
"@ai-sdk/google": "https://generativelanguage.googleapis.com/v1beta",
|
||||
|
||||
// Thin OpenAI-compatible wrappers.
|
||||
"@ai-sdk/groq": "https://api.groq.com/openai/v1",
|
||||
"@ai-sdk/cerebras": "https://api.cerebras.ai/v1",
|
||||
"@ai-sdk/perplexity": "https://api.perplexity.ai",
|
||||
"@ai-sdk/togetherai": "https://api.together.xyz/v1",
|
||||
"@ai-sdk/xai": "https://api.x.ai/v1",
|
||||
"@ai-sdk/deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"@ai-sdk/mistral": "https://api.mistral.ai/v1",
|
||||
"@ai-sdk/cohere": "https://api.cohere.com/compatibility/v1",
|
||||
"@ai-sdk/vercel": "https://api.v0.dev/v1",
|
||||
"@aihubmix/ai-sdk-provider": "https://aihubmix.com/v1",
|
||||
"venice-ai-sdk-provider": "https://api.venice.ai/api/v1",
|
||||
"merge-gateway-ai-sdk-provider": "https://api-gateway.merge.dev/v1/ai-sdk",
|
||||
|
||||
// Native handlers — included for ResolveProviderBaseURL introspection
|
||||
// even though CreateProvider routes these via dedicated cases.
|
||||
"@ai-sdk/gateway": "https://ai-gateway.vercel.sh/v1",
|
||||
"@openrouter/ai-sdk-provider": "https://openrouter.ai/api/v1",
|
||||
}
|
||||
|
||||
+444
-74
@@ -9,8 +9,11 @@ import (
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -25,11 +28,30 @@ import (
|
||||
openaisdk "github.com/charmbracelet/openai-go"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const (
|
||||
// ClaudeCodePrompt is the required system prompt for OAuth authentication.
|
||||
ClaudeCodePrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
|
||||
// copilotProviderID is the canonical models.dev provider key. The CLI also
|
||||
// accepts the shorter "copilot" alias for user-facing model strings.
|
||||
copilotProviderID = "github-copilot"
|
||||
// copilotAliasProviderID is the short provider prefix accepted by kit.
|
||||
copilotAliasProviderID = "copilot"
|
||||
// copilotBaseURL is the fallback API URL if the model catalog has no API URL.
|
||||
copilotBaseURL = "https://api.githubcopilot.com"
|
||||
|
||||
// GitHub Copilot currently expects VS Code Copilot Chat client identifiers.
|
||||
// Keep these centralized so they are easy to audit and update when GitHub
|
||||
// changes accepted client metadata.
|
||||
copilotIntegrationID = "vscode-chat"
|
||||
copilotEditorVersion = "vscode/1.104.1"
|
||||
copilotEditorPluginVersion = "copilot-chat/0.31.0"
|
||||
copilotUserAgent = "GitHubCopilotChat/0.31.0"
|
||||
copilotOpenAIIntent = "conversation-agent"
|
||||
copilotGitHubAPIVersion = "2026-01-09"
|
||||
)
|
||||
|
||||
// resolveModelAlias resolves model aliases to their full names using the registry
|
||||
@@ -164,6 +186,13 @@ type ProviderConfig struct {
|
||||
ThinkingLevel ThinkingLevel
|
||||
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
|
||||
|
||||
// ConfigStore is the per-instance configuration store used to resolve
|
||||
// "explicitly set" precedence checks (isExplicitlySet), per-model
|
||||
// settings, and right-sizing. When nil, the process-global viper store is
|
||||
// used. Threading a per-Kit store here keeps generation-parameter
|
||||
// precedence isolated between Kit instances in the same process.
|
||||
ConfigStore *viper.Viper
|
||||
|
||||
// ProgressReaderFunc, when set, wraps an io.Reader with progress display
|
||||
// for long operations like Ollama model pulls. The returned io.ReadCloser
|
||||
// must be closed when done. When nil, the raw reader is consumed directly
|
||||
@@ -205,6 +234,20 @@ func ParseModelString(modelString string) (provider, model string, err error) {
|
||||
return "", "", fmt.Errorf("invalid model format %q: expected provider/model (e.g. anthropic/claude-sonnet-4-5)", modelString)
|
||||
}
|
||||
|
||||
// isCopilotProvider reports whether provider is the canonical catalog key or
|
||||
// the user-facing shorthand alias.
|
||||
func isCopilotProvider(provider string) bool {
|
||||
return provider == copilotAliasProviderID || provider == copilotProviderID
|
||||
}
|
||||
|
||||
// catalogProviderID maps supported provider aliases to their models.dev keys.
|
||||
func catalogProviderID(provider string) string {
|
||||
if isCopilotProvider(provider) {
|
||||
return copilotProviderID
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
// CreateProvider creates a fantasy LanguageModel based on the provider configuration.
|
||||
// Model metadata is looked up from the models.dev database for cost tracking and
|
||||
// capability detection, but unknown models are passed through to the provider
|
||||
@@ -212,8 +255,10 @@ func ParseModelString(modelString string) (provider, model string, err error) {
|
||||
//
|
||||
// Native providers: anthropic, openai, google, ollama, azure, google-vertex-anthropic,
|
||||
// openrouter, bedrock, vercel.
|
||||
// Any provider in models.dev with an api URL or openai-compatible npm package
|
||||
// is auto-routed through fantasy's openaicompat provider.
|
||||
// Any other provider in models.dev is auto-routed by wire protocol: its npm
|
||||
// package (or per-model override) selects the OpenAI, Anthropic, or Google
|
||||
// transport, using the provider's api URL as the base. Providers with an api
|
||||
// URL but an unrecognized npm package fall back to the OpenAI-compatible wire.
|
||||
func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) {
|
||||
provider, modelName, err := ParseModelString(config.ModelString)
|
||||
if err != nil {
|
||||
@@ -226,17 +271,30 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
}
|
||||
|
||||
registry := GetGlobalRegistry()
|
||||
lookupProvider := catalogProviderID(provider)
|
||||
|
||||
// Look up model metadata (advisory, not blocking).
|
||||
// Look up model metadata (advisory for most providers, strict for Copilot).
|
||||
// When the model is known we validate config limits and print
|
||||
// suggestions on likely typos; when unknown we let the provider
|
||||
// API be the authority.
|
||||
modelInfo := registry.LookupModel(provider, modelName)
|
||||
if modelInfo == nil && provider != "ollama" && config.ProviderURL == "" {
|
||||
// API be the authority except for Copilot, whose non-GPT catalog entries
|
||||
// require unsupported wire protocols.
|
||||
modelInfo := registry.LookupModel(lookupProvider, modelName)
|
||||
if isCopilotProvider(provider) {
|
||||
providerInfo := registry.GetProviderInfo(copilotProviderID)
|
||||
if providerInfo == nil {
|
||||
return nil, fmt.Errorf("unsupported provider: %s (not found in model database)", copilotProviderID)
|
||||
}
|
||||
if modelInfo == nil {
|
||||
if suggestions := registry.SuggestModels(copilotProviderID, modelName); len(suggestions) > 0 {
|
||||
return nil, fmt.Errorf("model %q not found for provider %s. Did you mean one of: %s", modelName, copilotProviderID, strings.Join(suggestions, ", "))
|
||||
}
|
||||
return nil, fmt.Errorf("model %q not found for provider %s", modelName, copilotProviderID)
|
||||
}
|
||||
} else if modelInfo == nil && provider != "ollama" && config.ProviderURL == "" {
|
||||
// Model not in database — warn with suggestions but don't block.
|
||||
if suggestions := registry.SuggestModels(provider, modelName); len(suggestions) > 0 {
|
||||
if suggestions := registry.SuggestModels(lookupProvider, modelName); len(suggestions) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "Warning: model %q not found in model database for provider %s. Similar models: %s\n",
|
||||
modelName, provider, strings.Join(suggestions, ", "))
|
||||
modelName, lookupProvider, strings.Join(suggestions, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,17 +328,21 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
result, createErr = createAnthropicProvider(ctx, config, modelName)
|
||||
case "openai":
|
||||
result, createErr = createOpenAIProvider(ctx, config, modelName)
|
||||
case "copilot", "github-copilot":
|
||||
result, createErr = createCopilotProvider(ctx, config, modelName)
|
||||
case "google", "gemini":
|
||||
result, createErr = createGoogleProvider(ctx, config, modelName)
|
||||
case "ollama":
|
||||
result, createErr = createOllamaProvider(ctx, config, modelName)
|
||||
case "azure":
|
||||
case "azure", "azure-cognitive-services":
|
||||
result, createErr = createAzureProvider(ctx, config, modelName)
|
||||
case "google-vertex-anthropic":
|
||||
result, createErr = createVertexAnthropicProvider(ctx, config, modelName)
|
||||
case "google-vertex":
|
||||
result, createErr = createGoogleVertexProvider(ctx, config, modelName)
|
||||
case "openrouter":
|
||||
result, createErr = createOpenRouterProvider(ctx, config, modelName)
|
||||
case "bedrock":
|
||||
case "bedrock", "amazon-bedrock":
|
||||
result, createErr = createBedrockProvider(ctx, config, modelName)
|
||||
case "vercel":
|
||||
result, createErr = createVercelProvider(ctx, config, modelName)
|
||||
@@ -327,44 +389,100 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
|
||||
// autoRouteProvider attempts to create a provider by looking up its npm package
|
||||
// in the models.dev database and routing through the appropriate fantasy provider.
|
||||
// For openai-compatible providers, it uses the api URL from models.dev.
|
||||
// Models may have a provider override that specifies a different npm package than
|
||||
// the provider's default (e.g., opencode's claude-opus-4-6 uses @ai-sdk/anthropic).
|
||||
// It routes on wire protocol (openai, anthropic, google) rather than per-npm
|
||||
// provider name: fantasy implements three native wire protocols, and every other
|
||||
// entry in its providers/ tree is a thin wrapper around one of them. Using the
|
||||
// provider's api URL from models.dev as the base URL, any proxy that re-flavors
|
||||
// one of these protocols (e.g. opencode's Gemini routes) Just Works.
|
||||
//
|
||||
// Models may carry a provider override that specifies a different npm package
|
||||
// than the provider's default (e.g. opencode's claude-* uses @ai-sdk/anthropic
|
||||
// and its gemini-* uses @ai-sdk/google), which is resolved first.
|
||||
func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, modelName string, registry *ModelsRegistry) (*ProviderResult, error) {
|
||||
providerInfo := registry.GetProviderInfo(provider)
|
||||
if providerInfo == nil {
|
||||
return nil, fmt.Errorf("unsupported provider: %s (not found in model database)", provider)
|
||||
}
|
||||
|
||||
// Check for model-specific provider override
|
||||
// Resolve npm: per-model override > provider default.
|
||||
npmPackage := providerInfo.NPM
|
||||
if modelInfo := registry.LookupModel(provider, modelName); modelInfo != nil && modelInfo.ProviderNPM != "" {
|
||||
npmPackage = modelInfo.ProviderNPM
|
||||
}
|
||||
|
||||
// Determine the LLM provider for this npm package
|
||||
llmProvider := npmToLLMProvider[npmPackage]
|
||||
if llmProvider == "" && providerInfo.API != "" {
|
||||
// Unknown npm but has API URL → route through openaicompat
|
||||
llmProvider = "openaicompat"
|
||||
wire, known := npmToWireProtocol[npmPackage]
|
||||
if !known {
|
||||
// Unknown npm but the provider has an API URL → assume OpenAI-compatible.
|
||||
// (Preserves the long-standing "any provider in models.dev with an api URL
|
||||
// is auto-routed through openaicompat" behaviour.)
|
||||
if providerInfo.API == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"cannot auto-route provider %s: npm package %q has no known wire protocol "+
|
||||
"and the registry has no API URL (use --provider-url to override)",
|
||||
provider, npmPackage,
|
||||
)
|
||||
}
|
||||
wire = wireOpenAI
|
||||
}
|
||||
|
||||
switch llmProvider {
|
||||
case "openaicompat":
|
||||
return createAutoRoutedOpenAICompatProvider(ctx, config, modelName, providerInfo)
|
||||
case "anthropic":
|
||||
if config.ProviderURL == "" && providerInfo.API != "" {
|
||||
// All three wires use the provider's API URL from models.dev as the base.
|
||||
// When the registry has none, fall back to the SDK's hard-coded default for
|
||||
// this npm package (covers groq, cerebras, mistral, x.ai, etc. — providers
|
||||
// whose JS SDK ships a built-in baseURL that models.dev doesn't restate).
|
||||
if config.ProviderURL == "" {
|
||||
if providerInfo.API != "" {
|
||||
config.ProviderURL = providerInfo.API
|
||||
} else if defaultURL, ok := sdkDefaultBaseURL[npmPackage]; ok {
|
||||
config.ProviderURL = defaultURL
|
||||
providerInfo.API = defaultURL // for downstream helpers that read info.API
|
||||
}
|
||||
return createAutoRoutedAnthropicProvider(ctx, config, modelName, providerInfo)
|
||||
case "openai":
|
||||
if config.ProviderURL == "" && providerInfo.API != "" {
|
||||
config.ProviderURL = providerInfo.API
|
||||
}
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no LLM provider mapping)", provider, npmPackage)
|
||||
}
|
||||
|
||||
// Provider templates a runtime account/region/deployment segment into the
|
||||
// URL (cloudflare-ai-gateway, databricks, snowflake-cortex, gitlab,
|
||||
// sap-ai-core). Resolve via environment variables, or surface a targeted
|
||||
// error pointing the user at the right knobs.
|
||||
if resolved, err := resolveTemplatedAPIURL(config.ProviderURL, providerInfo); err != nil {
|
||||
return nil, err
|
||||
} else if resolved != "" {
|
||||
config.ProviderURL = resolved
|
||||
providerInfo.API = resolved
|
||||
}
|
||||
|
||||
switch wire {
|
||||
case wireOpenAI:
|
||||
// The native OpenAI SDK package (@ai-sdk/openai) speaks the Responses
|
||||
// API; openai-compatible proxies (and unknown-npm fallbacks) use the
|
||||
// chat-completions wire via fantasy's openaicompat provider.
|
||||
if npmPackage == "@ai-sdk/openai" {
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
}
|
||||
return createAutoRoutedOpenAICompatProvider(ctx, config, modelName, providerInfo)
|
||||
case wireAnthropic:
|
||||
return createAutoRoutedAnthropicProvider(ctx, config, modelName, providerInfo)
|
||||
case wireGoogle:
|
||||
return createAutoRoutedGoogleProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("internal error: unknown wire protocol for provider %s (npm: %s)", provider, npmPackage)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveAutoRouteAPIKey looks up the API key for an auto-routed provider,
|
||||
// returning a uniform error message when none can be resolved.
|
||||
func resolveAutoRouteAPIKey(config *ProviderConfig, info *ProviderInfo) (string, error) {
|
||||
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
|
||||
if apiKey == "" {
|
||||
return "", fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
|
||||
info.Name, strings.Join(info.Env, " / "))
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// wrapProviderErr produces the uniform "failed to create X provider/model: %w"
|
||||
// error wrap used by every createXxxProvider path. kind is typically
|
||||
// "provider" or "model".
|
||||
func wrapProviderErr(name, kind string, err error) error {
|
||||
return fmt.Errorf("failed to create %s %s: %w", name, kind, err)
|
||||
}
|
||||
|
||||
// createAutoRoutedOpenAICompatProvider creates an openaicompat provider using
|
||||
@@ -378,10 +496,9 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC
|
||||
return nil, fmt.Errorf("provider %s requires --provider-url (no API URL in database)", info.ID)
|
||||
}
|
||||
|
||||
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
|
||||
info.Name, strings.Join(info.Env, " / "))
|
||||
apiKey, err := resolveAutoRouteAPIKey(config, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts []openaicompat.Option
|
||||
@@ -395,12 +512,12 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC
|
||||
|
||||
p, err := openaicompat.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "provider", err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -411,10 +528,9 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC
|
||||
func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) {
|
||||
clearConflictingAnthropicSamplingParams(config)
|
||||
|
||||
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
|
||||
info.Name, strings.Join(info.Env, " / "))
|
||||
apiKey, err := resolveAutoRouteAPIKey(config, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts []anthropic.Option
|
||||
@@ -433,12 +549,12 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf
|
||||
|
||||
p, err := anthropic.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "provider", err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -447,10 +563,9 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf
|
||||
// createAutoRoutedOpenAIProvider creates an openai provider for
|
||||
// third-party providers with openai-compatible APIs.
|
||||
func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) {
|
||||
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
|
||||
info.Name, strings.Join(info.Env, " / "))
|
||||
apiKey, err := resolveAutoRouteAPIKey(config, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts []openai.Option
|
||||
@@ -467,12 +582,12 @@ func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig,
|
||||
|
||||
p, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "provider", err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "model", err)
|
||||
}
|
||||
|
||||
providerOpts := buildOpenAIProviderOptions(config, modelName)
|
||||
@@ -480,6 +595,114 @@ func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig,
|
||||
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
|
||||
}
|
||||
|
||||
// createAutoRoutedGoogleProvider creates a Google (Gemini) provider for
|
||||
// third-party providers that expose a Gemini-compatible API (e.g. opencode's
|
||||
// Gemini routes, which carry an @ai-sdk/google per-model override).
|
||||
//
|
||||
// The underlying genai SDK always injects its own API version segment
|
||||
// ("v1beta") between the base URL and the resource path. When the proxy's
|
||||
// base URL from models.dev already carries a version segment (e.g. opencode's
|
||||
// https://opencode.ai/zen/v1), that produces a doubled ".../v1/v1beta/..."
|
||||
// path that the proxy rejects. In that case we install a transport that
|
||||
// strips the injected segment so the proxy's own version is used.
|
||||
func createAutoRoutedGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) {
|
||||
apiKey, err := resolveAutoRouteAPIKey(config, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := []google.Option{
|
||||
google.WithGeminiAPIKey(apiKey),
|
||||
google.WithName(info.ID),
|
||||
}
|
||||
|
||||
if config.ProviderURL != "" {
|
||||
opts = append(opts, google.WithBaseURL(config.ProviderURL))
|
||||
}
|
||||
|
||||
// Decide whether the genai-injected version segment needs stripping.
|
||||
var httpClient *http.Client
|
||||
if basePath := versionedBasePath(config.ProviderURL); basePath != "" {
|
||||
httpClient = newGeminiProxyHTTPClient(basePath, config.TLSSkipVerify)
|
||||
} else if config.TLSSkipVerify {
|
||||
httpClient = createHTTPClientWithTLSConfig(true)
|
||||
}
|
||||
if httpClient != nil {
|
||||
opts = append(opts, google.WithHTTPClient(httpClient))
|
||||
}
|
||||
|
||||
p, err := google.New(opts...)
|
||||
if err != nil {
|
||||
return nil, wrapProviderErr(info.Name, "provider", err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, wrapProviderErr(info.Name, "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
|
||||
// versionSegmentRe matches a trailing API version segment in a URL path,
|
||||
// e.g. "/v1", "/v1beta", "/v1beta1", "/v2alpha".
|
||||
var versionSegmentRe = regexp.MustCompile(`/v\d+(?:beta\d*|alpha\d*)?$`)
|
||||
|
||||
// versionedBasePath returns the path component of rawURL when that path ends
|
||||
// with an API version segment (e.g. opencode's ".../zen/v1" → "/zen/v1").
|
||||
// It returns "" when rawURL is empty, unparseable, or has no version suffix
|
||||
// — in which case the genai SDK's default version injection is correct and
|
||||
// no rewriting is needed.
|
||||
func versionedBasePath(rawURL string) string {
|
||||
if rawURL == "" {
|
||||
return ""
|
||||
}
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
path := strings.TrimSuffix(u.Path, "/")
|
||||
if versionSegmentRe.MatchString(path) {
|
||||
return path
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// newGeminiProxyHTTPClient builds an HTTP client whose transport strips the
|
||||
// genai-injected version segment ("v1beta"/"v1beta1") that directly follows
|
||||
// basePath, collapsing "{basePath}/v1beta/..." back to "{basePath}/...".
|
||||
func newGeminiProxyHTTPClient(basePath string, skipVerify bool) *http.Client {
|
||||
var base http.RoundTripper
|
||||
if skipVerify {
|
||||
base = &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
} else {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: &geminiProxyTransport{base: base, basePath: basePath},
|
||||
}
|
||||
}
|
||||
|
||||
// geminiProxyTransport removes the redundant API version segment that the
|
||||
// genai SDK injects after a proxy base URL that already carries its own
|
||||
// version segment.
|
||||
type geminiProxyTransport struct {
|
||||
base http.RoundTripper
|
||||
basePath string
|
||||
}
|
||||
|
||||
func (t *geminiProxyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for _, injected := range []string{"/v1beta1", "/v1beta"} {
|
||||
prefix := t.basePath + injected + "/"
|
||||
if strings.HasPrefix(req.URL.Path, prefix) {
|
||||
newReq := req.Clone(req.Context())
|
||||
newReq.URL.Path = t.basePath + strings.TrimPrefix(req.URL.Path, t.basePath+injected)
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
}
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
// resolveAPIKey returns the first non-empty API key from the explicit key
|
||||
// or the environment variables.
|
||||
func resolveAPIKey(explicitKey string, envVars []string) string {
|
||||
@@ -530,7 +753,7 @@ func rightSizeMaxTokens(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
if modelInfo == nil || modelInfo.Limit.Output <= 0 {
|
||||
return
|
||||
}
|
||||
if isExplicitlySet("max-tokens") {
|
||||
if isExplicitlySet(config.ConfigStore, "max-tokens") {
|
||||
return
|
||||
}
|
||||
target := min(modelInfo.Limit.Output, defaultRightSizeCap)
|
||||
@@ -709,7 +932,7 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN
|
||||
}
|
||||
|
||||
// Handle OAuth vs API key authentication
|
||||
if strings.HasPrefix(source, "stored OAuth") {
|
||||
if source == auth.CredentialSourceOAuth {
|
||||
httpClient := createOAuthHTTPClient(apiKey, config.TLSSkipVerify)
|
||||
opts = append(opts, anthropic.WithHTTPClient(httpClient))
|
||||
// Note: For OAuth, the API key is set as a placeholder; the transport handles auth
|
||||
@@ -719,12 +942,12 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN
|
||||
|
||||
provider, err := anthropic.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Anthropic provider: %w", err)
|
||||
return nil, wrapProviderErr("Anthropic", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Anthropic model: %w", err)
|
||||
return nil, wrapProviderErr("Anthropic", "model", err)
|
||||
}
|
||||
|
||||
// Build provider options for extended thinking (reasoning budget).
|
||||
@@ -761,12 +984,12 @@ func createVertexAnthropicProvider(ctx context.Context, config *ProviderConfig,
|
||||
|
||||
provider, err := anthropic.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Vertex Anthropic provider: %w", err)
|
||||
return nil, wrapProviderErr("Vertex Anthropic", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Vertex Anthropic model: %w", err)
|
||||
return nil, wrapProviderErr("Vertex Anthropic", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -834,12 +1057,12 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
provider, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI provider: %w", err)
|
||||
return nil, wrapProviderErr("OpenAI", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI model: %w", err)
|
||||
return nil, wrapProviderErr("OpenAI", "model", err)
|
||||
}
|
||||
|
||||
// Build provider options for OpenAI Responses API reasoning models.
|
||||
@@ -848,6 +1071,72 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
|
||||
}
|
||||
|
||||
// createCopilotProvider builds a GitHub Copilot provider through fantasy's
|
||||
// OpenAI-compatible provider. The catalog key is github-copilot, but the public
|
||||
// model prefix may be either copilot/ or github-copilot/.
|
||||
//
|
||||
// Only gpt-* Copilot models are enabled here. The catalog also lists Claude and
|
||||
// Gemini Copilot models, but those require different wire protocols and must be
|
||||
// routed explicitly before they can be safely accepted.
|
||||
func createCopilotProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
if !strings.HasPrefix(modelName, "gpt-") {
|
||||
return nil, fmt.Errorf("GitHub Copilot model %q is not supported yet: only gpt-* models use the OpenAI-compatible protocol", modelName)
|
||||
}
|
||||
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
token, err := cm.GetValidCopilotAccessTokenContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GitHub Copilot credentials not available. Use 'kit auth login copilot': %w", err)
|
||||
}
|
||||
|
||||
expiresAt := int64(0)
|
||||
if creds, err := cm.GetCopilotCredentials(); err == nil && creds != nil && creds.CopilotAccessToken == token {
|
||||
expiresAt = creds.ExpiresAt
|
||||
}
|
||||
|
||||
baseURL := copilotBaseURL
|
||||
if providerInfo := GetGlobalRegistry().GetProviderInfo(copilotProviderID); providerInfo != nil && providerInfo.API != "" {
|
||||
baseURL = providerInfo.API
|
||||
}
|
||||
if config.ProviderURL != "" {
|
||||
baseURL = config.ProviderURL
|
||||
}
|
||||
|
||||
opts := []openai.Option{
|
||||
openai.WithName(copilotAliasProviderID),
|
||||
openai.WithBaseURL(baseURL),
|
||||
openai.WithAPIKey(token),
|
||||
openai.WithHTTPClient(createCopilotHTTPClient(token, expiresAt, config.TLSSkipVerify)),
|
||||
openai.WithUseResponsesAPI(),
|
||||
openai.WithResponsesAPIFunc(copilotUsesResponsesAPI),
|
||||
openai.WithObjectMode(fantasy.ObjectModeTool),
|
||||
}
|
||||
|
||||
provider, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GitHub Copilot provider: %w", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GitHub Copilot model: %w", err)
|
||||
}
|
||||
|
||||
providerOpts := buildOpenAIProviderOptions(config, modelName)
|
||||
|
||||
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
|
||||
}
|
||||
|
||||
// copilotUsesResponsesAPI selects the OpenAI Responses API for Copilot models
|
||||
// known to support it. Non-gpt models are rejected before provider creation.
|
||||
func copilotUsesResponsesAPI(modelID string) bool {
|
||||
return strings.HasPrefix(modelID, "gpt-5")
|
||||
}
|
||||
|
||||
// createOpenAICodexProvider creates a provider for ChatGPT/Codex OAuth tokens.
|
||||
// Uses the chatgpt.com/backend-api/codex endpoint with special headers.
|
||||
func createOpenAICodexProvider(ctx context.Context, config *ProviderConfig, modelName, token, accountID string) (*ProviderResult, error) {
|
||||
@@ -875,12 +1164,12 @@ func createOpenAICodexProvider(ctx context.Context, config *ProviderConfig, mode
|
||||
|
||||
provider, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex provider: %w", err)
|
||||
return nil, wrapProviderErr("OpenAI Codex", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex model: %w", err)
|
||||
return nil, wrapProviderErr("OpenAI Codex", "model", err)
|
||||
}
|
||||
|
||||
providerOpts := buildCodexProviderOptions(config, modelName)
|
||||
@@ -977,6 +1266,87 @@ func (t *codexTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// createCopilotHTTPClient returns an HTTP client that injects Copilot-specific
|
||||
// authorization and client metadata headers. The token and expiry are cached in
|
||||
// the transport so streaming requests do not hit credentials.json on every
|
||||
// RoundTrip; the credential manager is consulted only near expiry.
|
||||
func createCopilotHTTPClient(token string, expiresAt int64, skipVerify bool) *http.Client {
|
||||
var base http.RoundTripper
|
||||
if skipVerify {
|
||||
base = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &copilotTransport{
|
||||
base: base,
|
||||
token: token,
|
||||
expiresAt: expiresAt,
|
||||
},
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// copilotTransport decorates requests for api.githubcopilot.com.
|
||||
//
|
||||
// It owns a cached Copilot access token. When the token is still valid, the hot
|
||||
// path is in-memory only. Near expiry it refreshes through CredentialManager,
|
||||
// which updates both the cache here and credentials.json.
|
||||
type copilotTransport struct {
|
||||
base http.RoundTripper
|
||||
token string
|
||||
expiresAt int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *copilotTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
token := t.cachedToken(req.Context())
|
||||
|
||||
newReq := req.Clone(req.Context())
|
||||
newReq.Header.Set("Authorization", "Bearer "+token)
|
||||
newReq.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
||||
newReq.Header.Set("Editor-Version", copilotEditorVersion)
|
||||
newReq.Header.Set("Editor-Plugin-Version", copilotEditorPluginVersion)
|
||||
newReq.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
||||
newReq.Header.Set("User-Agent", copilotUserAgent)
|
||||
newReq.Header.Set("X-GitHub-Api-Version", copilotGitHubAPIVersion)
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// cachedToken returns the cached token unless it is within the five-minute
|
||||
// refresh window. Refresh errors fall back to the last token so the request can
|
||||
// surface any authoritative auth failure from the Copilot API.
|
||||
func (t *copilotTransport) cachedToken(ctx context.Context) string {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.expiresAt == 0 || time.Now().Unix() < t.expiresAt-300 {
|
||||
return t.token
|
||||
}
|
||||
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return t.token
|
||||
}
|
||||
|
||||
fresh, err := cm.GetValidCopilotAccessTokenContext(ctx)
|
||||
if err != nil || fresh == "" {
|
||||
return t.token
|
||||
}
|
||||
|
||||
t.token = fresh
|
||||
if creds, err := cm.GetCopilotCredentials(); err == nil && creds != nil && creds.CopilotAccessToken == fresh {
|
||||
t.expiresAt = creds.ExpiresAt
|
||||
}
|
||||
return t.token
|
||||
}
|
||||
|
||||
func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
apiKey := firstNonEmpty(
|
||||
config.ProviderAPIKey,
|
||||
@@ -993,12 +1363,12 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
provider, err := google.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Google provider: %w", err)
|
||||
return nil, wrapProviderErr("Google", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Google model: %w", err)
|
||||
return nil, wrapProviderErr("Google", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1031,12 +1401,12 @@ func createAzureProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
provider, err := azure.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Azure OpenAI provider: %w", err)
|
||||
return nil, wrapProviderErr("Azure OpenAI", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Azure OpenAI model: %w", err)
|
||||
return nil, wrapProviderErr("Azure OpenAI", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1056,12 +1426,12 @@ func createOpenRouterProvider(ctx context.Context, config *ProviderConfig, model
|
||||
|
||||
provider, err := openrouter.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenRouter provider: %w", err)
|
||||
return nil, wrapProviderErr("OpenRouter", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenRouter model: %w", err)
|
||||
return nil, wrapProviderErr("OpenRouter", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1073,12 +1443,12 @@ func createBedrockProvider(ctx context.Context, config *ProviderConfig, modelNam
|
||||
// Bedrock uses AWS SDK default credential chain (env vars, shared config, etc.)
|
||||
provider, err := bedrock.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Bedrock provider: %w", err)
|
||||
return nil, wrapProviderErr("Bedrock", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Bedrock model: %w", err)
|
||||
return nil, wrapProviderErr("Bedrock", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1102,12 +1472,12 @@ func createVercelProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
provider, err := vercel.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Vercel provider: %w", err)
|
||||
return nil, wrapProviderErr("Vercel", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Vercel model: %w", err)
|
||||
return nil, wrapProviderErr("Vercel", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1160,12 +1530,12 @@ func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
p, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create custom provider: %w", err)
|
||||
return nil, wrapProviderErr("custom", "provider", err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create custom model: %w", err)
|
||||
return nil, wrapProviderErr("custom", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1209,12 +1579,12 @@ func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
provider, err := openaicompat.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Ollama provider: %w", err)
|
||||
return nil, wrapProviderErr("Ollama", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Ollama model: %w", err)
|
||||
return nil, wrapProviderErr("Ollama", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{
|
||||
|
||||
@@ -246,6 +246,7 @@ func loadEmbeddedProviders() map[string]modelsDBProvider {
|
||||
// doesn't track yet. Callers should treat a nil return as "unknown model"
|
||||
// and continue with sensible defaults.
|
||||
func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
|
||||
provider = catalogProviderID(provider)
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil
|
||||
@@ -273,6 +274,7 @@ func LookupModelForSettings(modelString string) *ModelInfo {
|
||||
|
||||
// getRequiredEnvVars returns the required environment variables for a provider.
|
||||
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
|
||||
provider = catalogProviderID(provider)
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
||||
@@ -287,6 +289,7 @@ func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
|
||||
// variables. Returns nil for providers not in the registry (unknown
|
||||
// providers are assumed to handle auth themselves or via --provider-api-key).
|
||||
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
|
||||
provider = catalogProviderID(provider)
|
||||
if apiKey != "" {
|
||||
return nil
|
||||
}
|
||||
@@ -311,6 +314,15 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
|
||||
}
|
||||
}
|
||||
|
||||
// For GitHub Copilot, check stored GitHub OAuth credentials.
|
||||
if provider == copilotProviderID {
|
||||
if cm, err := auth.NewCredentialManager(); err == nil {
|
||||
if has, _ := cm.HasCopilotCredentials(); has {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
envVars, err := r.getRequiredEnvVars(provider)
|
||||
if err != nil {
|
||||
// Unknown provider — nothing to validate
|
||||
@@ -350,6 +362,7 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
|
||||
|
||||
// SuggestModels returns similar model names when an invalid model is provided.
|
||||
func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
|
||||
provider = catalogProviderID(provider)
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil
|
||||
@@ -404,8 +417,8 @@ func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if npm maps to an LLM provider
|
||||
if _, ok := npmToLLMProvider[info.NPM]; ok {
|
||||
// Check if npm maps to a known wire protocol
|
||||
if _, ok := npmToWireProtocol[info.NPM]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -415,6 +428,7 @@ func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
|
||||
// GetModelsForProvider returns all models for a specific provider.
|
||||
func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]ModelInfo, error) {
|
||||
provider = catalogProviderID(provider)
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
||||
@@ -425,6 +439,7 @@ func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]Model
|
||||
|
||||
// GetProviderInfo returns the full provider info, or nil if not found.
|
||||
func (r *ModelsRegistry) GetProviderInfo(provider string) *ProviderInfo {
|
||||
provider = catalogProviderID(provider)
|
||||
info, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy/providers/google"
|
||||
)
|
||||
|
||||
// templatePlaceholderRe matches "${NAME}" placeholders in URL templates from
|
||||
// models.dev (e.g. "https://${DATABRICKS_HOST}/ai-gateway/mlflow/v1").
|
||||
var templatePlaceholderRe = regexp.MustCompile(`\$\{([A-Z0-9_]+)\}`)
|
||||
|
||||
// templateEnvVarOverrides supplies fallback environment variable names for
|
||||
// placeholders that providers commonly use under non-obvious env names.
|
||||
// The placeholder name itself is always tried first; this map adds extra
|
||||
// names to try when the placeholder doesn't match the canonical env var.
|
||||
var templateEnvVarOverrides = map[string][]string{
|
||||
"CLOUDFLARE_ACCOUNT_ID": {"CF_ACCOUNT_ID"},
|
||||
"CLOUDFLARE_GATEWAY_NAME": {"CF_GATEWAY", "CLOUDFLARE_GATEWAY"},
|
||||
"DATABRICKS_HOST": {"DATABRICKS_WORKSPACE_URL"},
|
||||
"SNOWFLAKE_ACCOUNT": {"SNOWFLAKE_ACCOUNT_ID"},
|
||||
}
|
||||
|
||||
// resolveTemplatedAPIURL substitutes "${VAR}" placeholders in apiURL with the
|
||||
// values of the named environment variables. Returns:
|
||||
// - ("", nil) when apiURL contains no placeholders (caller keeps current URL),
|
||||
// - (resolved, nil) when every placeholder was resolved,
|
||||
// - ("", error) when one or more placeholders are unset, with a message that
|
||||
// names the missing env vars and points at the relevant provider.
|
||||
//
|
||||
// The info parameter is used purely for error messaging (provider name).
|
||||
func resolveTemplatedAPIURL(apiURL string, info *ProviderInfo) (string, error) {
|
||||
if apiURL == "" || !strings.Contains(apiURL, "${") {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var missing []string
|
||||
resolved := templatePlaceholderRe.ReplaceAllStringFunc(apiURL, func(match string) string {
|
||||
// match is "${NAME}". Extract NAME.
|
||||
name := match[2 : len(match)-1]
|
||||
if v := os.Getenv(name); v != "" {
|
||||
return v
|
||||
}
|
||||
for _, alt := range templateEnvVarOverrides[name] {
|
||||
if v := os.Getenv(alt); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
missing = append(missing, name)
|
||||
return match
|
||||
})
|
||||
|
||||
if len(missing) > 0 {
|
||||
providerName := info.ID
|
||||
if info.Name != "" {
|
||||
providerName = info.Name
|
||||
}
|
||||
return "", fmt.Errorf(
|
||||
"provider %s requires environment variable(s) %s to construct its API URL (%s); "+
|
||||
"set them or pass --provider-url to override",
|
||||
providerName, strings.Join(missing, ", "), apiURL,
|
||||
)
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
// ResolveProviderBaseURL returns the base API URL kit will use when talking to
|
||||
// the given provider, applying the same resolution order as CreateProvider:
|
||||
//
|
||||
// 1. The provider's `api` field from the models.dev registry.
|
||||
// 2. The hard-coded default base URL of its npm SDK package (e.g.
|
||||
// @ai-sdk/groq → https://api.groq.com/openai/v1).
|
||||
// 3. Template substitution against the current process environment when the
|
||||
// URL contains "${VAR}" placeholders (e.g. cloudflare-workers-ai needs
|
||||
// CLOUDFLARE_ACCOUNT_ID).
|
||||
//
|
||||
// It returns an error when the provider is unknown, when no URL can be derived,
|
||||
// or when a templated URL has unset placeholders. The error message is suitable
|
||||
// for direct display to end users.
|
||||
//
|
||||
// Note: providers handled by bespoke auth schemes (amazon-bedrock SigV4,
|
||||
// azure resource URLs, google-vertex project/location, sap-ai-core customer
|
||||
// deployments) may return either an empty URL or a regional/templated URL —
|
||||
// the actual endpoint is finalised inside their native handlers and depends on
|
||||
// runtime credentials.
|
||||
func ResolveProviderBaseURL(providerID string) (string, error) {
|
||||
registry := GetGlobalRegistry()
|
||||
info := registry.GetProviderInfo(providerID)
|
||||
if info == nil {
|
||||
return "", fmt.Errorf("unknown provider: %s", providerID)
|
||||
}
|
||||
|
||||
apiURL := info.API
|
||||
if apiURL == "" {
|
||||
if defaultURL, ok := sdkDefaultBaseURL[info.NPM]; ok {
|
||||
apiURL = defaultURL
|
||||
}
|
||||
}
|
||||
|
||||
if apiURL == "" {
|
||||
return "", fmt.Errorf(
|
||||
"provider %s has no default API URL: its npm package %q does not "+
|
||||
"ship a built-in baseURL (likely Bedrock SigV4, Azure deployment, "+
|
||||
"Vertex project/location, or a customer-hosted endpoint). "+
|
||||
"Pass --provider-url or set the provider's URL env var",
|
||||
providerID, info.NPM,
|
||||
)
|
||||
}
|
||||
|
||||
if strings.Contains(apiURL, "${") {
|
||||
resolved, err := resolveTemplatedAPIURL(apiURL, info)
|
||||
if err != nil {
|
||||
return apiURL, err
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
return apiURL, nil
|
||||
}
|
||||
|
||||
// createGoogleVertexProvider creates a Google Gemini provider that targets the
|
||||
// Vertex AI backend (rather than the public generativelanguage.googleapis.com
|
||||
// endpoint). It requires the same project/region environment variables as
|
||||
// google-vertex-anthropic.
|
||||
func createGoogleVertexProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
projectID := firstNonEmpty(
|
||||
os.Getenv("GOOGLE_VERTEX_PROJECT"),
|
||||
os.Getenv("GOOGLE_CLOUD_PROJECT"),
|
||||
os.Getenv("GCLOUD_PROJECT"),
|
||||
os.Getenv("CLOUDSDK_CORE_PROJECT"),
|
||||
)
|
||||
if projectID == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"google Vertex project ID not provided, set GOOGLE_VERTEX_PROJECT, " +
|
||||
"GOOGLE_CLOUD_PROJECT, or GCLOUD_PROJECT environment variable",
|
||||
)
|
||||
}
|
||||
|
||||
region := firstNonEmpty(
|
||||
os.Getenv("GOOGLE_VERTEX_LOCATION"),
|
||||
os.Getenv("CLOUD_ML_REGION"),
|
||||
)
|
||||
if region == "" {
|
||||
region = "global"
|
||||
}
|
||||
|
||||
opts := []google.Option{
|
||||
google.WithVertex(projectID, region),
|
||||
google.WithName("google-vertex"),
|
||||
}
|
||||
|
||||
if config.TLSSkipVerify {
|
||||
opts = append(opts, google.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
||||
}
|
||||
|
||||
provider, err := google.New(opts...)
|
||||
if err != nil {
|
||||
return nil, wrapProviderErr("Google Vertex", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, wrapProviderErr("Google Vertex", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSDKDefaultBaseURL_CoversAllWireMappedPackages enforces the invariant
|
||||
// that every npm package recognised by the auto-router has a corresponding
|
||||
// default base URL — otherwise a provider that omits its `api` field in the
|
||||
// registry would silently fail to route at runtime.
|
||||
func TestSDKDefaultBaseURL_CoversAllWireMappedPackages(t *testing.T) {
|
||||
for npm := range npmToWireProtocol {
|
||||
// @ai-sdk/openai-compatible is a wire family, not a single SDK with
|
||||
// a default URL — providers using it always supply their own `api`.
|
||||
if npm == "@ai-sdk/openai-compatible" {
|
||||
continue
|
||||
}
|
||||
if _, ok := sdkDefaultBaseURL[npm]; !ok {
|
||||
t.Errorf("npm %q is in npmToWireProtocol but has no sdkDefaultBaseURL entry — "+
|
||||
"providers using this npm with no `api` field cannot be routed", npm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSDKDefaultBaseURL_AllURLsAreAbsolute sanity-checks that every default
|
||||
// URL is a well-formed absolute https endpoint (catches typos in the table).
|
||||
func TestSDKDefaultBaseURL_AllURLsAreAbsolute(t *testing.T) {
|
||||
for npm, url := range sdkDefaultBaseURL {
|
||||
if !strings.HasPrefix(url, "https://") {
|
||||
t.Errorf("sdkDefaultBaseURL[%q] = %q is not an absolute https URL", npm, url)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveProviderBaseURL_RegistryFirst verifies that the registry's `api`
|
||||
// field wins over any SDK default.
|
||||
func TestResolveProviderBaseURL_RegistryFirst(t *testing.T) {
|
||||
// xai is in the registry with no `api` field — its URL comes from the
|
||||
// SDK default. Use a synthetic registry-backed provider to test the
|
||||
// priority via the public registry instead.
|
||||
url, err := ResolveProviderBaseURL("openai")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveProviderBaseURL(openai): %v", err)
|
||||
}
|
||||
if url != "https://api.openai.com/v1" {
|
||||
t.Errorf("openai URL = %q, want https://api.openai.com/v1", url)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveProviderBaseURL_SDKDefaultFallback verifies that providers
|
||||
// without an `api` field (groq, cerebras, xai, …) resolve to their SDK
|
||||
// hard-coded default URL.
|
||||
func TestResolveProviderBaseURL_SDKDefaultFallback(t *testing.T) {
|
||||
tests := map[string]string{
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
"cerebras": "https://api.cerebras.ai/v1",
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"mistral": "https://api.mistral.ai/v1",
|
||||
"perplexity": "https://api.perplexity.ai",
|
||||
"togetherai": "https://api.together.xyz/v1",
|
||||
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"cohere": "https://api.cohere.com/compatibility/v1",
|
||||
"v0": "https://api.v0.dev/v1",
|
||||
"aihubmix": "https://aihubmix.com/v1",
|
||||
"venice": "https://api.venice.ai/api/v1",
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
}
|
||||
for providerID, wantURL := range tests {
|
||||
t.Run(providerID, func(t *testing.T) {
|
||||
got, err := ResolveProviderBaseURL(providerID)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveProviderBaseURL(%s): %v", providerID, err)
|
||||
}
|
||||
if got != wantURL {
|
||||
t.Errorf("%s URL = %q, want %q", providerID, got, wantURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveProviderBaseURL_TemplatedURL_MissingEnv verifies that providers
|
||||
// whose URL contains "${VAR}" placeholders surface a targeted error when the
|
||||
// environment variables are unset.
|
||||
func TestResolveProviderBaseURL_TemplatedURL_MissingEnv(t *testing.T) {
|
||||
// cloudflare-workers-ai's api URL contains ${CLOUDFLARE_ACCOUNT_ID}.
|
||||
// Ensure the variable is unset for this test.
|
||||
t.Setenv("CLOUDFLARE_ACCOUNT_ID", "")
|
||||
t.Setenv("CF_ACCOUNT_ID", "")
|
||||
|
||||
_, err := ResolveProviderBaseURL("cloudflare-workers-ai")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unset CLOUDFLARE_ACCOUNT_ID, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "CLOUDFLARE_ACCOUNT_ID") {
|
||||
t.Errorf("error should name the missing env var, got: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "--provider-url") {
|
||||
t.Errorf("error should suggest --provider-url override, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveProviderBaseURL_TemplatedURL_Resolved verifies env-var
|
||||
// substitution succeeds when the placeholder is set.
|
||||
func TestResolveProviderBaseURL_TemplatedURL_Resolved(t *testing.T) {
|
||||
t.Setenv("CLOUDFLARE_ACCOUNT_ID", "test-acct-123")
|
||||
got, err := ResolveProviderBaseURL("cloudflare-workers-ai")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveProviderBaseURL: %v", err)
|
||||
}
|
||||
if !strings.Contains(got, "test-acct-123") {
|
||||
t.Errorf("resolved URL %q should contain test-acct-123", got)
|
||||
}
|
||||
if strings.Contains(got, "${") {
|
||||
t.Errorf("resolved URL %q still contains template placeholder", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveProviderBaseURL_UnknownProvider verifies the not-in-registry error.
|
||||
func TestResolveProviderBaseURL_UnknownProvider(t *testing.T) {
|
||||
_, err := ResolveProviderBaseURL("does-not-exist")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown provider, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown provider") {
|
||||
t.Errorf("error should say 'unknown provider', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoRouteProvider_SDKDefaultURLFallback verifies that providers whose
|
||||
// registry entry omits the `api` field (groq, mistral, xai, etc.) are still
|
||||
// auto-routed by falling back to the SDK's hard-coded default URL.
|
||||
func TestAutoRouteProvider_SDKDefaultURLFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
npmPackage string
|
||||
wantInURL string
|
||||
}{
|
||||
{"groq", "@ai-sdk/groq", "groq.com"},
|
||||
{"cerebras", "@ai-sdk/cerebras", "cerebras.ai"},
|
||||
{"xai", "@ai-sdk/xai", "x.ai"},
|
||||
{"mistral", "@ai-sdk/mistral", "mistral.ai"},
|
||||
{"v0", "@ai-sdk/vercel", "v0.dev"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reg := &ModelsRegistry{
|
||||
providers: map[string]ProviderInfo{
|
||||
"testfallback": {
|
||||
ID: "testfallback",
|
||||
Name: "Test Fallback",
|
||||
Env: []string{"TESTFALLBACK_API_KEY"},
|
||||
NPM: tt.npmPackage,
|
||||
// API intentionally omitted — must fall back to SDK default.
|
||||
Models: map[string]ModelInfo{
|
||||
"any-model": {ID: "any-model", Name: "any-model"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
config := &ProviderConfig{ProviderAPIKey: "test-key"}
|
||||
|
||||
result, err := autoRouteProvider(context.Background(), config, "testfallback", "any-model", reg)
|
||||
if err != nil {
|
||||
t.Fatalf("autoRouteProvider returned error: %v", err)
|
||||
}
|
||||
if result == nil || result.Model == nil {
|
||||
t.Fatal("autoRouteProvider returned nil model")
|
||||
}
|
||||
// Verify the SDK default URL was picked up.
|
||||
if !strings.Contains(config.ProviderURL, tt.wantInURL) {
|
||||
t.Errorf("config.ProviderURL = %q, want substring %q (SDK default)",
|
||||
config.ProviderURL, tt.wantInURL)
|
||||
}
|
||||
// All these wrappers route through the openai-compat wire.
|
||||
gotType := reflect.TypeOf(result.Model).String()
|
||||
if gotType != "openai.languageModel" {
|
||||
t.Errorf("model type = %q, want openai.languageModel", gotType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveTemplatedAPIURL_NoPlaceholders verifies that URLs without
|
||||
// placeholders are returned as-is (the caller keeps using the original).
|
||||
func TestResolveTemplatedAPIURL_NoPlaceholders(t *testing.T) {
|
||||
got, err := resolveTemplatedAPIURL("https://api.example.com/v1", &ProviderInfo{ID: "x"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Errorf("got %q, want empty string for URL with no placeholders", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveTemplatedAPIURL_AltEnvVar verifies that the alternative env-var
|
||||
// names (e.g. CF_ACCOUNT_ID for CLOUDFLARE_ACCOUNT_ID) are honoured.
|
||||
func TestResolveTemplatedAPIURL_AltEnvVar(t *testing.T) {
|
||||
t.Setenv("CLOUDFLARE_ACCOUNT_ID", "")
|
||||
t.Setenv("CF_ACCOUNT_ID", "alt-name-123")
|
||||
|
||||
got, err := resolveTemplatedAPIURL(
|
||||
"https://api.cloudflare.com/client/v4/accounts/${CLOUDFLARE_ACCOUNT_ID}/ai/v1",
|
||||
&ProviderInfo{ID: "cloudflare-workers-ai"},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(got, "alt-name-123") {
|
||||
t.Errorf("resolved URL %q should have picked up CF_ACCOUNT_ID alternative", got)
|
||||
}
|
||||
}
|
||||
@@ -70,7 +70,8 @@ func ParseTemplate(path string) (*PromptTemplate, error) {
|
||||
}
|
||||
|
||||
// ParseCommandArgs splits a command line into arguments respecting quotes.
|
||||
// It handles single quotes, double quotes, and backslash escaping.
|
||||
// It handles single quotes, double quotes, backslash escaping, and splits on
|
||||
// spaces and tabs.
|
||||
func ParseCommandArgs(input string) []string {
|
||||
var args []string
|
||||
var current strings.Builder
|
||||
@@ -78,7 +79,7 @@ func ParseCommandArgs(input string) []string {
|
||||
inDoubleQuote := false
|
||||
escaped := false
|
||||
|
||||
for i, r := range input {
|
||||
for _, r := range input {
|
||||
if escaped {
|
||||
current.WriteRune(r)
|
||||
escaped = false
|
||||
@@ -101,7 +102,7 @@ func ParseCommandArgs(input string) []string {
|
||||
continue
|
||||
}
|
||||
|
||||
if r == ' ' && !inSingleQuote && !inDoubleQuote {
|
||||
if (r == ' ' || r == '\t') && !inSingleQuote && !inDoubleQuote {
|
||||
if current.Len() > 0 {
|
||||
args = append(args, current.String())
|
||||
current.Reset()
|
||||
@@ -110,7 +111,6 @@ func ParseCommandArgs(input string) []string {
|
||||
}
|
||||
|
||||
current.WriteRune(r)
|
||||
_ = i // silence unused warning when we need position later
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
@@ -325,8 +325,3 @@ func (t *PromptTemplate) Expand(argsInput string) string {
|
||||
args := ParseCommandArgs(argsInput)
|
||||
return SubstituteArgs(t.Content, args)
|
||||
}
|
||||
|
||||
// ExpandWithArgs substitutes the provided arguments into the template content.
|
||||
func (t *PromptTemplate) ExpandWithArgs(args []string) string {
|
||||
return SubstituteArgs(t.Content, args)
|
||||
}
|
||||
|
||||
@@ -458,11 +458,6 @@ func (tm *TreeManager) AppendLLMMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendMessage(message.FromLLMMessage(msg))
|
||||
}
|
||||
|
||||
// Deprecated: Use AppendLLMMessage instead.
|
||||
func (tm *TreeManager) AppendFantasyMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendLLMMessage(msg)
|
||||
}
|
||||
|
||||
// AppendModelChange records a model/provider change.
|
||||
func (tm *TreeManager) AppendModelChange(provider, modelID string) (string, error) {
|
||||
tm.mu.Lock()
|
||||
@@ -1170,11 +1165,6 @@ func (tm *TreeManager) AddLLMMessages(msgs []fantasy.Message) error {
|
||||
return tm.flushLocked()
|
||||
}
|
||||
|
||||
// Deprecated: Use AddLLMMessages instead.
|
||||
func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error {
|
||||
return tm.AddLLMMessages(msgs)
|
||||
}
|
||||
|
||||
// GetLLMMessages builds the context and returns just the messages.
|
||||
// This satisfies the same conceptual role as the old Manager.GetMessages().
|
||||
func (tm *TreeManager) GetLLMMessages() []fantasy.Message {
|
||||
@@ -1182,11 +1172,6 @@ func (tm *TreeManager) GetLLMMessages() []fantasy.Message {
|
||||
return msgs
|
||||
}
|
||||
|
||||
// Deprecated: Use GetLLMMessages instead.
|
||||
func (tm *TreeManager) GetFantasyMessages() []fantasy.Message {
|
||||
return tm.GetLLMMessages()
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
// addEntryToIndex adds an entry to the in-memory indices.
|
||||
|
||||
@@ -18,8 +18,11 @@ type PromptTemplate struct {
|
||||
Variables []string
|
||||
}
|
||||
|
||||
// variableRe matches {{variable_name}} placeholders.
|
||||
var variableRe = regexp.MustCompile(`\{\{(\w+)\}\}`)
|
||||
// variableRe matches {{variable_name}} placeholders, tolerating surrounding
|
||||
// whitespace inside the braces (e.g. {{ name }}). This is the canonical
|
||||
// template grammar shared by skill prompts and the extension template API
|
||||
// (pkg/kit ParseTemplate/RenderTemplate delegate here).
|
||||
var variableRe = regexp.MustCompile(`\{\{\s*(\w+)\s*\}\}`)
|
||||
|
||||
// NewPromptTemplate creates a PromptTemplate, automatically extracting
|
||||
// variable names from {{...}} placeholders in content.
|
||||
@@ -50,11 +53,13 @@ func LoadPromptTemplate(path string) (*PromptTemplate, error) {
|
||||
// Expand replaces all {{variable}} placeholders with values from the
|
||||
// provided map. Missing variables are left as-is (no error).
|
||||
func (t *PromptTemplate) Expand(values map[string]string) string {
|
||||
result := t.Content
|
||||
for k, v := range values {
|
||||
result = strings.ReplaceAll(result, "{{"+k+"}}", v)
|
||||
}
|
||||
return result
|
||||
return variableRe.ReplaceAllStringFunc(t.Content, func(m string) string {
|
||||
name := variableRe.FindStringSubmatch(m)[1]
|
||||
if v, ok := values[name]; ok {
|
||||
return v
|
||||
}
|
||||
return m
|
||||
})
|
||||
}
|
||||
|
||||
// ExpandStrict replaces all {{variable}} placeholders and returns an error
|
||||
|
||||
@@ -345,49 +345,70 @@ func (p *MCPConnectionPool) createStdioClient(ctx context.Context, serverConfig
|
||||
return stdioClient, nil
|
||||
}
|
||||
|
||||
// createSSEClient creates an SSE client
|
||||
// parseHeaders parses "Key: Value" header strings into a map.
|
||||
func parseHeaders(raw []string) map[string]string {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
headers := make(map[string]string)
|
||||
for _, header := range raw {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) == 0 {
|
||||
return nil
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// buildOAuthConfig constructs a transport.OAuthConfig from the server config
|
||||
// and the pool's OAuth flow. Returns nil if OAuth is not applicable.
|
||||
func (p *MCPConnectionPool) buildOAuthConfig(serverConfig config.MCPServerConfig) (*transport.OAuthConfig, error) {
|
||||
if p.oauthFlow == nil || serverConfig.NoOAuth {
|
||||
return nil, nil
|
||||
}
|
||||
tokenStore, err := p.createTokenStore(serverConfig.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", err)
|
||||
}
|
||||
cfg := &transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
cfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
cfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
cfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
var options []transport.ClientOption
|
||||
|
||||
if len(serverConfig.Headers) > 0 {
|
||||
headers := make(map[string]string)
|
||||
for _, header := range serverConfig.Headers {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
options = append(options, transport.WithHeaders(headers))
|
||||
}
|
||||
if headers := parseHeaders(serverConfig.Headers); headers != nil {
|
||||
options = append(options, transport.WithHeaders(headers))
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured
|
||||
// and the server hasn't opted out via NoOAuth. Public MCP servers (e.g.
|
||||
// PubMed) set NoOAuth to skip dynamic client registration and token
|
||||
// exchange, which would otherwise fail with a 404.
|
||||
if p.oauthFlow != nil && !serverConfig.NoOAuth {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithOAuth(oauthCfg))
|
||||
oauthCfg, err := p.buildOAuthConfig(serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oauthCfg != nil {
|
||||
options = append(options, transport.WithOAuth(*oauthCfg))
|
||||
}
|
||||
|
||||
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
||||
@@ -406,43 +427,18 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
|
||||
func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
var options []transport.StreamableHTTPCOption
|
||||
|
||||
if len(serverConfig.Headers) > 0 {
|
||||
headers := make(map[string]string)
|
||||
for _, header := range serverConfig.Headers {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
options = append(options, transport.WithHTTPHeaders(headers))
|
||||
}
|
||||
if headers := parseHeaders(serverConfig.Headers); headers != nil {
|
||||
options = append(options, transport.WithHTTPHeaders(headers))
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured
|
||||
// and the server hasn't opted out via NoOAuth.
|
||||
if p.oauthFlow != nil && !serverConfig.NoOAuth {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithHTTPOAuth(oauthCfg))
|
||||
oauthCfg, err := p.buildOAuthConfig(serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oauthCfg != nil {
|
||||
options = append(options, transport.WithHTTPOAuth(*oauthCfg))
|
||||
}
|
||||
|
||||
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
||||
|
||||
+60
-57
@@ -641,30 +641,16 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
|
||||
Request: mcp.Request{Method: "tools/call"},
|
||||
Params: callParams,
|
||||
}
|
||||
result, callErr := conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
||||
}
|
||||
result, callErr = conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
} else {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
var result *mcp.CallToolResult
|
||||
err := m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error {
|
||||
var callErr error
|
||||
result, callErr = conn.client.CallTool(ctx, callRequest)
|
||||
return callErr
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(result)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: result.IsError,
|
||||
}, nil
|
||||
return marshalToolResult(result)
|
||||
}
|
||||
|
||||
// Task-augmented path. Bypass the upstream CallTool helper because its
|
||||
@@ -683,40 +669,25 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(result)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{Content: string(marshaledResult), IsError: result.IsError}, nil
|
||||
return marshalToolResult(result)
|
||||
}
|
||||
|
||||
callResult, taskResult, callErr := callToolWithTask(ctx, rawClient, callParams)
|
||||
if callErr != nil {
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
||||
}
|
||||
callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
} else {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
var (
|
||||
callResult *mcp.CallToolResult
|
||||
taskResult *mcp.CreateTaskResult
|
||||
)
|
||||
err = m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error {
|
||||
var callErr error
|
||||
callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
|
||||
return callErr
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Server chose to answer synchronously — same shape as the no-task path.
|
||||
if callResult != nil {
|
||||
marshaledResult, mErr := json.Marshal(callResult)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: callResult.IsError,
|
||||
}, nil
|
||||
return marshalToolResult(callResult)
|
||||
}
|
||||
|
||||
// Asynchronous task path: poll until terminal, then return the result.
|
||||
@@ -732,18 +703,50 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
|
||||
}
|
||||
|
||||
// Adapt TaskResultResult → CallToolResult for downstream JSON shape parity.
|
||||
adapted := &mcp.CallToolResult{
|
||||
return marshalToolResult(&mcp.CallToolResult{
|
||||
Content: final.Content,
|
||||
StructuredContent: final.StructuredContent,
|
||||
IsError: final.IsError,
|
||||
})
|
||||
}
|
||||
|
||||
// withOAuthRetry runs call once; when it fails with an OAuth error and an
|
||||
// OAuth flow is configured, it re-authorizes the server and retries once.
|
||||
// Connection failures are reported to the pool and wrapped uniformly. This
|
||||
// consolidates the retry/error chain shared by the synchronous and
|
||||
// task-augmented tool-call paths.
|
||||
func (m *MCPToolManager) withOAuthRetry(ctx context.Context, serverName, toolName string, call func() error) error {
|
||||
callErr := call()
|
||||
if callErr == nil {
|
||||
return nil
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(adapted)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, serverName, callErr); flowErr != nil {
|
||||
return fmt.Errorf("OAuth re-authorization failed for tool %s: %w", toolName, flowErr)
|
||||
}
|
||||
if callErr = call(); callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(serverName, callErr)
|
||||
return fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
m.connectionPool.HandleConnectionError(serverName, callErr)
|
||||
return fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
|
||||
// marshalToolResult converts an MCP CallToolResult into the JSON-encoded
|
||||
// MCPToolResult shape returned to the agent.
|
||||
func marshalToolResult(result *mcp.CallToolResult) (*MCPToolResult, error) {
|
||||
if result == nil {
|
||||
return nil, errors.New("mcp tool call returned nil result")
|
||||
}
|
||||
marshaled, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", err)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: final.IsError,
|
||||
Content: string(marshaled),
|
||||
IsError: result.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -167,6 +167,21 @@ var SlashCommands = []SlashCommand{
|
||||
Category: "System",
|
||||
Aliases: []string{"/cp"},
|
||||
},
|
||||
{
|
||||
Name: "/retry",
|
||||
Description: "Resubmit the last user message (e.g. after a provider error)",
|
||||
Category: "System",
|
||||
Aliases: []string{"/rt"},
|
||||
},
|
||||
{
|
||||
Name: "/edit",
|
||||
Description: "Open a file in $EDITOR (fuzzy-find a path, then edit)",
|
||||
Category: "System",
|
||||
Aliases: []string{"/ed"},
|
||||
HasArgs: true,
|
||||
// Note: no Complete callback — file fuzzy-finding is driven directly
|
||||
// by InputComponent (mirroring the @file popup with directory drill).
|
||||
},
|
||||
{
|
||||
Name: "/export",
|
||||
Description: "Export session (JSONL by default, or /export path.jsonl)",
|
||||
|
||||
+29
-35
@@ -2,7 +2,6 @@ package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -44,28 +43,39 @@ func parseModelName(modelString string) (provider, model string) {
|
||||
// ollama or unrecognised models). This is used by the interactive TUI path
|
||||
// which doesn't go through SetupCLI.
|
||||
func CreateUsageTracker(modelString, providerAPIKey string) *UsageTracker {
|
||||
provider, model := parseModelName(modelString)
|
||||
if provider == "unknown" || model == "unknown" || provider == "ollama" {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry := models.GetGlobalRegistry()
|
||||
modelInfo := registry.LookupModel(provider, model)
|
||||
modelInfo, provider := lookupTrackableModel(modelString)
|
||||
if modelInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
isOAuth := false
|
||||
if provider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(providerAPIKey)
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
|
||||
isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)
|
||||
return NewUsageTracker(modelInfo, provider, 80, isOAuth)
|
||||
}
|
||||
|
||||
// UpdateUsageTrackerForModel refreshes an existing tracker after a model
|
||||
// switch so token counting and cost reporting use the new model's metadata.
|
||||
// No-op for a nil tracker or untrackable models (unknown/ollama).
|
||||
func UpdateUsageTrackerForModel(t *UsageTracker, modelString, providerAPIKey string) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
modelInfo, provider := lookupTrackableModel(modelString)
|
||||
if modelInfo == nil {
|
||||
return
|
||||
}
|
||||
isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)
|
||||
t.UpdateModelInfo(modelInfo, provider, isOAuth)
|
||||
}
|
||||
|
||||
// lookupTrackableModel resolves a model string to registry metadata, returning
|
||||
// nil for models without usage tracking support (unknown or ollama models).
|
||||
func lookupTrackableModel(modelString string) (*models.ModelInfo, string) {
|
||||
provider, model := parseModelName(modelString)
|
||||
if provider == "unknown" || model == "unknown" || provider == "ollama" {
|
||||
return nil, provider
|
||||
}
|
||||
return models.GetGlobalRegistry().LookupModel(provider, model), provider
|
||||
}
|
||||
|
||||
// SetupCLI creates, configures, and initializes a CLI instance with the provided
|
||||
// options. It sets up model display, usage tracking for supported providers, and
|
||||
// shows initial loading information. Returns nil in quiet mode or an initialized
|
||||
@@ -89,24 +99,8 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
|
||||
}
|
||||
|
||||
// Set up usage tracking for supported providers
|
||||
if provider != "unknown" && model != "unknown" {
|
||||
// Skip usage tracking for ollama as it's not in models.dev
|
||||
if provider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(provider, model); modelInfo != nil {
|
||||
// Check if OAuth credentials are being used for Anthropic models
|
||||
isOAuth := false
|
||||
if provider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(opts.ProviderAPIKey)
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
|
||||
usageTracker := NewUsageTracker(modelInfo, provider, 80, isOAuth) // Will be updated with actual width
|
||||
cli.SetUsageTracker(usageTracker)
|
||||
}
|
||||
}
|
||||
if usageTracker := CreateUsageTracker(opts.ModelString, opts.ProviderAPIKey); usageTracker != nil {
|
||||
cli.SetUsageTracker(usageTracker)
|
||||
}
|
||||
|
||||
// Display model info (the system message block provides its own spacing).
|
||||
|
||||
@@ -125,6 +125,33 @@ func ExtractAtPrefix(line string, cursorCol int) (hasAt bool, prefix string, sta
|
||||
return true, raw, atIdx
|
||||
}
|
||||
|
||||
// editTriggerPrefixes lists the command tokens (including trailing space)
|
||||
// that activate the /edit fuzzy-file picker. Aliases come first so the
|
||||
// longer alias "/edit " is matched before a hypothetical superset.
|
||||
var editTriggerPrefixes = []string{"/edit ", "/ed "}
|
||||
|
||||
// ExtractEditPrefix detects when the input value is a single-line /edit (or
|
||||
// alias) invocation and returns the path-portion the user has typed so far.
|
||||
//
|
||||
// Returns:
|
||||
// - cmdLen: byte offset where the path argument begins (i.e. length of
|
||||
// the matched command token, including its trailing space)
|
||||
// - pathPrefix: text the user has typed after the command token
|
||||
// - ok: true when the value matches one of the /edit triggers
|
||||
//
|
||||
// Multi-line values never match — /edit only makes sense as a single line.
|
||||
func ExtractEditPrefix(value string) (cmdLen int, pathPrefix string, ok bool) {
|
||||
if strings.Contains(value, "\n") {
|
||||
return 0, "", false
|
||||
}
|
||||
for _, p := range editTriggerPrefixes {
|
||||
if strings.HasPrefix(value, p) {
|
||||
return len(p), value[len(p):], true
|
||||
}
|
||||
}
|
||||
return 0, "", false
|
||||
}
|
||||
|
||||
// GetFileSuggestions returns file/directory suggestions matching the given
|
||||
// prefix. It tries `git ls-files` first (fast, respects .gitignore), then
|
||||
// falls back to a simple directory walk.
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
// Package imagepreview renders low-resolution, in-terminal thumbnails of
|
||||
// images using Unicode upper half-block characters (U+2580, "▀") combined
|
||||
// with SGR foreground/background color codes.
|
||||
//
|
||||
// The technique stacks two vertical pixels into a single character cell: the
|
||||
// foreground color paints the top pixel and the background color paints the
|
||||
// bottom pixel. This produces pure styled text — no graphics escape sequences
|
||||
// — so the output survives terminal multiplexers (tmux, zellij) untouched.
|
||||
//
|
||||
// The Kitty graphics protocol, Sixel, and iTerm2 inline images are
|
||||
// deliberately NOT used: those are graphics escape-sequence protocols that
|
||||
// tmux and zellij strip or mangle by default.
|
||||
package imagepreview
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
// Register the standard image decoders so image.Decode can handle the
|
||||
// common clipboard / attachment formats.
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
|
||||
"github.com/charmbracelet/colorprofile"
|
||||
"github.com/charmbracelet/x/ansi"
|
||||
xdraw "golang.org/x/image/draw"
|
||||
)
|
||||
|
||||
// upperHalfBlock is U+2580 ("▀"). The glyph fills the top half of a cell,
|
||||
// letting the foreground color render the top pixel and the cell's background
|
||||
// color render the bottom pixel.
|
||||
const upperHalfBlock = "▀"
|
||||
|
||||
// reset is the SGR reset sequence appended after each rendered row.
|
||||
const reset = "\x1b[0m"
|
||||
|
||||
// maxImageDimension is the largest width or height, in pixels, that Render will
|
||||
// fully decode. Images larger than this in either axis are rejected before the
|
||||
// expensive image.Decode call to guard against decompression bombs (small
|
||||
// encoded payloads that expand to enormous pixel buffers).
|
||||
const maxImageDimension = 20000
|
||||
|
||||
// Render returns a half-block ANSI thumbnail of the image, scaled to fit
|
||||
// within maxCols x maxRows terminal cells while preserving aspect ratio.
|
||||
//
|
||||
// Each terminal cell encodes two vertically-stacked pixels, so the effective
|
||||
// pixel resolution of the thumbnail is up to maxCols x (maxRows*2).
|
||||
//
|
||||
// Colors are emitted at the fidelity of the detected terminal color profile:
|
||||
// truecolor (24-bit) when available, degrading to 256-color. When the
|
||||
// terminal supports neither (no truecolor and no 256-color), Render returns
|
||||
// an empty string and a nil error so the caller can fall back to a text
|
||||
// indicator. A non-nil error is only returned when the image data cannot be
|
||||
// decoded.
|
||||
//
|
||||
// bg is the color used to composite transparent pixels (typically the
|
||||
// terminal background). A nil bg defaults to black.
|
||||
func Render(data []byte, mediaType string, maxCols, maxRows int, bg color.Color) (string, error) {
|
||||
profile := colorprofile.Env(os.Environ())
|
||||
return renderWithProfile(data, maxCols, maxRows, bg, profile)
|
||||
}
|
||||
|
||||
// renderWithProfile is the testable core of Render. It accepts an explicit
|
||||
// color profile instead of detecting one from the environment.
|
||||
func renderWithProfile(data []byte, maxCols, maxRows int, bg color.Color, profile colorprofile.Profile) (string, error) {
|
||||
// Half-block fidelity needs at least 256-color support. Anything less
|
||||
// degrades to the caller's text fallback.
|
||||
if profile < colorprofile.ANSI256 {
|
||||
return "", nil
|
||||
}
|
||||
if maxCols < 1 || maxRows < 1 {
|
||||
return "", nil
|
||||
}
|
||||
if bg == nil {
|
||||
bg = color.Black
|
||||
}
|
||||
|
||||
// Guard against decompression bombs: inspect the header dimensions before
|
||||
// fully decoding, so a small malicious payload cannot expand into an
|
||||
// enormous pixel buffer.
|
||||
cfg, _, err := image.DecodeConfig(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode image config: %w", err)
|
||||
}
|
||||
if cfg.Width > maxImageDimension || cfg.Height > maxImageDimension {
|
||||
return "", fmt.Errorf("decode image: dimensions %dx%d exceed limit %d", cfg.Width, cfg.Height, maxImageDimension)
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode image: %w", err)
|
||||
}
|
||||
|
||||
// Target pixel dimensions: one pixel per column horizontally and two
|
||||
// pixels per row vertically (the half-block trick).
|
||||
cols, rows := fitDimensions(img.Bounds().Dx(), img.Bounds().Dy(), maxCols, maxRows)
|
||||
if cols < 1 || rows < 1 {
|
||||
return "", nil
|
||||
}
|
||||
pxW, pxH := cols, rows*2
|
||||
|
||||
scaled := image.NewRGBA(image.Rect(0, 0, pxW, pxH))
|
||||
xdraw.CatmullRom.Scale(scaled, scaled.Bounds(), img, img.Bounds(), xdraw.Over, nil)
|
||||
|
||||
var b strings.Builder
|
||||
for y := 0; y < pxH; y += 2 {
|
||||
for x := range pxW {
|
||||
top := composite(scaled.At(x, y), bg)
|
||||
bottom := composite(scaled.At(x, y+1), bg)
|
||||
b.WriteString(sgr(top, bottom, profile))
|
||||
b.WriteString(upperHalfBlock)
|
||||
}
|
||||
b.WriteString(reset)
|
||||
if y+2 < pxH {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
// fitDimensions returns the largest cell dimensions (cols, rows) that fit a
|
||||
// srcW x srcH image inside a maxCols x maxRows box while preserving aspect
|
||||
// ratio. Because each cell stacks two vertical pixels, a terminal cell is
|
||||
// treated as roughly twice as tall as it is wide, which keeps the thumbnail's
|
||||
// aspect ratio visually correct.
|
||||
func fitDimensions(srcW, srcH, maxCols, maxRows int) (cols, rows int) {
|
||||
if srcW <= 0 || srcH <= 0 {
|
||||
return 0, 0
|
||||
}
|
||||
// Work in pixel space: the box is maxCols wide and maxRows*2 tall.
|
||||
maxPxW := float64(maxCols)
|
||||
maxPxH := float64(maxRows * 2)
|
||||
scale := maxPxW / float64(srcW)
|
||||
if h := maxPxH / float64(srcH); h < scale {
|
||||
scale = h
|
||||
}
|
||||
if scale > 1 {
|
||||
scale = 1 // never upscale; keep the low-res look
|
||||
}
|
||||
pxW := int(float64(srcW) * scale)
|
||||
pxH := int(float64(srcH) * scale)
|
||||
if pxW < 1 {
|
||||
pxW = 1
|
||||
}
|
||||
if pxH < 2 {
|
||||
pxH = 2
|
||||
}
|
||||
// Convert back to cells; round the row count up to an even pixel height.
|
||||
cols = pxW
|
||||
rows = (pxH + 1) / 2
|
||||
if cols > maxCols {
|
||||
cols = maxCols
|
||||
}
|
||||
if rows > maxRows {
|
||||
rows = maxRows
|
||||
}
|
||||
return cols, rows
|
||||
}
|
||||
|
||||
// composite blends a (possibly translucent) pixel over the background color,
|
||||
// returning an opaque color. Fully opaque pixels are returned unchanged.
|
||||
func composite(c, bg color.Color) color.Color {
|
||||
r, g, b, a := c.RGBA()
|
||||
if a == 0xffff {
|
||||
return c
|
||||
}
|
||||
br, bgc, bb, _ := bg.RGBA()
|
||||
// Standard "over" alpha compositing in 16-bit space.
|
||||
inv := 0xffff - a
|
||||
out := color.RGBA64{
|
||||
R: uint16(r + br*inv/0xffff),
|
||||
G: uint16(g + bgc*inv/0xffff),
|
||||
B: uint16(b + bb*inv/0xffff),
|
||||
A: 0xffff,
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// sgr builds the SGR escape sequence that sets the foreground (top pixel) and
|
||||
// background (bottom pixel) colors at the fidelity of the given profile.
|
||||
func sgr(fg, bg color.Color, profile colorprofile.Profile) string {
|
||||
if profile >= colorprofile.TrueColor {
|
||||
fr, fgc, fb := rgb8(fg)
|
||||
br, bgc, bb := rgb8(bg)
|
||||
return fmt.Sprintf("\x1b[38;2;%d;%d;%d;48;2;%d;%d;%dm", fr, fgc, fb, br, bgc, bb)
|
||||
}
|
||||
return fmt.Sprintf("\x1b[38;5;%d;48;5;%dm", index256(fg, profile), index256(bg, profile))
|
||||
}
|
||||
|
||||
// rgb8 reduces a color to 8-bit RGB components.
|
||||
func rgb8(c color.Color) (r, g, b uint8) {
|
||||
cr, cg, cb, _ := c.RGBA()
|
||||
return uint8(cr >> 8), uint8(cg >> 8), uint8(cb >> 8)
|
||||
}
|
||||
|
||||
// index256 converts a color to its nearest 256-color palette index using the
|
||||
// supplied profile.
|
||||
func index256(c color.Color, profile colorprofile.Profile) uint8 {
|
||||
cc := profile.Convert(c)
|
||||
if idx, ok := cc.(ansi.IndexedColor); ok {
|
||||
return uint8(idx)
|
||||
}
|
||||
if idx, ok := cc.(ansi.BasicColor); ok {
|
||||
return uint8(idx)
|
||||
}
|
||||
// Fallback: derive an index directly if conversion produced an
|
||||
// unexpected type.
|
||||
r, g, b := rgb8(c)
|
||||
return ansi256FromRGB(r, g, b)
|
||||
}
|
||||
|
||||
// ansi256FromRGB maps an 8-bit RGB color to the xterm 256-color cube. It is a
|
||||
// best-effort fallback used only when profile.Convert does not yield a known
|
||||
// indexed color type.
|
||||
func ansi256FromRGB(r, g, b uint8) uint8 {
|
||||
q := func(v uint8) int {
|
||||
switch {
|
||||
case v < 48:
|
||||
return 0
|
||||
case v < 115:
|
||||
return 1
|
||||
default:
|
||||
return int((v - 35) / 40)
|
||||
}
|
||||
}
|
||||
ri, gi, bi := q(r), q(g), q(b)
|
||||
return uint8(16 + 36*ri + 6*gi + bi)
|
||||
}
|
||||
@@ -0,0 +1,193 @@
|
||||
package imagepreview
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/png"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/colorprofile"
|
||||
)
|
||||
|
||||
// makePNG builds a simple w x h PNG filled with the given color and returns
|
||||
// its encoded bytes.
|
||||
func makePNG(t *testing.T, w, h int, c color.Color) []byte {
|
||||
t.Helper()
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for y := range h {
|
||||
for x := range w {
|
||||
img.Set(x, y, c)
|
||||
}
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
t.Fatalf("encode png: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestRenderTrueColor(t *testing.T) {
|
||||
data := makePNG(t, 20, 20, color.RGBA{R: 255, A: 255})
|
||||
out, err := renderWithProfile(data, 10, 5, color.Black, colorprofile.TrueColor)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out == "" {
|
||||
t.Fatal("expected non-empty thumbnail for truecolor profile")
|
||||
}
|
||||
if !strings.Contains(out, upperHalfBlock) {
|
||||
t.Error("output should contain upper half block glyphs")
|
||||
}
|
||||
if !strings.Contains(out, "\x1b[38;2;") || !strings.Contains(out, "48;2;") {
|
||||
t.Errorf("expected truecolor SGR sequences, got %q", out)
|
||||
}
|
||||
// Red fill should appear as 255;0;0 somewhere.
|
||||
if !strings.Contains(out, "255;0;0") {
|
||||
t.Errorf("expected red color in output, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderANSI256(t *testing.T) {
|
||||
data := makePNG(t, 20, 20, color.RGBA{G: 255, A: 255})
|
||||
out, err := renderWithProfile(data, 8, 4, color.Black, colorprofile.ANSI256)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out == "" {
|
||||
t.Fatal("expected non-empty thumbnail for ANSI256 profile")
|
||||
}
|
||||
if !strings.Contains(out, "\x1b[38;5;") || !strings.Contains(out, "48;5;") {
|
||||
t.Errorf("expected 256-color SGR sequences, got %q", out)
|
||||
}
|
||||
if strings.Contains(out, "38;2;") {
|
||||
t.Errorf("ANSI256 output should not contain truecolor sequences, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderDegradesBelowANSI256(t *testing.T) {
|
||||
data := makePNG(t, 20, 20, color.RGBA{B: 255, A: 255})
|
||||
for _, p := range []colorprofile.Profile{colorprofile.ANSI, colorprofile.ASCII, colorprofile.NoTTY} {
|
||||
out, err := renderWithProfile(data, 10, 5, color.Black, p)
|
||||
if err != nil {
|
||||
t.Fatalf("profile %v: unexpected error: %v", p, err)
|
||||
}
|
||||
if out != "" {
|
||||
t.Errorf("profile %v: expected empty fallback, got %q", p, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderInvalidImage(t *testing.T) {
|
||||
out, err := renderWithProfile([]byte("not an image"), 10, 5, color.Black, colorprofile.TrueColor)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid image data")
|
||||
}
|
||||
if out != "" {
|
||||
t.Errorf("expected empty output on decode error, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderRejectsOversizedImage(t *testing.T) {
|
||||
// A header advertising dimensions beyond maxImageDimension must be
|
||||
// rejected before full decode (decompression-bomb guard). image.RGBA
|
||||
// allocation is avoided by only checking the config path here.
|
||||
w := maxImageDimension + 1
|
||||
data := makePNG(t, w, 1, color.White)
|
||||
out, err := renderWithProfile(data, 10, 5, color.Black, colorprofile.TrueColor)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for oversized image dimensions")
|
||||
}
|
||||
if out != "" {
|
||||
t.Errorf("expected empty output for oversized image, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderZeroBox(t *testing.T) {
|
||||
data := makePNG(t, 20, 20, color.White)
|
||||
out, err := renderWithProfile(data, 0, 0, color.Black, colorprofile.TrueColor)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out != "" {
|
||||
t.Errorf("expected empty output for zero-sized box, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderNilBackgroundDefaults(t *testing.T) {
|
||||
data := makePNG(t, 10, 10, color.RGBA{R: 10, G: 20, B: 30, A: 255})
|
||||
out, err := renderWithProfile(data, 6, 3, nil, colorprofile.TrueColor)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out == "" {
|
||||
t.Fatal("expected output with nil background (defaults to black)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowCountWithinBounds(t *testing.T) {
|
||||
// A tall image should be capped at maxRows cells.
|
||||
data := makePNG(t, 10, 100, color.White)
|
||||
out, err := renderWithProfile(data, 20, 6, color.Black, colorprofile.TrueColor)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
rows := strings.Count(out, "\n") + 1
|
||||
if rows > 6 {
|
||||
t.Errorf("expected at most 6 rows, got %d", rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestColumnCountWithinBounds(t *testing.T) {
|
||||
// A wide image should be capped at maxCols cells per row.
|
||||
data := makePNG(t, 100, 10, color.White)
|
||||
out, err := renderWithProfile(data, 8, 20, color.Black, colorprofile.TrueColor)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
firstRow := strings.SplitN(out, "\n", 2)[0]
|
||||
cols := strings.Count(firstRow, upperHalfBlock)
|
||||
if cols > 8 {
|
||||
t.Errorf("expected at most 8 columns, got %d", cols)
|
||||
}
|
||||
if cols == 0 {
|
||||
t.Error("expected at least one column")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFitDimensionsPreservesAspect(t *testing.T) {
|
||||
// 2:1 (wide) image into a 40x20 box. Pixel box is 40x40; width-bound.
|
||||
cols, rows := fitDimensions(200, 100, 40, 20)
|
||||
if cols != 40 {
|
||||
t.Errorf("expected 40 cols, got %d", cols)
|
||||
}
|
||||
// pxH = 100 * (40/200) = 20 → 10 rows.
|
||||
if rows != 10 {
|
||||
t.Errorf("expected 10 rows, got %d", rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFitDimensionsNeverUpscales(t *testing.T) {
|
||||
cols, rows := fitDimensions(4, 4, 40, 20)
|
||||
if cols != 4 || rows != 2 {
|
||||
t.Errorf("expected 4x2 (no upscale), got %dx%d", cols, rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeOpaquePassthrough(t *testing.T) {
|
||||
c := color.RGBA{R: 1, G: 2, B: 3, A: 255}
|
||||
got := composite(c, color.White)
|
||||
if got != color.Color(c) {
|
||||
t.Errorf("opaque color should pass through unchanged, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeTransparentOverBackground(t *testing.T) {
|
||||
// Fully transparent pixel over red background should yield red.
|
||||
got := composite(color.RGBA{}, color.RGBA{R: 255, A: 255})
|
||||
r, g, b, a := got.RGBA()
|
||||
if r>>8 != 255 || g>>8 != 0 || b>>8 != 0 || a != 0xffff {
|
||||
t.Errorf("expected opaque red, got r=%d g=%d b=%d a=%d", r>>8, g>>8, b>>8, a)
|
||||
}
|
||||
}
|
||||
+224
-187
@@ -2,6 +2,7 @@ package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image/color"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"github.com/mark3labs/kit/internal/clipboard"
|
||||
"github.com/mark3labs/kit/internal/ui/commands"
|
||||
"github.com/mark3labs/kit/internal/ui/core"
|
||||
"github.com/mark3labs/kit/internal/ui/imagepreview"
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
@@ -42,6 +44,12 @@ type InputComponent struct {
|
||||
popupHeight int
|
||||
submitNext bool // defer submit one tick so popup dismisses cleanly
|
||||
|
||||
// popup is the shared PopupList used to render the / and @ autocomplete
|
||||
// dropdowns. State (items, cursor, visible search-driven filter) is
|
||||
// driven externally by InputComponent — we only use PopupList for the
|
||||
// rendering chrome so all popups in the app look identical.
|
||||
popup *PopupList
|
||||
|
||||
// Argument completion state. When the user types "/cmd " followed by
|
||||
// a partial argument and the command has a Complete function, the popup
|
||||
// switches to argument-completion mode showing suggestions from Complete.
|
||||
@@ -53,10 +61,16 @@ type InputComponent struct {
|
||||
// file path, the popup shows file/directory suggestions from the cwd.
|
||||
fileMode bool // true when showing @file completions
|
||||
filePrefix string // current text after @ being matched
|
||||
fileAtStartIdx int // byte offset of @ in the textarea value
|
||||
fileAtStartIdx int // byte offset of @ (or path start in /edit mode) in the textarea value
|
||||
fileSuggestions []FileSuggestion // backing storage for file entries
|
||||
fileSynthCmds []commands.SlashCommand // synthetic commands.SlashCommands wrapping file entries
|
||||
|
||||
// fileEditMode is true when fileMode was activated by the /edit slash
|
||||
// command rather than an @ trigger. Selecting a file submits the line
|
||||
// (running $EDITOR on it); selecting a directory drills further like @
|
||||
// does. MCP resources are excluded in this mode.
|
||||
fileEditMode bool
|
||||
|
||||
// cwd is the working directory used for @file path resolution and
|
||||
// autocomplete suggestions. Set by the parent via SetCwd.
|
||||
cwd string
|
||||
@@ -80,6 +94,23 @@ type InputComponent struct {
|
||||
// Images are added via Ctrl+V and cleared on submit or Ctrl+U.
|
||||
pendingImages []core.ImageAttachment
|
||||
|
||||
// imageThumbs caches the rendered half-block thumbnail for each entry in
|
||||
// pendingImages (1:1 index correspondence). Thumbnails are rendered
|
||||
// asynchronously off the Bubble Tea event loop (decode + resample is too
|
||||
// slow to run inside Update), so an entry starts as the empty string
|
||||
// placeholder and is filled in when the matching thumbnailReadyMsg
|
||||
// arrives. An entry stays empty when the terminal cannot display a
|
||||
// half-block preview, in which case the text pill is shown alone.
|
||||
// See internal/ui/imagepreview.
|
||||
imageThumbs []string
|
||||
|
||||
// imageGen is a monotonic generation counter incremented whenever the
|
||||
// pending image set is cleared. Async thumbnail results carry the
|
||||
// generation they were enqueued under and are discarded if it no longer
|
||||
// matches, preventing a stale thumbnail from landing on the wrong slot
|
||||
// after a clear + re-attach.
|
||||
imageGen int
|
||||
|
||||
// history stores previously submitted prompts (most recent last).
|
||||
// Limited to maxHistory entries; duplicates of the previous entry are
|
||||
// skipped. Empty strings are never stored.
|
||||
@@ -105,6 +136,16 @@ type clipboardImageMsg struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// thumbnailReadyMsg carries the result of an async thumbnail render back to
|
||||
// the Update loop. gen and index identify the pendingImages slot the
|
||||
// thumbnail belongs to; the result is dropped if the generation no longer
|
||||
// matches (the pending set was cleared) or the index is out of range.
|
||||
type thumbnailReadyMsg struct {
|
||||
gen int
|
||||
index int
|
||||
thumb string
|
||||
}
|
||||
|
||||
// NewInputComponent creates a new InputComponent with the given width and
|
||||
// optional AppController. If appCtrl is nil the component still works but
|
||||
// /clear and /clear-queue are no-ops.
|
||||
@@ -135,7 +176,7 @@ func NewInputComponent(width int, appCtrl AppController) *InputComponent {
|
||||
styles.Focused.CursorLine = lipgloss.NewStyle()
|
||||
ta.SetStyles(styles)
|
||||
|
||||
return &InputComponent{
|
||||
ic := &InputComponent{
|
||||
textarea: ta,
|
||||
commands: commands.SlashCommands,
|
||||
width: width,
|
||||
@@ -143,6 +184,12 @@ func NewInputComponent(width int, appCtrl AppController) *InputComponent {
|
||||
appCtrl: appCtrl,
|
||||
hideHint: true,
|
||||
}
|
||||
ic.popup = NewPopupList("", nil, width, 0)
|
||||
ic.popup.ShowSearch = false
|
||||
ic.popup.HideCount = true
|
||||
ic.popup.MaxVisible = ic.popupHeight
|
||||
ic.popup.FooterHint = "↑↓ navigate • tab complete • ↵ select • esc dismiss"
|
||||
return ic
|
||||
}
|
||||
|
||||
// SetCwd sets the working directory used for @file autocomplete suggestions
|
||||
@@ -193,7 +240,23 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return s, nil
|
||||
}
|
||||
if msg.image != nil {
|
||||
s.pendingImages = append(s.pendingImages, *msg.image)
|
||||
img := *msg.image
|
||||
index := len(s.pendingImages)
|
||||
s.pendingImages = append(s.pendingImages, img)
|
||||
// Reserve a placeholder; the async render fills it in via
|
||||
// thumbnailReadyMsg so Update never blocks on decode/resample.
|
||||
s.imageThumbs = append(s.imageThumbs, "")
|
||||
cols := s.thumbCols()
|
||||
if cols < 1 {
|
||||
return s, nil
|
||||
}
|
||||
return s, renderThumbnailCmd(img, cols, thumbMaxRows, style.GetTheme().Background, s.imageGen, index)
|
||||
}
|
||||
return s, nil
|
||||
|
||||
case thumbnailReadyMsg:
|
||||
if msg.gen == s.imageGen && msg.index >= 0 && msg.index < len(s.imageThumbs) {
|
||||
s.imageThumbs[msg.index] = msg.thumb
|
||||
}
|
||||
return s, nil
|
||||
|
||||
@@ -250,6 +313,8 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Clear all pending image attachments.
|
||||
if len(s.pendingImages) > 0 {
|
||||
s.pendingImages = nil
|
||||
s.imageThumbs = nil
|
||||
s.imageGen++
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
@@ -405,10 +470,17 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
} else {
|
||||
s.showPopup = false
|
||||
s.fileMode = false
|
||||
s.fileEditMode = false
|
||||
}
|
||||
} else if len(lines) == 1 && strings.HasPrefix(lines[0], "/") {
|
||||
s.fileMode = false
|
||||
if !strings.Contains(lines[0], " ") {
|
||||
s.fileEditMode = false
|
||||
if cmdLen, pathPrefix, isEdit := ExtractEditPrefix(lines[0]); isEdit {
|
||||
// /edit fuzzy-file picker. Behaves like @ except
|
||||
// MCP resources are excluded and selecting a file
|
||||
// submits the line (running $EDITOR).
|
||||
s.updateEditFilePopup(cmdLen, pathPrefix)
|
||||
} else if !strings.Contains(lines[0], " ") {
|
||||
// Command name completion.
|
||||
s.showPopup = true
|
||||
s.argMode = false
|
||||
@@ -428,6 +500,7 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.showPopup = false
|
||||
s.argMode = false
|
||||
s.fileMode = false
|
||||
s.fileEditMode = false
|
||||
}
|
||||
}
|
||||
return s, cmd
|
||||
@@ -486,6 +559,8 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
|
||||
// images and clear them.
|
||||
images := s.pendingImages
|
||||
s.pendingImages = nil
|
||||
s.imageThumbs = nil
|
||||
s.imageGen++
|
||||
return func() tea.Msg {
|
||||
return core.SubmitMsg{Text: trimmed, Images: images}
|
||||
}
|
||||
@@ -519,6 +594,42 @@ func (s *InputComponent) resetHistoryBrowsing() {
|
||||
s.savedInput = ""
|
||||
}
|
||||
|
||||
// thumbMaxCols and thumbMaxRows cap the size, in terminal cells, of pending
|
||||
// image previews. Kept small for the low-res look and to keep scrollback
|
||||
// light.
|
||||
const (
|
||||
thumbMaxCols = 40
|
||||
thumbMaxRows = 12
|
||||
)
|
||||
|
||||
// thumbCols returns the thumbnail width in terminal cells given the current
|
||||
// input width, or 0 when there is no room to render a preview.
|
||||
func (s *InputComponent) thumbCols() int {
|
||||
if s.width <= 6 {
|
||||
return 0
|
||||
}
|
||||
cols := min(thumbMaxCols, s.width-6)
|
||||
if cols < 1 {
|
||||
return 0
|
||||
}
|
||||
return cols
|
||||
}
|
||||
|
||||
// renderThumbnailCmd returns a tea.Cmd that renders a half-block ANSI preview
|
||||
// off the Bubble Tea event loop. The decode + resample work runs in the Cmd
|
||||
// goroutine, and the result is delivered as a thumbnailReadyMsg tagged with
|
||||
// the generation and slot index it was enqueued for. An empty thumbnail
|
||||
// (terminal unsupported or render error) leaves the text pill in place.
|
||||
func renderThumbnailCmd(img core.ImageAttachment, cols, rows int, bg color.Color, gen, index int) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
thumb, err := imagepreview.Render(img.Data, img.MediaType, cols, rows, bg)
|
||||
if err != nil {
|
||||
thumb = ""
|
||||
}
|
||||
return thumbnailReadyMsg{gen: gen, index: index, thumb: thumb}
|
||||
}
|
||||
}
|
||||
|
||||
// View implements tea.Model. Renders the textarea, autocomplete popup
|
||||
// (if visible), and help text.
|
||||
func (s *InputComponent) View() tea.View {
|
||||
@@ -544,7 +655,9 @@ func (s *InputComponent) View() tea.View {
|
||||
// Popup is now rendered as a centered overlay in AppModel.View()
|
||||
// instead of inline here to prevent bottom overflow
|
||||
|
||||
// Show image attachment indicator when images are pending.
|
||||
// Show image attachment previews when images are pending. A cached
|
||||
// half-block thumbnail is rendered when the terminal supports it;
|
||||
// otherwise the text pill alone is shown.
|
||||
if len(s.pendingImages) > 0 {
|
||||
imgStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
@@ -553,6 +666,14 @@ func (s *InputComponent) View() tea.View {
|
||||
label := fmt.Sprintf("[%d image(s) attached] ctrl+u to clear", len(s.pendingImages))
|
||||
view.WriteString("\n")
|
||||
view.WriteString(imgStyle.Render(label))
|
||||
|
||||
thumbStyle := lipgloss.NewStyle().PaddingLeft(3)
|
||||
for i := range s.pendingImages {
|
||||
if i < len(s.imageThumbs) && s.imageThumbs[i] != "" {
|
||||
view.WriteString("\n")
|
||||
view.WriteString(thumbStyle.Render(s.imageThumbs[i]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !s.hideHint {
|
||||
@@ -591,191 +712,37 @@ func (s *InputComponent) View() tea.View {
|
||||
return tea.NewView(containerStyle.Render(view.String()))
|
||||
}
|
||||
|
||||
// renderPopup renders the autocomplete popup for slash command suggestions.
|
||||
// When rendered inline (not centered), returns the styled popup content.
|
||||
// RenderPopupCentered renders the popup as a centered overlay.
|
||||
// RenderPopupCentered renders the autocomplete popup for / or @ as a
|
||||
// centered overlay. Returns "" when the popup is not currently shown.
|
||||
// The actual filtering / selection state lives on InputComponent — this
|
||||
// method merely converts the filtered FuzzyMatch list into PopupItems
|
||||
// and asks the shared PopupList to draw it. As a result the / popup, the
|
||||
// @ popup, the model picker, the tree selector and the session selector
|
||||
// all share identical chrome.
|
||||
func (s *InputComponent) RenderPopupCentered(termWidth, termHeight int) string {
|
||||
if !s.showPopup || len(s.filtered) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
popupContent := s.renderPopupWithOptions(true)
|
||||
|
||||
// Center popup using lipgloss.Place
|
||||
positioned := lipgloss.Place(
|
||||
termWidth,
|
||||
termHeight,
|
||||
lipgloss.Center,
|
||||
lipgloss.Center,
|
||||
popupContent,
|
||||
)
|
||||
|
||||
return positioned
|
||||
}
|
||||
|
||||
// renderPopupWithOptions renders the popup content with optional center styling.
|
||||
func (s *InputComponent) renderPopupWithOptions(centered bool) string {
|
||||
theme := style.GetTheme()
|
||||
popupWidth := max(s.width-4, 20)
|
||||
|
||||
// Use the theme background for the popup - the full-width item backgrounds
|
||||
// and primary-colored selection will provide sufficient contrast
|
||||
popupBg := theme.Background
|
||||
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Primary).
|
||||
Background(popupBg).
|
||||
Padding(1, 2).
|
||||
Width(popupWidth).
|
||||
MarginLeft(0).
|
||||
MarginBottom(1) // Visual depth/shadow effect
|
||||
|
||||
// Inner content width: popup minus border (2) and horizontal padding (4).
|
||||
innerWidth := max(popupWidth-6, 10)
|
||||
|
||||
// Item background styles for high contrast
|
||||
normalItemBg := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.Text).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
|
||||
selectedItemBg := lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1).
|
||||
Bold(true)
|
||||
|
||||
var items []string
|
||||
|
||||
visibleItems := min(len(s.filtered), s.popupHeight)
|
||||
startIdx := 0
|
||||
if s.selected >= s.popupHeight {
|
||||
startIdx = s.selected - s.popupHeight + 1
|
||||
}
|
||||
endIdx := min(startIdx+visibleItems, len(s.filtered))
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
match := s.filtered[i]
|
||||
sc := match.Command
|
||||
|
||||
// Choose the appropriate background style
|
||||
itemStyle := normalItemBg
|
||||
if i == s.selected {
|
||||
itemStyle = selectedItemBg
|
||||
items := make([]PopupItem, len(s.filtered))
|
||||
for i, m := range s.filtered {
|
||||
desc := ""
|
||||
if m.Command != nil {
|
||||
desc = m.Command.Description
|
||||
}
|
||||
|
||||
// Build indicator with proper coloring
|
||||
var indicator string
|
||||
if i == s.selected {
|
||||
indicator = "> "
|
||||
} else {
|
||||
indicator = " "
|
||||
name := ""
|
||||
if m.Command != nil {
|
||||
name = m.Command.Name
|
||||
}
|
||||
|
||||
// Build content with name and description
|
||||
var content string
|
||||
if s.fileMode {
|
||||
// File mode: use full width for the path, show description inline
|
||||
maxNameLen := max(innerWidth-16, 8)
|
||||
displayName := sc.Name
|
||||
if len(displayName) > maxNameLen && maxNameLen > 3 {
|
||||
displayName = displayName[:maxNameLen-3] + "..."
|
||||
}
|
||||
|
||||
if sc.Description != "" && innerWidth > 30 {
|
||||
content = indicator + displayName + " " + sc.Description
|
||||
} else {
|
||||
content = indicator + displayName
|
||||
}
|
||||
} else {
|
||||
// Line layout: indicator(2) + name(nameWidth-2 visual) + desc
|
||||
if innerWidth < 20 {
|
||||
// Very narrow: show truncated name only
|
||||
displayName := sc.Name
|
||||
maxName := max(innerWidth-2, 3)
|
||||
if len(displayName) > maxName {
|
||||
displayName = displayName[:maxName-1] + "…"
|
||||
}
|
||||
content = indicator + displayName
|
||||
} else {
|
||||
// Compute nameWidth from the longest command name in the
|
||||
// visible slice so we never truncate unnecessarily.
|
||||
nameWidth := 0
|
||||
for _, fm := range s.filtered {
|
||||
if n := len([]rune(fm.Command.Name)); n > nameWidth {
|
||||
nameWidth = n
|
||||
}
|
||||
}
|
||||
nameWidth += 3 // account for indicator prefix (2) + gap before description (1)
|
||||
// Ensure descriptions still get at least 20 chars when possible.
|
||||
maxForName := innerWidth - 20
|
||||
if maxForName < 8 {
|
||||
maxForName = innerWidth * 2 / 3
|
||||
}
|
||||
if nameWidth > maxForName {
|
||||
nameWidth = maxForName
|
||||
}
|
||||
if nameWidth < 8 {
|
||||
nameWidth = 8
|
||||
}
|
||||
maxNameChars := nameWidth - 2
|
||||
displayName := sc.Name
|
||||
if len(displayName) > maxNameChars {
|
||||
displayName = displayName[:maxNameChars-1] + "…"
|
||||
}
|
||||
|
||||
// Description gets remaining space
|
||||
maxDescLen := max(innerWidth-nameWidth, 0)
|
||||
desc := sc.Description
|
||||
if maxDescLen >= 4 && desc != "" {
|
||||
if len(desc) > maxDescLen {
|
||||
desc = desc[:maxDescLen-3] + "..."
|
||||
}
|
||||
content = indicator + lipgloss.NewStyle().Width(maxNameChars).Render(displayName) + desc
|
||||
} else {
|
||||
content = indicator + displayName
|
||||
}
|
||||
}
|
||||
items[i] = PopupItem{
|
||||
Label: name,
|
||||
Description: desc,
|
||||
}
|
||||
|
||||
items = append(items, itemStyle.Render(content))
|
||||
}
|
||||
|
||||
// Add scroll indicators with background
|
||||
scrollStyle := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
|
||||
if startIdx > 0 {
|
||||
items = append([]string{scrollStyle.Render(" ↑ more above")}, items...)
|
||||
}
|
||||
if endIdx < len(s.filtered) {
|
||||
items = append(items, scrollStyle.Render(" ↓ more below"))
|
||||
}
|
||||
|
||||
content := strings.Join(items, "\n")
|
||||
|
||||
// Adapt footer text to available width with background
|
||||
var footerText string
|
||||
if innerWidth >= 50 {
|
||||
footerText = "↑↓ navigate • tab complete • ↵ select • esc dismiss"
|
||||
} else if innerWidth >= 30 {
|
||||
footerText = "↑↓ nav • tab • ↵ select • esc"
|
||||
} else {
|
||||
footerText = "↑↓ tab ↵ esc"
|
||||
}
|
||||
footer := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Italic(true).
|
||||
Render(footerText)
|
||||
|
||||
return popupStyle.Render(content + "\n\n" + footer)
|
||||
s.popup.SetSize(termWidth, termHeight)
|
||||
s.popup.SetItems(items)
|
||||
s.popup.SetCursor(s.selected)
|
||||
return s.popup.RenderCentered(termWidth, termHeight)
|
||||
}
|
||||
|
||||
// completeArgs checks whether the input line matches a command with a Complete
|
||||
@@ -844,6 +811,8 @@ func readClipboardImageCmd() tea.Cmd {
|
||||
func (s *InputComponent) ClearPendingImages() []core.ImageAttachment {
|
||||
images := s.pendingImages
|
||||
s.pendingImages = nil
|
||||
s.imageThumbs = nil
|
||||
s.imageGen++
|
||||
return images
|
||||
}
|
||||
|
||||
@@ -862,6 +831,7 @@ func (s *InputComponent) Clear() bool {
|
||||
s.showPopup = false
|
||||
s.argMode = false
|
||||
s.fileMode = false
|
||||
s.fileEditMode = false
|
||||
s.browsingHistory = false
|
||||
s.savedInput = ""
|
||||
return hadContent
|
||||
@@ -871,6 +841,11 @@ func (s *InputComponent) Clear() bool {
|
||||
// file or MCP resource suggestion. For directories, it keeps the popup open
|
||||
// for further drilling. For files and resources, it closes the popup and adds
|
||||
// a trailing space.
|
||||
//
|
||||
// When fileEditMode is active the same path-replacement happens against the
|
||||
// /edit (or alias) command prefix instead of an @ trigger. Selecting a file
|
||||
// also arms submitNext so the next tick runs $EDITOR on it; selecting a
|
||||
// directory keeps the popup open for drill-down.
|
||||
func (s *InputComponent) applyFileCompletion(idx int) {
|
||||
if idx >= len(s.fileSuggestions) {
|
||||
return
|
||||
@@ -889,7 +864,17 @@ func (s *InputComponent) applyFileCompletion(idx int) {
|
||||
beforeAt := lastLine[:s.fileAtStartIdx]
|
||||
|
||||
var replacement string
|
||||
if suggestion.IsMCPResource {
|
||||
switch {
|
||||
case s.fileEditMode:
|
||||
// /edit path mode — no @ prefix; the path is the bare argument.
|
||||
// MCP resources are excluded upstream, so only file/dir entries reach here.
|
||||
needsQuote := strings.Contains(suggestion.RelPath, " ")
|
||||
if needsQuote {
|
||||
replacement = `"` + suggestion.RelPath + `"`
|
||||
} else {
|
||||
replacement = suggestion.RelPath
|
||||
}
|
||||
case suggestion.IsMCPResource:
|
||||
// MCP resources use @mcp:server:uri format.
|
||||
// Quote if the URI contains spaces.
|
||||
ref := "mcp:" + suggestion.MCPServerName + ":" + suggestion.MCPResourceURI
|
||||
@@ -899,7 +884,7 @@ func (s *InputComponent) applyFileCompletion(idx int) {
|
||||
replacement = "@" + ref
|
||||
}
|
||||
replacement += " "
|
||||
} else {
|
||||
default:
|
||||
needsQuote := strings.Contains(suggestion.RelPath, " ")
|
||||
if needsQuote {
|
||||
replacement = `@"` + suggestion.RelPath + `"`
|
||||
@@ -925,9 +910,61 @@ func (s *InputComponent) applyFileCompletion(idx int) {
|
||||
if suggestion.IsDir && !suggestion.IsMCPResource {
|
||||
// Keep popup open — trigger a refresh for the new directory.
|
||||
s.lastValue = "" // force re-evaluation on next update tick
|
||||
} else {
|
||||
s.showPopup = false
|
||||
s.fileMode = false
|
||||
s.selected = 0
|
||||
return
|
||||
}
|
||||
|
||||
s.showPopup = false
|
||||
s.fileMode = false
|
||||
s.selected = 0
|
||||
|
||||
if s.fileEditMode {
|
||||
// A file was selected via /edit — submit on the next tick so the
|
||||
// popup dismisses cleanly before $EDITOR takes the terminal.
|
||||
s.fileEditMode = false
|
||||
s.submitNext = true
|
||||
}
|
||||
}
|
||||
|
||||
// updateEditFilePopup queries the file-suggestion engine for the /edit path
|
||||
// prefix and populates the popup state. cmdLen is the byte offset of the path
|
||||
// argument within the current line (i.e. length of "/edit " or "/ed ").
|
||||
// Directories are kept so the user can drill down; MCP resources are skipped.
|
||||
func (s *InputComponent) updateEditFilePopup(cmdLen int, pathPrefix string) {
|
||||
var suggestions []FileSuggestion
|
||||
if s.cwd != "" {
|
||||
suggestions = GetFileSuggestions(pathPrefix, s.cwd)
|
||||
}
|
||||
if len(suggestions) == 0 {
|
||||
s.showPopup = false
|
||||
s.fileMode = false
|
||||
s.fileEditMode = false
|
||||
return
|
||||
}
|
||||
|
||||
sort.Slice(suggestions, func(i, j int) bool {
|
||||
return suggestions[i].Score > suggestions[j].Score
|
||||
})
|
||||
if len(suggestions) > maxFileSuggestions {
|
||||
suggestions = suggestions[:maxFileSuggestions]
|
||||
}
|
||||
|
||||
s.showPopup = true
|
||||
s.fileMode = true
|
||||
s.fileEditMode = true
|
||||
s.argMode = false
|
||||
s.filePrefix = pathPrefix
|
||||
s.fileAtStartIdx = cmdLen
|
||||
s.fileSuggestions = suggestions
|
||||
s.fileSynthCmds = make([]commands.SlashCommand, len(suggestions))
|
||||
s.filtered = make([]FuzzyMatch, len(suggestions))
|
||||
for i, fs := range suggestions {
|
||||
name := fs.RelPath
|
||||
desc := ""
|
||||
if fs.IsDir {
|
||||
desc = "directory"
|
||||
}
|
||||
s.fileSynthCmds[i] = commands.SlashCommand{Name: name, Description: desc}
|
||||
s.filtered[i] = FuzzyMatch{Command: &s.fileSynthCmds[i], Score: fs.Score}
|
||||
}
|
||||
s.selected = 0
|
||||
}
|
||||
|
||||
+358
-116
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"github.com/mark3labs/kit/internal/ui/commands"
|
||||
uicore "github.com/mark3labs/kit/internal/ui/core"
|
||||
"github.com/mark3labs/kit/internal/ui/fileutil"
|
||||
"github.com/mark3labs/kit/internal/ui/imagepreview"
|
||||
"github.com/mark3labs/kit/internal/ui/prefs"
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
@@ -124,6 +126,14 @@ type AppController interface {
|
||||
// attachments (e.g. pasted images) into the currently running agent
|
||||
// turn. Behaves like Steer but includes file parts alongside the text.
|
||||
SteerWithFiles(prompt string, files []kit.LLMFilePart) int
|
||||
// PopLastUserMessage truncates the tree session at the parent of the
|
||||
// most recent user message on the current branch, syncs the in-memory
|
||||
// message store, and returns that user prompt (plus any image file
|
||||
// parts) so the caller can resubmit it. Used by /retry to recover from
|
||||
// provider errors (overloaded, timeout) without duplicating the user
|
||||
// message in context. Returns an error if the agent is busy, no tree
|
||||
// session is active, or no user message exists on the current branch.
|
||||
PopLastUserMessage() (string, []kit.LLMFilePart, error)
|
||||
}
|
||||
|
||||
// SkillItem holds display metadata about a loaded skill for the startup
|
||||
@@ -1198,53 +1208,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.modelSelector = nil
|
||||
m.state = stateInput
|
||||
if m.setModel != nil {
|
||||
previousModel := m.providerName + "/" + m.modelName
|
||||
|
||||
// Check if thinking level needs adjustment for the new model.
|
||||
// Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none".
|
||||
if m.thinkingLevel != "" && m.thinkingLevel != "off" {
|
||||
parts := strings.SplitN(msg.ModelString, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
modelName := parts[1]
|
||||
currentLevel := models.ParseThinkingLevel(m.thinkingLevel)
|
||||
if !models.IsValidThinkingLevelForModel(currentLevel, modelName) {
|
||||
fallback := models.SuggestThinkingLevelFallback(currentLevel, modelName)
|
||||
if fallback != models.ThinkingOff {
|
||||
m.printSystemMessage(fmt.Sprintf(
|
||||
"Note: Model %s doesn't support '%s' thinking level. Adjusted to '%s'.",
|
||||
modelName, currentLevel, fallback,
|
||||
))
|
||||
m.thinkingLevel = string(fallback)
|
||||
if m.setThinkingLevel != nil {
|
||||
_ = m.setThinkingLevel(string(fallback))
|
||||
}
|
||||
go func() { _ = prefs.SaveThinkingLevelPreference(string(fallback)) }()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.setModel(msg.ModelString); err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
|
||||
} else {
|
||||
// Update display state directly — we cannot use
|
||||
// NotifyModelChanged (prog.Send) from inside Update()
|
||||
// without deadlocking BubbleTea.
|
||||
parts := strings.SplitN(msg.ModelString, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
m.providerName = parts[0]
|
||||
m.modelName = parts[1]
|
||||
}
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to %s", msg.ModelString))
|
||||
// Persist model selection for next launch.
|
||||
go func() { _ = prefs.SaveModelPreference(msg.ModelString) }()
|
||||
if m.emitModelChange != nil {
|
||||
emit := m.emitModelChange
|
||||
newModel := msg.ModelString
|
||||
prev := previousModel
|
||||
go emit(newModel, prev, "user")
|
||||
}
|
||||
}
|
||||
m.switchModel(msg.ModelString)
|
||||
}
|
||||
return m, tea.Batch(cmds...)
|
||||
|
||||
@@ -1794,14 +1758,27 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// messages stay in chronological order.
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
// Insert inline thumbnail previews after the user message.
|
||||
cmds = append(cmds, m.transcriptPreviewCmd(msg.Images, m.lastMessageID()))
|
||||
}
|
||||
} else {
|
||||
m.printUserMessage(displayText)
|
||||
// Insert inline thumbnail previews after the user message.
|
||||
cmds = append(cmds, m.transcriptPreviewCmd(msg.Images, m.lastMessageID()))
|
||||
}
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
}
|
||||
|
||||
// ── Async transcript image preview ───────────────────────────────────────
|
||||
case imagePreviewReadyMsg:
|
||||
if msg.block != "" {
|
||||
item := NewStyledMessageItem(generateMessageID(), "user", "", msg.block)
|
||||
m.insertMessageAfter(msg.anchorID, item)
|
||||
m.refreshContent()
|
||||
m.layoutDirty = true
|
||||
}
|
||||
|
||||
// ── Shell command (! / !!) ───────────────────────────────────────────────
|
||||
case uicore.ShellCommandMsg:
|
||||
// Show spinner while the shell command runs.
|
||||
@@ -2397,6 +2374,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.layoutDirty = true
|
||||
}
|
||||
|
||||
case editFileMsg:
|
||||
// User returned from $EDITOR after `/edit <path>`. The file was
|
||||
// edited directly on disk — no textarea changes. Report the result.
|
||||
if msg.err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Editor exited with error: %v", msg.err))
|
||||
} else {
|
||||
m.printSystemMessage(fmt.Sprintf("Edited `%s`", msg.path))
|
||||
}
|
||||
m.layoutDirty = true
|
||||
|
||||
case extReloadResultMsg:
|
||||
if msg.err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Extension reload failed: %v", msg.err))
|
||||
@@ -2447,6 +2434,19 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.printSystemMessage(msg.Text)
|
||||
}
|
||||
|
||||
// ── Clipboard image attached / thumbnail rendered ────────────────────────
|
||||
// Both messages change the input region's rendered height (the pill and
|
||||
// the async half-block preview), so forward them to the input and mark the
|
||||
// layout dirty — otherwise distributeHeight keeps a stale, too-short input
|
||||
// height and the preview is clipped off the bottom of the screen.
|
||||
case clipboardImageMsg, thumbnailReadyMsg:
|
||||
if m.input != nil {
|
||||
updated, cmd := m.input.Update(msg)
|
||||
m.input, _ = updated.(inputComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
m.layoutDirty = true
|
||||
|
||||
default:
|
||||
// Pass unrecognised messages to all children.
|
||||
if m.input != nil {
|
||||
@@ -3046,6 +3046,85 @@ func truncateMessageForBlock(msg string, maxLines, width int) string {
|
||||
// Print helpers — add content to ScrollList
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// imagePreviewReadyMsg carries an asynchronously rendered transcript image
|
||||
// preview block back to the Update loop, where it is inserted into the
|
||||
// ScrollList directly after the originating user message (identified by
|
||||
// anchorID). Inserting by anchor — rather than appending — keeps the preview
|
||||
// next to its message even when the agent's streamed reply has already been
|
||||
// appended while the thumbnail was being decoded off the event loop.
|
||||
type imagePreviewReadyMsg struct {
|
||||
block string
|
||||
anchorID string
|
||||
}
|
||||
|
||||
// transcriptPreviewCmd returns a tea.Cmd that renders half-block thumbnail
|
||||
// previews for the given clipboard images off the Bubble Tea event loop
|
||||
// (decode + resample must not block Update). The rendered block is delivered
|
||||
// via imagePreviewReadyMsg, tagged with anchorID so the consumer can place it
|
||||
// directly after the originating user message. Returns nil when there is
|
||||
// nothing to render or no room for a preview; an empty result (terminal lacks
|
||||
// color support) yields a nil message that Bubble Tea ignores.
|
||||
func (m *AppModel) transcriptPreviewCmd(images []uicore.ImageAttachment, anchorID string) tea.Cmd {
|
||||
if len(images) == 0 {
|
||||
return nil
|
||||
}
|
||||
cols := thumbMaxCols
|
||||
if m.width > 6 && m.width-6 < cols {
|
||||
cols = m.width - 6
|
||||
}
|
||||
if cols < 1 {
|
||||
return nil
|
||||
}
|
||||
bg := style.GetTheme().Background
|
||||
imgs := images
|
||||
return func() tea.Msg {
|
||||
pad := lipgloss.NewStyle().PaddingLeft(2)
|
||||
var blocks []string
|
||||
for _, img := range imgs {
|
||||
thumb, err := imagepreview.Render(img.Data, img.MediaType, cols, thumbMaxRows, bg)
|
||||
if err != nil || thumb == "" {
|
||||
continue
|
||||
}
|
||||
blocks = append(blocks, pad.Render(thumb))
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
return nil
|
||||
}
|
||||
return imagePreviewReadyMsg{block: strings.Join(blocks, "\n"), anchorID: anchorID}
|
||||
}
|
||||
}
|
||||
|
||||
// lastMessageID returns the ID of the most recently added ScrollList message,
|
||||
// or "" when there are none. Used to anchor an async transcript preview to the
|
||||
// user message that was just printed.
|
||||
func (m *AppModel) lastMessageID() string {
|
||||
if len(m.messages) == 0 {
|
||||
return ""
|
||||
}
|
||||
return m.messages[len(m.messages)-1].ID()
|
||||
}
|
||||
|
||||
// insertMessageAfter inserts item immediately after the message whose ID
|
||||
// matches anchorID. If anchorID is empty or not found, item is appended.
|
||||
func (m *AppModel) insertMessageAfter(anchorID string, item MessageItem) {
|
||||
idx := -1
|
||||
if anchorID != "" {
|
||||
for i, msgItem := range m.messages {
|
||||
if msgItem.ID() == anchorID {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
m.messages = append(m.messages, item)
|
||||
return
|
||||
}
|
||||
m.messages = append(m.messages, nil)
|
||||
copy(m.messages[idx+2:], m.messages[idx+1:])
|
||||
m.messages[idx+1] = item
|
||||
}
|
||||
|
||||
// printUserMessage renders a user message into the ScrollList.
|
||||
func (m *AppModel) printUserMessage(text string) {
|
||||
// Check if this exact message was just added (prevents duplicates)
|
||||
@@ -3171,6 +3250,10 @@ func (m *AppModel) handleSlashCommand(sc *commands.SlashCommand, args string) te
|
||||
return m.handleExportCommand(args)
|
||||
case "/copy":
|
||||
return m.handleCopyCommand()
|
||||
case "/retry":
|
||||
return m.handleRetryCommand()
|
||||
case "/edit":
|
||||
return m.handleEditCommand(args)
|
||||
case "/share":
|
||||
return m.handleShareCommand()
|
||||
case "/import":
|
||||
@@ -3596,6 +3679,8 @@ func (m *AppModel) printHelpMessage() {
|
||||
"- `/compact [instructions]`: Summarise older messages to free context space\n" +
|
||||
"- `/clear`: Clear message history\n" +
|
||||
"- `/copy`: Copy the last message to the system clipboard\n" +
|
||||
"- `/retry`: Resubmit the last user message (e.g. after a provider error)\n" +
|
||||
"- `/edit [path]`: Open a file in `$EDITOR` (fuzzy-find from cwd)\n" +
|
||||
"- `/export [path]`: Export session as JSONL\n" +
|
||||
"- `/import <path.jsonl>`: Import session from JSONL file\n" +
|
||||
"- `/reset-usage`: Reset usage statistics\n" +
|
||||
@@ -4080,11 +4165,31 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Direct model switch with the provided model string.
|
||||
m.switchModel(args)
|
||||
return nil
|
||||
}
|
||||
|
||||
// switchModel performs a direct model switch, shared by the model selector
|
||||
// overlay and the /model slash command: it adjusts the thinking level when
|
||||
// the new model doesn't support the current one, calls the setModel
|
||||
// callback, updates display state, persists preferences, and emits the
|
||||
// ModelChange extension event.
|
||||
//
|
||||
// Display state is updated directly — we cannot use NotifyModelChanged
|
||||
// (prog.Send) from inside Update() without deadlocking BubbleTea.
|
||||
func (m *AppModel) switchModel(modelString string) {
|
||||
if m.setModel == nil {
|
||||
m.printSystemMessage("Model switching is not available.")
|
||||
return
|
||||
}
|
||||
|
||||
previousModel := m.providerName + "/" + m.modelName
|
||||
|
||||
// Check if thinking level needs adjustment for the new model.
|
||||
// Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none".
|
||||
if m.thinkingLevel != "" && m.thinkingLevel != "off" {
|
||||
parts := strings.SplitN(args, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 {
|
||||
modelName := parts[1]
|
||||
currentLevel := models.ParseThinkingLevel(m.thinkingLevel)
|
||||
if !models.IsValidThinkingLevelForModel(currentLevel, modelName) {
|
||||
@@ -4104,32 +4209,26 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
|
||||
}
|
||||
}
|
||||
|
||||
// Direct model switch with the provided model string.
|
||||
previousModel := m.providerName + "/" + m.modelName
|
||||
if err := m.setModel(args); err != nil {
|
||||
if err := m.setModel(modelString); err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// Update display state directly (cannot use prog.Send from Update).
|
||||
parts := strings.SplitN(args, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 {
|
||||
m.providerName = parts[0]
|
||||
m.modelName = parts[1]
|
||||
}
|
||||
|
||||
if m.emitModelChange != nil {
|
||||
emit := m.emitModelChange
|
||||
prev := previousModel
|
||||
newModel := args
|
||||
go emit(newModel, prev, "user")
|
||||
}
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to %s", modelString))
|
||||
|
||||
// Persist model selection for next launch.
|
||||
go func() { _ = prefs.SaveModelPreference(args) }()
|
||||
go func() { _ = prefs.SaveModelPreference(modelString) }()
|
||||
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to %s", args))
|
||||
return nil
|
||||
if m.emitModelChange != nil {
|
||||
emit := m.emitModelChange
|
||||
go emit(modelString, previousModel, "user")
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -4446,6 +4545,141 @@ func (m *AppModel) handleCopyCommand() tea.Cmd {
|
||||
return clipboard.CopyToClipboard(text)
|
||||
}
|
||||
|
||||
// handleRetryCommand resubmits the most recent user message on the current
|
||||
// branch. Used to recover from transient provider errors (overloaded,
|
||||
// timeout) without users having to retype — and without the duplicate-user-
|
||||
// message bloat that retyping creates.
|
||||
//
|
||||
// Flow:
|
||||
// 1. App.PopLastUserMessage() truncates the tree at the parent of the last
|
||||
// user message and returns its text + any image parts. The failed turn's
|
||||
// entries become orphaned (still on disk, off-branch) so they will not
|
||||
// be re-sent to the LLM.
|
||||
// 2. The visible message list is rebuilt from the truncated branch so the
|
||||
// prior user message + any partial assistant + error rendering vanish.
|
||||
// 3. The prompt is resubmitted via Run/RunWithFiles, mirroring the normal
|
||||
// SubmitMsg display path (badge formatting, pending-prints flush,
|
||||
// stateWorking transition).
|
||||
func (m *AppModel) handleRetryCommand() tea.Cmd {
|
||||
if m.appCtrl == nil {
|
||||
m.printSystemMessage("App controller unavailable.")
|
||||
return nil
|
||||
}
|
||||
|
||||
prompt, files, err := m.appCtrl.PopLastUserMessage()
|
||||
if err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Cannot retry: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rebuild the visible ScrollList from the truncated branch so the failed
|
||||
// turn's user message and any partial assistant/error rendering disappear
|
||||
// before the resubmit prints a fresh user message.
|
||||
m.messages = []MessageItem{}
|
||||
m.renderSessionHistory()
|
||||
|
||||
// Mirror SubmitMsg's badge formatting for the display text.
|
||||
var imageCount, fileOnlyCount int
|
||||
for _, f := range files {
|
||||
if strings.HasPrefix(f.MediaType, "image/") {
|
||||
imageCount++
|
||||
} else {
|
||||
fileOnlyCount++
|
||||
}
|
||||
}
|
||||
displayText := prompt
|
||||
if imageCount > 0 || fileOnlyCount > 0 {
|
||||
var badges []string
|
||||
if imageCount > 0 {
|
||||
badges = append(badges, fmt.Sprintf("%d image(s) pasted", imageCount))
|
||||
}
|
||||
if fileOnlyCount > 0 {
|
||||
badges = append(badges, fmt.Sprintf("%d file(s) attached", fileOnlyCount))
|
||||
}
|
||||
displayText = fmt.Sprintf("%s\n[%s]", prompt, strings.Join(badges, ", "))
|
||||
}
|
||||
|
||||
var qLen int
|
||||
if len(files) > 0 {
|
||||
qLen = m.appCtrl.RunWithFiles(prompt, files)
|
||||
} else {
|
||||
qLen = m.appCtrl.Run(prompt)
|
||||
}
|
||||
if qLen > 0 {
|
||||
m.queuedMessages = append(m.queuedMessages, displayText)
|
||||
m.layoutDirty = true
|
||||
} else {
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
}
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleEditCommand opens the supplied path in $EDITOR via tea.ExecProcess,
|
||||
// pausing the TUI for the duration of the editor session. The path is
|
||||
// resolved relative to cwd; ~/ and absolute paths are honoured. Non-existent
|
||||
// paths are allowed — most editors will create the file on save.
|
||||
//
|
||||
// On exit an editFileMsg is emitted with the resolved path (or error) so the
|
||||
// Update loop can report the result. The textarea is not touched — use
|
||||
// Ctrl+X e if you want to round-trip a prompt through $EDITOR instead.
|
||||
func (m *AppModel) handleEditCommand(args string) tea.Cmd {
|
||||
path := strings.TrimSpace(args)
|
||||
if path == "" {
|
||||
m.printSystemMessage("Usage: `/edit <path>` — or type `/edit ` and pick a file from the popup.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Strip optional surrounding double-quotes (the autocomplete inserts
|
||||
// these when a path contains spaces).
|
||||
if len(path) >= 2 && strings.HasPrefix(path, `"`) && strings.HasSuffix(path, `"`) {
|
||||
path = path[1 : len(path)-1]
|
||||
}
|
||||
|
||||
// Resolve ~/, relative, and absolute paths against cwd.
|
||||
resolved := path
|
||||
if strings.HasPrefix(resolved, "~/") {
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
resolved = filepath.Join(home, resolved[2:])
|
||||
}
|
||||
}
|
||||
if !filepath.IsAbs(resolved) {
|
||||
cwd, err := os.Getwd()
|
||||
if err == nil {
|
||||
resolved = filepath.Join(cwd, resolved)
|
||||
}
|
||||
}
|
||||
resolved = filepath.Clean(resolved)
|
||||
|
||||
// Reject paths that exist but are directories — $EDITOR semantics vary.
|
||||
if info, err := os.Stat(resolved); err == nil && info.IsDir() {
|
||||
m.printSystemMessage(fmt.Sprintf("`%s` is a directory, not a file.", resolved))
|
||||
return nil
|
||||
}
|
||||
|
||||
editorApp := os.Getenv("VISUAL")
|
||||
if editorApp == "" {
|
||||
editorApp = os.Getenv("EDITOR")
|
||||
}
|
||||
if editorApp == "" {
|
||||
m.printSystemMessage("Set `$EDITOR` or `$VISUAL` to use `/edit`")
|
||||
return nil
|
||||
}
|
||||
|
||||
editorCmd, cmdErr := editor.Command(editorApp, resolved)
|
||||
if cmdErr != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to open editor: %v", cmdErr))
|
||||
return nil
|
||||
}
|
||||
|
||||
return tea.ExecProcess(editorCmd, func(err error) tea.Msg {
|
||||
return editFileMsg{path: resolved, err: err}
|
||||
})
|
||||
}
|
||||
|
||||
// handleExportCommand exports the current session to a file.
|
||||
// Usage: /export — copies the JSONL file to cwd with a descriptive name.
|
||||
//
|
||||
@@ -4561,61 +4795,11 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
|
||||
return r
|
||||
}, name)
|
||||
|
||||
tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name))
|
||||
tmpPath, err := buildShareFile(name, data, sysPromptJSON)
|
||||
if err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to create temp file: %v", err))
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to share session: %v", err))
|
||||
return nil
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
// Write the session data with the system prompt entry inserted after the header.
|
||||
// The header is the first line, so we write:
|
||||
// 1. First line (header) from original data
|
||||
// 2. System prompt entry
|
||||
// 3. Remaining lines from original data
|
||||
lines := strings.Split(string(data), "\n")
|
||||
if len(lines) > 0 && lines[len(lines)-1] == "" {
|
||||
lines = lines[:len(lines)-1] // Remove trailing empty line
|
||||
}
|
||||
|
||||
if len(lines) > 0 {
|
||||
// Write header (first line)
|
||||
if _, err := tmpFile.WriteString(lines[0] + "\n"); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write system prompt entry
|
||||
if _, err := tmpFile.Write(sysPromptJSON); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write system prompt: %v", err))
|
||||
return nil
|
||||
}
|
||||
if _, err := tmpFile.WriteString("\n"); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write remaining lines
|
||||
for i := 1; i < len(lines); i++ {
|
||||
if lines[i] == "" {
|
||||
continue // Skip empty lines
|
||||
}
|
||||
if _, err := tmpFile.WriteString(lines[i] + "\n"); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ = tmpFile.Close()
|
||||
|
||||
m.printSystemMessage("Uploading session to GitHub Gist...")
|
||||
|
||||
@@ -4641,6 +4825,56 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
|
||||
}
|
||||
}
|
||||
|
||||
// buildShareFile assembles a temp JSONL file containing the session data
|
||||
// with the system-prompt entry inserted after the header line. On success
|
||||
// the caller owns the returned file and must remove it when done; on error
|
||||
// any partially-written temp file has already been cleaned up.
|
||||
func buildShareFile(name string, data, sysPromptJSON []byte) (tmpPath string, err error) {
|
||||
tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath = tmpFile.Name()
|
||||
defer func() {
|
||||
_ = tmpFile.Close()
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
// Write the session data with the system prompt entry inserted after the
|
||||
// header. The header is the first line, so we write:
|
||||
// 1. First line (header) from original data
|
||||
// 2. System prompt entry
|
||||
// 3. Remaining lines from original data
|
||||
lines := strings.Split(string(data), "\n")
|
||||
if len(lines) > 0 && lines[len(lines)-1] == "" {
|
||||
lines = lines[:len(lines)-1] // Remove trailing empty line
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
if _, err = tmpFile.WriteString(lines[0] + "\n"); err != nil {
|
||||
return "", fmt.Errorf("write temp file: %w", err)
|
||||
}
|
||||
if _, err = tmpFile.Write(sysPromptJSON); err != nil {
|
||||
return "", fmt.Errorf("write system prompt: %w", err)
|
||||
}
|
||||
if _, err = tmpFile.WriteString("\n"); err != nil {
|
||||
return "", fmt.Errorf("write temp file: %w", err)
|
||||
}
|
||||
for i := 1; i < len(lines); i++ {
|
||||
if lines[i] == "" {
|
||||
continue // Skip empty lines
|
||||
}
|
||||
if _, err = tmpFile.WriteString(lines[i] + "\n"); err != nil {
|
||||
return "", fmt.Errorf("write temp file: %w", err)
|
||||
}
|
||||
}
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
// handleImportCommand imports a session from a JSONL file.
|
||||
// Usage: /import path.jsonl
|
||||
func (m *AppModel) handleImportCommand(args string) tea.Cmd {
|
||||
@@ -4856,6 +5090,14 @@ type externalEditorMsg struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// editFileMsg is sent when the user returns from $EDITOR after invoking the
|
||||
// /edit slash command on a specific file. Unlike externalEditorMsg, no text
|
||||
// is read back — the user edited the file directly on disk.
|
||||
type editFileMsg struct {
|
||||
path string
|
||||
err error
|
||||
}
|
||||
|
||||
// shareResultMsg carries the result of an async gist upload.
|
||||
type shareResultMsg struct {
|
||||
err error
|
||||
|
||||
@@ -2,6 +2,7 @@ package ui
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -87,6 +88,10 @@ func (s *stubAppController) SteerWithFiles(prompt string, _ []kit.LLMFilePart) i
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
func (s *stubAppController) PopLastUserMessage() (string, []kit.LLMFilePart, error) {
|
||||
return "", nil, fmt.Errorf("no user message to retry")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Stub child components
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
uicore "github.com/mark3labs/kit/internal/ui/core"
|
||||
)
|
||||
|
||||
// drainCmds runs a tea.Cmd chain back through m.Update like the BubbleTea
|
||||
// event loop, expanding batches, until no further messages are produced.
|
||||
func drainCmds(t *testing.T, m *AppModel, cmd tea.Cmd) *AppModel {
|
||||
t.Helper()
|
||||
queue := []tea.Cmd{cmd}
|
||||
for i := 0; i < 50 && len(queue) > 0; i++ {
|
||||
c := queue[0]
|
||||
queue = queue[1:]
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
msg := c()
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if batch, ok := msg.(tea.BatchMsg); ok {
|
||||
queue = append(queue, batch...)
|
||||
continue
|
||||
}
|
||||
updated, nc := m.Update(msg)
|
||||
m = updated.(*AppModel)
|
||||
_ = m.View()
|
||||
if nc != nil {
|
||||
queue = append(queue, nc)
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func measuredInputHeight(m *AppModel) int {
|
||||
rendered := m.renderInput()
|
||||
if rendered == "" {
|
||||
return 0
|
||||
}
|
||||
return strings.Count(rendered, "\n") + 1
|
||||
}
|
||||
|
||||
// TestPendingThumbnailTriggersLayoutRecompute is a regression test for the bug
|
||||
// where a pasted image's async half-block preview rendered but was clipped off
|
||||
// the bottom of the screen: the thumbnail arrives via thumbnailReadyMsg after
|
||||
// distributeHeight already measured the input region without it. The parent
|
||||
// must mark the layout dirty so the (now taller) input is re-measured.
|
||||
func TestPendingThumbnailTriggersLayoutRecompute(t *testing.T) {
|
||||
// Force a truecolor profile so imagepreview.Render deterministically
|
||||
// produces a thumbnail regardless of the CI terminal's color support.
|
||||
// Without this, a low-color test environment yields an empty preview and
|
||||
// the glyph / height assertions below would flake.
|
||||
t.Setenv("TERM", "xterm-256color")
|
||||
t.Setenv("COLORTERM", "truecolor")
|
||||
t.Setenv("NO_COLOR", "")
|
||||
|
||||
real := NewInputComponent(80, nil)
|
||||
m, _, _ := newTestAppModel(nil)
|
||||
m.input = real
|
||||
m = sendMsg(m, tea.WindowSizeMsg{Width: 80, Height: 24})
|
||||
|
||||
heightBefore := measuredInputHeight(m)
|
||||
|
||||
updated, cmd := m.Update(clipboardImageMsg{image: &uicore.ImageAttachment{
|
||||
Data: makeTestPNG(t, 16, 16),
|
||||
MediaType: "image/png",
|
||||
}})
|
||||
m = updated.(*AppModel)
|
||||
_ = m.View()
|
||||
m = drainCmds(t, m, cmd)
|
||||
|
||||
heightAfter := measuredInputHeight(m)
|
||||
if heightAfter <= heightBefore {
|
||||
t.Errorf("input region should grow to fit the thumbnail (before=%d after=%d)", heightBefore, heightAfter)
|
||||
}
|
||||
|
||||
if !strings.Contains(m.View().Content, "▀") {
|
||||
t.Error("parent View should contain the half-block thumbnail (was clipped or not rendered)")
|
||||
}
|
||||
}
|
||||
+182
-46
@@ -20,17 +20,23 @@ type PopupItem struct {
|
||||
Meta any // opaque data returned on selection
|
||||
}
|
||||
|
||||
// PopupList is a generic, themed, scrollable fuzzy-find popup list. It is
|
||||
// rendered as a centered overlay on top of the normal TUI layout and can be
|
||||
// reused by any feature that needs a selection popup (slash commands, model
|
||||
// selector, session picker, extension-provided lists, etc.).
|
||||
// PopupList is a generic, themed, scrollable popup list used by every
|
||||
// list-style popup in the TUI (slash commands, @file autocomplete, model
|
||||
// picker, session picker, tree navigation, etc.).
|
||||
//
|
||||
// The caller is responsible for:
|
||||
// - Building the initial item list
|
||||
// - Providing a fuzzy-filter callback (or nil for substring matching)
|
||||
// - Handling the result when the user selects or cancels
|
||||
// Two layout modes:
|
||||
// - Centered (default): bordered ~80-col box centered on the screen. Used
|
||||
// for the input-bar popups (/ and @) and the model picker.
|
||||
// - FullScreen: bordered panel filling almost the entire terminal. Used by
|
||||
// /tree, /fork, /sessions and other browse-many-items popups.
|
||||
//
|
||||
// Navigation: up/down to move, enter to select, esc to cancel, type to filter.
|
||||
// Two usage modes:
|
||||
// - Internal state: caller creates the list with items, calls HandleKey for
|
||||
// navigation/search, and PopupList owns the cursor and search string.
|
||||
// Used by selectors like ModelSelector, TreeSelector, SessionSelector.
|
||||
// - External state: caller drives the items / cursor / search themselves
|
||||
// (e.g. InputComponent, where typing in the textarea filters the list).
|
||||
// Caller uses SetItems / SetCursor / SetSearch and only calls Render.
|
||||
type PopupList struct {
|
||||
// Title shown at the top of the popup.
|
||||
Title string
|
||||
@@ -38,20 +44,45 @@ type PopupList struct {
|
||||
Subtitle string
|
||||
// FooterHint overrides the default keyboard-hint footer.
|
||||
FooterHint string
|
||||
// ExtraFooter is appended to the footer line (after the default hint).
|
||||
// Used by selectors to surface mode info like the active filter.
|
||||
ExtraFooter string
|
||||
|
||||
allItems []PopupItem // full unfiltered list
|
||||
filtered []PopupItem // subset matching the current search
|
||||
cursor int
|
||||
search string
|
||||
// FullScreen renders the popup at almost the full terminal size instead
|
||||
// of a centered ~80-col box. Used by tree/session/fork selectors.
|
||||
FullScreen bool
|
||||
|
||||
// ShowSearch toggles the "> <query>" search input line. Default true.
|
||||
ShowSearch bool
|
||||
|
||||
// HideCount suppresses the "(i/N)" count in the footer.
|
||||
HideCount bool
|
||||
|
||||
// MaxVisible caps the number of items visible at once. 0 = derive from
|
||||
// available height.
|
||||
MaxVisible int
|
||||
|
||||
// RenderItem optionally renders a single item row. When nil, the
|
||||
// built-in label + description + active-checkmark renderer is used.
|
||||
// innerWidth is the usable line width inside the popup (after border
|
||||
// and padding). The returned string must already be styled — the
|
||||
// shared selection-row background is applied by the popup only when
|
||||
// RenderItem is nil.
|
||||
RenderItem func(item PopupItem, innerWidth int, isCursor bool) string
|
||||
|
||||
// FilterFunc is called with (query, allItems) and should return the
|
||||
// filtered+scored subset. When nil, a default substring match is used.
|
||||
// filtered+scored subset. When nil, a default substring + fuzzy match
|
||||
// is used. Only consulted in internal-state mode (via HandleKey).
|
||||
FilterFunc func(query string, items []PopupItem) []PopupItem
|
||||
|
||||
width int
|
||||
height int
|
||||
maxVisible int // max items visible at once (0 = auto from height)
|
||||
showSearch bool
|
||||
allItems []PopupItem // full unfiltered list (internal-state mode)
|
||||
filtered []PopupItem // items currently rendered (driven by FilterFunc
|
||||
// in internal-state mode, or set directly via SetItems in external mode)
|
||||
cursor int
|
||||
search string
|
||||
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
// PopupResult is returned by HandleKey to tell the caller what happened.
|
||||
@@ -72,7 +103,7 @@ func NewPopupList(title string, items []PopupItem, width, height int) *PopupList
|
||||
filtered: items,
|
||||
width: width,
|
||||
height: height,
|
||||
showSearch: true,
|
||||
ShowSearch: true,
|
||||
}
|
||||
// Position cursor on the active item if one exists.
|
||||
for i, item := range p.filtered {
|
||||
@@ -90,25 +121,102 @@ func (p *PopupList) SetSize(width, height int) {
|
||||
p.height = height
|
||||
}
|
||||
|
||||
// SetItems replaces the displayed item list and clamps the cursor. Used by
|
||||
// external-state callers (e.g. InputComponent) that filter items themselves.
|
||||
// In internal-state mode, this also replaces the unfiltered backing list.
|
||||
func (p *PopupList) SetItems(items []PopupItem) {
|
||||
p.allItems = items
|
||||
p.filtered = items
|
||||
if p.cursor >= len(p.filtered) {
|
||||
p.cursor = max(len(p.filtered)-1, 0)
|
||||
}
|
||||
if p.cursor < 0 {
|
||||
p.cursor = 0
|
||||
}
|
||||
}
|
||||
|
||||
// SetCursor moves the selection to the given index (clamped to range).
|
||||
func (p *PopupList) SetCursor(i int) {
|
||||
if len(p.filtered) == 0 {
|
||||
p.cursor = 0
|
||||
return
|
||||
}
|
||||
if i < 0 {
|
||||
i = 0
|
||||
}
|
||||
if i >= len(p.filtered) {
|
||||
i = len(p.filtered) - 1
|
||||
}
|
||||
p.cursor = i
|
||||
}
|
||||
|
||||
// Cursor returns the current selection index.
|
||||
func (p *PopupList) Cursor() int { return p.cursor }
|
||||
|
||||
// SetSearch replaces the search string without rebuilding the filtered list.
|
||||
// Used by external-state callers that filter items themselves.
|
||||
func (p *PopupList) SetSearch(s string) { p.search = s }
|
||||
|
||||
// Items returns the currently-visible (filtered) items.
|
||||
func (p *PopupList) Items() []PopupItem { return p.filtered }
|
||||
|
||||
// Search returns the current search string.
|
||||
func (p *PopupList) Search() string { return p.search }
|
||||
|
||||
// dimensions returns the (popupWidth, popupHeight, innerWidth, innerHeight)
|
||||
// the popup will render at, given its current size and FullScreen flag.
|
||||
func (p *PopupList) dimensions() (popupW, popupH, innerW, innerH int) {
|
||||
if p.FullScreen {
|
||||
// Leave a small margin so the border doesn't kiss the screen edge.
|
||||
popupW = max(p.width-2, 20)
|
||||
popupH = max(p.height-2, 10)
|
||||
} else {
|
||||
// Centered: cap at 80 cols, leave a 4-col margin.
|
||||
popupW = max(min(p.width-4, 80), 20)
|
||||
// Height is dynamic — let it grow with content within the screen.
|
||||
popupH = 0
|
||||
}
|
||||
// Border (2) + horizontal padding (4) = 6 chrome cols.
|
||||
innerW = max(popupW-6, 10)
|
||||
if popupH > 0 {
|
||||
// Border (2) + vertical padding (2) = 4 chrome rows.
|
||||
innerH = max(popupH-4, 6)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// visibleCount returns the number of items visible at once.
|
||||
func (p *PopupList) visibleCount() int {
|
||||
if p.maxVisible > 0 {
|
||||
return p.maxVisible
|
||||
if p.MaxVisible > 0 {
|
||||
return p.MaxVisible
|
||||
}
|
||||
// Reserve: title(1) + subtitle(1) + search(1) + separator(1) + footer(2) + border(2) + padding(2) = 10
|
||||
if p.FullScreen {
|
||||
_, _, _, innerH := p.dimensions()
|
||||
// Reserve: title(1) + subtitle(0|1) + search(0|2) + sep(1) + footer(2)
|
||||
overhead := 4
|
||||
if p.Subtitle != "" {
|
||||
overhead++
|
||||
}
|
||||
if p.ShowSearch {
|
||||
overhead += 2
|
||||
}
|
||||
return max(innerH-overhead, 3)
|
||||
}
|
||||
// Centered: derive from terminal height (legacy behaviour).
|
||||
overhead := 8
|
||||
if p.Subtitle != "" {
|
||||
overhead++
|
||||
}
|
||||
if p.showSearch {
|
||||
overhead += 2 // search line + separator
|
||||
if p.ShowSearch {
|
||||
overhead += 2
|
||||
}
|
||||
return max(p.height/2-overhead, 3)
|
||||
}
|
||||
|
||||
// HandleKey processes a single key event and returns the result. The caller
|
||||
// should inspect PopupResult to decide whether to re-render, close the popup,
|
||||
// or act on a selection.
|
||||
// or act on a selection. Internal-state mode only — external-state callers
|
||||
// drive cursor/search themselves and never call this.
|
||||
//
|
||||
// keyName is the Bubble Tea key string (e.g. "up", "down", "enter", "esc").
|
||||
// keyText is the printable text for character keys (e.g. "a", "1").
|
||||
@@ -191,7 +299,7 @@ func (p *PopupList) HandleKey(keyName, keyText string) PopupResult {
|
||||
// as a centered overlay via lipgloss.Place + overlayContent.
|
||||
func (p *PopupList) Render() string {
|
||||
theme := style.GetTheme()
|
||||
popupWidth := max(min(p.width-4, 80), 20)
|
||||
popupW, popupH, innerW, _ := p.dimensions()
|
||||
popupBg := theme.Background
|
||||
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
@@ -199,11 +307,12 @@ func (p *PopupList) Render() string {
|
||||
BorderForeground(theme.Primary).
|
||||
Background(popupBg).
|
||||
Padding(1, 2).
|
||||
Width(popupWidth).
|
||||
MarginBottom(1)
|
||||
|
||||
// Inner content width: popup minus border (2) and horizontal padding (4).
|
||||
innerWidth := max(popupWidth-6, 10)
|
||||
Width(popupW)
|
||||
if popupH > 0 {
|
||||
popupStyle = popupStyle.Height(popupH)
|
||||
} else {
|
||||
popupStyle = popupStyle.MarginBottom(1)
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
@@ -212,7 +321,7 @@ func (p *PopupList) Render() string {
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
Width(innerW)
|
||||
b.WriteString(titleStyle.Render(p.Title))
|
||||
b.WriteString("\n")
|
||||
|
||||
@@ -221,17 +330,17 @@ func (p *PopupList) Render() string {
|
||||
subtitleStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
Width(innerW)
|
||||
b.WriteString(subtitleStyle.Render(p.Subtitle))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Search input.
|
||||
if p.showSearch {
|
||||
if p.ShowSearch {
|
||||
searchStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
Width(innerW)
|
||||
if p.search != "" {
|
||||
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", p.search)))
|
||||
} else {
|
||||
@@ -243,7 +352,7 @@ func (p *PopupList) Render() string {
|
||||
sepStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg)
|
||||
b.WriteString(sepStyle.Render(strings.Repeat("─", innerWidth)))
|
||||
b.WriteString(sepStyle.Render(strings.Repeat("─", innerW)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
@@ -251,20 +360,20 @@ func (p *PopupList) Render() string {
|
||||
normalItemBg := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.Text).
|
||||
Width(innerWidth).
|
||||
Width(innerW).
|
||||
Padding(0, 1)
|
||||
|
||||
selectedItemBg := lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Width(innerWidth).
|
||||
Width(innerW).
|
||||
Padding(0, 1).
|
||||
Bold(true)
|
||||
|
||||
scrollStyle := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Width(innerWidth).
|
||||
Width(innerW).
|
||||
Padding(0, 1)
|
||||
|
||||
vis := p.visibleCount()
|
||||
@@ -274,7 +383,7 @@ func (p *PopupList) Render() string {
|
||||
emptyStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg).
|
||||
Width(innerWidth).
|
||||
Width(innerW).
|
||||
Padding(0, 1)
|
||||
if p.search != "" {
|
||||
items = append(items, emptyStyle.Render("No matches for \""+p.search+"\""))
|
||||
@@ -282,9 +391,14 @@ func (p *PopupList) Render() string {
|
||||
items = append(items, emptyStyle.Render("No items"))
|
||||
}
|
||||
} else {
|
||||
// Center the cursor in the visible window so the user always sees
|
||||
// context above and below. Clamp to bounds.
|
||||
startIdx := 0
|
||||
if p.cursor >= vis {
|
||||
startIdx = p.cursor - vis + 1
|
||||
if len(p.filtered) > vis {
|
||||
startIdx = max(p.cursor-vis/2, 0)
|
||||
if startIdx+vis > len(p.filtered) {
|
||||
startIdx = len(p.filtered) - vis
|
||||
}
|
||||
}
|
||||
endIdx := min(startIdx+vis, len(p.filtered))
|
||||
|
||||
@@ -292,10 +406,27 @@ func (p *PopupList) Render() string {
|
||||
items = append(items, scrollStyle.Render(" ↑ more above"))
|
||||
}
|
||||
|
||||
// Account for the consumed padding (1 left + 1 right = 2 cols)
|
||||
// when rendering item content so RenderItem callbacks can match.
|
||||
itemContentWidth := max(innerW-2, 6)
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
entry := p.filtered[i]
|
||||
isCursor := i == p.cursor
|
||||
|
||||
if p.RenderItem != nil {
|
||||
// Custom renderer: caller produces the inner text. We still
|
||||
// wrap it in a full-width row so the selection highlight
|
||||
// covers the line edge-to-edge.
|
||||
rowStyle := normalItemBg
|
||||
if isCursor {
|
||||
rowStyle = selectedItemBg
|
||||
}
|
||||
content := p.RenderItem(entry, itemContentWidth, isCursor)
|
||||
items = append(items, rowStyle.Render(content))
|
||||
continue
|
||||
}
|
||||
|
||||
itemStyle := normalItemBg
|
||||
if isCursor {
|
||||
itemStyle = selectedItemBg
|
||||
@@ -310,7 +441,7 @@ func (p *PopupList) Render() string {
|
||||
}
|
||||
|
||||
// Build content: indicator + label + description + active checkmark.
|
||||
content := p.renderItemContent(indicator, entry, innerWidth, isCursor)
|
||||
content := p.renderItemContent(indicator, entry, itemContentWidth, isCursor)
|
||||
items = append(items, itemStyle.Render(content))
|
||||
}
|
||||
|
||||
@@ -323,19 +454,24 @@ func (p *PopupList) Render() string {
|
||||
|
||||
// Footer with count and keyboard hints.
|
||||
var footerParts []string
|
||||
footerParts = append(footerParts, fmt.Sprintf("(%d/%d)", p.cursor+1, len(p.filtered)))
|
||||
if !p.HideCount {
|
||||
footerParts = append(footerParts, fmt.Sprintf("(%d/%d)", p.cursor+1, len(p.filtered)))
|
||||
}
|
||||
|
||||
footerHint := p.FooterHint
|
||||
if footerHint == "" {
|
||||
if innerWidth >= 50 {
|
||||
if innerW >= 50 {
|
||||
footerHint = "↑↓ navigate • enter select • esc cancel • type to filter"
|
||||
} else if innerWidth >= 30 {
|
||||
} else if innerW >= 30 {
|
||||
footerHint = "↑↓ nav • ↵ select • esc"
|
||||
} else {
|
||||
footerHint = "↑↓ ↵ esc"
|
||||
}
|
||||
}
|
||||
footerParts = append(footerParts, footerHint)
|
||||
if p.ExtraFooter != "" {
|
||||
footerParts = append(footerParts, p.ExtraFooter)
|
||||
}
|
||||
|
||||
footer := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
|
||||
+131
-304
@@ -5,7 +5,6 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"charm.land/bubbles/v2/key"
|
||||
tea "charm.land/bubbletea/v2"
|
||||
@@ -62,17 +61,14 @@ func (m SessionFilterMode) String() string {
|
||||
// controlCharsRe matches ASCII control characters for stripping from previews.
|
||||
var controlCharsRe = regexp.MustCompile(`[\x00-\x1f\x7f]`)
|
||||
|
||||
// SessionSelectorComponent is a full-screen Bubble Tea component that lets
|
||||
// the user browse and select from available sessions. Modeled after pi's
|
||||
// session picker: right-aligned metadata, background-highlighted selection,
|
||||
// scope/filter toggles, and inline search.
|
||||
// SessionSelectorComponent is a Bubble Tea component that lets the user browse
|
||||
// and select from available sessions. It wraps PopupList in FullScreen mode:
|
||||
// PopupList owns the cursor/search/scroll math/chrome; this component owns
|
||||
// the session list, scope/filter toggles, and delete-confirmation flow.
|
||||
type SessionSelectorComponent struct {
|
||||
allSessions []session.SessionInfo
|
||||
cwdSessions []session.SessionInfo
|
||||
filtered []session.SessionInfo
|
||||
|
||||
cursor int
|
||||
search string
|
||||
filtered []session.SessionInfo // matches popup.Items() 1:1
|
||||
|
||||
scope SessionScopeMode
|
||||
filter SessionFilterMode
|
||||
@@ -80,6 +76,7 @@ type SessionSelectorComponent struct {
|
||||
// currentPath is the active session file path for marking it in the list.
|
||||
currentPath string
|
||||
|
||||
popup *PopupList
|
||||
width int
|
||||
height int
|
||||
active bool
|
||||
@@ -110,7 +107,12 @@ func NewSessionSelector(cwd string, width, height int) *SessionSelectorComponent
|
||||
ss.scope = SessionScopeAll
|
||||
}
|
||||
|
||||
ss.rebuildFiltered()
|
||||
ss.popup = NewPopupList("Resume Session", nil, width, height)
|
||||
ss.popup.FullScreen = true
|
||||
ss.popup.FooterHint = "↑↓ nav • ↵ open • esc cancel • tab scope • ^N named • d delete • type to search"
|
||||
ss.popup.RenderItem = ss.renderEntry
|
||||
|
||||
ss.rebuild()
|
||||
return ss
|
||||
}
|
||||
|
||||
@@ -131,10 +133,11 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case tea.WindowSizeMsg:
|
||||
ss.width = msg.Width
|
||||
ss.height = msg.Height
|
||||
ss.popup.SetSize(msg.Width, msg.Height)
|
||||
return ss, nil
|
||||
|
||||
case tea.KeyPressMsg:
|
||||
// Delete confirmation mode.
|
||||
// Delete confirmation mode swallows all keys until y/n.
|
||||
if ss.confirmDelete >= 0 {
|
||||
switch msg.String() {
|
||||
case "y", "Y":
|
||||
@@ -145,7 +148,7 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if err := session.DeleteSession(info.Path); err == nil {
|
||||
name := sessionDisplayName(info)
|
||||
ss.removeSession(info.Path)
|
||||
ss.rebuildFiltered()
|
||||
ss.rebuild()
|
||||
return ss, func() tea.Msg {
|
||||
return SessionDeletedMsg{Name: name}
|
||||
}
|
||||
@@ -159,64 +162,14 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
switch {
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("up"))):
|
||||
if ss.cursor > 0 {
|
||||
ss.cursor--
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("down"))):
|
||||
if ss.cursor < len(ss.filtered)-1 {
|
||||
ss.cursor++
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("pgup"))):
|
||||
ss.cursor -= ss.visibleHeight()
|
||||
if ss.cursor < 0 {
|
||||
ss.cursor = 0
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("pgdown"))):
|
||||
ss.cursor += ss.visibleHeight()
|
||||
if ss.cursor >= len(ss.filtered) {
|
||||
ss.cursor = len(ss.filtered) - 1
|
||||
}
|
||||
if ss.cursor < 0 {
|
||||
ss.cursor = 0
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("home"))):
|
||||
ss.cursor = 0
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("end"))):
|
||||
ss.cursor = max(len(ss.filtered)-1, 0)
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))):
|
||||
if ss.cursor < len(ss.filtered) {
|
||||
info := ss.filtered[ss.cursor]
|
||||
ss.active = false
|
||||
return ss, func() tea.Msg {
|
||||
return SessionSelectedMsg{Path: info.Path}
|
||||
}
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))):
|
||||
if ss.search != "" {
|
||||
ss.search = ""
|
||||
ss.rebuildFiltered()
|
||||
} else {
|
||||
ss.active = false
|
||||
return ss, func() tea.Msg {
|
||||
return SessionSelectorCancelledMsg{}
|
||||
}
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("tab"))):
|
||||
if ss.scope == SessionScopeCwd {
|
||||
ss.scope = SessionScopeAll
|
||||
} else {
|
||||
ss.scope = SessionScopeCwd
|
||||
}
|
||||
ss.rebuildFiltered()
|
||||
ss.rebuild()
|
||||
return ss, nil
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+n"))):
|
||||
if ss.filter == SessionFilterAll {
|
||||
@@ -224,25 +177,48 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
} else {
|
||||
ss.filter = SessionFilterAll
|
||||
}
|
||||
ss.rebuildFiltered()
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("d"))):
|
||||
if ss.cursor < len(ss.filtered) {
|
||||
ss.confirmDelete = ss.cursor
|
||||
}
|
||||
ss.rebuild()
|
||||
return ss, nil
|
||||
|
||||
default:
|
||||
if msg.Text != "" && len(msg.Text) == 1 {
|
||||
ch := msg.Text[0]
|
||||
if ch >= 32 && ch < 127 {
|
||||
ss.search += string(ch)
|
||||
ss.rebuildFiltered()
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+d"))):
|
||||
// Ctrl+D as an explicit delete shortcut. Plain "d" still works
|
||||
// below when the search field is empty so it doesn't conflict
|
||||
// with typing the letter 'd' into a query.
|
||||
if c := ss.popup.Cursor(); c < len(ss.filtered) {
|
||||
ss.confirmDelete = c
|
||||
}
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// Plain 'd' triggers delete only when there's no active search
|
||||
// query (otherwise the user would never be able to type 'd' into
|
||||
// a search like "doc").
|
||||
if msg.String() == "d" && !ss.popup.IsSearching() {
|
||||
if c := ss.popup.Cursor(); c < len(ss.filtered) {
|
||||
ss.confirmDelete = c
|
||||
return ss, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Delegate everything else to the popup.
|
||||
result := ss.popup.HandleKey(msg.String(), msg.Text)
|
||||
if result.Changed {
|
||||
ss.syncFiltered()
|
||||
}
|
||||
if result.Selected != nil {
|
||||
cursor := ss.popup.Cursor()
|
||||
if cursor < len(ss.filtered) {
|
||||
info := ss.filtered[cursor]
|
||||
ss.active = false
|
||||
return ss, func() tea.Msg {
|
||||
return SessionSelectedMsg{Path: info.Path}
|
||||
}
|
||||
}
|
||||
if key.Matches(msg, key.NewBinding(key.WithKeys("backspace"))) && len(ss.search) > 0 {
|
||||
ss.search = ss.search[:len(ss.search)-1]
|
||||
ss.rebuildFiltered()
|
||||
}
|
||||
if result.Cancelled {
|
||||
ss.active = false
|
||||
return ss, func() tea.Msg {
|
||||
return SessionSelectorCancelledMsg{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -251,152 +227,17 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// View implements tea.Model.
|
||||
func (ss *SessionSelectorComponent) View() tea.View {
|
||||
theme := style.GetTheme()
|
||||
|
||||
// Full-screen bordered container - uses entire terminal width and height
|
||||
maxWidth := ss.width - 2 // Small margin on each side
|
||||
if maxWidth < 20 {
|
||||
maxWidth = ss.width
|
||||
}
|
||||
maxHeight := ss.height - 2 // Small margin top/bottom to prevent overflow
|
||||
if maxHeight < 10 {
|
||||
maxHeight = ss.height
|
||||
}
|
||||
horizontalPadding := 1
|
||||
innerWidth := maxWidth - 4 // Account for border (2) + padding (2)
|
||||
innerHeight := maxHeight - 4 // Account for border (2) + padding (2)
|
||||
|
||||
// Container style with border - full width/height like a framed panel
|
||||
containerStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Primary).
|
||||
Background(theme.Background).
|
||||
Padding(1, horizontalPadding).
|
||||
Width(maxWidth).
|
||||
Height(maxHeight)
|
||||
|
||||
var contentBuilder strings.Builder
|
||||
|
||||
// ── Header: title + scope badges ─────────────────────────────
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(titleStyle.Render(fmt.Sprintf("Resume Session (%s)", ss.scope)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Help / keybindings ───────────────────────────────────────
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
if innerWidth >= 75 {
|
||||
contentBuilder.WriteString(helpStyle.Render("tab: scope N: named D: delete R: rename type to search esc: cancel"))
|
||||
} else if innerWidth >= 50 {
|
||||
contentBuilder.WriteString(helpStyle.Render("tab scope N named D del type to search esc"))
|
||||
} else {
|
||||
contentBuilder.WriteString(helpStyle.Render("tab N D esc"))
|
||||
}
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Search (only shown when active) ──────────────────────────
|
||||
if ss.search != "" {
|
||||
searchStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ss.search)))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
// Separator line
|
||||
sepWidth := innerWidth
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Delete confirmation ──────────────────────────────────────
|
||||
// Compose dynamic footer extras: scope + filter + (delete confirm).
|
||||
extra := fmt.Sprintf("scope: %s • filter: %s", ss.scope, ss.filter)
|
||||
if ss.confirmDelete >= 0 && ss.confirmDelete < len(ss.filtered) {
|
||||
warnStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true).
|
||||
Background(theme.Background)
|
||||
name := sessionDisplayName(ss.filtered[ss.confirmDelete])
|
||||
contentBuilder.WriteString(warnStyle.Render(fmt.Sprintf("Delete %q? (y/N)", truncateRunes(name, 40))))
|
||||
contentBuilder.WriteString("\n")
|
||||
name := truncateRunes(sessionDisplayName(ss.filtered[ss.confirmDelete]), 30)
|
||||
extra = fmt.Sprintf("delete %q? y/N", name)
|
||||
}
|
||||
ss.popup.Title = fmt.Sprintf("Resume Session (%s)", ss.scope)
|
||||
ss.popup.ExtraFooter = extra
|
||||
|
||||
// ── Session list ─────────────────────────────────────────────
|
||||
if len(ss.filtered) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
if ss.search != "" {
|
||||
contentBuilder.WriteString(emptyStyle.Render(fmt.Sprintf("No sessions matching %q", ss.search)))
|
||||
} else if ss.filter == SessionFilterNamed {
|
||||
contentBuilder.WriteString(emptyStyle.Render("No named sessions. Press N to show all."))
|
||||
} else if ss.scope == SessionScopeCwd {
|
||||
contentBuilder.WriteString(emptyStyle.Render("No sessions in current folder. Press tab to view all."))
|
||||
} else {
|
||||
contentBuilder.WriteString(emptyStyle.Render("No sessions found"))
|
||||
}
|
||||
contentBuilder.WriteString("\n")
|
||||
} else {
|
||||
// Compute visible window based on inner container height
|
||||
// Chrome: header(2) + separator(1) + footer separator(1) + footer(1) = 5
|
||||
chromeLines := 5
|
||||
if ss.search != "" {
|
||||
chromeLines++
|
||||
}
|
||||
if ss.confirmDelete >= 0 {
|
||||
chromeLines++
|
||||
}
|
||||
visH := max(innerHeight-chromeLines, 3)
|
||||
|
||||
// Center the cursor in the visible window.
|
||||
startIdx := max(0, min(ss.cursor-visH/2, len(ss.filtered)-visH))
|
||||
endIdx := min(startIdx+visH, len(ss.filtered))
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
info := ss.filtered[i]
|
||||
isCursor := i == ss.cursor
|
||||
isCurrent := info.Path == ss.currentPath
|
||||
isDeleting := i == ss.confirmDelete
|
||||
line := ss.renderEntry(info, isCursor, isCurrent, isDeleting, innerWidth)
|
||||
contentBuilder.WriteString(line)
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
// Scroll position indicator.
|
||||
if len(ss.filtered) > visH {
|
||||
posStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(posStyle.Render(fmt.Sprintf("(%d/%d)", ss.cursor+1, len(ss.filtered))))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Footer separator
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// Footer with filter info
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(footerStyle.Render(fmt.Sprintf("Filter: %s", ss.filter)))
|
||||
|
||||
// Apply the bordered container
|
||||
content := contentBuilder.String()
|
||||
borderedContent := containerStyle.Render(content)
|
||||
|
||||
v := tea.NewView(borderedContent)
|
||||
rendered := ss.popup.RenderCentered(ss.width, ss.height)
|
||||
v := tea.NewView(rendered)
|
||||
v.AltScreen = true
|
||||
return v
|
||||
}
|
||||
@@ -408,20 +249,9 @@ func (ss *SessionSelectorComponent) IsActive() bool {
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
func (ss *SessionSelectorComponent) visibleHeight() int {
|
||||
// Reserve: title(1) + help(1) + blank(1) + scroll indicator(1) = 4.
|
||||
// Optional: search(1), delete confirm(1).
|
||||
chrome := 4
|
||||
if ss.search != "" {
|
||||
chrome++
|
||||
}
|
||||
if ss.confirmDelete >= 0 {
|
||||
chrome++
|
||||
}
|
||||
return max(ss.height-chrome, 3)
|
||||
}
|
||||
|
||||
func (ss *SessionSelectorComponent) rebuildFiltered() {
|
||||
// rebuild applies the scope and filter selections, then publishes the
|
||||
// resulting session list to the popup.
|
||||
func (ss *SessionSelectorComponent) rebuild() {
|
||||
var source []session.SessionInfo
|
||||
if ss.scope == SessionScopeCwd {
|
||||
source = ss.cwdSessions
|
||||
@@ -439,23 +269,33 @@ func (ss *SessionSelectorComponent) rebuildFiltered() {
|
||||
source = named
|
||||
}
|
||||
|
||||
if ss.search != "" {
|
||||
query := strings.ToLower(ss.search)
|
||||
var matches []session.SessionInfo
|
||||
for _, s := range source {
|
||||
haystack := strings.ToLower(s.Name + " " + s.FirstMessage + " " + s.Cwd)
|
||||
if strings.Contains(haystack, query) {
|
||||
matches = append(matches, s)
|
||||
}
|
||||
// Build PopupItems. The Label holds a haystack string (name + first
|
||||
// message + cwd) so PopupList's default filter can match against any
|
||||
// of those fields. We render each row with a custom RenderItem.
|
||||
items := make([]PopupItem, len(source))
|
||||
for i, s := range source {
|
||||
haystack := strings.TrimSpace(s.Name + " " + s.FirstMessage + " " + s.Cwd)
|
||||
items[i] = PopupItem{
|
||||
Label: haystack,
|
||||
Active: s.Path == ss.currentPath,
|
||||
Meta: s,
|
||||
}
|
||||
ss.filtered = matches
|
||||
} else {
|
||||
ss.filtered = source
|
||||
}
|
||||
ss.popup.SetItems(items)
|
||||
ss.syncFiltered()
|
||||
}
|
||||
|
||||
if ss.cursor >= len(ss.filtered) {
|
||||
ss.cursor = max(len(ss.filtered)-1, 0)
|
||||
// syncFiltered refreshes the filtered slice from popup.Items() so cursor
|
||||
// indices map back to session.SessionInfo for the parent.
|
||||
func (ss *SessionSelectorComponent) syncFiltered() {
|
||||
items := ss.popup.Items()
|
||||
out := make([]session.SessionInfo, 0, len(items))
|
||||
for _, it := range items {
|
||||
if s, ok := it.Meta.(session.SessionInfo); ok {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
ss.filtered = out
|
||||
}
|
||||
|
||||
func (ss *SessionSelectorComponent) removeSession(path string) {
|
||||
@@ -473,87 +313,74 @@ func removeByPath(sessions []session.SessionInfo, path string) []session.Session
|
||||
return result
|
||||
}
|
||||
|
||||
// renderEntry renders a single session line with right-aligned metadata.
|
||||
// Layout: [cursor 2] [message ...variable...] [padding] [count age] [cwd?]
|
||||
func (ss *SessionSelectorComponent) renderEntry(info session.SessionInfo, isCursor, isCurrent, isDeleting bool, width int) string {
|
||||
// renderEntry is the RenderItem callback handed to PopupList. It produces a
|
||||
// single-line entry with left-aligned message text and right-aligned
|
||||
// metadata (message count + relative time, plus optional cwd in "All" scope).
|
||||
//
|
||||
// When isCursor we return a plain (unstyled) string so PopupList's outer
|
||||
// row style can paint one continuous fg+bg span. Mixing inner lipgloss
|
||||
// Render calls with an outer Background() breaks the highlight into bars,
|
||||
// because each inner Render emits an ANSI reset that drops the background.
|
||||
func (ss *SessionSelectorComponent) renderEntry(item PopupItem, innerWidth int, isCursor bool) string {
|
||||
theme := style.GetTheme()
|
||||
info, ok := item.Meta.(session.SessionInfo)
|
||||
if !ok {
|
||||
return item.Label
|
||||
}
|
||||
isCurrent := info.Path == ss.currentPath
|
||||
isDeleting := ss.confirmDelete >= 0 && ss.confirmDelete < len(ss.filtered) &&
|
||||
ss.filtered[ss.confirmDelete].Path == info.Path
|
||||
|
||||
// ── Cursor indicator (2 chars) ───────────────────────────────
|
||||
cursorStr := " "
|
||||
// Cursor indicator (2 cells).
|
||||
indicator := " "
|
||||
if isCursor {
|
||||
cursorStr = lipgloss.NewStyle().Foreground(theme.Accent).Render("> ")
|
||||
indicator = "> "
|
||||
}
|
||||
const cursorW = 2
|
||||
|
||||
// ── Right part: message count + relative time (+ optional cwd) ──
|
||||
// Right-hand metadata.
|
||||
age := relativeTime(info.Modified)
|
||||
msgCount := fmt.Sprintf("%d", info.MessageCount)
|
||||
rightPart := msgCount + " " + age
|
||||
right := fmt.Sprintf("%d %s", info.MessageCount, age)
|
||||
if ss.scope == SessionScopeAll && info.Cwd != "" {
|
||||
shortCwd := shortenPath(info.Cwd)
|
||||
if len(shortCwd) > 25 {
|
||||
shortCwd = "..." + shortCwd[len(shortCwd)-22:]
|
||||
}
|
||||
rightPart = shortCwd + " " + rightPart
|
||||
shortCwd := truncateRunes(shortenPath(info.Cwd), 25)
|
||||
right = shortCwd + " " + right
|
||||
}
|
||||
rightW := utf8.RuneCountInString(rightPart)
|
||||
rightW := lipgloss.Width(right)
|
||||
|
||||
// Message text width: innerWidth minus indicator(2) minus right minus gap(2).
|
||||
availForMsg := max(innerWidth-2-rightW-2, 10)
|
||||
|
||||
// ── Message text ─────────────────────────────────────────────
|
||||
displayText := sessionDisplayName(info)
|
||||
// Strip control characters and collapse whitespace.
|
||||
displayText = controlCharsRe.ReplaceAllString(displayText, " ")
|
||||
displayText = strings.Join(strings.Fields(displayText), " ")
|
||||
displayText = truncateRunes(displayText, availForMsg)
|
||||
|
||||
availableForMsg := max(width-cursorW-rightW-2, 10) // 2 for min spacing
|
||||
displayText = truncateRunes(displayText, availableForMsg)
|
||||
msgW := utf8.RuneCountInString(displayText)
|
||||
msgW := lipgloss.Width(displayText)
|
||||
spacing := max(innerWidth-2-msgW-rightW, 1)
|
||||
|
||||
// ── Style the message ────────────────────────────────────────
|
||||
var msgStyle lipgloss.Style
|
||||
// Selected row: raw string, outer row style paints it.
|
||||
if isCursor {
|
||||
return indicator + displayText + strings.Repeat(" ", spacing) + right
|
||||
}
|
||||
|
||||
// Color the message text by state.
|
||||
var msgStyle, rightStyle lipgloss.Style
|
||||
switch {
|
||||
case isDeleting:
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Error)
|
||||
case isCurrent:
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Accent)
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Accent).Bold(true)
|
||||
case info.Name != "":
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Warning)
|
||||
default:
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
}
|
||||
|
||||
// ── Style the right part ─────────────────────────────────────
|
||||
rightColor := theme.Muted
|
||||
if isDeleting {
|
||||
rightColor = theme.Error
|
||||
}
|
||||
var styledRight string
|
||||
|
||||
// ── Assemble with spacing ────────────────────────────────────
|
||||
spacing := max(width-cursorW-msgW-rightW, 1)
|
||||
|
||||
// If selected, use inverted colors like PopupList
|
||||
if isCursor {
|
||||
// Inverted colors for selected item
|
||||
msgStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Bold(true)
|
||||
styledRight = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(rightColor).
|
||||
Render(rightPart)
|
||||
cursorStr = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Accent).
|
||||
Render("> ")
|
||||
rightStyle = lipgloss.NewStyle().Foreground(theme.Error)
|
||||
} else {
|
||||
styledRight = lipgloss.NewStyle().Foreground(rightColor).Render(rightPart)
|
||||
rightStyle = lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
}
|
||||
|
||||
styledMsg := msgStyle.Render(displayText)
|
||||
line := cursorStr + styledMsg + strings.Repeat(" ", spacing) + styledRight
|
||||
|
||||
return line
|
||||
return indicator + msgStyle.Render(displayText) + strings.Repeat(" ", spacing) + rightStyle.Render(right)
|
||||
}
|
||||
|
||||
// --- Package helpers ---
|
||||
@@ -570,7 +397,7 @@ func sessionDisplayName(info session.SessionInfo) string {
|
||||
return "(empty session)"
|
||||
}
|
||||
|
||||
// truncateRunes truncates a string to at most maxRunes runes, appending "..."
|
||||
// truncateRunes truncates a string to at most maxRunes runes, appending "…"
|
||||
// if truncated.
|
||||
func truncateRunes(s string, maxRunes int) string {
|
||||
if maxRunes <= 0 {
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/png"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
uicore "github.com/mark3labs/kit/internal/ui/core"
|
||||
)
|
||||
|
||||
// makeTestPNG builds a small solid-color PNG for transcript preview tests.
|
||||
func makeTestPNG(t *testing.T, w, h int) []byte {
|
||||
t.Helper()
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for y := range h {
|
||||
for x := range w {
|
||||
img.Set(x, y, color.RGBA{R: 200, G: 40, B: 90, A: 255})
|
||||
}
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
t.Fatalf("encode png: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestTranscriptPreviewCmdNoImages(t *testing.T) {
|
||||
m, _, _ := newTestAppModel(nil)
|
||||
if cmd := m.transcriptPreviewCmd(nil, ""); cmd != nil {
|
||||
t.Error("expected nil cmd when there are no images")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranscriptPreviewCmdRendersBlock(t *testing.T) {
|
||||
m, _, _ := newTestAppModel(nil)
|
||||
images := []uicore.ImageAttachment{
|
||||
{Data: makeTestPNG(t, 16, 16), MediaType: "image/png"},
|
||||
}
|
||||
cmd := m.transcriptPreviewCmd(images, "anchor-1")
|
||||
if cmd == nil {
|
||||
t.Fatal("expected a non-nil cmd for a valid image")
|
||||
}
|
||||
msg := cmd()
|
||||
// The result depends on the test process color profile. When the
|
||||
// terminal supports color the cmd yields a preview block; otherwise it
|
||||
// yields nil (caller keeps the text badge). Both are valid — assert the
|
||||
// shape only when a block is produced.
|
||||
if msg == nil {
|
||||
t.Skip("color profile below ANSI256 in test env; preview correctly skipped")
|
||||
}
|
||||
ready, ok := msg.(imagePreviewReadyMsg)
|
||||
if !ok {
|
||||
t.Fatalf("expected imagePreviewReadyMsg, got %T", msg)
|
||||
}
|
||||
if !strings.Contains(ready.block, "▀") {
|
||||
t.Errorf("preview block should contain half-block glyphs, got %q", ready.block)
|
||||
}
|
||||
if ready.anchorID != "anchor-1" {
|
||||
t.Errorf("preview should carry the originating anchorID, got %q", ready.anchorID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImagePreviewReadyMsgAppendsItem(t *testing.T) {
|
||||
m, _, _ := newTestAppModel(nil)
|
||||
before := len(m.messages)
|
||||
m = sendMsg(m, imagePreviewReadyMsg{block: "\x1b[38;2;1;2;3;48;2;4;5;6m▀\x1b[0m"})
|
||||
if len(m.messages) != before+1 {
|
||||
t.Fatalf("expected one appended message item, got %d (was %d)", len(m.messages), before)
|
||||
}
|
||||
last, ok := m.messages[len(m.messages)-1].(*TextMessageItem)
|
||||
if !ok {
|
||||
t.Fatalf("expected last item to be *TextMessageItem, got %T", m.messages[len(m.messages)-1])
|
||||
}
|
||||
if !strings.Contains(last.Render(0), "▀") {
|
||||
t.Error("appended preview item should render the half-block block verbatim")
|
||||
}
|
||||
}
|
||||
|
||||
// TestImagePreviewReadyMsgInsertsAfterAnchor verifies the preview is placed
|
||||
// directly after its originating user message even when a later message (e.g.
|
||||
// a streamed assistant reply) was already appended while the thumbnail was
|
||||
// being decoded asynchronously.
|
||||
func TestImagePreviewReadyMsgInsertsAfterAnchor(t *testing.T) {
|
||||
m, _, _ := newTestAppModel(nil)
|
||||
userItem := NewStyledMessageItem("user-anchor", "user", "hi", "hi")
|
||||
assistantItem := NewStyledMessageItem("assistant-1", "assistant", "reply", "reply")
|
||||
m.messages = append(m.messages, userItem, assistantItem)
|
||||
|
||||
m = sendMsg(m, imagePreviewReadyMsg{
|
||||
block: "\x1b[38;2;1;2;3;48;2;4;5;6m▀\x1b[0m",
|
||||
anchorID: "user-anchor",
|
||||
})
|
||||
|
||||
// Expect order: user, preview, assistant.
|
||||
if len(m.messages) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d", len(m.messages))
|
||||
}
|
||||
if m.messages[0].ID() != "user-anchor" {
|
||||
t.Errorf("messages[0] should be the user message, got %q", m.messages[0].ID())
|
||||
}
|
||||
if m.messages[2].ID() != "assistant-1" {
|
||||
t.Errorf("messages[2] should be the assistant message, got %q", m.messages[2].ID())
|
||||
}
|
||||
if !strings.Contains(m.messages[1].Render(0), "▀") {
|
||||
t.Errorf("messages[1] should be the inserted preview, got %q", m.messages[1].Render(0))
|
||||
}
|
||||
}
|
||||
|
||||
// TestImagePreviewReadyMsgUnknownAnchorAppends verifies that when the anchor
|
||||
// is missing (e.g. the message was cleared), the preview falls back to append.
|
||||
func TestImagePreviewReadyMsgUnknownAnchorAppends(t *testing.T) {
|
||||
m, _, _ := newTestAppModel(nil)
|
||||
m.messages = append(m.messages, NewStyledMessageItem("only", "user", "hi", "hi"))
|
||||
m = sendMsg(m, imagePreviewReadyMsg{
|
||||
block: "\x1b[38;2;1;2;3;48;2;4;5;6m▀\x1b[0m",
|
||||
anchorID: "does-not-exist",
|
||||
})
|
||||
if len(m.messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(m.messages))
|
||||
}
|
||||
if !strings.Contains(m.messages[1].Render(0), "▀") {
|
||||
t.Error("preview should be appended as the last item when anchor is unknown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImagePreviewReadyMsgEmptyBlockIgnored(t *testing.T) {
|
||||
m, _, _ := newTestAppModel(nil)
|
||||
before := len(m.messages)
|
||||
m = sendMsg(m, imagePreviewReadyMsg{block: ""})
|
||||
if len(m.messages) != before {
|
||||
t.Errorf("empty preview block should not append an item; got %d (was %d)", len(m.messages), before)
|
||||
}
|
||||
}
|
||||
+183
-315
@@ -53,16 +53,19 @@ type FlatNode struct {
|
||||
}
|
||||
|
||||
// TreeSelectorComponent is a Bubble Tea component that renders the session
|
||||
// tree as an ASCII art list with navigation and selection.
|
||||
// tree as an ASCII art list with navigation and selection. It is a thin
|
||||
// wrapper around PopupList (in FullScreen mode) — PopupList owns the cursor,
|
||||
// search, scroll math, and chrome; TreeSelectorComponent supplies the
|
||||
// filtered node list and a custom RenderItem that draws each tree node with
|
||||
// its indentation prefix and role colors.
|
||||
type TreeSelectorComponent struct {
|
||||
tm *session.TreeManager
|
||||
flatNodes []FlatNode
|
||||
cursor int
|
||||
flatNodes []FlatNode // visible nodes (matches popup.Items() 1:1)
|
||||
filter TreeFilterMode
|
||||
leafID string // real leaf for "active" marker
|
||||
popup *PopupList
|
||||
width int
|
||||
height int
|
||||
search string
|
||||
active bool
|
||||
selectedID string // set when user selects a node
|
||||
cancelled bool
|
||||
@@ -78,11 +81,12 @@ func NewTreeSelector(tm *session.TreeManager, width, height int) *TreeSelectorCo
|
||||
height: height,
|
||||
active: true,
|
||||
}
|
||||
ts.rebuildFlatList()
|
||||
ts.initPopup()
|
||||
ts.rebuild()
|
||||
// Position cursor at the active leaf.
|
||||
for i, node := range ts.flatNodes {
|
||||
if node.ID == ts.leafID {
|
||||
ts.cursor = i
|
||||
ts.popup.SetCursor(i)
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -100,17 +104,25 @@ func NewTreeSelectorForFork(tm *session.TreeManager, width, height int) *TreeSel
|
||||
height: height,
|
||||
active: true,
|
||||
}
|
||||
ts.rebuildFlatList()
|
||||
ts.initPopup()
|
||||
ts.rebuild()
|
||||
// Position cursor at the last user message before the leaf.
|
||||
for i := len(ts.flatNodes) - 1; i >= 0; i-- {
|
||||
if ts.isUserMessage(ts.flatNodes[i].Entry) {
|
||||
ts.cursor = i
|
||||
ts.popup.SetCursor(i)
|
||||
break
|
||||
}
|
||||
}
|
||||
return ts
|
||||
}
|
||||
|
||||
func (ts *TreeSelectorComponent) initPopup() {
|
||||
ts.popup = NewPopupList("Session Tree", nil, ts.width, ts.height)
|
||||
ts.popup.FullScreen = true
|
||||
ts.popup.FooterHint = "↑↓ nav • ←→ page • ↵ select • esc cancel • ^O filter • type to search"
|
||||
ts.popup.RenderItem = ts.renderNode
|
||||
}
|
||||
|
||||
// Init implements tea.Model.
|
||||
func (ts *TreeSelectorComponent) Init() tea.Cmd {
|
||||
return nil
|
||||
@@ -122,96 +134,75 @@ func (ts *TreeSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case tea.WindowSizeMsg:
|
||||
ts.width = msg.Width
|
||||
ts.height = msg.Height
|
||||
ts.popup.SetSize(msg.Width, msg.Height)
|
||||
return ts, nil
|
||||
|
||||
case tea.KeyPressMsg:
|
||||
// Tree-specific keys we handle ourselves before delegating to popup.
|
||||
switch {
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("up"))):
|
||||
if ts.cursor > 0 {
|
||||
ts.cursor--
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("down"))):
|
||||
if ts.cursor < len(ts.flatNodes)-1 {
|
||||
ts.cursor++
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("left", "pgup"))):
|
||||
// Page up.
|
||||
ts.cursor -= ts.visibleHeight()
|
||||
if ts.cursor < 0 {
|
||||
ts.cursor = 0
|
||||
}
|
||||
result := ts.popup.HandleKey("pgup", "")
|
||||
_ = result
|
||||
return ts, nil
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("right", "pgdown"))):
|
||||
// Page down.
|
||||
ts.cursor += ts.visibleHeight()
|
||||
if ts.cursor >= len(ts.flatNodes) {
|
||||
ts.cursor = len(ts.flatNodes) - 1
|
||||
}
|
||||
result := ts.popup.HandleKey("pgdown", "")
|
||||
_ = result
|
||||
return ts, nil
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("home"))):
|
||||
ts.cursor = 0
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+o"))):
|
||||
ts.filter = (ts.filter + 1) % 5
|
||||
ts.rebuild()
|
||||
return ts, nil
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("end"))):
|
||||
ts.cursor = len(ts.flatNodes) - 1
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+d"))):
|
||||
ts.filter = TreeFilterDefault
|
||||
ts.rebuild()
|
||||
return ts, nil
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+t"))):
|
||||
ts.filter = TreeFilterNoTools
|
||||
ts.rebuild()
|
||||
return ts, nil
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+u"))):
|
||||
ts.filter = TreeFilterUserOnly
|
||||
ts.rebuild()
|
||||
return ts, nil
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+l"))):
|
||||
ts.filter = TreeFilterLabelOnly
|
||||
ts.rebuild()
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))):
|
||||
if ts.cursor < len(ts.flatNodes) {
|
||||
ts.selectedID = ts.flatNodes[ts.cursor].ID
|
||||
// Delegate everything else (nav, search, enter, esc) to the popup.
|
||||
result := ts.popup.HandleKey(msg.String(), msg.Text)
|
||||
|
||||
// Update our flatNodes view if popup filtered/changed search.
|
||||
if result.Changed {
|
||||
ts.syncFlatNodes()
|
||||
}
|
||||
|
||||
if result.Selected != nil {
|
||||
cursor := ts.popup.Cursor()
|
||||
if cursor < len(ts.flatNodes) {
|
||||
node := ts.flatNodes[cursor]
|
||||
ts.selectedID = node.ID
|
||||
ts.active = false
|
||||
return ts, func() tea.Msg {
|
||||
return core.TreeNodeSelectedMsg{
|
||||
ID: ts.selectedID,
|
||||
Entry: ts.flatNodes[ts.cursor].Entry,
|
||||
IsUser: ts.isUserMessage(ts.flatNodes[ts.cursor].Entry),
|
||||
UserText: ts.extractUserText(ts.flatNodes[ts.cursor].Entry),
|
||||
ID: node.ID,
|
||||
Entry: node.Entry,
|
||||
IsUser: ts.isUserMessage(node.Entry),
|
||||
UserText: ts.extractUserText(node.Entry),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))):
|
||||
if ts.search != "" {
|
||||
ts.search = ""
|
||||
ts.rebuildFlatList()
|
||||
} else {
|
||||
ts.cancelled = true
|
||||
ts.active = false
|
||||
return ts, func() tea.Msg {
|
||||
return core.TreeCancelledMsg{}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter cycle with ctrl+o.
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+o"))):
|
||||
ts.filter = (ts.filter + 1) % 5
|
||||
ts.rebuildFlatList()
|
||||
|
||||
// Direct filter shortcuts.
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+d"))):
|
||||
ts.filter = TreeFilterDefault
|
||||
ts.rebuildFlatList()
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+t"))):
|
||||
ts.filter = TreeFilterNoTools
|
||||
ts.rebuildFlatList()
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+u"))):
|
||||
ts.filter = TreeFilterUserOnly
|
||||
ts.rebuildFlatList()
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+l"))):
|
||||
ts.filter = TreeFilterLabelOnly
|
||||
ts.rebuildFlatList()
|
||||
default:
|
||||
// Typing search.
|
||||
if msg.Text != "" && len(msg.Text) == 1 {
|
||||
ch := msg.Text[0]
|
||||
if ch >= 32 && ch < 127 {
|
||||
ts.search += string(ch)
|
||||
ts.rebuildFlatList()
|
||||
}
|
||||
}
|
||||
if key.Matches(msg, key.NewBinding(key.WithKeys("backspace"))) && len(ts.search) > 0 {
|
||||
ts.search = ts.search[:len(ts.search)-1]
|
||||
ts.rebuildFlatList()
|
||||
}
|
||||
if result.Cancelled {
|
||||
ts.cancelled = true
|
||||
ts.active = false
|
||||
return ts, func() tea.Msg {
|
||||
return core.TreeCancelledMsg{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -220,128 +211,10 @@ func (ts *TreeSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// View implements tea.Model.
|
||||
func (ts *TreeSelectorComponent) View() tea.View {
|
||||
theme := GetTheme()
|
||||
|
||||
// Full-screen bordered container - uses entire terminal width and height
|
||||
maxWidth := ts.width - 2 // Small margin on each side
|
||||
if maxWidth < 20 {
|
||||
maxWidth = ts.width
|
||||
}
|
||||
maxHeight := ts.height - 2 // Small margin top/bottom to prevent overflow
|
||||
if maxHeight < 10 {
|
||||
maxHeight = ts.height
|
||||
}
|
||||
horizontalPadding := 1
|
||||
innerWidth := maxWidth - 4 // Account for border (2) + padding (2)
|
||||
innerHeight := maxHeight - 4 // Account for border (2) + padding (2)
|
||||
|
||||
// Container style with border - full width/height like a framed panel
|
||||
containerStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Primary).
|
||||
Background(theme.Background).
|
||||
Padding(1, horizontalPadding).
|
||||
Width(maxWidth).
|
||||
Height(maxHeight)
|
||||
|
||||
// Header style with background highlight (like PopupList title)
|
||||
headerStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
Background(theme.Background)
|
||||
|
||||
// Help text style
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
|
||||
var contentBuilder strings.Builder
|
||||
|
||||
// Header row with title and help
|
||||
headerRow := headerStyle.Render("Session Tree")
|
||||
contentBuilder.WriteString(headerRow)
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// Help text - adapt to terminal width
|
||||
var helpText string
|
||||
if ts.width >= 70 {
|
||||
helpText = "↑/↓: move ←/→: page enter: select esc: cancel ^O: cycle filter"
|
||||
} else if ts.width >= 45 {
|
||||
helpText = "↑↓ move ↵ select esc cancel ^O filter"
|
||||
} else {
|
||||
helpText = "↑↓ ↵ esc ^O"
|
||||
}
|
||||
contentBuilder.WriteString(helpStyle.Render(helpText))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// Search display (if active)
|
||||
if ts.search != "" {
|
||||
searchStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ts.search)))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
// Separator line - full width
|
||||
sepWidth := innerWidth
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// Tree content
|
||||
if len(ts.flatNodes) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(emptyStyle.Render("No entries in session"))
|
||||
contentBuilder.WriteString("\n")
|
||||
} else {
|
||||
// Compute visible window based on inner container height
|
||||
// Chrome: header(2) + separator(1) + footer separator(1) + footer(1) = 5
|
||||
chromeLines := 5
|
||||
if ts.search != "" {
|
||||
chromeLines++
|
||||
}
|
||||
visH := max(innerHeight-chromeLines, 3)
|
||||
|
||||
startIdx := 0
|
||||
if ts.cursor >= visH {
|
||||
startIdx = ts.cursor - visH + 1
|
||||
}
|
||||
endIdx := min(startIdx+visH, len(ts.flatNodes))
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
node := ts.flatNodes[i]
|
||||
line := ts.renderNode(node, i == ts.cursor, node.ID == ts.leafID, innerWidth)
|
||||
contentBuilder.WriteString(line)
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Footer separator
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// Footer with count and filter
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
footer := fmt.Sprintf("(%d/%d) [%s]", ts.cursor+1, len(ts.flatNodes), ts.filter)
|
||||
contentBuilder.WriteString(footerStyle.Render(footer))
|
||||
|
||||
// Apply the bordered container - full width, no centering
|
||||
content := contentBuilder.String()
|
||||
borderedContent := containerStyle.Render(content)
|
||||
|
||||
v := tea.NewView(borderedContent)
|
||||
// Update extra footer with current filter mode.
|
||||
ts.popup.ExtraFooter = fmt.Sprintf("[%s]", ts.filter)
|
||||
rendered := ts.popup.RenderCentered(ts.width, ts.height)
|
||||
v := tea.NewView(rendered)
|
||||
v.AltScreen = true
|
||||
return v
|
||||
}
|
||||
@@ -353,38 +226,46 @@ func (ts *TreeSelectorComponent) IsActive() bool {
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
func (ts *TreeSelectorComponent) visibleHeight() int {
|
||||
// Chrome: header(1) + help(1) + separator(1) + entries + separator(1) + footer(1) = 5 fixed.
|
||||
// Optional search line adds 1 more. Use 7 as a safe estimate.
|
||||
const chromeLines = 7
|
||||
return max(ts.height-chromeLines, 3)
|
||||
}
|
||||
|
||||
func (ts *TreeSelectorComponent) rebuildFlatList() {
|
||||
tree := ts.tm.GetTree()
|
||||
// rebuild reflattens the tree under the current filter and reseeds the popup
|
||||
// with PopupItems. Called on initial load and whenever the filter changes.
|
||||
func (ts *TreeSelectorComponent) rebuild() {
|
||||
ts.flatNodes = ts.flatNodes[:0]
|
||||
tree := ts.tm.GetTree()
|
||||
for i, root := range tree {
|
||||
isLast := i == len(tree)-1
|
||||
ts.flattenNode(root, 0, isLast, "")
|
||||
}
|
||||
ts.publishItems()
|
||||
}
|
||||
|
||||
// Apply search filter.
|
||||
if ts.search != "" {
|
||||
query := strings.ToLower(ts.search)
|
||||
filtered := make([]FlatNode, 0)
|
||||
for _, node := range ts.flatNodes {
|
||||
text := ts.entryDisplayText(node.Entry)
|
||||
if strings.Contains(strings.ToLower(text), query) {
|
||||
filtered = append(filtered, node)
|
||||
}
|
||||
// syncFlatNodes refreshes flatNodes from the popup's current filtered view.
|
||||
// Called after a search-driven HandleKey result so the cursor index matches.
|
||||
func (ts *TreeSelectorComponent) syncFlatNodes() {
|
||||
items := ts.popup.Items()
|
||||
newFlat := make([]FlatNode, len(items))
|
||||
for i, it := range items {
|
||||
if fn, ok := it.Meta.(FlatNode); ok {
|
||||
newFlat[i] = fn
|
||||
}
|
||||
ts.flatNodes = filtered
|
||||
}
|
||||
ts.flatNodes = newFlat
|
||||
}
|
||||
|
||||
// Clamp cursor.
|
||||
if ts.cursor >= len(ts.flatNodes) {
|
||||
ts.cursor = max(len(ts.flatNodes)-1, 0)
|
||||
// publishItems converts flatNodes → PopupItems and seeds the popup. We rely
|
||||
// on PopupList's default substring filter against item.Label (which holds
|
||||
// the display text) for search.
|
||||
func (ts *TreeSelectorComponent) publishItems() {
|
||||
items := make([]PopupItem, len(ts.flatNodes))
|
||||
for i, n := range ts.flatNodes {
|
||||
items[i] = PopupItem{
|
||||
Label: ts.entryDisplayText(n.Entry),
|
||||
Active: n.ID == ts.leafID,
|
||||
Meta: n,
|
||||
}
|
||||
}
|
||||
ts.popup.SetItems(items)
|
||||
// Mirror the popup's current view in flatNodes so cursor lookups work.
|
||||
ts.syncFlatNodes()
|
||||
}
|
||||
|
||||
func (ts *TreeSelectorComponent) flattenNode(node *session.TreeNode, depth int, isLast bool, gutterPrefix string) {
|
||||
@@ -473,35 +354,73 @@ func (ts *TreeSelectorComponent) passesFilter(node *session.TreeNode) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool, innerWidth int) string {
|
||||
// renderNode is the RenderItem callback handed to PopupList. PopupList wraps
|
||||
// the returned string with a full-width row style.
|
||||
//
|
||||
// When isCursor we return a plain (unstyled) string so the outer row style
|
||||
// can paint a single continuous fg+bg span across the line. Composing inner
|
||||
// lipgloss.Render calls emits ANSI resets mid-string which knock the
|
||||
// background back out, breaking the highlight into disjoint bars (issue
|
||||
// observed with deep tool-interaction branches).
|
||||
func (ts *TreeSelectorComponent) renderNode(item PopupItem, innerWidth int, isCursor bool) string {
|
||||
theme := GetTheme()
|
||||
node, ok := item.Meta.(FlatNode)
|
||||
if !ok {
|
||||
return item.Label
|
||||
}
|
||||
isLeaf := node.ID == ts.leafID
|
||||
|
||||
// Cursor indicator - use ">" for selected (like PopupList)
|
||||
var cursor string
|
||||
// Indicator (2 cells).
|
||||
indicator := " "
|
||||
if isCursor {
|
||||
cursor = lipgloss.NewStyle().Foreground(theme.Accent).Render("> ")
|
||||
} else {
|
||||
cursor = " "
|
||||
indicator = "> "
|
||||
}
|
||||
|
||||
// Role-colored content with background support for selection
|
||||
text := ts.entryDisplayText(node.Entry)
|
||||
// Prefix (tree art) — width measured in display cells via lipgloss.
|
||||
prefix := node.Prefix
|
||||
prefixW := lipgloss.Width(prefix)
|
||||
|
||||
// Calculate available width accounting for cursor, prefix, and markers
|
||||
prefixLen := len(node.Prefix)
|
||||
available := innerWidth - prefixLen - 4 // 4 for cursor and some padding
|
||||
if available > 3 && len(text) > available {
|
||||
trimLen := max(available-3, 1)
|
||||
if trimLen < len(text) {
|
||||
text = text[:trimLen] + "..."
|
||||
// Compute right-side fixed parts: label badge + active marker.
|
||||
var labelBadgeRaw, activeMarkerRaw string
|
||||
if node.Label != "" {
|
||||
labelBadgeRaw = " [" + node.Label + "]"
|
||||
}
|
||||
if isLeaf {
|
||||
activeMarkerRaw = " ← active"
|
||||
}
|
||||
rightW := lipgloss.Width(labelBadgeRaw) + lipgloss.Width(activeMarkerRaw)
|
||||
|
||||
// If the tree prefix is so deep it would push the text off the row,
|
||||
// truncate the prefix from the LEFT and prepend an ellipsis. Keeping
|
||||
// the right-most segment preserves the most recent depth indicator
|
||||
// (└─ / ├─) so the user can still see this row's connection to its
|
||||
// parent. We reserve at least 20 cells for the actual entry text.
|
||||
const minTextWidth = 20
|
||||
budget := innerWidth - 2 - rightW - minTextWidth
|
||||
if prefixW > budget && budget > 2 {
|
||||
runes := []rune(prefix)
|
||||
// Strip from the left until lipgloss.Width fits the budget.
|
||||
for len(runes) > 0 && lipgloss.Width(string(runes)) > budget-1 {
|
||||
runes = runes[1:]
|
||||
}
|
||||
prefix = "…" + string(runes)
|
||||
prefixW = lipgloss.Width(prefix)
|
||||
}
|
||||
|
||||
// Build the full line style
|
||||
var lineStyle lipgloss.Style
|
||||
var textStyle lipgloss.Style
|
||||
// Reserve space for indicator(2) + prefix + right parts.
|
||||
available := max(innerWidth-2-prefixW-rightW, 4)
|
||||
|
||||
// Base text color based on role
|
||||
text := ts.entryDisplayText(node.Entry)
|
||||
text = truncateRunes(text, available)
|
||||
|
||||
// Selected row: emit raw text. The outer row style applies fg+bg in one
|
||||
// uninterrupted span, keeping the highlight solid edge-to-edge.
|
||||
if isCursor {
|
||||
return indicator + prefix + text + labelBadgeRaw + activeMarkerRaw
|
||||
}
|
||||
|
||||
// Role-based text color.
|
||||
var textStyle lipgloss.Style
|
||||
switch e := node.Entry.(type) {
|
||||
case *session.MessageEntry:
|
||||
switch e.Role {
|
||||
@@ -520,77 +439,27 @@ func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool
|
||||
textStyle = lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
}
|
||||
|
||||
// Apply selection highlighting (like PopupList)
|
||||
if isCursor {
|
||||
// Inverted colors for selected item - matches PopupList style
|
||||
lineStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Bold(true)
|
||||
textStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// Render components
|
||||
content := textStyle.Render(text)
|
||||
|
||||
// Label badge.
|
||||
var labelBadge string
|
||||
if node.Label != "" {
|
||||
labelStyle := lipgloss.NewStyle().Foreground(theme.Warning)
|
||||
if isCursor {
|
||||
labelStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Warning)
|
||||
}
|
||||
labelBadge = " " + labelStyle.Render("["+node.Label+"]")
|
||||
}
|
||||
|
||||
// Active marker - use Success color for better visibility
|
||||
var activeMarker string
|
||||
if isLeaf {
|
||||
markerStyle := lipgloss.NewStyle().Foreground(theme.Success).Bold(true)
|
||||
if isCursor {
|
||||
markerStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Success).
|
||||
Bold(true)
|
||||
}
|
||||
activeMarker = markerStyle.Render(" ← active")
|
||||
}
|
||||
|
||||
// Prefix (tree lines) - use MutedBorder for subtler appearance
|
||||
prefixStyle := lipgloss.NewStyle().Foreground(theme.MutedBorder)
|
||||
if isCursor {
|
||||
prefixStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.MutedBorder)
|
||||
labelStyle := lipgloss.NewStyle().Foreground(theme.Warning)
|
||||
markerStyle := lipgloss.NewStyle().Foreground(theme.Success).Bold(true)
|
||||
|
||||
parts := indicator + prefixStyle.Render(prefix) + textStyle.Render(text)
|
||||
if labelBadgeRaw != "" {
|
||||
parts += labelStyle.Render(labelBadgeRaw)
|
||||
}
|
||||
renderedPrefix := prefixStyle.Render(node.Prefix)
|
||||
|
||||
// Combine all parts
|
||||
line := cursor + renderedPrefix + content + labelBadge + activeMarker
|
||||
|
||||
// If selected, apply the background to the entire line
|
||||
if isCursor {
|
||||
return lineStyle.Render(line)
|
||||
if activeMarkerRaw != "" {
|
||||
parts += markerStyle.Render(activeMarkerRaw)
|
||||
}
|
||||
|
||||
return line
|
||||
return parts
|
||||
}
|
||||
|
||||
func (ts *TreeSelectorComponent) entryDisplayText(entry any) string {
|
||||
switch e := entry.(type) {
|
||||
case *session.MessageEntry:
|
||||
role := e.Role
|
||||
text := extractTextFromParts(e.Parts)
|
||||
if len(text) > 80 {
|
||||
text = text[:80] + "..."
|
||||
}
|
||||
text := collapseToLine(extractTextFromParts(e.Parts))
|
||||
text = truncateRunes(text, 200)
|
||||
if text == "" {
|
||||
// Tool call messages may not have text.
|
||||
text = "(tool interaction)"
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", role, text)
|
||||
@@ -599,18 +468,10 @@ func (ts *TreeSelectorComponent) entryDisplayText(entry any) string {
|
||||
return fmt.Sprintf("model: %s/%s", e.Provider, e.ModelID)
|
||||
|
||||
case *session.BranchSummaryEntry:
|
||||
summary := e.Summary
|
||||
if len(summary) > 60 {
|
||||
summary = summary[:60] + "..."
|
||||
}
|
||||
return fmt.Sprintf("branch summary: %s", summary)
|
||||
return fmt.Sprintf("branch summary: %s", truncateRunes(collapseToLine(e.Summary), 200))
|
||||
|
||||
case *session.CompactionEntry:
|
||||
summary := e.Summary
|
||||
if len(summary) > 60 {
|
||||
summary = summary[:60] + "..."
|
||||
}
|
||||
return fmt.Sprintf("compaction: %s", summary)
|
||||
return fmt.Sprintf("compaction: %s", truncateRunes(collapseToLine(e.Summary), 200))
|
||||
|
||||
case *session.LabelEntry:
|
||||
return fmt.Sprintf("label: %s", e.Label)
|
||||
@@ -623,6 +484,13 @@ func (ts *TreeSelectorComponent) entryDisplayText(entry any) string {
|
||||
}
|
||||
}
|
||||
|
||||
// collapseToLine flattens any multi-line string into a single line by
|
||||
// replacing whitespace runs (including newlines and tabs) with single
|
||||
// spaces. Used so popup rows never wrap and break the layout.
|
||||
func collapseToLine(s string) string {
|
||||
return strings.Join(strings.Fields(s), " ")
|
||||
}
|
||||
|
||||
func (ts *TreeSelectorComponent) isUserMessage(entry any) bool {
|
||||
if me, ok := entry.(*session.MessageEntry); ok {
|
||||
return me.Role == "user"
|
||||
|
||||
+73
-1
@@ -49,6 +49,36 @@ The SDK behaves identically to the CLI:
|
||||
- Respects all environment variables (`KIT_*`)
|
||||
- Uses the same defaults as the CLI
|
||||
|
||||
Each `kit.New` / `kit.NewAgent` call owns an **isolated configuration store**,
|
||||
so constructing multiple Kit instances in the same process is safe — setting
|
||||
the model, thinking level, or generation parameters on one never affects
|
||||
another, and runtime mutators (`SetModel`, `SetThinkingLevel`) only touch the
|
||||
owning instance. This makes subagent spawning and multi-Kit embedding race-free
|
||||
without external synchronization.
|
||||
|
||||
### Functional options (`NewAgent`)
|
||||
|
||||
For simple programmatic setups, `kit.NewAgent` is an ergonomic
|
||||
functional-options front door over `kit.New`. Streaming is enabled by default;
|
||||
pass `kit.WithStreaming(false)` to opt out.
|
||||
|
||||
```go
|
||||
host, err := kit.NewAgent(ctx,
|
||||
kit.WithModel("anthropic/claude-sonnet-4-5-20250929"),
|
||||
kit.WithSystemPrompt("You are a helpful assistant."),
|
||||
kit.WithMaxTokens(8192),
|
||||
kit.WithThinkingLevel("medium"),
|
||||
kit.Ephemeral(), // in-memory session, no persistence
|
||||
)
|
||||
```
|
||||
|
||||
Helpers: `WithModel`, `WithSystemPrompt`, `WithStreaming`, `WithMaxTokens`,
|
||||
`WithThinkingLevel`, `WithTools`, `WithExtraTools`, `WithProviderAPIKey`,
|
||||
`WithProviderURL`, `WithConfigFile`, `WithDebug`, and `Ephemeral`. `Option` is
|
||||
a plain `func(*Options)`, so you can define your own. For fields without a
|
||||
`With*` helper (`MCPConfig`, `InProcessMCPServers`, `SessionManager`, MCP task
|
||||
tuning) construct an `Options` value and call `kit.New`.
|
||||
|
||||
### Options
|
||||
|
||||
You can override specific settings:
|
||||
@@ -59,7 +89,7 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
SystemPrompt: "You are a helpful bot", // Override system prompt
|
||||
ConfigFile: "/path/to/config.yml", // Use specific config file
|
||||
MaxSteps: 10, // Override max steps
|
||||
Streaming: true, // Enable streaming
|
||||
Streaming: ptrBool(true), // *bool: nil = unset (default true), &false = off
|
||||
Quiet: true, // Suppress debug output
|
||||
|
||||
// Session options
|
||||
@@ -241,6 +271,43 @@ response, _ := host.Prompt(ctx, "What's my name?")
|
||||
host.ClearSession()
|
||||
```
|
||||
|
||||
### Runtime Skills and Context Files
|
||||
|
||||
For multi-tenant chatbots, web services, or any host that needs per-user or
|
||||
per-session instructions, the SDK lets you add, remove, and replace skills and
|
||||
project context files (e.g. `AGENTS.md`) **after** Kit construction. Every
|
||||
mutation recomposes the system prompt and applies it to the agent so the next
|
||||
turn picks up the new instructions — no restart required.
|
||||
|
||||
```go
|
||||
// Add a programmatic skill (no file on disk required).
|
||||
host.AddSkill(&kit.Skill{
|
||||
Name: "polite-french",
|
||||
Description: "Respond in French and always greet the user.",
|
||||
Content: "Always reply in French. Open every response with 'Bonjour'.",
|
||||
})
|
||||
|
||||
// Or load one from disk.
|
||||
host.LoadAndAddSkill("/var/skills/refund-policy.md")
|
||||
|
||||
// Swap per-user AGENTS.md content fetched from your database.
|
||||
host.AddContextFileContent(
|
||||
fmt.Sprintf("session://%s/AGENTS.md", userID),
|
||||
rulesFromDB,
|
||||
)
|
||||
|
||||
// Tear down session-specific state when the user logs off.
|
||||
host.RemoveSkill("polite-french")
|
||||
host.RemoveContextFile(fmt.Sprintf("session://%s/AGENTS.md", userID))
|
||||
|
||||
// Or replace the whole set in one shot.
|
||||
host.SetSkills(activeSkillsForUser)
|
||||
host.SetContextFiles(activeContextForUser)
|
||||
```
|
||||
|
||||
Readers (`GetSkills`, `GetContextFiles`) return snapshots, and every mutator
|
||||
is safe to call concurrently from multiple goroutines.
|
||||
|
||||
## Re-exported Types
|
||||
|
||||
The SDK re-exports message/session/MCP types so you don't need direct internal imports. Agent-configuration types are Kit-owned (not aliases) and use only SDK types in their signatures, so consumers never need to import the underlying LLM-provider package.
|
||||
@@ -294,6 +361,7 @@ msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message
|
||||
|
||||
- `Kit` - Main SDK type
|
||||
- `Options` - Configuration options
|
||||
- `Option` - Functional option (`func(*Options)`) for `NewAgent`
|
||||
- `Message` - Conversation message with typed content parts
|
||||
- `Tool` - Agent tool interface
|
||||
- `TurnResult` - Full result from a prompt including usage stats
|
||||
@@ -301,6 +369,7 @@ msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message
|
||||
### Key Methods
|
||||
|
||||
- `New(ctx, opts)` - Create new Kit instance
|
||||
- `NewAgent(ctx, ...Option)` - Create a Kit via functional options (streaming on by default)
|
||||
- `Prompt(ctx, message)` - Send message and get response string
|
||||
- `PromptResult(ctx, message)` - Send message and get full TurnResult
|
||||
- `PromptWithOptions(ctx, message, opts)` - Prompt with per-call options
|
||||
@@ -312,6 +381,9 @@ msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message
|
||||
- `ClearSession()` - Clear conversation history
|
||||
- `GetSessionPath()` - Get session file path
|
||||
- `GetSessionID()` - Get session UUID
|
||||
- `AddSkill(*Skill)` / `LoadAndAddSkill(path)` / `RemoveSkill(name)` / `SetSkills([])` - Manage skills at runtime
|
||||
- `AddContextFile(*ContextFile)` / `AddContextFileContent(path, content)` / `LoadAndAddContextFile(path)` / `RemoveContextFile(path)` / `SetContextFiles([])` - Manage AGENTS.md-style context files at runtime
|
||||
- `RefreshSystemPrompt()` - Re-apply the composed system prompt to the agent
|
||||
- `Close()` - Clean up resources
|
||||
|
||||
### Options
|
||||
|
||||
+2
-2
@@ -11,12 +11,12 @@ import (
|
||||
// treeManagerAdapter adapts TreeManager to SessionManager interface.
|
||||
// This is unexported - users don't interact with it directly.
|
||||
type treeManagerAdapter struct {
|
||||
inner *session.TreeManager
|
||||
inner *TreeManager
|
||||
}
|
||||
|
||||
// NewTreeManagerAdapter creates an adapter (exported for use in New function).
|
||||
// This is used by the SDK when no custom SessionManager is provided.
|
||||
func NewTreeManagerAdapter(tm *session.TreeManager) SessionManager {
|
||||
func NewTreeManagerAdapter(tm *TreeManager) SessionManager {
|
||||
return &treeManagerAdapter{inner: tm}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,9 @@ type AnthropicCredentials = auth.AnthropicCredentials
|
||||
// and API key authentication methods.
|
||||
type OpenAICredentials = auth.OpenAICredentials
|
||||
|
||||
// CopilotCredentials holds GitHub OAuth and Copilot API credentials.
|
||||
type CopilotCredentials = auth.CopilotCredentials
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
type CredentialStore = auth.CredentialStore
|
||||
|
||||
@@ -65,6 +68,37 @@ func HasOpenAICredentials() bool {
|
||||
return has
|
||||
}
|
||||
|
||||
// HasCopilotCredentials checks if valid GitHub Copilot credentials are stored.
|
||||
func HasCopilotCredentials() bool {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
has, err := cm.HasCopilotCredentials()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return has
|
||||
}
|
||||
|
||||
// GetCopilotCredentials retrieves stored GitHub Copilot credentials.
|
||||
func GetCopilotCredentials() (*CopilotCredentials, error) {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cm.GetCopilotCredentials()
|
||||
}
|
||||
|
||||
// GetValidCopilotAccessToken returns a fresh GitHub Copilot access token.
|
||||
func GetValidCopilotAccessToken() (string, error) {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cm.GetValidCopilotAccessToken()
|
||||
}
|
||||
|
||||
// GetOpenAIAPIKey resolves the OpenAI API key using the standard
|
||||
// resolution order: stored credentials -> OPENAI_API_KEY env var.
|
||||
// Returns an empty string if no key is found.
|
||||
|
||||
+50
-23
@@ -65,23 +65,46 @@ const sdkDefaultMaxTokens = 8192
|
||||
// which returns models.ThinkingOff.
|
||||
// - sampling params (temperature, top-p, top-k, frequency/presence-penalty):
|
||||
// left as nil pointers so provider libraries apply their own defaults.
|
||||
func setSDKDefaults() {
|
||||
viper.SetDefault("model", "anthropic/claude-sonnet-4-5-20250929")
|
||||
viper.SetDefault("system-prompt", defaultSystemPrompt)
|
||||
viper.SetDefault("stream", true)
|
||||
viper.SetDefault("num-gpu-layers", -1)
|
||||
viper.SetDefault("main-gpu", 0)
|
||||
func setSDKDefaults(v *viper.Viper) {
|
||||
v.SetDefault("model", "anthropic/claude-sonnet-4-5-20250929")
|
||||
v.SetDefault("system-prompt", defaultSystemPrompt)
|
||||
v.SetDefault("stream", true)
|
||||
v.SetDefault("num-gpu-layers", -1)
|
||||
v.SetDefault("main-gpu", 0)
|
||||
}
|
||||
|
||||
// InitConfig initializes the viper configuration system.
|
||||
// InitConfig initializes the process-global viper configuration system.
|
||||
// It searches for config files in standard locations and loads them with
|
||||
// environment variable substitution.
|
||||
//
|
||||
// configFile: explicit config file path (empty = search defaults).
|
||||
// debug: if true, print warnings about missing configs to stderr.
|
||||
//
|
||||
// This wraps [initConfig] using the process-global store and is retained for
|
||||
// the CLI, which binds its flags to the global viper.
|
||||
func InitConfig(configFile string, debug bool) error {
|
||||
return initConfig(viper.GetViper(), configFile, debug)
|
||||
}
|
||||
|
||||
// initConfig loads configuration into the supplied per-instance store. When v
|
||||
// is nil the process-global store is used.
|
||||
func initConfig(v *viper.Viper, configFile string, debug bool) error {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
|
||||
// Configure KIT_* environment overrides unconditionally, before any file
|
||||
// is loaded, so that an explicit config file does not disable env support.
|
||||
// Map hyphenated config keys (e.g. "max-tokens") to underscored env var
|
||||
// names (e.g. KIT_MAX_TOKENS); without this AutomaticEnv looks for
|
||||
// KIT_MAX-TOKENS and silently misses valid overrides. Precedence is
|
||||
// resolved at read time, so calling these before ReadConfig is fine.
|
||||
v.SetEnvPrefix("KIT")
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||
v.AutomaticEnv()
|
||||
|
||||
if configFile != "" {
|
||||
return LoadConfigWithEnvSubstitution(configFile)
|
||||
return loadConfigWithEnvSubstitution(v, configFile)
|
||||
}
|
||||
|
||||
// Ensure a config file exists (create default if none found).
|
||||
@@ -97,15 +120,15 @@ func InitConfig(configFile string, debug bool) error {
|
||||
}
|
||||
|
||||
// Current directory has higher priority than home directory.
|
||||
viper.AddConfigPath(".")
|
||||
viper.AddConfigPath(home)
|
||||
v.AddConfigPath(".")
|
||||
v.AddConfigPath(home)
|
||||
|
||||
configLoaded := false
|
||||
|
||||
viper.SetConfigName(".kit")
|
||||
if err := viper.ReadInConfig(); err == nil {
|
||||
configPath := viper.ConfigFileUsed()
|
||||
if err := LoadConfigWithEnvSubstitution(configPath); err != nil {
|
||||
v.SetConfigName(".kit")
|
||||
if err := v.ReadInConfig(); err == nil {
|
||||
configPath := v.ConfigFileUsed()
|
||||
if err := loadConfigWithEnvSubstitution(v, configPath); err != nil {
|
||||
if strings.Contains(err.Error(), "environment variable substitution failed") {
|
||||
return fmt.Errorf("error reading config file '%s': %w", configPath, err)
|
||||
}
|
||||
@@ -118,17 +141,21 @@ func InitConfig(configFile string, debug bool) error {
|
||||
fmt.Fprintf(os.Stderr, "No config file found in current directory or home directory\n")
|
||||
}
|
||||
|
||||
viper.SetEnvPrefix("KIT")
|
||||
// Map hyphenated config keys (e.g. "max-tokens") to underscored env
|
||||
// var names (e.g. KIT_MAX_TOKENS). Without this, AutomaticEnv looks
|
||||
// for KIT_MAX-TOKENS and silently misses valid env overrides.
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||
viper.AutomaticEnv()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfigWithEnvSubstitution loads a config file with ${ENV_VAR} expansion.
|
||||
// LoadConfigWithEnvSubstitution loads a config file with ${ENV_VAR} expansion
|
||||
// into the process-global viper store.
|
||||
func LoadConfigWithEnvSubstitution(configPath string) error {
|
||||
return loadConfigWithEnvSubstitution(viper.GetViper(), configPath)
|
||||
}
|
||||
|
||||
// loadConfigWithEnvSubstitution loads a config file with ${ENV_VAR} expansion
|
||||
// into the supplied per-instance store (or the global store when v is nil).
|
||||
func loadConfigWithEnvSubstitution(v *viper.Viper, configPath string) error {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
rawContent, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config file: %w", err)
|
||||
@@ -146,6 +173,6 @@ func LoadConfigWithEnvSubstitution(configPath string) error {
|
||||
}
|
||||
|
||||
config.SetConfigPath(configPath)
|
||||
viper.SetConfigType(configType)
|
||||
return viper.ReadConfig(strings.NewReader(processedContent))
|
||||
v.SetConfigType(configType)
|
||||
return v.ReadConfig(strings.NewReader(processedContent))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Runtime context-file management (Issue #36)
|
||||
// ---------------------------------------------------------------------------
|
||||
//
|
||||
// Project context files (AGENTS.md and friends) are normally auto-discovered
|
||||
// during Kit.New() and injected into the system prompt. SDK consumers building
|
||||
// multi-tenant chatbots often need to swap context per user/session at runtime
|
||||
// without restarting the agent. The methods below provide that surface.
|
||||
//
|
||||
// Every mutation recomposes the system prompt and applies it to the underlying
|
||||
// agent so the next turn sees the updated project context.
|
||||
|
||||
// AddContextFile registers a project context file (e.g. an AGENTS.md
|
||||
// equivalent) on this Kit instance. The file does not need to exist on
|
||||
// disk — Path is treated as an opaque identifier used both for de-duplication
|
||||
// and for the "Instructions from: <Path>" header injected into the system
|
||||
// prompt. If a context file with the same Path is already loaded the new
|
||||
// content replaces it.
|
||||
//
|
||||
// Returns an error when cf is nil or has an empty Path. AddContextFile is
|
||||
// safe to call from any goroutine.
|
||||
func (m *Kit) AddContextFile(cf *ContextFile) error {
|
||||
if cf == nil {
|
||||
return fmt.Errorf("AddContextFile: context file is nil")
|
||||
}
|
||||
if cf.Path == "" {
|
||||
return fmt.Errorf("AddContextFile: context file path is required")
|
||||
}
|
||||
|
||||
// Take a defensive copy so later mutations by the caller don't race with
|
||||
// the agent reading the composed prompt.
|
||||
stored := &ContextFile{
|
||||
Path: cf.Path,
|
||||
Content: strings.TrimSpace(cf.Content),
|
||||
}
|
||||
|
||||
m.runtimeMu.Lock()
|
||||
replaced := false
|
||||
for i, existing := range m.contextFiles {
|
||||
if existing.Path == stored.Path {
|
||||
m.contextFiles[i] = stored
|
||||
replaced = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !replaced {
|
||||
m.contextFiles = append(m.contextFiles, stored)
|
||||
}
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
m.applyComposedSystemPrompt()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddContextFileContent is a convenience wrapper around [Kit.AddContextFile]
|
||||
// that builds the ContextFile from a path and inline content string. Use this
|
||||
// when the context originates from a database, API response, or any other
|
||||
// non-filesystem source.
|
||||
func (m *Kit) AddContextFileContent(path, content string) (*ContextFile, error) {
|
||||
cf := &ContextFile{Path: path, Content: content}
|
||||
if err := m.AddContextFile(cf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cf, nil
|
||||
}
|
||||
|
||||
// LoadAndAddContextFile reads a file from disk and registers it as a project
|
||||
// context file via [Kit.AddContextFile]. The absolute path is stored on the
|
||||
// resulting ContextFile.
|
||||
func (m *Kit) LoadAndAddContextFile(path string) (*ContextFile, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LoadAndAddContextFile: %w", err)
|
||||
}
|
||||
abs, absErr := filepath.Abs(path)
|
||||
if absErr != nil {
|
||||
abs = path
|
||||
}
|
||||
cf := &ContextFile{
|
||||
Path: abs,
|
||||
Content: strings.TrimSpace(string(data)),
|
||||
}
|
||||
if err := m.AddContextFile(cf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cf, nil
|
||||
}
|
||||
|
||||
// RemoveContextFile removes the context file with the given path and
|
||||
// recomposes the system prompt. Returns true when a matching file was found
|
||||
// and removed, false otherwise.
|
||||
func (m *Kit) RemoveContextFile(path string) bool {
|
||||
m.runtimeMu.Lock()
|
||||
found := false
|
||||
for i, cf := range m.contextFiles {
|
||||
if cf.Path == path {
|
||||
m.contextFiles = append(m.contextFiles[:i], m.contextFiles[i+1:]...)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
m.applyComposedSystemPrompt()
|
||||
return true
|
||||
}
|
||||
|
||||
// SetContextFiles replaces the active context-file set with the provided
|
||||
// slice. Pass nil or an empty slice to clear all context. The system prompt
|
||||
// is recomposed and applied. ContextFiles with empty Paths are rejected and
|
||||
// no mutation is performed.
|
||||
func (m *Kit) SetContextFiles(files []*ContextFile) error {
|
||||
// Validate first so a bad input doesn't partially mutate state.
|
||||
for i, cf := range files {
|
||||
if cf == nil {
|
||||
return fmt.Errorf("SetContextFiles: context file at index %d is nil", i)
|
||||
}
|
||||
if cf.Path == "" {
|
||||
return fmt.Errorf("SetContextFiles: context file at index %d has empty path", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Defensive copies so caller-side mutation cannot race with composition.
|
||||
copied := make([]*ContextFile, len(files))
|
||||
for i, cf := range files {
|
||||
copied[i] = &ContextFile{
|
||||
Path: cf.Path,
|
||||
Content: strings.TrimSpace(cf.Content),
|
||||
}
|
||||
}
|
||||
|
||||
m.runtimeMu.Lock()
|
||||
m.contextFiles = copied
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
m.applyComposedSystemPrompt()
|
||||
return nil
|
||||
}
|
||||
+54
-172
@@ -3,6 +3,8 @@ package kit
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -103,34 +105,21 @@ type Event interface {
|
||||
// appropriate visualizations (e.g. diff view for edit tools, command+output
|
||||
// for execute tools) and file trackers to identify which results contain
|
||||
// modifications.
|
||||
//
|
||||
// These constants re-export the canonical classification used by extension
|
||||
// events, so SDK events and extension events always agree.
|
||||
const (
|
||||
ToolKindExecute = "execute" // Shell execution (bash)
|
||||
ToolKindEdit = "edit" // File modification (edit, write)
|
||||
ToolKindRead = "read" // File reading (read, ls)
|
||||
ToolKindSearch = "search" // Content/file search (grep, find)
|
||||
ToolKindSubagent = "agent" // Subagent spawning (subagent)
|
||||
ToolKindExecute = extensions.ToolKindExecute // Shell execution (bash)
|
||||
ToolKindEdit = extensions.ToolKindEdit // File modification (edit, write)
|
||||
ToolKindRead = extensions.ToolKindRead // File reading (read, ls)
|
||||
ToolKindSearch = extensions.ToolKindSearch // Content/file search (grep, find)
|
||||
ToolKindSubagent = extensions.ToolKindSubagent // Subagent spawning (subagent)
|
||||
)
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind. MCP and extension
|
||||
// tools without an entry default to ToolKindExecute.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": ToolKindExecute,
|
||||
"edit": ToolKindEdit,
|
||||
"write": ToolKindEdit,
|
||||
"read": ToolKindRead,
|
||||
"ls": ToolKindRead,
|
||||
"grep": ToolKindSearch,
|
||||
"find": ToolKindSearch,
|
||||
"subagent": ToolKindSubagent,
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
// ToolKindExecute for unknown tools.
|
||||
func toolKindFor(toolName string) string {
|
||||
if kind, ok := coreToolKinds[toolName]; ok {
|
||||
return kind
|
||||
}
|
||||
return ToolKindExecute
|
||||
return extensions.ToolKindFor(toolName)
|
||||
}
|
||||
|
||||
// parseToolArgs attempts to parse a JSON-encoded tool args string into a map.
|
||||
@@ -571,67 +560,56 @@ func (eb *eventBus) emit(event Event) {
|
||||
// Typed convenience subscribers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// subscribeTyped is the generic backbone of all the typed `On<EventName>`
|
||||
// convenience methods on *Kit. It wraps Subscribe with a type assertion
|
||||
// against E so handlers receive a strongly-typed event without each
|
||||
// public method having to repeat the boilerplate. Returns an unsubscribe
|
||||
// function.
|
||||
func subscribeTyped[E Event](k *Kit, handler func(E)) func() {
|
||||
return k.Subscribe(func(e Event) {
|
||||
if tev, ok := e.(E); ok {
|
||||
handler(tev)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnToolCall registers a handler that fires only for ToolCallEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolCall(handler func(ToolCallEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if tc, ok := e.(ToolCallEvent); ok {
|
||||
handler(tc)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnToolCallStart registers a handler that fires only for ToolCallStartEvent.
|
||||
// This fires when the LLM begins generating tool call arguments — before the
|
||||
// full argument JSON is available. Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolCallStart(handler func(ToolCallStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if tcs, ok := e.(ToolCallStartEvent); ok {
|
||||
handler(tcs)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnToolCallDelta registers a handler that fires only for ToolCallDeltaEvent.
|
||||
// Each delta contains a JSON fragment of tool call arguments as they stream in.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolCallDelta(handler func(ToolCallDeltaEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if tcd, ok := e.(ToolCallDeltaEvent); ok {
|
||||
handler(tcd)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnToolCallEnd registers a handler that fires only for ToolCallEndEvent.
|
||||
// This fires when tool argument streaming is complete, before the tool call
|
||||
// is parsed and execution begins. Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolCallEnd(handler func(ToolCallEndEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if tce, ok := e.(ToolCallEndEvent); ok {
|
||||
handler(tce)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnToolResult registers a handler that fires only for ToolResultEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolResult(handler func(ToolResultEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if tr, ok := e.(ToolResultEvent); ok {
|
||||
handler(tr)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnToolOutput registers a handler that fires only for ToolOutputEvent
|
||||
// (streaming tool output chunks, e.g., from bash). Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolOutput(handler func(ToolOutputEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if to, ok := e.(ToolOutputEvent); ok {
|
||||
handler(to)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnStreaming registers a handler that fires only for MessageUpdateEvent
|
||||
@@ -646,41 +624,25 @@ func (m *Kit) OnStreaming(handler func(MessageUpdateEvent)) func() {
|
||||
// OnMessageUpdate registers a handler that fires only for MessageUpdateEvent
|
||||
// (streaming text chunks). Returns an unsubscribe function.
|
||||
func (m *Kit) OnMessageUpdate(handler func(MessageUpdateEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if mu, ok := e.(MessageUpdateEvent); ok {
|
||||
handler(mu)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnResponse registers a handler that fires only for ResponseEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnResponse(handler func(ResponseEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if r, ok := e.(ResponseEvent); ok {
|
||||
handler(r)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnTurnStart registers a handler that fires only for TurnStartEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnTurnStart(handler func(TurnStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if ts, ok := e.(TurnStartEvent); ok {
|
||||
handler(ts)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnTurnEnd registers a handler that fires only for TurnEndEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnTurnEnd(handler func(TurnEndEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if te, ok := e.(TurnEndEvent); ok {
|
||||
handler(te)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -690,101 +652,61 @@ func (m *Kit) OnTurnEnd(handler func(TurnEndEvent)) func() {
|
||||
// OnMessageStart registers a handler that fires only for MessageStartEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnMessageStart(handler func(MessageStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if ms, ok := e.(MessageStartEvent); ok {
|
||||
handler(ms)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnMessageEnd registers a handler that fires only for MessageEndEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnMessageEnd(handler func(MessageEndEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if me, ok := e.(MessageEndEvent); ok {
|
||||
handler(me)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnReasoningDelta registers a handler that fires only for ReasoningDeltaEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnReasoningDelta(handler func(ReasoningDeltaEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if rd, ok := e.(ReasoningDeltaEvent); ok {
|
||||
handler(rd)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnReasoningComplete registers a handler that fires only for ReasoningCompleteEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnReasoningComplete(handler func(ReasoningCompleteEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if rc, ok := e.(ReasoningCompleteEvent); ok {
|
||||
handler(rc)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnToolExecutionStart registers a handler that fires only for ToolExecutionStartEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolExecutionStart(handler func(ToolExecutionStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if tes, ok := e.(ToolExecutionStartEvent); ok {
|
||||
handler(tes)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnToolExecutionEnd registers a handler that fires only for ToolExecutionEndEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolExecutionEnd(handler func(ToolExecutionEndEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if tee, ok := e.(ToolExecutionEndEvent); ok {
|
||||
handler(tee)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnToolCallContent registers a handler that fires only for ToolCallContentEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolCallContent(handler func(ToolCallContentEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if tcc, ok := e.(ToolCallContentEvent); ok {
|
||||
handler(tcc)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnStepUsage registers a handler that fires only for StepUsageEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnStepUsage(handler func(StepUsageEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if su, ok := e.(StepUsageEvent); ok {
|
||||
handler(su)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnCompaction registers a handler that fires only for CompactionEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnCompaction(handler func(CompactionEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if ce, ok := e.(CompactionEvent); ok {
|
||||
handler(ce)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnSteerConsumed registers a handler that fires only for SteerConsumedEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnSteerConsumed(handler func(SteerConsumedEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if sc, ok := e.(SteerConsumedEvent); ok {
|
||||
handler(sc)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -794,101 +716,61 @@ func (m *Kit) OnSteerConsumed(handler func(SteerConsumedEvent)) func() {
|
||||
// OnStepStart registers a handler that fires only for StepStartEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnStepStart(handler func(StepStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if ss, ok := e.(StepStartEvent); ok {
|
||||
handler(ss)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnStepFinish registers a handler that fires only for StepFinishEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnStepFinish(handler func(StepFinishEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if sf, ok := e.(StepFinishEvent); ok {
|
||||
handler(sf)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnTextStart registers a handler that fires only for TextStartEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnTextStart(handler func(TextStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if ts, ok := e.(TextStartEvent); ok {
|
||||
handler(ts)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnTextEnd registers a handler that fires only for TextEndEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnTextEnd(handler func(TextEndEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if te, ok := e.(TextEndEvent); ok {
|
||||
handler(te)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnReasoningStart registers a handler that fires only for ReasoningStartEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnReasoningStart(handler func(ReasoningStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if rs, ok := e.(ReasoningStartEvent); ok {
|
||||
handler(rs)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnWarnings registers a handler that fires only for WarningsEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnWarnings(handler func(WarningsEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if w, ok := e.(WarningsEvent); ok {
|
||||
handler(w)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnSource registers a handler that fires only for SourceEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnSource(handler func(SourceEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if s, ok := e.(SourceEvent); ok {
|
||||
handler(s)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnStreamFinish registers a handler that fires only for StreamFinishEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnStreamFinish(handler func(StreamFinishEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if sf, ok := e.(StreamFinishEvent); ok {
|
||||
handler(sf)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnError registers a handler that fires only for ErrorEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnError(handler func(ErrorEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if ee, ok := e.(ErrorEvent); ok {
|
||||
handler(ee)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// OnRetry registers a handler that fires only for RetryEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnRetry(handler func(RetryEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if r, ok := e.(RetryEvent); ok {
|
||||
handler(r)
|
||||
}
|
||||
})
|
||||
return subscribeTyped(m, handler)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
package kit
|
||||
|
||||
// This file exposes a handful of internal accessors to the external kit_test
|
||||
// package. Because it ends in _test.go it is only compiled during testing and
|
||||
// is therefore not part of the public SDK surface.
|
||||
|
||||
// ConfigValueIsSetForTest reports whether key is explicitly set in this Kit's
|
||||
// isolated configuration store. Used by tests to assert the tri-state
|
||||
// precedence contract per-instance.
|
||||
func (m *Kit) ConfigValueIsSetForTest(key string) bool { return m.v.IsSet(key) }
|
||||
|
||||
// ConfigStringForTest returns the string value of key from this Kit's isolated
|
||||
// configuration store.
|
||||
func (m *Kit) ConfigStringForTest(key string) string { return m.v.GetString(key) }
|
||||
|
||||
// ConfigFloatForTest returns the float64 value of key from this Kit's isolated
|
||||
// configuration store.
|
||||
func (m *Kit) ConfigFloatForTest(key string) float64 { return m.v.GetFloat64(key) }
|
||||
|
||||
// ConfigBoolForTest returns the bool value of key from this Kit's isolated
|
||||
// configuration store.
|
||||
func (m *Kit) ConfigBoolForTest(key string) bool { return m.v.GetBool(key) }
|
||||
|
||||
// ConfigStringSliceForTest returns the string slice value of key from this
|
||||
// Kit's isolated configuration store.
|
||||
func (m *Kit) ConfigStringSliceForTest(key string) []string {
|
||||
return m.v.GetStringSlice(key)
|
||||
}
|
||||
+178
-49
@@ -2,61 +2,129 @@ package kit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
)
|
||||
|
||||
// ==== Extension Types ====
|
||||
//
|
||||
// Type aliases for internal extension types exposed through the public
|
||||
// ExtensionAPI interface. External SDK consumers can use these without
|
||||
// importing internal packages directly.
|
||||
|
||||
// ExtensionContext holds the runtime context passed to extensions, including
|
||||
// callbacks for printing, sending messages, and accessing session state.
|
||||
type ExtensionContext = extensions.Context
|
||||
|
||||
// ExtensionWidgetConfig describes a widget registered by an extension.
|
||||
type ExtensionWidgetConfig = extensions.WidgetConfig
|
||||
|
||||
// ExtensionWidgetPlacement indicates where a widget should be rendered
|
||||
// (e.g. above or below the conversation).
|
||||
type ExtensionWidgetPlacement = extensions.WidgetPlacement
|
||||
|
||||
// ExtensionHeaderFooterConfig describes a header or footer registered by an extension.
|
||||
type ExtensionHeaderFooterConfig = extensions.HeaderFooterConfig
|
||||
|
||||
// ExtensionEditorConfig configures editor behaviour overrides set by extensions.
|
||||
type ExtensionEditorConfig = extensions.EditorConfig
|
||||
|
||||
// ExtensionUIVisibility controls which UI elements are visible.
|
||||
type ExtensionUIVisibility = extensions.UIVisibility
|
||||
|
||||
// ExtensionToolRenderConfig describes custom tool output rendering registered by an extension.
|
||||
type ExtensionToolRenderConfig = extensions.ToolRenderConfig
|
||||
|
||||
// ExtensionMessageRendererConfig describes custom message rendering registered by an extension.
|
||||
type ExtensionMessageRendererConfig = extensions.MessageRendererConfig
|
||||
|
||||
// ExtensionSessionMessage represents a single message in the session history
|
||||
// as exposed to extensions.
|
||||
type ExtensionSessionMessage = extensions.SessionMessage
|
||||
|
||||
// ExtensionEntry represents a custom data entry stored by an extension
|
||||
// in the session tree.
|
||||
type ExtensionEntry = extensions.ExtensionEntry
|
||||
|
||||
// ExtensionStatusBarEntry describes a status bar entry registered by an extension.
|
||||
type ExtensionStatusBarEntry = extensions.StatusBarEntry
|
||||
|
||||
// ExtensionToolInfo describes a tool available to the agent, as seen by extensions.
|
||||
type ExtensionToolInfo = extensions.ToolInfo
|
||||
|
||||
// ExtensionCommandDef describes a slash command registered by an extension.
|
||||
type ExtensionCommandDef = extensions.CommandDef
|
||||
|
||||
// ExtensionAPI provides grouped access to all extension-related functionality.
|
||||
// This cleans up the main Kit API surface while keeping all extension capabilities available.
|
||||
type ExtensionAPI interface {
|
||||
// Context management
|
||||
SetContext(ctx extensions.Context)
|
||||
GetContext() extensions.Context
|
||||
SetContext(ctx ExtensionContext)
|
||||
GetContext() ExtensionContext
|
||||
UpdateContextModel(model string)
|
||||
|
||||
// Widgets
|
||||
SetWidget(config extensions.WidgetConfig)
|
||||
SetWidget(config ExtensionWidgetConfig)
|
||||
RemoveWidget(id string)
|
||||
GetWidgets(placement extensions.WidgetPlacement) []extensions.WidgetConfig
|
||||
GetWidgets(placement ExtensionWidgetPlacement) []ExtensionWidgetConfig
|
||||
|
||||
// Header/Footer
|
||||
SetHeader(config extensions.HeaderFooterConfig)
|
||||
SetHeader(config ExtensionHeaderFooterConfig)
|
||||
RemoveHeader()
|
||||
GetHeader() *extensions.HeaderFooterConfig
|
||||
SetFooter(config extensions.HeaderFooterConfig)
|
||||
GetHeader() *ExtensionHeaderFooterConfig
|
||||
SetFooter(config ExtensionHeaderFooterConfig)
|
||||
RemoveFooter()
|
||||
GetFooter() *extensions.HeaderFooterConfig
|
||||
GetFooter() *ExtensionHeaderFooterConfig
|
||||
|
||||
// Editor
|
||||
SetEditor(config extensions.EditorConfig)
|
||||
SetEditor(config ExtensionEditorConfig)
|
||||
ResetEditor()
|
||||
GetEditor() *extensions.EditorConfig
|
||||
GetEditor() *ExtensionEditorConfig
|
||||
|
||||
// UI Visibility
|
||||
SetUIVisibility(v extensions.UIVisibility)
|
||||
GetUIVisibility() *extensions.UIVisibility
|
||||
SetUIVisibility(v ExtensionUIVisibility)
|
||||
GetUIVisibility() *ExtensionUIVisibility
|
||||
|
||||
// Tool rendering
|
||||
GetToolRenderer(toolName string) *extensions.ToolRenderConfig
|
||||
GetMessageRenderer(name string) *extensions.MessageRendererConfig
|
||||
GetToolRenderer(toolName string) *ExtensionToolRenderConfig
|
||||
GetMessageRenderer(name string) *ExtensionMessageRendererConfig
|
||||
|
||||
// Session data
|
||||
GetSessionMessages() []extensions.SessionMessage
|
||||
GetSessionMessages() []ExtensionSessionMessage
|
||||
AppendEntry(extType, data string) (string, error)
|
||||
GetEntries(extType string) []extensions.ExtensionEntry
|
||||
GetEntries(extType string) []ExtensionEntry
|
||||
|
||||
// Session-scoped extension state (last-write-wins key-value store).
|
||||
// Backed by an in-memory map and (optionally) a sidecar file per session;
|
||||
// state lives outside the conversation tree and is not visible to the LLM.
|
||||
SetState(key, value string)
|
||||
GetState(key string) (string, bool)
|
||||
DeleteState(key string)
|
||||
ListState() []string
|
||||
|
||||
// InitStatePersistence loads any existing state from the per-session
|
||||
// sidecar file and installs a saver hook so that subsequent SetState /
|
||||
// DeleteState mutations are flushed to disk. Safe to call multiple times;
|
||||
// repeat calls simply reload and reinstall the saver.
|
||||
//
|
||||
// For ephemeral or in-memory sessions (no session file path), the call
|
||||
// is a no-op and state remains in memory for the lifetime of the runner.
|
||||
InitStatePersistence() error
|
||||
|
||||
// Status bar
|
||||
SetStatus(entry extensions.StatusBarEntry)
|
||||
SetStatus(entry ExtensionStatusBarEntry)
|
||||
RemoveStatus(key string)
|
||||
GetStatusEntries() []extensions.StatusBarEntry
|
||||
GetStatusEntries() []ExtensionStatusBarEntry
|
||||
|
||||
// Shortcuts
|
||||
GetShortcuts() map[string]func()
|
||||
|
||||
// Tools
|
||||
GetToolInfos() []extensions.ToolInfo
|
||||
GetToolInfos() []ExtensionToolInfo
|
||||
SetActiveTools(names []string)
|
||||
|
||||
// Options
|
||||
@@ -71,7 +139,7 @@ type ExtensionAPI interface {
|
||||
EmitBeforeSessionSwitch(switchReason string) (cancelled bool, reason string)
|
||||
|
||||
// Commands
|
||||
Commands() []extensions.CommandDef
|
||||
Commands() []ExtensionCommandDef
|
||||
|
||||
// Lifecycle
|
||||
Reload() error
|
||||
@@ -106,17 +174,17 @@ func (m *Kit) Extensions() ExtensionAPI {
|
||||
|
||||
// Context management
|
||||
|
||||
func (e *extensionAPI) SetContext(ctx extensions.Context) {
|
||||
func (e *extensionAPI) SetContext(ctx ExtensionContext) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetContext() extensions.Context {
|
||||
func (e *extensionAPI) GetContext() ExtensionContext {
|
||||
if e.kit.extRunner != nil {
|
||||
return e.kit.extRunner.GetContext()
|
||||
}
|
||||
return extensions.Context{}
|
||||
return ExtensionContext{}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) UpdateContextModel(model string) {
|
||||
@@ -129,7 +197,7 @@ func (e *extensionAPI) UpdateContextModel(model string) {
|
||||
|
||||
// Widgets
|
||||
|
||||
func (e *extensionAPI) SetWidget(config extensions.WidgetConfig) {
|
||||
func (e *extensionAPI) SetWidget(config ExtensionWidgetConfig) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetWidget(config)
|
||||
}
|
||||
@@ -141,7 +209,7 @@ func (e *extensionAPI) RemoveWidget(id string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetWidgets(placement extensions.WidgetPlacement) []extensions.WidgetConfig {
|
||||
func (e *extensionAPI) GetWidgets(placement ExtensionWidgetPlacement) []ExtensionWidgetConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -150,7 +218,7 @@ func (e *extensionAPI) GetWidgets(placement extensions.WidgetPlacement) []extens
|
||||
|
||||
// Header/Footer
|
||||
|
||||
func (e *extensionAPI) SetHeader(config extensions.HeaderFooterConfig) {
|
||||
func (e *extensionAPI) SetHeader(config ExtensionHeaderFooterConfig) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetHeader(config)
|
||||
}
|
||||
@@ -162,14 +230,14 @@ func (e *extensionAPI) RemoveHeader() {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetHeader() *extensions.HeaderFooterConfig {
|
||||
func (e *extensionAPI) GetHeader() *ExtensionHeaderFooterConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.GetHeader()
|
||||
}
|
||||
|
||||
func (e *extensionAPI) SetFooter(config extensions.HeaderFooterConfig) {
|
||||
func (e *extensionAPI) SetFooter(config ExtensionHeaderFooterConfig) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetFooter(config)
|
||||
}
|
||||
@@ -181,7 +249,7 @@ func (e *extensionAPI) RemoveFooter() {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetFooter() *extensions.HeaderFooterConfig {
|
||||
func (e *extensionAPI) GetFooter() *ExtensionHeaderFooterConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -190,7 +258,7 @@ func (e *extensionAPI) GetFooter() *extensions.HeaderFooterConfig {
|
||||
|
||||
// Editor
|
||||
|
||||
func (e *extensionAPI) SetEditor(config extensions.EditorConfig) {
|
||||
func (e *extensionAPI) SetEditor(config ExtensionEditorConfig) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetEditor(config)
|
||||
}
|
||||
@@ -202,7 +270,7 @@ func (e *extensionAPI) ResetEditor() {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetEditor() *extensions.EditorConfig {
|
||||
func (e *extensionAPI) GetEditor() *ExtensionEditorConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -211,13 +279,13 @@ func (e *extensionAPI) GetEditor() *extensions.EditorConfig {
|
||||
|
||||
// UI Visibility
|
||||
|
||||
func (e *extensionAPI) SetUIVisibility(v extensions.UIVisibility) {
|
||||
func (e *extensionAPI) SetUIVisibility(v ExtensionUIVisibility) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetUIVisibility(v)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetUIVisibility() *extensions.UIVisibility {
|
||||
func (e *extensionAPI) GetUIVisibility() *ExtensionUIVisibility {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -226,14 +294,14 @@ func (e *extensionAPI) GetUIVisibility() *extensions.UIVisibility {
|
||||
|
||||
// Tool rendering
|
||||
|
||||
func (e *extensionAPI) GetToolRenderer(toolName string) *extensions.ToolRenderConfig {
|
||||
func (e *extensionAPI) GetToolRenderer(toolName string) *ExtensionToolRenderConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.GetToolRenderer(toolName)
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetMessageRenderer(name string) *extensions.MessageRendererConfig {
|
||||
func (e *extensionAPI) GetMessageRenderer(name string) *ExtensionMessageRendererConfig {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -242,7 +310,7 @@ func (e *extensionAPI) GetMessageRenderer(name string) *extensions.MessageRender
|
||||
|
||||
// Session data
|
||||
|
||||
func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage {
|
||||
func (e *extensionAPI) GetSessionMessages() []ExtensionSessionMessage {
|
||||
if e.kit.session == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -250,8 +318,8 @@ func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage {
|
||||
// Try to use the legacy iterBranchMessages for backward compatibility
|
||||
// with the default TreeManager adapter
|
||||
if adapter, ok := e.kit.session.(*treeManagerAdapter); ok {
|
||||
return iterBranchMessages(adapter.inner, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage {
|
||||
return extensions.SessionMessage{
|
||||
return iterBranchMessages(adapter.inner, func(me *session.MessageEntry, msg message.Message) ExtensionSessionMessage {
|
||||
return ExtensionSessionMessage{
|
||||
ID: me.ID,
|
||||
Role: string(msg.Role),
|
||||
Content: msg.Content(),
|
||||
@@ -262,10 +330,10 @@ func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage {
|
||||
|
||||
// For custom SessionManagers, use the public interface
|
||||
branch := e.kit.session.GetCurrentBranch()
|
||||
var result []extensions.SessionMessage
|
||||
var result []ExtensionSessionMessage
|
||||
for _, entry := range branch {
|
||||
if entry.Type == EntryTypeMessage {
|
||||
result = append(result, extensions.SessionMessage{
|
||||
result = append(result, ExtensionSessionMessage{
|
||||
ID: entry.ID,
|
||||
Role: entry.Role,
|
||||
Content: entry.Content,
|
||||
@@ -283,14 +351,75 @@ func (e *extensionAPI) AppendEntry(extType, data string) (string, error) {
|
||||
return e.kit.session.AppendExtensionData(extType, data)
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetEntries(extType string) []extensions.ExtensionEntry {
|
||||
func (e *extensionAPI) SetState(key, value string) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetState(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetState(key string) (string, bool) {
|
||||
if e.kit.extRunner == nil {
|
||||
return "", false
|
||||
}
|
||||
return e.kit.extRunner.GetState(key)
|
||||
}
|
||||
|
||||
func (e *extensionAPI) DeleteState(key string) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.DeleteState(key)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) ListState() []string {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
return e.kit.extRunner.ListState()
|
||||
}
|
||||
|
||||
func (e *extensionAPI) InitStatePersistence() error {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
path := extStateSidecarPath(e.kit.GetSessionPath())
|
||||
if path == "" {
|
||||
// Ephemeral or in-memory session; no on-disk state.
|
||||
e.kit.extRunner.SetStateSaver(nil)
|
||||
return nil
|
||||
}
|
||||
if err := e.kit.extRunner.LoadStateFromFile(path); err != nil {
|
||||
return err
|
||||
}
|
||||
runner := e.kit.extRunner
|
||||
runner.SetStateSaver(func() {
|
||||
if err := runner.SaveStateToFile(path); err != nil {
|
||||
log.Printf("WARN extension state save failed: path=%s err=%v", path, err)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// extStateSidecarPath returns the path to the per-session extension state
|
||||
// sidecar file derived from the session's JSONL path. Returns empty for
|
||||
// ephemeral / in-memory sessions where no JSONL is being written.
|
||||
func extStateSidecarPath(sessionPath string) string {
|
||||
if sessionPath == "" {
|
||||
return ""
|
||||
}
|
||||
if trimmed, ok := strings.CutSuffix(sessionPath, ".jsonl"); ok {
|
||||
return trimmed + ".ext-state.json"
|
||||
}
|
||||
return sessionPath + ".ext-state.json"
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetEntries(extType string) []ExtensionEntry {
|
||||
if e.kit.session == nil {
|
||||
return nil
|
||||
}
|
||||
entries := e.kit.session.GetExtensionData(extType)
|
||||
result := make([]extensions.ExtensionEntry, 0, len(entries))
|
||||
result := make([]ExtensionEntry, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
result = append(result, extensions.ExtensionEntry{
|
||||
result = append(result, ExtensionEntry{
|
||||
ID: e.ID,
|
||||
EntryType: e.ExtType,
|
||||
Data: e.Data,
|
||||
@@ -302,7 +431,7 @@ func (e *extensionAPI) GetEntries(extType string) []extensions.ExtensionEntry {
|
||||
|
||||
// Status bar
|
||||
|
||||
func (e *extensionAPI) SetStatus(entry extensions.StatusBarEntry) {
|
||||
func (e *extensionAPI) SetStatus(entry ExtensionStatusBarEntry) {
|
||||
if e.kit.extRunner != nil {
|
||||
e.kit.extRunner.SetStatusEntry(entry)
|
||||
}
|
||||
@@ -314,7 +443,7 @@ func (e *extensionAPI) RemoveStatus(key string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extensionAPI) GetStatusEntries() []extensions.StatusBarEntry {
|
||||
func (e *extensionAPI) GetStatusEntries() []ExtensionStatusBarEntry {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -345,12 +474,12 @@ func (e *extensionAPI) GetShortcuts() map[string]func() {
|
||||
|
||||
// Tools
|
||||
|
||||
func (e *extensionAPI) GetToolInfos() []extensions.ToolInfo {
|
||||
func (e *extensionAPI) GetToolInfos() []ExtensionToolInfo {
|
||||
agentTools := e.kit.agent.GetTools()
|
||||
coreCount := e.kit.agent.GetCoreToolCount()
|
||||
mcpCount := e.kit.agent.GetMCPToolCount()
|
||||
|
||||
result := make([]extensions.ToolInfo, 0, len(agentTools))
|
||||
result := make([]ExtensionToolInfo, 0, len(agentTools))
|
||||
for i, t := range agentTools {
|
||||
info := t.Info()
|
||||
source := "core"
|
||||
@@ -363,7 +492,7 @@ func (e *extensionAPI) GetToolInfos() []extensions.ToolInfo {
|
||||
if e.kit.extRunner != nil && e.kit.extRunner.IsToolDisabled(info.Name) {
|
||||
enabled = false
|
||||
}
|
||||
result = append(result, extensions.ToolInfo{
|
||||
result = append(result, ExtensionToolInfo{
|
||||
Name: info.Name,
|
||||
Description: info.Description,
|
||||
Source: source,
|
||||
@@ -456,7 +585,7 @@ func (e *extensionAPI) EmitBeforeSessionSwitch(switchReason string) (cancelled b
|
||||
|
||||
// Commands
|
||||
|
||||
func (e *extensionAPI) Commands() []extensions.CommandDef {
|
||||
func (e *extensionAPI) Commands() []ExtensionCommandDef {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
+366
-217
@@ -3,8 +3,11 @@ package kit
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
)
|
||||
|
||||
// bridgeExtensions registers extension event handlers as SDK hooks and
|
||||
@@ -19,6 +22,30 @@ import (
|
||||
// wrapper (internal/extensions/wrapper.go) which composes underneath the SDK
|
||||
// hook wrapper.
|
||||
func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
// Per-turn aggregator: collects tool/LLM/usage signals between AgentStart
|
||||
// and AgentEnd so the enriched AgentEndEvent can be populated without
|
||||
// requiring extensions to maintain parallel bookkeeping.
|
||||
//
|
||||
// NOTE: this aggregator assumes a single in-flight turn per *Kit instance,
|
||||
// which is the current contract — runTurn does not serialize callers and
|
||||
// the SDK's TurnStartEvent/TurnEndEvent do not carry a turn ID, so two
|
||||
// concurrent Prompt() calls on the same *Kit would clobber the counters.
|
||||
// All current callers (TUI app layer, CLI runner, SDK examples) serialize
|
||||
// turns above this layer. If concurrent turns become a supported use case,
|
||||
// extend TurnStartEvent/TurnEndEvent with a turn ID and key this map per
|
||||
// turn instead.
|
||||
turnAgg := &turnAggregator{kit: m}
|
||||
m.Subscribe(func(e Event) {
|
||||
switch ev := e.(type) {
|
||||
case TurnStartEvent:
|
||||
turnAgg.start()
|
||||
case ToolResultEvent:
|
||||
turnAgg.recordTool(ev.ToolName)
|
||||
case StepFinishEvent:
|
||||
turnAgg.recordStep(ev.Usage)
|
||||
}
|
||||
})
|
||||
|
||||
// --- Interception hooks ---
|
||||
|
||||
// Extension Input → BeforeTurn hook (high priority, runs first).
|
||||
@@ -54,83 +81,51 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
// Subscribe to SDK events and forward to extension runner so extensions
|
||||
// see lifecycle events from the SDK's runTurn()/generate() path.
|
||||
|
||||
if runner.HasHandlers(extensions.AgentStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(TurnStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.AgentStartEvent{Prompt: ev.Prompt})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.AgentStart, func(ev TurnStartEvent) extensions.Event {
|
||||
return extensions.AgentStartEvent{Prompt: ev.Prompt}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.MessageStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if _, ok := e.(MessageStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.MessageStartEvent{})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.MessageStart, func(_ MessageStartEvent) extensions.Event {
|
||||
return extensions.MessageStartEvent{}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.MessageUpdate) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(MessageUpdateEvent); ok {
|
||||
_, _ = runner.Emit(extensions.MessageUpdateEvent{Chunk: ev.Chunk})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.MessageUpdate, func(ev MessageUpdateEvent) extensions.Event {
|
||||
return extensions.MessageUpdateEvent{Chunk: ev.Chunk}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.MessageEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(MessageEndEvent); ok {
|
||||
_, _ = runner.Emit(extensions.MessageEndEvent{Content: ev.Content})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.MessageEnd, func(ev MessageEndEvent) extensions.Event {
|
||||
return extensions.MessageEndEvent{Content: ev.Content}
|
||||
})
|
||||
|
||||
// Tool output streaming events (observation only).
|
||||
if runner.HasHandlers(extensions.ToolOutput) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ToolOutputEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ToolOutputEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName,
|
||||
Chunk: ev.Chunk,
|
||||
IsStderr: ev.IsStderr,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.ToolOutput, func(ev ToolOutputEvent) extensions.Event {
|
||||
return extensions.ToolOutputEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName,
|
||||
Chunk: ev.Chunk,
|
||||
IsStderr: ev.IsStderr,
|
||||
}
|
||||
})
|
||||
|
||||
// Tool call input streaming events — fire as the LLM generates tool arguments.
|
||||
if runner.HasHandlers(extensions.ToolCallInputStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ToolCallStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ToolCallInputStartEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName,
|
||||
ToolKind: ev.ToolKind,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
if runner.HasHandlers(extensions.ToolCallInputDelta) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ToolCallDeltaEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ToolCallInputDeltaEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Delta: ev.Delta,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
if runner.HasHandlers(extensions.ToolCallInputEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ToolCallEndEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ToolCallInputEndEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.ToolCallInputStart, func(ev ToolCallStartEvent) extensions.Event {
|
||||
return extensions.ToolCallInputStartEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName,
|
||||
ToolKind: ev.ToolKind,
|
||||
}
|
||||
})
|
||||
bridgeObserve(m, runner, extensions.ToolCallInputDelta, func(ev ToolCallDeltaEvent) extensions.Event {
|
||||
return extensions.ToolCallInputDeltaEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Delta: ev.Delta,
|
||||
}
|
||||
})
|
||||
bridgeObserve(m, runner, extensions.ToolCallInputEnd, func(ev ToolCallEndEvent) extensions.Event {
|
||||
return extensions.ToolCallInputEndEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.AgentEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
@@ -141,9 +136,19 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
} else if stopReason == "" {
|
||||
stopReason = "completed"
|
||||
}
|
||||
agg := turnAgg.consume()
|
||||
_, _ = runner.Emit(extensions.AgentEndEvent{
|
||||
Response: response,
|
||||
StopReason: stopReason,
|
||||
Response: response,
|
||||
StopReason: stopReason,
|
||||
ToolCallCount: agg.toolCallCount,
|
||||
ToolNames: agg.toolNames,
|
||||
LLMCallCount: agg.llmCallCount,
|
||||
InputTokensDelta: agg.inputTokens,
|
||||
OutputTokensDelta: agg.outputTokens,
|
||||
CacheReadTokensDelta: agg.cacheReadTokens,
|
||||
CacheWriteTokensDelta: agg.cacheWriteTokens,
|
||||
CostDelta: agg.cost,
|
||||
DurationMs: agg.durationMs(),
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -278,54 +283,13 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
// Extension ContextPrepare → SDK ContextPrepare hook.
|
||||
if runner.HasHandlers(extensions.ContextPrepare) {
|
||||
m.OnContextPrepare(HookPriorityNormal, func(h ContextPrepareHook) *ContextPrepareResult {
|
||||
// Convert LLM message slice to extension ContextMessage slice.
|
||||
// Extract plain text from each message for the extension API.
|
||||
extMsgs := make([]extensions.ContextMessage, len(h.Messages))
|
||||
for i, msg := range h.Messages {
|
||||
var sb strings.Builder
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(LLMTextPart); ok {
|
||||
sb.WriteString(tp.Text)
|
||||
}
|
||||
}
|
||||
extMsgs[i] = extensions.ContextMessage{
|
||||
Index: i,
|
||||
Role: string(msg.Role),
|
||||
Content: sb.String(),
|
||||
}
|
||||
}
|
||||
|
||||
extMsgs := llmToContextMessages(h.Messages)
|
||||
result, _ := runner.Emit(extensions.ContextPrepareEvent{Messages: extMsgs})
|
||||
r, ok := result.(extensions.ContextPrepareResult)
|
||||
if !ok || r.Messages == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rebuild LLM message slice from extension result.
|
||||
rebuilt := make([]LLMMessage, 0, len(r.Messages))
|
||||
for _, cm := range r.Messages {
|
||||
if cm.Index >= 0 && cm.Index < len(h.Messages) {
|
||||
// Reuse original message (preserves original role and content).
|
||||
rebuilt = append(rebuilt, h.Messages[cm.Index])
|
||||
} else {
|
||||
// New message injected by extension — construct from role + text.
|
||||
role := LLMRoleUser
|
||||
switch cm.Role {
|
||||
case "assistant":
|
||||
role = LLMRoleAssistant
|
||||
case "system":
|
||||
role = LLMRoleSystem
|
||||
case "tool":
|
||||
role = LLMRoleTool
|
||||
}
|
||||
rebuilt = append(rebuilt, LLMMessage{
|
||||
Role: role,
|
||||
Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &ContextPrepareResult{Messages: rebuilt}
|
||||
return &ContextPrepareResult{Messages: contextMessagesToLLM(r.Messages, h.Messages)}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -359,99 +323,82 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
|
||||
// --- Step lifecycle observation events ---
|
||||
|
||||
if runner.HasHandlers(extensions.StepStart) {
|
||||
bridgeObserve(m, runner, extensions.StepStart, func(ev StepStartEvent) extensions.Event {
|
||||
return extensions.StepStartEvent{StepNumber: ev.StepNumber}
|
||||
})
|
||||
|
||||
bridgeObserve(m, runner, extensions.StepFinish, func(ev StepFinishEvent) extensions.Event {
|
||||
return extensions.StepFinishEvent{
|
||||
StepNumber: ev.StepNumber,
|
||||
HasToolCalls: ev.HasToolCalls,
|
||||
FinishReason: ev.FinishReason,
|
||||
InputTokens: ev.Usage.InputTokens,
|
||||
OutputTokens: ev.Usage.OutputTokens,
|
||||
CacheReadTokens: ev.Usage.CacheReadTokens,
|
||||
CacheWriteTokens: ev.Usage.CacheCreationTokens,
|
||||
}
|
||||
})
|
||||
|
||||
// LLMUsage: derive per-call usage from StepFinish. Each step corresponds
|
||||
// to one LLM provider call, so the step's usage is the per-call delta.
|
||||
// Cost is computed from the current model's pricing (zero when unknown
|
||||
// or OAuth credentials are in use). RequestID is left empty until the
|
||||
// SDK surfaces a correlation id from the underlying provider.
|
||||
if runner.HasHandlers(extensions.LLMUsage) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(StepStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.StepStartEvent{StepNumber: ev.StepNumber})
|
||||
ev, ok := e.(StepFinishEvent)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
provider, modelID, cost := llmUsageMeta(m, ev.Usage)
|
||||
_, _ = runner.Emit(extensions.LLMUsageEvent{
|
||||
InputTokens: int(ev.Usage.InputTokens),
|
||||
OutputTokens: int(ev.Usage.OutputTokens),
|
||||
CacheReadTokens: int(ev.Usage.CacheReadTokens),
|
||||
CacheWriteTokens: int(ev.Usage.CacheCreationTokens),
|
||||
Cost: cost,
|
||||
Model: modelID,
|
||||
Provider: provider,
|
||||
StepNumber: ev.StepNumber,
|
||||
FinishReason: ev.FinishReason,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
if runner.HasHandlers(extensions.StepFinish) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(StepFinishEvent); ok {
|
||||
_, _ = runner.Emit(extensions.StepFinishEvent{
|
||||
StepNumber: ev.StepNumber,
|
||||
HasToolCalls: ev.HasToolCalls,
|
||||
FinishReason: ev.FinishReason,
|
||||
InputTokens: ev.Usage.InputTokens,
|
||||
OutputTokens: ev.Usage.OutputTokens,
|
||||
CacheReadTokens: ev.Usage.CacheReadTokens,
|
||||
CacheWriteTokens: ev.Usage.CacheCreationTokens,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.ReasoningStart, func(ev ReasoningStartEvent) extensions.Event {
|
||||
return extensions.ReasoningStartEvent{ID: ev.ID}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.ReasoningStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ReasoningStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ReasoningStartEvent{ID: ev.ID})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.Warnings, func(ev WarningsEvent) extensions.Event {
|
||||
return extensions.WarningsEvent{Warnings: ev.Warnings}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.Warnings) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(WarningsEvent); ok {
|
||||
_, _ = runner.Emit(extensions.WarningsEvent{Warnings: ev.Warnings})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.Source, func(ev SourceEvent) extensions.Event {
|
||||
return extensions.SourceEvent{
|
||||
SourceType: ev.SourceType,
|
||||
ID: ev.ID,
|
||||
URL: ev.URL,
|
||||
Title: ev.Title,
|
||||
}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.Source) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(SourceEvent); ok {
|
||||
_, _ = runner.Emit(extensions.SourceEvent{
|
||||
SourceType: ev.SourceType,
|
||||
ID: ev.ID,
|
||||
URL: ev.URL,
|
||||
Title: ev.Title,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.Error, func(ev ErrorEvent) extensions.Event {
|
||||
return extensions.ErrorEvent{Error: ev.Error.Error()}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.Error) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ErrorEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ErrorEvent{Error: ev.Error.Error()})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if runner.HasHandlers(extensions.Retry) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(RetryEvent); ok {
|
||||
_, _ = runner.Emit(extensions.RetryEvent{
|
||||
Attempt: ev.Attempt,
|
||||
Error: ev.Error.Error(),
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.Retry, func(ev RetryEvent) extensions.Event {
|
||||
return extensions.RetryEvent{
|
||||
Attempt: ev.Attempt,
|
||||
Error: ev.Error.Error(),
|
||||
}
|
||||
})
|
||||
|
||||
// --- PrepareStep hook ---
|
||||
// Extension PrepareStep → SDK PrepareStep hook.
|
||||
// Same pattern as ContextPrepare: convert LLMMessage ↔ ContextMessage.
|
||||
if runner.HasHandlers(extensions.PrepareStep) {
|
||||
m.OnPrepareStep(HookPriorityNormal, func(h PrepareStepHook) *PrepareStepResult {
|
||||
// Convert LLM message slice to extension ContextMessage slice.
|
||||
extMsgs := make([]extensions.ContextMessage, len(h.Messages))
|
||||
for i, msg := range h.Messages {
|
||||
var sb strings.Builder
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(LLMTextPart); ok {
|
||||
sb.WriteString(tp.Text)
|
||||
}
|
||||
}
|
||||
extMsgs[i] = extensions.ContextMessage{
|
||||
Index: i,
|
||||
Role: string(msg.Role),
|
||||
Content: sb.String(),
|
||||
}
|
||||
}
|
||||
|
||||
extMsgs := llmToContextMessages(h.Messages)
|
||||
result, _ := runner.Emit(extensions.PrepareStepEvent{
|
||||
StepNumber: h.StepNumber,
|
||||
Messages: extMsgs,
|
||||
@@ -460,30 +407,232 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
if !ok || r.Messages == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rebuild LLM message slice from extension result.
|
||||
rebuilt := make([]LLMMessage, 0, len(r.Messages))
|
||||
for _, cm := range r.Messages {
|
||||
if cm.Index >= 0 && cm.Index < len(h.Messages) {
|
||||
rebuilt = append(rebuilt, h.Messages[cm.Index])
|
||||
} else {
|
||||
role := LLMRoleUser
|
||||
switch cm.Role {
|
||||
case "assistant":
|
||||
role = LLMRoleAssistant
|
||||
case "system":
|
||||
role = LLMRoleSystem
|
||||
case "tool":
|
||||
role = LLMRoleTool
|
||||
}
|
||||
rebuilt = append(rebuilt, LLMMessage{
|
||||
Role: role,
|
||||
Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &PrepareStepResult{Messages: rebuilt}
|
||||
return &PrepareStepResult{Messages: contextMessagesToLLM(r.Messages, h.Messages)}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// bridgeObserve subscribes to SDK events of type In and forwards them to the
|
||||
// extension runner as the event returned by conv. The subscription is only
|
||||
// registered when the runner has handlers for the given event kind.
|
||||
func bridgeObserve[In Event](m *Kit, runner *extensions.Runner, kind extensions.EventType, conv func(In) extensions.Event) {
|
||||
if !runner.HasHandlers(kind) {
|
||||
return
|
||||
}
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(In); ok {
|
||||
_, _ = runner.Emit(conv(ev))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// turnAggregator collects per-turn signals (tool calls, LLM round-trips, token
|
||||
// usage, wall-clock duration) so that the enriched AgentEndEvent can be
|
||||
// populated without requiring extensions to maintain parallel bookkeeping.
|
||||
//
|
||||
// The aggregator resets on each TurnStartEvent and is consumed (snapshotted +
|
||||
// reset) on TurnEndEvent. All access is serialized via a mutex because the
|
||||
// underlying event bus may fan handlers across goroutines in the future.
|
||||
type turnAggregator struct {
|
||||
mu sync.Mutex
|
||||
started time.Time
|
||||
ended time.Time
|
||||
toolCallCount int
|
||||
toolNames []string
|
||||
llmCallCount int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheReadTokens int
|
||||
cacheWriteTokens int
|
||||
cost float64
|
||||
kit *Kit
|
||||
}
|
||||
|
||||
type turnSnapshot struct {
|
||||
started time.Time
|
||||
ended time.Time
|
||||
toolCallCount int
|
||||
toolNames []string
|
||||
llmCallCount int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheReadTokens int
|
||||
cacheWriteTokens int
|
||||
cost float64
|
||||
}
|
||||
|
||||
func (s turnSnapshot) durationMs() int64 {
|
||||
if s.started.IsZero() {
|
||||
return 0
|
||||
}
|
||||
end := s.ended
|
||||
if end.IsZero() {
|
||||
end = time.Now()
|
||||
}
|
||||
return end.Sub(s.started).Milliseconds()
|
||||
}
|
||||
|
||||
// start resets all counters and records the turn's start time. Called from
|
||||
// the TurnStartEvent subscriber.
|
||||
func (a *turnAggregator) start() {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.started = time.Now()
|
||||
a.ended = time.Time{}
|
||||
a.toolCallCount = 0
|
||||
a.toolNames = nil
|
||||
a.llmCallCount = 0
|
||||
a.inputTokens = 0
|
||||
a.outputTokens = 0
|
||||
a.cacheReadTokens = 0
|
||||
a.cacheWriteTokens = 0
|
||||
a.cost = 0
|
||||
}
|
||||
|
||||
func (a *turnAggregator) recordTool(name string) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.toolCallCount++
|
||||
if name != "" {
|
||||
a.toolNames = append(a.toolNames, name)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *turnAggregator) recordStep(usage LLMUsage) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.llmCallCount++
|
||||
a.inputTokens += int(usage.InputTokens)
|
||||
a.outputTokens += int(usage.OutputTokens)
|
||||
a.cacheReadTokens += int(usage.CacheReadTokens)
|
||||
a.cacheWriteTokens += int(usage.CacheCreationTokens)
|
||||
if a.kit != nil {
|
||||
_, _, c := llmUsageMeta(a.kit, usage)
|
||||
a.cost += c
|
||||
}
|
||||
}
|
||||
|
||||
// consume returns a snapshot of the current turn and marks it ended.
|
||||
// Subsequent start() calls clear the snapshot.
|
||||
func (a *turnAggregator) consume() turnSnapshot {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.ended = time.Now()
|
||||
names := a.toolNames
|
||||
if len(names) > 0 {
|
||||
copied := make([]string, len(names))
|
||||
copy(copied, names)
|
||||
names = copied
|
||||
}
|
||||
return turnSnapshot{
|
||||
started: a.started,
|
||||
ended: a.ended,
|
||||
toolCallCount: a.toolCallCount,
|
||||
toolNames: names,
|
||||
llmCallCount: a.llmCallCount,
|
||||
inputTokens: a.inputTokens,
|
||||
outputTokens: a.outputTokens,
|
||||
cacheReadTokens: a.cacheReadTokens,
|
||||
cacheWriteTokens: a.cacheWriteTokens,
|
||||
cost: a.cost,
|
||||
}
|
||||
}
|
||||
|
||||
// llmUsageMeta returns the current provider, model id, and computed cost for
|
||||
// the given usage values using the Kit instance's active model. Cost is zero
|
||||
// in any of the following cases:
|
||||
// - the *Kit pointer is nil or has no active model;
|
||||
// - the model is not in the registry (custom fine-tunes, unknown providers);
|
||||
// - the model has no pricing fields set;
|
||||
// - the active credential is an Anthropic OAuth token (matches the
|
||||
// existing usage_tracker behavior of suppressing cost for OAuth users).
|
||||
func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float64) {
|
||||
if m == nil {
|
||||
return "", "", 0
|
||||
}
|
||||
modelString := m.GetModelString()
|
||||
if modelString == "" {
|
||||
return "", "", 0
|
||||
}
|
||||
p, id, err := models.ParseModelString(modelString)
|
||||
if err != nil {
|
||||
return "", "", 0
|
||||
}
|
||||
provider, modelID = p, id
|
||||
info := models.GetGlobalRegistry().LookupModel(provider, modelID)
|
||||
if info == nil {
|
||||
return provider, modelID, 0
|
||||
}
|
||||
if isAnthropicOAuth(m, provider) {
|
||||
return provider, modelID, 0
|
||||
}
|
||||
cost = float64(usage.InputTokens) * info.Cost.Input / 1_000_000
|
||||
cost += float64(usage.OutputTokens) * info.Cost.Output / 1_000_000
|
||||
if info.Cost.CacheRead != nil {
|
||||
cost += float64(usage.CacheReadTokens) * (*info.Cost.CacheRead) / 1_000_000
|
||||
}
|
||||
if info.Cost.CacheWrite != nil {
|
||||
cost += float64(usage.CacheCreationTokens) * (*info.Cost.CacheWrite) / 1_000_000
|
||||
}
|
||||
return provider, modelID, cost
|
||||
}
|
||||
|
||||
// isAnthropicOAuth reports whether the current Anthropic credential resolves
|
||||
// to a stored OAuth token (in which case the user is not billed per-token),
|
||||
// so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage().
|
||||
func isAnthropicOAuth(m *Kit, provider string) bool {
|
||||
if m == nil || provider != "anthropic" {
|
||||
return false
|
||||
}
|
||||
return auth.IsAnthropicOAuth(m.v.GetString("provider-api-key"))
|
||||
}
|
||||
|
||||
// llmToContextMessages converts a slice of LLM messages to extension
|
||||
// ContextMessage values, extracting plain text from each message.
|
||||
func llmToContextMessages(msgs []LLMMessage) []extensions.ContextMessage {
|
||||
extMsgs := make([]extensions.ContextMessage, len(msgs))
|
||||
for i, msg := range msgs {
|
||||
var sb strings.Builder
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(LLMTextPart); ok {
|
||||
sb.WriteString(tp.Text)
|
||||
}
|
||||
}
|
||||
extMsgs[i] = extensions.ContextMessage{
|
||||
Index: i,
|
||||
Role: string(msg.Role),
|
||||
Content: sb.String(),
|
||||
}
|
||||
}
|
||||
return extMsgs
|
||||
}
|
||||
|
||||
// contextMessagesToLLM rebuilds an LLM message slice from extension
|
||||
// ContextMessages. Messages with a valid index reuse the original from
|
||||
// originals; new messages injected by extensions are constructed from
|
||||
// role + text.
|
||||
func contextMessagesToLLM(cms []extensions.ContextMessage, originals []LLMMessage) []LLMMessage {
|
||||
rebuilt := make([]LLMMessage, 0, len(cms))
|
||||
for _, cm := range cms {
|
||||
if cm.Index >= 0 && cm.Index < len(originals) {
|
||||
// Reuse original message (preserves original role and content).
|
||||
rebuilt = append(rebuilt, originals[cm.Index])
|
||||
} else {
|
||||
// New message injected by extension — construct from role + text.
|
||||
role := LLMRoleUser
|
||||
switch cm.Role {
|
||||
case "assistant":
|
||||
role = LLMRoleAssistant
|
||||
case "system":
|
||||
role = LLMRoleSystem
|
||||
case "tool":
|
||||
role = LLMRoleTool
|
||||
}
|
||||
rebuilt = append(rebuilt, LLMMessage{
|
||||
Role: role,
|
||||
Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}},
|
||||
})
|
||||
}
|
||||
}
|
||||
return rebuilt
|
||||
}
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestTurnAggregator_BasicLifecycle exercises the per-turn aggregator:
|
||||
// start → record several tools and steps → consume → snapshot should reflect
|
||||
// the accumulated counts and zero out for the next turn.
|
||||
func TestTurnAggregator_BasicLifecycle(t *testing.T) {
|
||||
agg := &turnAggregator{}
|
||||
|
||||
agg.start()
|
||||
agg.recordTool("bash")
|
||||
agg.recordTool("read")
|
||||
agg.recordTool("bash")
|
||||
agg.recordStep(LLMUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheReadTokens: 10,
|
||||
CacheCreationTokens: 5,
|
||||
})
|
||||
agg.recordStep(LLMUsage{
|
||||
InputTokens: 200,
|
||||
OutputTokens: 75,
|
||||
})
|
||||
|
||||
snap := agg.consume()
|
||||
if snap.toolCallCount != 3 {
|
||||
t.Errorf("toolCallCount: got %d want 3", snap.toolCallCount)
|
||||
}
|
||||
wantNames := []string{"bash", "read", "bash"}
|
||||
if len(snap.toolNames) != len(wantNames) {
|
||||
t.Fatalf("toolNames length: got %d want %d", len(snap.toolNames), len(wantNames))
|
||||
}
|
||||
for i, n := range wantNames {
|
||||
if snap.toolNames[i] != n {
|
||||
t.Errorf("toolNames[%d]: got %q want %q", i, snap.toolNames[i], n)
|
||||
}
|
||||
}
|
||||
if snap.llmCallCount != 2 {
|
||||
t.Errorf("llmCallCount: got %d want 2", snap.llmCallCount)
|
||||
}
|
||||
if snap.inputTokens != 300 {
|
||||
t.Errorf("inputTokens: got %d want 300", snap.inputTokens)
|
||||
}
|
||||
if snap.outputTokens != 125 {
|
||||
t.Errorf("outputTokens: got %d want 125", snap.outputTokens)
|
||||
}
|
||||
if snap.cacheReadTokens != 10 {
|
||||
t.Errorf("cacheReadTokens: got %d want 10", snap.cacheReadTokens)
|
||||
}
|
||||
if snap.cacheWriteTokens != 5 {
|
||||
t.Errorf("cacheWriteTokens: got %d want 5", snap.cacheWriteTokens)
|
||||
}
|
||||
if snap.durationMs() < 0 {
|
||||
t.Errorf("durationMs should not be negative, got %d", snap.durationMs())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTurnAggregator_StartResetsCounters(t *testing.T) {
|
||||
agg := &turnAggregator{}
|
||||
agg.start()
|
||||
agg.recordTool("bash")
|
||||
agg.recordStep(LLMUsage{InputTokens: 50})
|
||||
|
||||
// Begin a new turn — previous counters should be cleared.
|
||||
agg.start()
|
||||
snap := agg.consume()
|
||||
|
||||
if snap.toolCallCount != 0 || snap.llmCallCount != 0 || snap.inputTokens != 0 {
|
||||
t.Errorf("expected counters zeroed after start(), got %+v", snap)
|
||||
}
|
||||
if snap.toolNames != nil {
|
||||
t.Errorf("expected toolNames=nil after start(), got %v", snap.toolNames)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTurnAggregator_DurationMs verifies the snapshot computes a positive
|
||||
// duration when consume() runs after start().
|
||||
func TestTurnAggregator_DurationMs(t *testing.T) {
|
||||
agg := &turnAggregator{}
|
||||
agg.start()
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
snap := agg.consume()
|
||||
if snap.durationMs() < 1 {
|
||||
t.Errorf("expected positive duration, got %d", snap.durationMs())
|
||||
}
|
||||
}
|
||||
|
||||
// TestTurnAggregator_ZeroStartSafe ensures a snapshot taken without a prior
|
||||
// start() doesn't crash and reports zero duration.
|
||||
func TestTurnAggregator_ZeroStartSafe(t *testing.T) {
|
||||
agg := &turnAggregator{}
|
||||
snap := agg.consume()
|
||||
if snap.durationMs() != 0 {
|
||||
t.Errorf("expected zero duration for unstarted aggregator, got %d", snap.durationMs())
|
||||
}
|
||||
}
|
||||
|
||||
// TestLLMUsageMeta_NilKit verifies the helper degrades gracefully when given
|
||||
// a nil Kit instance (zero values, no panic).
|
||||
func TestLLMUsageMeta_NilKit(t *testing.T) {
|
||||
provider, modelID, cost := llmUsageMeta(nil, LLMUsage{InputTokens: 100})
|
||||
if provider != "" || modelID != "" || cost != 0 {
|
||||
t.Errorf("expected zero values for nil kit, got (%q,%q,%v)", provider, modelID, cost)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsAnthropicOAuth_NonAnthropic verifies the helper short-circuits for any
|
||||
// provider other than "anthropic" without touching the credential store.
|
||||
func TestIsAnthropicOAuth_NonAnthropic(t *testing.T) {
|
||||
for _, provider := range []string{"openai", "google", "openrouter", ""} {
|
||||
if isAnthropicOAuth(nil, provider) {
|
||||
t.Errorf("isAnthropicOAuth(nil, %q) = true, want false", provider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtStateSidecarPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"empty", "", ""},
|
||||
{"jsonl", "/tmp/sessions/abc.jsonl", "/tmp/sessions/abc.ext-state.json"},
|
||||
{"jsonl with subdir", "/a/b/c.jsonl", "/a/b/c.ext-state.json"},
|
||||
{"no extension", "/tmp/session-blob", "/tmp/session-blob.ext-state.json"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := extStateSidecarPath(tc.in)
|
||||
if got != tc.want {
|
||||
t.Errorf("extStateSidecarPath(%q): got %q want %q", tc.in, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+288
-146
@@ -53,6 +53,14 @@ type Kit struct {
|
||||
opts *Options // stored for reload operations (skills, etc.)
|
||||
mcpConfig *config.Config // loaded MCP/server config, shared with subagents
|
||||
|
||||
// v is this Kit instance's isolated configuration store. Each Kit owns its
|
||||
// own *viper.Viper (constructed via viper.New) so that runtime config
|
||||
// mutators (SetModel, SetThinkingLevel) and config reads do not clobber or
|
||||
// observe state from other Kit instances in the same process. When the CLI
|
||||
// constructs a Kit (Options.CLI != nil) this points at the process-global
|
||||
// store so cobra flag bindings remain in effect.
|
||||
v *viper.Viper
|
||||
|
||||
// hasCustomSystemPrompt is true when the user explicitly configured a
|
||||
// system prompt (via --system-prompt flag, config file, or SDK option).
|
||||
// When false, per-model system prompts from modelSettings/customModels
|
||||
@@ -61,6 +69,11 @@ type Kit struct {
|
||||
// systemPromptSource holds the raw configured value (file path or text)
|
||||
// when hasCustomSystemPrompt is true; empty when the built-in default is in use.
|
||||
systemPromptSource string
|
||||
// basePrompt holds the resolved base system prompt text (post file-load,
|
||||
// pre runtime-context composition) captured during New. Used by
|
||||
// RefreshSystemPrompt to recompose after skills/context-file mutations.
|
||||
// Protected by runtimeMu.
|
||||
basePrompt string
|
||||
|
||||
// Hook registries — interception layer (see hooks.go).
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
|
||||
@@ -90,6 +103,12 @@ type Kit struct {
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// runtimeMu protects contextFiles and skills against concurrent runtime
|
||||
// mutations via AddSkill / RemoveSkill / AddContextFile etc. The fields
|
||||
// are read by composeSystemPrompt and several other accessors, so all
|
||||
// reads and writes after Kit construction must take this lock.
|
||||
runtimeMu sync.RWMutex
|
||||
|
||||
// steerCh is a buffered channel used to inject steering messages into
|
||||
// the running agent turn via the LLM library's PrepareStep. Created fresh for
|
||||
// each generate() call and set to nil when idle. Protected by steerMu.
|
||||
@@ -119,6 +138,19 @@ func (m *Kit) GetToolNames() []string {
|
||||
return names
|
||||
}
|
||||
|
||||
// GetToolsForSubagent like GetTools but eliminates subagent tool
|
||||
// to avoid infinite recursion.
|
||||
func (m *Kit) GetToolsForSubagent() []Tool {
|
||||
var tools []Tool
|
||||
for _, t := range m.agent.GetTools() {
|
||||
if t.Info().Name == "subagent" {
|
||||
continue
|
||||
}
|
||||
tools = append(tools, t)
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// GetLoadingMessage returns the agent's startup info message (e.g. GPU
|
||||
// fallback info), or empty string if none.
|
||||
func (m *Kit) GetLoadingMessage() string {
|
||||
@@ -544,8 +576,8 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
|
||||
// Build a provider config from current settings, overriding the model.
|
||||
// Load system prompt properly (handles both file paths and inline content).
|
||||
systemPrompt, _ := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
thinkingLevel := models.ParseThinkingLevel(viper.GetString("thinking-level"))
|
||||
systemPrompt, _ := config.LoadSystemPrompt(m.v.GetString("system-prompt"))
|
||||
thinkingLevel := models.ParseThinkingLevel(m.v.GetString("thinking-level"))
|
||||
|
||||
// Validate and adjust thinking level for the target model.
|
||||
// Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none".
|
||||
@@ -556,8 +588,8 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
if !models.IsValidThinkingLevelForModel(thinkingLevel, modelName) {
|
||||
fallback := models.SuggestThinkingLevelFallback(thinkingLevel, modelName)
|
||||
if fallback != models.ThinkingOff {
|
||||
// Adjust the thinking level in viper so the change persists.
|
||||
viper.Set("thinking-level", string(fallback))
|
||||
// Adjust the thinking level in the instance store so the change persists.
|
||||
m.v.Set("thinking-level", string(fallback))
|
||||
thinkingLevel = fallback
|
||||
}
|
||||
}
|
||||
@@ -569,35 +601,36 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
cfg := &models.ProviderConfig{
|
||||
ModelString: modelString,
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
ProviderAPIKey: m.v.GetString("provider-api-key"),
|
||||
ProviderURL: m.v.GetString("provider-url"),
|
||||
MaxTokens: m.v.GetInt("max-tokens"),
|
||||
TLSSkipVerify: m.v.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: thinkingLevel,
|
||||
DisableCaching: false, // Caching enabled by default, works with thinking
|
||||
ConfigStore: m.v,
|
||||
}
|
||||
|
||||
// Only set generation parameter pointers when the user has explicitly
|
||||
// provided a value. This leaves nil pointers for unset params, allowing
|
||||
// per-model defaults (modelSettings / customModels params) to apply.
|
||||
if viper.IsSet("temperature") {
|
||||
v := float32(viper.GetFloat64("temperature"))
|
||||
if m.v.IsSet("temperature") {
|
||||
v := float32(m.v.GetFloat64("temperature"))
|
||||
cfg.Temperature = &v
|
||||
}
|
||||
if viper.IsSet("top-p") {
|
||||
v := float32(viper.GetFloat64("top-p"))
|
||||
if m.v.IsSet("top-p") {
|
||||
v := float32(m.v.GetFloat64("top-p"))
|
||||
cfg.TopP = &v
|
||||
}
|
||||
if viper.IsSet("top-k") {
|
||||
v := int32(viper.GetInt("top-k"))
|
||||
if m.v.IsSet("top-k") {
|
||||
v := int32(m.v.GetInt("top-k"))
|
||||
cfg.TopK = &v
|
||||
}
|
||||
if viper.IsSet("frequency-penalty") {
|
||||
v := float32(viper.GetFloat64("frequency-penalty"))
|
||||
if m.v.IsSet("frequency-penalty") {
|
||||
v := float32(m.v.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &v
|
||||
}
|
||||
if viper.IsSet("presence-penalty") {
|
||||
v := float32(viper.GetFloat64("presence-penalty"))
|
||||
if m.v.IsSet("presence-penalty") {
|
||||
v := float32(m.v.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &v
|
||||
}
|
||||
|
||||
@@ -653,18 +686,25 @@ func (m *Kit) GetSystemPromptSource() string {
|
||||
// composeSystemPrompt takes a base system prompt and composes it with the
|
||||
// current runtime context: AGENTS.md content, skills metadata, and date/cwd.
|
||||
// This mirrors the composition done during Kit.New() initialization.
|
||||
// It acquires a read lock on runtimeMu while snapshotting contextFiles and
|
||||
// skills, so callers must not hold the write lock.
|
||||
func (m *Kit) composeSystemPrompt(basePrompt string) string {
|
||||
cwd, _ := os.Getwd()
|
||||
pb := skills.NewPromptBuilder(basePrompt)
|
||||
|
||||
m.runtimeMu.RLock()
|
||||
contextFiles := append([]*ContextFile(nil), m.contextFiles...)
|
||||
loadedSkills := append([]*skills.Skill(nil), m.skills...)
|
||||
m.runtimeMu.RUnlock()
|
||||
|
||||
// Inject AGENTS.md content as project context.
|
||||
for _, cf := range m.contextFiles {
|
||||
for _, cf := range contextFiles {
|
||||
pb.WithSection("", fmt.Sprintf("Instructions from: %s\n\n%s", cf.Path, cf.Content))
|
||||
}
|
||||
|
||||
// Inject skills metadata.
|
||||
if len(m.skills) > 0 {
|
||||
pb.WithSkills(m.skills)
|
||||
if len(loadedSkills) > 0 {
|
||||
pb.WithSkills(loadedSkills)
|
||||
}
|
||||
|
||||
// Append current date/time and working directory.
|
||||
@@ -716,7 +756,7 @@ func (m *Kit) ReloadExtensions() error {
|
||||
}
|
||||
|
||||
// Re-load from disk.
|
||||
extraPaths := viper.GetStringSlice("extension")
|
||||
extraPaths := m.v.GetStringSlice("extension")
|
||||
loaded, err := extensions.LoadExtensions(extraPaths)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reloading extensions: %w", err)
|
||||
@@ -724,6 +764,7 @@ func (m *Kit) ReloadExtensions() error {
|
||||
|
||||
// Swap extensions on the runner (clears dynamic state).
|
||||
m.extRunner.Reload(loaded)
|
||||
m.extRunner.SetConfigStore(m.v)
|
||||
|
||||
// Update extension tools on the agent so the LLM sees changes.
|
||||
if m.agent != nil {
|
||||
@@ -762,7 +803,8 @@ func (m *Kit) ExecuteCompletion(ctx context.Context, req extensions.CompleteRequ
|
||||
// Create a temporary provider for the requested model.
|
||||
config := &models.ProviderConfig{
|
||||
ModelString: req.Model,
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
TLSSkipVerify: m.v.GetBool("tls-skip-verify"),
|
||||
ConfigStore: m.v,
|
||||
}
|
||||
if req.MaxTokens > 0 {
|
||||
config.MaxTokens = req.MaxTokens
|
||||
@@ -848,37 +890,30 @@ func (m *Kit) ExecuteCompletion(ctx context.Context, req extensions.CompleteRequ
|
||||
// prompts, configuration, and behavior settings. All fields are optional
|
||||
// and will use CLI defaults if not specified.
|
||||
//
|
||||
// Global viper state warning:
|
||||
// Options are applied by [New] via [viper.Set] calls against viper's
|
||||
// process-global store. This store is shared with every downstream reader
|
||||
// (e.g. [Kit.SetModel], [Kit.GetThinkingLevel], BuildProviderConfig, and
|
||||
// any other code path that calls viper.Get*). Two consequences:
|
||||
//
|
||||
// 1. Kit instances are NOT isolated from each other within a single
|
||||
// process. Values set by the second New() call overwrite the first,
|
||||
// and any code that later reads viper will see the most recent Set.
|
||||
// 2. Fields left at the zero value do NOT clear prior viper state; they
|
||||
// simply skip the viper.Set. Callers that need a clean slate between
|
||||
// constructions should invoke viper.Reset() (the test suite uses a
|
||||
// private resetViper() helper that wraps it) before the next New().
|
||||
//
|
||||
// Recommended usage: create one Kit per process, or reset viper between
|
||||
// constructions. Concurrent calls to New are serialized internally by
|
||||
// [viperInitMu], but that mutex does not prevent later viper reads (from
|
||||
// a different Kit) from observing mutated keys.
|
||||
//
|
||||
// TODO: refactor New to use a per-instance *viper.Viper (constructed via
|
||||
// viper.New()) so each Kit owns its own isolated config store and Options
|
||||
// no longer leak through the global singleton.
|
||||
// Config isolation: each [New] / [NewAgent] call constructs its own isolated
|
||||
// configuration store (via viper.New internally). Options are applied to that
|
||||
// per-instance store, so two Kits constructed in the same process do NOT share
|
||||
// or clobber each other's configuration. Runtime mutators ([Kit.SetModel],
|
||||
// [Kit.SetThinkingLevel]) and config readers ([Kit.GetThinkingLevel]) operate
|
||||
// only on the owning instance. Fields left at their zero value are simply not
|
||||
// applied; they fall through to the precedence chain (env → .kit.yml →
|
||||
// per-model defaults) resolved within the instance's own store.
|
||||
type Options struct {
|
||||
Model string // Override model (e.g., "anthropic/claude-sonnet-4-5-20250929")
|
||||
SystemPrompt string // Override system prompt
|
||||
ConfigFile string // Override config file path
|
||||
MaxSteps int // Override max steps (0 = use default)
|
||||
Streaming bool // Enable streaming (default from config)
|
||||
Quiet bool // Suppress debug output
|
||||
Tools []Tool // Custom tool set. If empty, AllTools() is used.
|
||||
ExtraTools []Tool // Additional tools added alongside core/MCP/extension tools.
|
||||
|
||||
// Streaming enables or disables streaming output. It is a pointer so the
|
||||
// SDK can distinguish "unset" (nil) from an explicit choice, mirroring the
|
||||
// sampling-parameter fields below. nil leaves streaming to the precedence
|
||||
// chain (env → .kit.yml → default true); a non-nil value forces it. Prefer
|
||||
// [WithStreaming] for the functional-options API.
|
||||
Streaming *bool
|
||||
|
||||
Quiet bool // Suppress debug output
|
||||
Tools []Tool // Custom tool set. If empty, AllTools() is used.
|
||||
ExtraTools []Tool // Additional tools added alongside core/MCP/extension tools.
|
||||
|
||||
// Generation parameters. These override the corresponding values from
|
||||
// .kit.yml / KIT_* environment variables. Leaving a field at its
|
||||
@@ -1125,7 +1160,7 @@ type CLIOptions struct {
|
||||
// - Continue: resume most recent session for SessionDir (or cwd)
|
||||
// - SessionPath: open a specific JSONL session file
|
||||
// - default: create a new tree session for SessionDir (or cwd)
|
||||
func InitTreeSession(opts *Options) (*session.TreeManager, error) {
|
||||
func InitTreeSession(opts *Options) (*TreeManager, error) {
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
}
|
||||
@@ -1151,40 +1186,40 @@ func InitTreeSession(opts *Options) (*session.TreeManager, error) {
|
||||
return session.CreateTreeSession(sessionDir)
|
||||
}
|
||||
|
||||
// viperInitMu serializes viper writes during [New]. Viper's global state
|
||||
// is not thread-safe, so concurrent calls (e.g. parallel subagent spawns)
|
||||
// must not overlap the Set/Get window. Note that this mutex only protects
|
||||
// the construction window — it does not isolate long-lived Kit instances
|
||||
// from each other. See the "Global viper state warning" on [Options].
|
||||
var viperInitMu sync.Mutex
|
||||
|
||||
// New creates a Kit instance using the same initialization as the CLI.
|
||||
// It loads configuration, initializes MCP servers, creates the LLM model, and
|
||||
// sets up the agent for interaction. Returns an error if initialization fails.
|
||||
//
|
||||
// Global viper state warning: fields on [Options] are applied by calling
|
||||
// [viper.Set] on viper's process-global store. As a result, two Kits
|
||||
// constructed in the same process are NOT isolated: the second New
|
||||
// overwrites viper keys set by the first, and any downstream reader
|
||||
// (e.g. [Kit.SetModel], [Kit.GetThinkingLevel]) will observe the most
|
||||
// recent value. Callers that need multiple independent Kits should call
|
||||
// viper.Reset() between constructions, or avoid constructing more than
|
||||
// one Kit per process. Writes during New are serialized by [viperInitMu].
|
||||
// Config isolation: New constructs a per-instance configuration store (via
|
||||
// viper.New internally) and applies [Options] to it. Two Kits constructed in
|
||||
// the same process are therefore fully isolated — neither overwrites the
|
||||
// other's model, thinking level, or generation parameters, and runtime
|
||||
// mutators ([Kit.SetModel], [Kit.SetThinkingLevel]) only affect the owning
|
||||
// instance. This makes subagent spawning and multi-Kit embedding safe without
|
||||
// any external synchronization.
|
||||
//
|
||||
// TODO: refactor to use a per-call viper.New() instance so each Kit owns
|
||||
// its own isolated config store and Options stop leaking through the
|
||||
// global singleton.
|
||||
// CLI integration: when Options.CLI is non-nil the Kit shares the
|
||||
// process-global viper store instead of allocating a fresh one, so cobra flag
|
||||
// bindings established by the CLI remain in effect. SDK callers leave
|
||||
// Options.CLI nil and always get an isolated store.
|
||||
//
|
||||
// For an ergonomic functional-options front door, see [NewAgent].
|
||||
func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
}
|
||||
|
||||
// All viper writes (SetSDKDefaults, InitConfig, Set calls, system-prompt
|
||||
// composition) happen under viperInitMu. We also call BuildProviderConfig
|
||||
// here — it's fast (just reads) — so we can capture the full config
|
||||
// snapshot before releasing the lock. The expensive work (MCP loading,
|
||||
// provider creation, session init) then runs outside the lock, allowing
|
||||
// parallel subagent spawns to proceed concurrently.
|
||||
// Construct this Kit's configuration store. SDK callers get a fresh,
|
||||
// isolated *viper.Viper so concurrent constructions never clobber each
|
||||
// other. The CLI (Options.CLI != nil) shares the process-global store so
|
||||
// its cobra flag bindings and pre-loaded config remain visible.
|
||||
var v *viper.Viper
|
||||
if opts.CLI != nil {
|
||||
v = viper.GetViper()
|
||||
} else {
|
||||
v = viper.New()
|
||||
}
|
||||
|
||||
var (
|
||||
providerConfig *models.ProviderConfig
|
||||
modelString string
|
||||
@@ -1194,86 +1229,93 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
mcpConfig *config.Config
|
||||
debug bool
|
||||
noExtensions bool
|
||||
disableCoreTools bool
|
||||
maxSteps int
|
||||
streaming bool
|
||||
hasCustomSystemPrompt bool
|
||||
systemPromptSource string
|
||||
capturedBasePrompt string
|
||||
)
|
||||
|
||||
if err := func() error {
|
||||
viperInitMu.Lock()
|
||||
defer viperInitMu.Unlock()
|
||||
// Set CLI-equivalent defaults on the instance store. When used as an
|
||||
// SDK (without cobra), these defaults are not registered via flag bindings.
|
||||
setSDKDefaults(v)
|
||||
|
||||
// Set CLI-equivalent defaults for viper. When used as an SDK (without
|
||||
// cobra), these defaults are not registered via flag bindings.
|
||||
setSDKDefaults()
|
||||
|
||||
// Initialize config (loads config files and env vars).
|
||||
// Only initialize if not already done (e.g., by CLI's cobra.OnInitialize).
|
||||
// Check if model is already set, which indicates config was loaded.
|
||||
// Initialize config (loads config files and env vars) into the instance
|
||||
// store. The CLI shares the process-global store, which cobra.OnInitialize
|
||||
// has already populated, so re-running initConfig there is unnecessary;
|
||||
// SDK callers get a fresh isolated store that must be loaded here.
|
||||
// We key off opts.CLI (not a config value) because setSDKDefaults always
|
||||
// seeds "model", which would otherwise mask an empty store.
|
||||
// SkipConfig bypasses .kit.yml file loading (viper defaults and env vars still apply).
|
||||
if !opts.SkipConfig && viper.GetString("model") == "" {
|
||||
if err := InitConfig(opts.ConfigFile, false); err != nil {
|
||||
if !opts.SkipConfig && opts.CLI == nil {
|
||||
if err := initConfig(v, opts.ConfigFile, false); err != nil {
|
||||
return fmt.Errorf("failed to initialize config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle CLI debug mode.
|
||||
if opts.Debug {
|
||||
viper.Set("debug", true)
|
||||
v.Set("debug", true)
|
||||
}
|
||||
|
||||
// Override viper settings with options.
|
||||
// Override instance settings with options.
|
||||
if opts.Model != "" {
|
||||
viper.Set("model", opts.Model)
|
||||
v.Set("model", opts.Model)
|
||||
}
|
||||
if opts.SystemPrompt != "" {
|
||||
viper.Set("system-prompt", opts.SystemPrompt)
|
||||
v.Set("system-prompt", opts.SystemPrompt)
|
||||
}
|
||||
if opts.MaxSteps > 0 {
|
||||
viper.Set("max-steps", opts.MaxSteps)
|
||||
v.Set("max-steps", opts.MaxSteps)
|
||||
}
|
||||
// Only override streaming when the caller explicitly set it. Otherwise
|
||||
// leave the precedence chain (env → config → default true) untouched so a
|
||||
// zero-valued Options does not silently force stream=false.
|
||||
if opts.Streaming != nil {
|
||||
v.Set("stream", *opts.Streaming)
|
||||
}
|
||||
viper.Set("stream", opts.Streaming)
|
||||
|
||||
// Generation parameter overrides. Each Options field, when set,
|
||||
// is pushed into viper here so the existing downstream code
|
||||
// (BuildProviderConfig, SetModel, modelSettings lookups) picks
|
||||
// it up uniformly. Pointer-typed sampling params use viper.Set
|
||||
// only when non-nil so that nil means "leave provider/per-model
|
||||
// default in place" (BuildProviderConfig keys off viper.IsSet).
|
||||
// is pushed into the instance store here so the existing downstream
|
||||
// code (BuildProviderConfig, SetModel, modelSettings lookups) picks
|
||||
// it up uniformly. Pointer-typed sampling params use Set only when
|
||||
// non-nil so that nil means "leave provider/per-model default in
|
||||
// place" (BuildProviderConfig keys off IsSet).
|
||||
if opts.MaxTokens > 0 {
|
||||
viper.Set("max-tokens", opts.MaxTokens)
|
||||
v.Set("max-tokens", opts.MaxTokens)
|
||||
}
|
||||
if opts.ThinkingLevel != "" {
|
||||
viper.Set("thinking-level", opts.ThinkingLevel)
|
||||
v.Set("thinking-level", opts.ThinkingLevel)
|
||||
}
|
||||
if opts.Temperature != nil {
|
||||
viper.Set("temperature", *opts.Temperature)
|
||||
v.Set("temperature", *opts.Temperature)
|
||||
}
|
||||
if opts.TopP != nil {
|
||||
viper.Set("top-p", *opts.TopP)
|
||||
v.Set("top-p", *opts.TopP)
|
||||
}
|
||||
if opts.TopK != nil {
|
||||
viper.Set("top-k", *opts.TopK)
|
||||
v.Set("top-k", *opts.TopK)
|
||||
}
|
||||
if opts.FrequencyPenalty != nil {
|
||||
viper.Set("frequency-penalty", *opts.FrequencyPenalty)
|
||||
v.Set("frequency-penalty", *opts.FrequencyPenalty)
|
||||
}
|
||||
if opts.PresencePenalty != nil {
|
||||
viper.Set("presence-penalty", *opts.PresencePenalty)
|
||||
v.Set("presence-penalty", *opts.PresencePenalty)
|
||||
}
|
||||
|
||||
// Provider overrides. TLSSkipVerify only takes effect when true —
|
||||
// callers wanting to force-disable should use the config file or
|
||||
// env var instead.
|
||||
if opts.ProviderAPIKey != "" {
|
||||
viper.Set("provider-api-key", opts.ProviderAPIKey)
|
||||
v.Set("provider-api-key", opts.ProviderAPIKey)
|
||||
}
|
||||
if opts.ProviderURL != "" {
|
||||
viper.Set("provider-url", opts.ProviderURL)
|
||||
v.Set("provider-url", opts.ProviderURL)
|
||||
}
|
||||
if opts.TLSSkipVerify {
|
||||
viper.Set("tls-skip-verify", true)
|
||||
v.Set("tls-skip-verify", true)
|
||||
}
|
||||
|
||||
// Resolve working directory for context/skill discovery.
|
||||
@@ -1288,9 +1330,25 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
}
|
||||
|
||||
// Load skills — either from explicit paths or via auto-discovery.
|
||||
if !opts.NoSkills {
|
||||
// Merge viper config with opts: CLI flag / config file values are
|
||||
// already bound to viper by cmd/root.go, so v.GetBool("no-skills"),
|
||||
// v.GetStringSlice("skill"), and v.GetString("skills-dir") capture
|
||||
// both --flag and .kit.yml keys transparently.
|
||||
noSkills := opts.NoSkills || v.GetBool("no-skills")
|
||||
skillPaths := opts.Skills
|
||||
if len(skillPaths) == 0 {
|
||||
skillPaths = v.GetStringSlice("skill")
|
||||
}
|
||||
skillsDir := opts.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = v.GetString("skills-dir")
|
||||
}
|
||||
if !noSkills {
|
||||
mergedOpts := *opts
|
||||
mergedOpts.Skills = skillPaths
|
||||
mergedOpts.SkillsDir = skillsDir
|
||||
var err error
|
||||
loadedSkills, err = loadSkills(opts)
|
||||
loadedSkills, err = loadSkills(&mergedOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load skills: %w", err)
|
||||
}
|
||||
@@ -1304,7 +1362,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
// explicitly set system-prompt, use the per-model prompt as the
|
||||
// base instead of the global default.
|
||||
{
|
||||
rawPromptInput := viper.GetString("system-prompt")
|
||||
rawPromptInput := v.GetString("system-prompt")
|
||||
|
||||
// Resolve a file path to its content so PromptBuilder receives the
|
||||
// actual prompt text rather than a literal path string. Without this,
|
||||
@@ -1329,12 +1387,12 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
// Check for per-model system prompt override when no explicit
|
||||
// global system-prompt was configured by the user.
|
||||
if !userSetSystemPrompt {
|
||||
modelStr := viper.GetString("model")
|
||||
modelStr := v.GetString("model")
|
||||
if modelStr != "" {
|
||||
if mi := models.LookupModelForSettings(modelStr); mi != nil {
|
||||
var perModelParams *models.GenerationParams
|
||||
// modelSettings takes priority over custom model params.
|
||||
if ms := models.LoadModelSettingsFromConfig(); ms != nil {
|
||||
if ms := models.LoadModelSettingsFrom(v); ms != nil {
|
||||
perModelParams = ms[modelStr]
|
||||
}
|
||||
if perModelParams == nil && mi.Params != nil {
|
||||
@@ -1349,6 +1407,10 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
|
||||
pb := skills.NewPromptBuilder(basePrompt)
|
||||
|
||||
// Capture the resolved base prompt so RefreshSystemPrompt can
|
||||
// recompose later after runtime skill/context-file mutations.
|
||||
capturedBasePrompt = basePrompt
|
||||
|
||||
// Inject AGENTS.md content as project context.
|
||||
for _, cf := range contextFiles {
|
||||
pb.WithSection("", fmt.Sprintf("Instructions from: %s\n\n%s", cf.Path, cf.Content))
|
||||
@@ -1365,41 +1427,42 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
time.Now().Format("Monday, January 2, 2006, 3:04:05 PM MST"), cwd,
|
||||
))
|
||||
|
||||
viper.Set("system-prompt", pb.Build())
|
||||
v.Set("system-prompt", pb.Build())
|
||||
}
|
||||
|
||||
// Snapshot all viper-derived values now, while the lock is held.
|
||||
// BuildProviderConfig is fast (pure reads), so we do it here.
|
||||
// Snapshot all instance-derived values now.
|
||||
// BuildProviderConfig is fast (pure reads).
|
||||
var pcErr error
|
||||
providerConfig, _, pcErr = kitsetup.BuildProviderConfig()
|
||||
providerConfig, _, pcErr = kitsetup.BuildProviderConfig(v)
|
||||
if pcErr != nil {
|
||||
return fmt.Errorf("failed to build provider config: %w", pcErr)
|
||||
}
|
||||
|
||||
// SDK last-resort max-tokens floor. When nothing — Options, env,
|
||||
// config, nor a per-model default — supplied a value, we land on
|
||||
// zero here (viper.GetInt returns 0 for unset keys). Apply the
|
||||
// SDK default directly on the struct rather than via viper so
|
||||
// viper.IsSet("max-tokens") stays false: downstream right-sizing
|
||||
// zero here (GetInt returns 0 for unset keys). Apply the
|
||||
// SDK default directly on the struct rather than via the store so
|
||||
// IsSet("max-tokens") stays false: downstream right-sizing
|
||||
// can still raise this toward the model's known output ceiling,
|
||||
// and per-model modelSettings[...].maxTokens can still win.
|
||||
if providerConfig.MaxTokens == 0 && opts.MaxTokens == 0 {
|
||||
providerConfig.MaxTokens = sdkDefaultMaxTokens
|
||||
}
|
||||
modelString = viper.GetString("model")
|
||||
debug = viper.GetBool("debug")
|
||||
noExtensions = opts.NoExtensions || viper.GetBool("no-extensions")
|
||||
maxSteps = viper.GetInt("max-steps")
|
||||
streaming = viper.GetBool("stream")
|
||||
modelString = v.GetString("model")
|
||||
debug = v.GetBool("debug")
|
||||
noExtensions = opts.NoExtensions || v.GetBool("no-extensions")
|
||||
disableCoreTools = opts.DisableCoreTools || v.GetBool("no-core-tools")
|
||||
maxSteps = v.GetInt("max-steps")
|
||||
streaming = v.GetBool("stream")
|
||||
|
||||
return nil
|
||||
}(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// ---- viperInitMu released — heavy I/O below runs concurrently ----
|
||||
// ---- config snapshot complete — heavy I/O below ----
|
||||
|
||||
// Load MCP configuration. Use pre-loaded config if provided directly,
|
||||
// via CLI options, or load from viper as a last resort.
|
||||
// via CLI options, or load from the instance store as a last resort.
|
||||
if opts.MCPConfig != nil {
|
||||
mcpConfig = opts.MCPConfig
|
||||
} else if opts.CLI != nil && opts.CLI.MCPConfig != nil {
|
||||
@@ -1407,7 +1470,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
}
|
||||
if mcpConfig == nil {
|
||||
var err error
|
||||
mcpConfig, err = config.LoadAndValidateConfig()
|
||||
mcpConfig, err = config.LoadAndValidateConfigFrom(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load MCP config: %w", err)
|
||||
}
|
||||
@@ -1446,7 +1509,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
MCPConfig: mcpConfig,
|
||||
Quiet: opts.Quiet,
|
||||
CoreTools: opts.Tools,
|
||||
DisableCoreTools: opts.DisableCoreTools,
|
||||
DisableCoreTools: disableCoreTools,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
|
||||
ProviderConfig: providerConfig,
|
||||
@@ -1463,6 +1526,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
timeout: opts.MCPTaskTimeout,
|
||||
progress: opts.MCPTaskProgress,
|
||||
}.toToolsConfig(),
|
||||
Viper: v,
|
||||
}
|
||||
|
||||
// Set up OAuth handler for remote MCP servers. The SDK does not create
|
||||
@@ -1532,8 +1596,10 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
authHandler: setupOpts.AuthHandler,
|
||||
opts: opts,
|
||||
mcpConfig: mcpConfig,
|
||||
v: v,
|
||||
hasCustomSystemPrompt: hasCustomSystemPrompt,
|
||||
systemPromptSource: systemPromptSource,
|
||||
basePrompt: capturedBasePrompt,
|
||||
beforeToolCall: beforeToolCall,
|
||||
afterToolResult: afterToolResult,
|
||||
beforeTurn: beforeTurn,
|
||||
@@ -1560,15 +1626,32 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// GetContextFiles returns the context files (e.g. AGENTS.md) loaded during
|
||||
// initialisation. Returns nil if no context files were found.
|
||||
// GetContextFiles returns the context files (e.g. AGENTS.md) currently active
|
||||
// on this Kit instance. The returned slice is a snapshot — mutating it does
|
||||
// not affect Kit state. Returns nil when no context files are loaded.
|
||||
func (m *Kit) GetContextFiles() []*ContextFile {
|
||||
return m.contextFiles
|
||||
m.runtimeMu.RLock()
|
||||
defer m.runtimeMu.RUnlock()
|
||||
if len(m.contextFiles) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*ContextFile, len(m.contextFiles))
|
||||
copy(out, m.contextFiles)
|
||||
return out
|
||||
}
|
||||
|
||||
// GetSkills returns the skills loaded during initialisation.
|
||||
// GetSkills returns the skills currently active on this Kit instance. The
|
||||
// returned slice is a snapshot — mutating it does not affect Kit state.
|
||||
// Returns nil when no skills are loaded.
|
||||
func (m *Kit) GetSkills() []*Skill {
|
||||
return m.skills
|
||||
m.runtimeMu.RLock()
|
||||
defer m.runtimeMu.RUnlock()
|
||||
if len(m.skills) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*Skill, len(m.skills))
|
||||
copy(out, m.skills)
|
||||
return out
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -1613,12 +1696,14 @@ func (m *Kit) expandSkillCommand(prompt string) string {
|
||||
|
||||
// Find the skill by name.
|
||||
var skillPath string
|
||||
m.runtimeMu.RLock()
|
||||
for _, s := range m.skills {
|
||||
if s.Name == name {
|
||||
skillPath = s.Path
|
||||
break
|
||||
}
|
||||
}
|
||||
m.runtimeMu.RUnlock()
|
||||
if skillPath == "" {
|
||||
return prompt
|
||||
}
|
||||
@@ -1758,8 +1843,14 @@ type SubagentConfig struct {
|
||||
// Empty string uses a minimal default prompt.
|
||||
SystemPrompt string
|
||||
|
||||
// Tools overrides the tool set. If nil, SubagentTools() is used (all
|
||||
// core tools except subagent, preventing infinite recursion).
|
||||
// Tools overrides the tool set available to the subagent.
|
||||
// If nil and the subagent is created via the SDK (Kit.Subagent()), the
|
||||
// static SubagentTools() set (all core tools except "subagent") is used.
|
||||
// When spawned internally by the agent loop, the parent's active tools
|
||||
// minus "subagent" are used instead (see GetToolsForSubagent()).
|
||||
// Pass m.GetToolsForSubagent() explicitly to opt into inheritance from
|
||||
// SDK call sites.
|
||||
// (The subagent tool is dropped to prevent infinite recursion.)
|
||||
Tools []Tool
|
||||
|
||||
// NoSession, when true, uses an in-memory ephemeral session. When false
|
||||
@@ -1791,6 +1882,50 @@ type SubagentResult struct {
|
||||
Elapsed time.Duration
|
||||
}
|
||||
|
||||
// inheritProviderConfig copies the parent's effective provider/runtime
|
||||
// configuration from its isolated config store onto child Options. Used by
|
||||
// Kit.Subagent so the child — which owns a separate store and re-loads only
|
||||
// .kit.yml / KIT_* on its own — still observes provider credentials, the
|
||||
// thinking level, and sampler/token overrides the parent acquired via
|
||||
// programmatic Options or runtime setters (e.g. SetThinkingLevel).
|
||||
//
|
||||
// max-tokens and the sampling parameters are only propagated when the parent
|
||||
// explicitly set them (IsSet), preserving the tri-state precedence so per-model
|
||||
// defaults still apply on the child when the parent left them unset. A nil
|
||||
// child or store is a no-op.
|
||||
func inheritProviderConfig(child *Options, v *viper.Viper) {
|
||||
if child == nil || v == nil {
|
||||
return
|
||||
}
|
||||
child.ProviderAPIKey = v.GetString("provider-api-key")
|
||||
child.ProviderURL = v.GetString("provider-url")
|
||||
child.TLSSkipVerify = v.GetBool("tls-skip-verify")
|
||||
child.ThinkingLevel = v.GetString("thinking-level")
|
||||
if v.IsSet("max-tokens") {
|
||||
child.MaxTokens = v.GetInt("max-tokens")
|
||||
}
|
||||
if v.IsSet("temperature") {
|
||||
t := float32(v.GetFloat64("temperature"))
|
||||
child.Temperature = &t
|
||||
}
|
||||
if v.IsSet("top-p") {
|
||||
p := float32(v.GetFloat64("top-p"))
|
||||
child.TopP = &p
|
||||
}
|
||||
if v.IsSet("top-k") {
|
||||
k := int32(v.GetInt("top-k"))
|
||||
child.TopK = &k
|
||||
}
|
||||
if v.IsSet("frequency-penalty") {
|
||||
fp := float32(v.GetFloat64("frequency-penalty"))
|
||||
child.FrequencyPenalty = &fp
|
||||
}
|
||||
if v.IsSet("presence-penalty") {
|
||||
pp := float32(v.GetFloat64("presence-penalty"))
|
||||
child.PresencePenalty = &pp
|
||||
}
|
||||
}
|
||||
|
||||
// Subagent spawns an in-process child Kit instance to perform a task. The
|
||||
// child gets its own session, event bus, and agent loop but shares the
|
||||
// parent's config (API keys, provider settings) and defaults to the parent's
|
||||
@@ -1860,22 +1995,28 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
}
|
||||
|
||||
// Create child Kit instance. Pass the parent's loaded MCP config to
|
||||
// avoid re-reading viper (which races with concurrent subagent spawns).
|
||||
// Streaming must be explicitly enabled — Options.Streaming defaults to
|
||||
// false, and New() unconditionally writes viper.Set("stream", opts.Streaming).
|
||||
// Without this, the subagent would (a) pollute viper global state for
|
||||
// other concurrent callers and (b) potentially hit provider-level
|
||||
// differences (e.g. Anthropic non-streaming timeouts with extended
|
||||
// thinking).
|
||||
// avoid re-loading and re-validating config for the child.
|
||||
// Streaming is enabled explicitly — without it, non-streaming can hit
|
||||
// provider-level differences (e.g. Anthropic non-streaming timeouts with
|
||||
// extended thinking). The child gets its own config store, so this does not
|
||||
// affect any other concurrent caller.
|
||||
streamOn := true
|
||||
childOpts := &Options{
|
||||
Model: model,
|
||||
SystemPrompt: systemPrompt,
|
||||
Tools: tools,
|
||||
NoSession: cfg.NoSession,
|
||||
Quiet: true,
|
||||
Streaming: true,
|
||||
Streaming: &streamOn,
|
||||
MCPConfig: m.mcpConfig,
|
||||
}
|
||||
|
||||
// Inherit the parent's effective provider/runtime configuration. Since #40
|
||||
// each Kit owns an isolated config store, so the child's New() only re-loads
|
||||
// .kit.yml / KIT_* on its own — values the parent picked up from
|
||||
// programmatic Options or runtime setters (e.g. SetThinkingLevel) would
|
||||
// otherwise be lost.
|
||||
inheritProviderConfig(childOpts, m.v)
|
||||
// Propagate the parent's MCP task configuration so a child subagent
|
||||
// invoking long-running MCP tools observes the same per-server modes,
|
||||
// timeouts, and progress callback as the parent. Without this, child
|
||||
@@ -1970,6 +2111,7 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
SystemPrompt: systemPrompt,
|
||||
Timeout: timeout,
|
||||
OnEvent: onEvent,
|
||||
Tools: m.GetToolsForSubagent(),
|
||||
})
|
||||
m.cleanupSubagentListeners(toolCallID)
|
||||
if result == nil {
|
||||
@@ -2084,7 +2226,7 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
}
|
||||
},
|
||||
OnStepUsage: func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
|
||||
if viper.GetBool("debug") {
|
||||
if m.v.GetBool("debug") {
|
||||
log.Printf("DEBUG Kit.generate emitting StepUsageEvent: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens,
|
||||
)
|
||||
@@ -2126,7 +2268,7 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
})
|
||||
},
|
||||
|
||||
// New callbacks for previously unwired Fantasy lifecycle events.
|
||||
// New callbacks for previously unwired agent lifecycle events.
|
||||
OnStepStart: func(stepNumber int) {
|
||||
m.events.emit(StepStartEvent{StepNumber: stepNumber})
|
||||
},
|
||||
@@ -2580,7 +2722,7 @@ func (m *Kit) IsReasoningModel() bool {
|
||||
|
||||
// GetThinkingLevel returns the current thinking level.
|
||||
func (m *Kit) GetThinkingLevel() string {
|
||||
return viper.GetString("thinking-level")
|
||||
return m.v.GetString("thinking-level")
|
||||
}
|
||||
|
||||
// SetThinkingLevel changes the thinking level and recreates the agent with
|
||||
@@ -2589,7 +2731,7 @@ func (m *Kit) GetThinkingLevel() string {
|
||||
// With message-level caching, both thinking and caching work together.
|
||||
// Caching reduces costs by 60-90% for repeated context.
|
||||
func (m *Kit) SetThinkingLevel(ctx context.Context, level string) error {
|
||||
viper.Set("thinking-level", level)
|
||||
m.v.Set("thinking-level", level)
|
||||
// Recreate agent with new thinking config by re-running SetModel
|
||||
// with the same model string. SetModel rebuilds the provider and
|
||||
// passes the updated viper config (including thinking-level).
|
||||
|
||||
+100
-24
@@ -86,8 +86,8 @@ func TestNewWithGenerationOptions(t *testing.T) {
|
||||
if got := host.MaxTokens(); got != want {
|
||||
t.Errorf("Options.MaxTokens=%d did not propagate; Kit.MaxTokens()=%d", want, got)
|
||||
}
|
||||
if !viper.IsSet("max-tokens") {
|
||||
t.Error("viper.IsSet(\"max-tokens\") should be true after MaxTokens override")
|
||||
if !host.ConfigValueIsSetForTest("max-tokens") {
|
||||
t.Error("max-tokens should be marked explicitly set on the instance store after MaxTokens override")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -129,11 +129,11 @@ func TestNewWithGenerationOptions(t *testing.T) {
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
if !viper.IsSet("temperature") {
|
||||
t.Fatal("viper.IsSet(\"temperature\") should be true after Temperature override")
|
||||
if !host.ConfigValueIsSetForTest("temperature") {
|
||||
t.Fatal("temperature should be marked explicitly set on the instance store after Temperature override")
|
||||
}
|
||||
if got := float32(viper.GetFloat64("temperature")); got != want {
|
||||
t.Errorf("Options.Temperature=%v did not propagate; viper=%v", want, got)
|
||||
if got := float32(host.ConfigFloatForTest("temperature")); got != want {
|
||||
t.Errorf("Options.Temperature=%v did not propagate; instance store=%v", want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -185,8 +185,8 @@ func TestNewPreservesIsSetSemantics(t *testing.T) {
|
||||
// from SDK-side SetDefault/Set calls — which is exactly what this
|
||||
// test is guarding against.
|
||||
for _, k := range checkKeys {
|
||||
if viper.IsSet(k) {
|
||||
t.Errorf("viper.IsSet(%q) == true when no Options field set it "+
|
||||
if host.ConfigValueIsSetForTest(k) {
|
||||
t.Errorf("instance store reports %q explicitly set when no Options field set it "+
|
||||
"(SDK defaults must not corrupt IsSet semantics)", k)
|
||||
}
|
||||
}
|
||||
@@ -217,14 +217,14 @@ func TestNewWithProviderOptions(t *testing.T) {
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
if got := viper.GetString("provider-api-key"); got != apiKey {
|
||||
t.Errorf("Options.ProviderAPIKey did not propagate to viper; got %q (len=%d)", got, len(got))
|
||||
if got := host.ConfigStringForTest("provider-api-key"); got != apiKey {
|
||||
t.Errorf("Options.ProviderAPIKey did not propagate to the instance store; got %q (len=%d)", got, len(got))
|
||||
}
|
||||
})
|
||||
|
||||
// Override precedence: even when viper already holds a different
|
||||
// provider-api-key value (as it would if a config file or earlier
|
||||
// Set() call populated one), Options.ProviderAPIKey must win.
|
||||
// Override precedence: even when the process-global store already holds a
|
||||
// different provider-api-key value, Options.ProviderAPIKey must win on the
|
||||
// Kit's isolated store.
|
||||
t.Run("Options override beats pre-existing viper state", func(t *testing.T) {
|
||||
defer resetViper()
|
||||
|
||||
@@ -242,15 +242,16 @@ func TestNewWithProviderOptions(t *testing.T) {
|
||||
ProviderAPIKey: want,
|
||||
})
|
||||
// Creation may still fail if the model registry is strict, but
|
||||
// we only care that the override reached viper before any
|
||||
// provider handshake happened.
|
||||
if host != nil {
|
||||
defer func() { _ = host.Close() }()
|
||||
// we only care that the override reached the instance store before
|
||||
// any provider handshake happened.
|
||||
if host == nil {
|
||||
t.Fatalf("expected a Kit instance to inspect; got nil (err=%v)", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
_ = err
|
||||
|
||||
if got := viper.GetString("provider-api-key"); got != want {
|
||||
t.Errorf("Options.ProviderAPIKey did not override pre-existing viper value; got %q, want %q", got, want)
|
||||
if got := host.ConfigStringForTest("provider-api-key"); got != want {
|
||||
t.Errorf("Options.ProviderAPIKey did not override pre-existing value on the instance store; got %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -270,7 +271,7 @@ func TestNewWithProviderOptions(t *testing.T) {
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
if got := viper.GetString("provider-url"); got != want {
|
||||
if got := host.ConfigStringForTest("provider-url"); got != want {
|
||||
t.Errorf("Options.ProviderURL did not propagate; got %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
@@ -353,9 +354,9 @@ func TestNewSystemPromptFilePath(t *testing.T) {
|
||||
t.Errorf("GetSystemPromptSource() = %q; want %q", got, want)
|
||||
}
|
||||
|
||||
// The composed system prompt is written back to viper after PromptBuilder
|
||||
// runs. It must contain the file's contents, not the file path.
|
||||
composed := viper.GetString("system-prompt")
|
||||
// The composed system prompt is written back to the instance store after
|
||||
// PromptBuilder runs. It must contain the file's contents, not the file path.
|
||||
composed := host.ConfigStringForTest("system-prompt")
|
||||
if !strings.Contains(composed, promptContent) {
|
||||
t.Errorf("composed system-prompt does not contain file contents\n composed = %q\n want substring = %q", composed, promptContent)
|
||||
}
|
||||
@@ -364,6 +365,81 @@ func TestNewSystemPromptFilePath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewWithSkillsOptions verifies that the three skills-related Options
|
||||
// fields (NoSkills, Skills, SkillsDir) are wired correctly into kit.New().
|
||||
func TestNewWithSkillsOptions(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("NoSkills disables skill loading", func(t *testing.T) {
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
NoSkills: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
if got := host.GetSkills(); len(got) != 0 {
|
||||
t.Errorf("NoSkills=true: expected 0 skills, got %d", len(got))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SkillsDir propagates", func(t *testing.T) {
|
||||
// Use a non-existent dir — no skills will load but the option must be
|
||||
// accepted without error and result in zero skills.
|
||||
dir := t.TempDir()
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
SkillsDir: dir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
// Empty dir → no skills; the important thing is no error.
|
||||
_ = host.GetSkills()
|
||||
})
|
||||
|
||||
t.Run("explicit Skills paths load correctly", func(t *testing.T) {
|
||||
// Write a minimal skill file to a temp dir.
|
||||
dir := t.TempDir()
|
||||
skillFile := dir + "/my-skill.md"
|
||||
content := "---\nname: test-skill\ndescription: A test skill\n---\nDo the thing.\n"
|
||||
if err := os.WriteFile(skillFile, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("failed to write skill file: %v", err)
|
||||
}
|
||||
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
Skills: []string{skillFile},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
skills := host.GetSkills()
|
||||
if len(skills) != 1 {
|
||||
t.Fatalf("expected 1 skill, got %d", len(skills))
|
||||
}
|
||||
if skills[0].Name != "test-skill" {
|
||||
t.Errorf("skill name = %q; want %q", skills[0].Name, "test-skill")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewSystemPromptInline confirms that inline system-prompt strings still
|
||||
// flow through unchanged after the file-path resolution change.
|
||||
func TestNewSystemPromptInline(t *testing.T) {
|
||||
@@ -392,7 +468,7 @@ func TestNewSystemPromptInline(t *testing.T) {
|
||||
if got := host.GetSystemPromptSource(); got != inline {
|
||||
t.Errorf("GetSystemPromptSource() = %q; want %q", got, inline)
|
||||
}
|
||||
if composed := viper.GetString("system-prompt"); !strings.Contains(composed, inline) {
|
||||
if composed := host.ConfigStringForTest("system-prompt"); !strings.Contains(composed, inline) {
|
||||
t.Errorf("composed system-prompt missing inline content; got %q", composed)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
@@ -163,3 +165,82 @@ func TestSubagentPropagatesMCPTaskOptions(t *testing.T) {
|
||||
inheritMCPTaskOptions(&Options{}, nil)
|
||||
inheritMCPTaskOptions(nil, parent)
|
||||
}
|
||||
|
||||
// TestInheritProviderConfig verifies that Kit.Subagent's provider/runtime
|
||||
// config inheritance copies the parent's effective settings onto child
|
||||
// Options, and that the tri-state (IsSet) keys are only propagated when the
|
||||
// parent explicitly set them. Regression test for config loss after the
|
||||
// per-instance viper store isolation (#40).
|
||||
func TestInheritProviderConfig(t *testing.T) {
|
||||
t.Run("explicit values propagate", func(t *testing.T) {
|
||||
v := viper.New()
|
||||
v.Set("provider-api-key", "sk-parent")
|
||||
v.Set("provider-url", "https://proxy.internal/v1")
|
||||
v.Set("tls-skip-verify", true)
|
||||
v.Set("thinking-level", "high")
|
||||
v.Set("max-tokens", 4321)
|
||||
v.Set("temperature", 0.25)
|
||||
v.Set("top-p", 0.9)
|
||||
v.Set("top-k", 40)
|
||||
v.Set("frequency-penalty", 0.1)
|
||||
v.Set("presence-penalty", 0.2)
|
||||
|
||||
child := &Options{}
|
||||
inheritProviderConfig(child, v)
|
||||
|
||||
if child.ProviderAPIKey != "sk-parent" {
|
||||
t.Errorf("ProviderAPIKey = %q, want sk-parent", child.ProviderAPIKey)
|
||||
}
|
||||
if child.ProviderURL != "https://proxy.internal/v1" {
|
||||
t.Errorf("ProviderURL = %q", child.ProviderURL)
|
||||
}
|
||||
if !child.TLSSkipVerify {
|
||||
t.Error("TLSSkipVerify not propagated")
|
||||
}
|
||||
if child.ThinkingLevel != "high" {
|
||||
t.Errorf("ThinkingLevel = %q, want high", child.ThinkingLevel)
|
||||
}
|
||||
if child.MaxTokens != 4321 {
|
||||
t.Errorf("MaxTokens = %d, want 4321", child.MaxTokens)
|
||||
}
|
||||
if child.Temperature == nil || *child.Temperature != 0.25 {
|
||||
t.Errorf("Temperature = %v, want 0.25", child.Temperature)
|
||||
}
|
||||
if child.TopP == nil || *child.TopP != 0.9 {
|
||||
t.Errorf("TopP = %v, want 0.9", child.TopP)
|
||||
}
|
||||
if child.TopK == nil || *child.TopK != 40 {
|
||||
t.Errorf("TopK = %v, want 40", child.TopK)
|
||||
}
|
||||
if child.FrequencyPenalty == nil || *child.FrequencyPenalty != 0.1 {
|
||||
t.Errorf("FrequencyPenalty = %v, want 0.1", child.FrequencyPenalty)
|
||||
}
|
||||
if child.PresencePenalty == nil || *child.PresencePenalty != 0.2 {
|
||||
t.Errorf("PresencePenalty = %v, want 0.2", child.PresencePenalty)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unset tri-state keys stay unset", func(t *testing.T) {
|
||||
// A store with no sampler / max-tokens keys must leave the child's
|
||||
// pointers nil and MaxTokens zero so per-model defaults still apply.
|
||||
v := viper.New()
|
||||
child := &Options{}
|
||||
inheritProviderConfig(child, v)
|
||||
|
||||
if child.MaxTokens != 0 {
|
||||
t.Errorf("MaxTokens = %d, want 0 (unset)", child.MaxTokens)
|
||||
}
|
||||
if child.Temperature != nil || child.TopP != nil || child.TopK != nil ||
|
||||
child.FrequencyPenalty != nil || child.PresencePenalty != nil {
|
||||
t.Error("sampler pointers must stay nil when the parent did not set them")
|
||||
}
|
||||
if child.ThinkingLevel != "" {
|
||||
t.Errorf("ThinkingLevel = %q, want empty", child.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil child or store is a no-op", func(t *testing.T) {
|
||||
inheritProviderConfig(nil, viper.New())
|
||||
inheritProviderConfig(&Options{}, nil)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -61,3 +61,23 @@ func CheckProviderReady(provider string) error {
|
||||
}
|
||||
return models.GetGlobalRegistry().ValidateEnvironment(provider, "")
|
||||
}
|
||||
|
||||
// ResolveProviderBaseURL returns the base API URL kit will use when talking to
|
||||
// the given provider, applying the same resolution order that CreateProvider
|
||||
// uses internally:
|
||||
//
|
||||
// 1. The provider's `api` field from the models.dev registry.
|
||||
// 2. The hard-coded default base URL of its npm SDK package (e.g.
|
||||
// @ai-sdk/groq → https://api.groq.com/openai/v1).
|
||||
// 3. Template substitution against the current process environment when the
|
||||
// URL contains "${VAR}" placeholders.
|
||||
//
|
||||
// Returns a non-nil error when the provider is unknown, when no URL can be
|
||||
// derived, or when a templated URL has unset placeholders.
|
||||
//
|
||||
// Use this from your SDK integration to surface the effective endpoint before
|
||||
// instantiating a Kit, or to validate that a provider is reachable without
|
||||
// running an actual request.
|
||||
func ResolveProviderBaseURL(providerID string) (string, error) {
|
||||
return models.ResolveProviderBaseURL(providerID)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
package kit
|
||||
|
||||
import "context"
|
||||
|
||||
// Option configures a [Kit] created via [NewAgent]. Options are applied in
|
||||
// order to an [Options] value, so later options override earlier ones. The
|
||||
// type is a plain func(*Options), so callers can define their own options
|
||||
// without depending on any internal type.
|
||||
type Option func(*Options)
|
||||
|
||||
// NewAgent creates a Kit using an ergonomic functional-options API. It is a
|
||||
// thin, additive front door over [New]: the supplied options are applied to a
|
||||
// fresh [Options] value which is then passed to [New]. For advanced
|
||||
// configuration not covered by the With* helpers (MCPConfig,
|
||||
// InProcessMCPServers, session backends, MCP task tuning, etc.) construct an
|
||||
// [Options] explicitly and call [New].
|
||||
//
|
||||
// Streaming defaults to enabled. Pass WithStreaming(false) to disable it.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// k, err := kit.NewAgent(ctx,
|
||||
// kit.WithModel("anthropic/claude-sonnet-4-5-20250929"),
|
||||
// kit.WithSystemPrompt("You are a helpful assistant."),
|
||||
// kit.WithMaxTokens(8192),
|
||||
// kit.Ephemeral(),
|
||||
// )
|
||||
func NewAgent(ctx context.Context, opts ...Option) (*Kit, error) {
|
||||
// Streaming defaults to true for the ergonomic constructor — this is the
|
||||
// natural expectation for interactive agents. WithStreaming(false) overrides it.
|
||||
streamOn := true
|
||||
o := &Options{Streaming: &streamOn}
|
||||
for _, fn := range opts {
|
||||
fn(o)
|
||||
}
|
||||
return New(ctx, o)
|
||||
}
|
||||
|
||||
// WithModel sets the model in "provider/model" format
|
||||
// (e.g. "anthropic/claude-sonnet-4-5-20250929").
|
||||
func WithModel(m string) Option { return func(o *Options) { o.Model = m } }
|
||||
|
||||
// WithSystemPrompt sets the system prompt. The value may be inline text or a
|
||||
// path to a file whose contents are loaded as the prompt.
|
||||
func WithSystemPrompt(p string) Option { return func(o *Options) { o.SystemPrompt = p } }
|
||||
|
||||
// WithStreaming enables or disables streaming responses. [NewAgent] enables
|
||||
// streaming by default, so pass WithStreaming(false) to opt out.
|
||||
func WithStreaming(b bool) Option {
|
||||
return func(o *Options) { o.Streaming = &b }
|
||||
}
|
||||
|
||||
// WithMaxTokens sets the maximum output tokens per LLM response. A value of 0
|
||||
// lets the precedence chain (env → config → per-model → SDK floor) resolve a
|
||||
// value; a non-zero value pins it and suppresses automatic right-sizing.
|
||||
func WithMaxTokens(n int) Option { return func(o *Options) { o.MaxTokens = n } }
|
||||
|
||||
// WithThinkingLevel sets the reasoning effort for models that support extended
|
||||
// thinking. Valid values: "off", "none", "minimal", "low", "medium", "high".
|
||||
// An empty string lets the precedence chain resolve a level.
|
||||
func WithThinkingLevel(level string) Option { return func(o *Options) { o.ThinkingLevel = level } }
|
||||
|
||||
// WithTools sets the agent's tool set, replacing the default core tools. When
|
||||
// no tools are provided the default set is used.
|
||||
func WithTools(t ...Tool) Option { return func(o *Options) { o.Tools = t } }
|
||||
|
||||
// WithExtraTools adds tools alongside the core/MCP/extension tools rather than
|
||||
// replacing them.
|
||||
func WithExtraTools(t ...Tool) Option { return func(o *Options) { o.ExtraTools = t } }
|
||||
|
||||
// WithProviderAPIKey overrides the API key used to authenticate with the model
|
||||
// provider.
|
||||
func WithProviderAPIKey(key string) Option { return func(o *Options) { o.ProviderAPIKey = key } }
|
||||
|
||||
// WithProviderURL overrides the provider endpoint URL. Useful for
|
||||
// OpenAI-compatible proxies (LiteLLM, vLLM, Azure OpenAI, etc.).
|
||||
func WithProviderURL(url string) Option { return func(o *Options) { o.ProviderURL = url } }
|
||||
|
||||
// WithConfigFile sets an explicit config file path, overriding the default
|
||||
// .kit.yml search.
|
||||
func WithConfigFile(path string) Option { return func(o *Options) { o.ConfigFile = path } }
|
||||
|
||||
// WithDebug enables SDK debug logging.
|
||||
func WithDebug() Option { return func(o *Options) { o.Debug = true } }
|
||||
|
||||
// Ephemeral configures an in-memory session with no persistence (equivalent to
|
||||
// Options.NoSession = true).
|
||||
func Ephemeral() Option { return func(o *Options) { o.NoSession = true } }
|
||||
@@ -0,0 +1,342 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
"github.com/mark3labs/kit/internal/skills"
|
||||
)
|
||||
|
||||
// TestAddSkill_AddsAndDeduplicates verifies that AddSkill registers new skills
|
||||
// and that re-adding a skill with the same Name replaces the existing entry
|
||||
// rather than appending a duplicate. agent is nil in these tests; the method
|
||||
// must still mutate the in-memory state and tolerate the absent agent.
|
||||
func TestAddSkill_AddsAndDeduplicates(t *testing.T) {
|
||||
k := &Kit{basePrompt: "base"}
|
||||
|
||||
if err := k.AddSkill(&skills.Skill{Name: "alpha", Content: "first"}); err != nil {
|
||||
t.Fatalf("AddSkill alpha: %v", err)
|
||||
}
|
||||
if err := k.AddSkill(&skills.Skill{Name: "beta", Content: "second"}); err != nil {
|
||||
t.Fatalf("AddSkill beta: %v", err)
|
||||
}
|
||||
got := k.GetSkills()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 skills, got %d", len(got))
|
||||
}
|
||||
|
||||
// Re-adding alpha with new content must replace, not duplicate.
|
||||
if err := k.AddSkill(&skills.Skill{Name: "alpha", Content: "replaced"}); err != nil {
|
||||
t.Fatalf("AddSkill alpha replace: %v", err)
|
||||
}
|
||||
got = k.GetSkills()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 skills after replace, got %d", len(got))
|
||||
}
|
||||
for _, s := range got {
|
||||
if s.Name == "alpha" && s.Content != "replaced" {
|
||||
t.Errorf("alpha content = %q; want %q", s.Content, "replaced")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddSkill_Validation rejects nil skills and unnamed skills with errors
|
||||
// instead of corrupting state.
|
||||
func TestAddSkill_Validation(t *testing.T) {
|
||||
k := &Kit{}
|
||||
if err := k.AddSkill(nil); err == nil {
|
||||
t.Error("expected error for nil skill")
|
||||
}
|
||||
if err := k.AddSkill(&skills.Skill{Content: "x"}); err == nil {
|
||||
t.Error("expected error for unnamed skill")
|
||||
}
|
||||
if got := k.GetSkills(); got != nil {
|
||||
t.Errorf("skills list mutated after invalid AddSkill calls: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemoveSkill verifies removal and the false return for misses.
|
||||
func TestRemoveSkill(t *testing.T) {
|
||||
k := &Kit{}
|
||||
_ = k.AddSkill(&skills.Skill{Name: "alpha"})
|
||||
_ = k.AddSkill(&skills.Skill{Name: "beta"})
|
||||
|
||||
if removed := k.RemoveSkill("missing"); removed {
|
||||
t.Error("RemoveSkill(missing) = true; want false")
|
||||
}
|
||||
if removed := k.RemoveSkill("alpha"); !removed {
|
||||
t.Error("RemoveSkill(alpha) = false; want true")
|
||||
}
|
||||
got := k.GetSkills()
|
||||
if len(got) != 1 || got[0].Name != "beta" {
|
||||
t.Errorf("remaining skills = %#v; want [beta]", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetSkills replaces the entire set and validates input.
|
||||
func TestSetSkills(t *testing.T) {
|
||||
k := &Kit{}
|
||||
_ = k.AddSkill(&skills.Skill{Name: "alpha"})
|
||||
|
||||
err := k.SetSkills([]*skills.Skill{
|
||||
{Name: "one"},
|
||||
{Name: "two"},
|
||||
{Name: "three"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SetSkills: %v", err)
|
||||
}
|
||||
if got := k.GetSkills(); len(got) != 3 {
|
||||
t.Errorf("expected 3 skills, got %d", len(got))
|
||||
}
|
||||
|
||||
// Invalid entry rejects the whole batch.
|
||||
bad := []*skills.Skill{{Name: "ok"}, nil}
|
||||
if err := k.SetSkills(bad); err == nil {
|
||||
t.Error("expected error when batch contains nil")
|
||||
}
|
||||
// State unchanged after rejected batch.
|
||||
if got := k.GetSkills(); len(got) != 3 {
|
||||
t.Errorf("skills mutated by rejected SetSkills batch: len=%d", len(got))
|
||||
}
|
||||
|
||||
// Empty slice clears.
|
||||
if err := k.SetSkills(nil); err != nil {
|
||||
t.Fatalf("SetSkills(nil): %v", err)
|
||||
}
|
||||
if got := k.GetSkills(); got != nil {
|
||||
t.Errorf("expected nil skills after clear; got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadAndAddSkill round-trips a skill file from disk.
|
||||
func TestLoadAndAddSkill(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "demo.md")
|
||||
body := "---\nname: demo\ndescription: demo skill\n---\nhello world"
|
||||
if err := os.WriteFile(path, []byte(body), 0o644); err != nil {
|
||||
t.Fatalf("write skill file: %v", err)
|
||||
}
|
||||
|
||||
k := &Kit{}
|
||||
s, err := k.LoadAndAddSkill(path)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAndAddSkill: %v", err)
|
||||
}
|
||||
if s.Name != "demo" {
|
||||
t.Errorf("loaded skill Name = %q; want demo", s.Name)
|
||||
}
|
||||
if got := k.GetSkills(); len(got) != 1 {
|
||||
t.Errorf("expected 1 skill registered, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddContextFile_DeduplicatesByPath confirms identical paths replace
|
||||
// rather than duplicate.
|
||||
func TestAddContextFile_DeduplicatesByPath(t *testing.T) {
|
||||
k := &Kit{}
|
||||
if err := k.AddContextFile(&ContextFile{Path: "/a/AGENTS.md", Content: "v1"}); err != nil {
|
||||
t.Fatalf("AddContextFile: %v", err)
|
||||
}
|
||||
if err := k.AddContextFile(&ContextFile{Path: "/b/AGENTS.md", Content: "vB"}); err != nil {
|
||||
t.Fatalf("AddContextFile: %v", err)
|
||||
}
|
||||
if err := k.AddContextFile(&ContextFile{Path: "/a/AGENTS.md", Content: "v2"}); err != nil {
|
||||
t.Fatalf("AddContextFile replace: %v", err)
|
||||
}
|
||||
|
||||
got := k.GetContextFiles()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 context files, got %d", len(got))
|
||||
}
|
||||
for _, cf := range got {
|
||||
if cf.Path == "/a/AGENTS.md" && cf.Content != "v2" {
|
||||
t.Errorf("/a/AGENTS.md content = %q; want v2", cf.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddContextFile_Validation rejects nil and unpathed entries.
|
||||
func TestAddContextFile_Validation(t *testing.T) {
|
||||
k := &Kit{}
|
||||
if err := k.AddContextFile(nil); err == nil {
|
||||
t.Error("expected error for nil context file")
|
||||
}
|
||||
if err := k.AddContextFile(&ContextFile{Content: "x"}); err == nil {
|
||||
t.Error("expected error for empty path")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemoveContextFile_Behavior verifies remove returns true on hit and
|
||||
// false on miss without mutating state on a miss.
|
||||
func TestRemoveContextFile_Behavior(t *testing.T) {
|
||||
k := &Kit{}
|
||||
_ = k.AddContextFile(&ContextFile{Path: "/a", Content: "x"})
|
||||
_ = k.AddContextFile(&ContextFile{Path: "/b", Content: "y"})
|
||||
|
||||
if removed := k.RemoveContextFile("/missing"); removed {
|
||||
t.Error("RemoveContextFile(missing) = true; want false")
|
||||
}
|
||||
if removed := k.RemoveContextFile("/a"); !removed {
|
||||
t.Error("RemoveContextFile(/a) = false; want true")
|
||||
}
|
||||
got := k.GetContextFiles()
|
||||
if len(got) != 1 || got[0].Path != "/b" {
|
||||
t.Errorf("remaining = %#v; want [/b]", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetContextFiles replaces and validates batch input.
|
||||
func TestSetContextFiles(t *testing.T) {
|
||||
k := &Kit{}
|
||||
_ = k.AddContextFile(&ContextFile{Path: "/seed", Content: "old"})
|
||||
|
||||
err := k.SetContextFiles([]*ContextFile{
|
||||
{Path: "/x", Content: "x"},
|
||||
{Path: "/y", Content: "y"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SetContextFiles: %v", err)
|
||||
}
|
||||
if got := k.GetContextFiles(); len(got) != 2 {
|
||||
t.Errorf("expected 2 context files, got %d", len(got))
|
||||
}
|
||||
|
||||
bad := []*ContextFile{{Path: "/ok"}, {Path: ""}}
|
||||
if err := k.SetContextFiles(bad); err == nil {
|
||||
t.Error("expected error for empty path in batch")
|
||||
}
|
||||
if got := k.GetContextFiles(); len(got) != 2 {
|
||||
t.Errorf("state mutated by rejected batch: len=%d", len(got))
|
||||
}
|
||||
|
||||
if err := k.SetContextFiles(nil); err != nil {
|
||||
t.Fatalf("SetContextFiles(nil): %v", err)
|
||||
}
|
||||
if got := k.GetContextFiles(); got != nil {
|
||||
t.Errorf("expected nil after clear; got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadAndAddContextFile reads from disk and registers the context file.
|
||||
func TestLoadAndAddContextFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "AGENTS.md")
|
||||
const content = "# Agent rules\nuse the new lint config"
|
||||
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
|
||||
k := &Kit{}
|
||||
cf, err := k.LoadAndAddContextFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAndAddContextFile: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(cf.Path, "AGENTS.md") {
|
||||
t.Errorf("Path = %q; want suffix AGENTS.md", cf.Path)
|
||||
}
|
||||
if !strings.Contains(cf.Content, "use the new lint config") {
|
||||
t.Errorf("Content missing expected body: %q", cf.Content)
|
||||
}
|
||||
got := k.GetContextFiles()
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 context file, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddContextFileContent registers an in-memory context blob.
|
||||
func TestAddContextFileContent(t *testing.T) {
|
||||
k := &Kit{}
|
||||
cf, err := k.AddContextFileContent("session://user-123/AGENTS.md", "always greet in French")
|
||||
if err != nil {
|
||||
t.Fatalf("AddContextFileContent: %v", err)
|
||||
}
|
||||
if cf.Path != "session://user-123/AGENTS.md" {
|
||||
t.Errorf("Path = %q", cf.Path)
|
||||
}
|
||||
if cf.Content != "always greet in French" {
|
||||
t.Errorf("Content = %q", cf.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// TestComposeSystemPrompt_IncludesSkillsAndContext verifies that runtime
|
||||
// mutations actually flow into the composed system prompt that the agent
|
||||
// would receive.
|
||||
func TestComposeSystemPrompt_IncludesSkillsAndContext(t *testing.T) {
|
||||
k := &Kit{basePrompt: "BASE-PROMPT-MARKER"}
|
||||
|
||||
if err := k.AddContextFile(&ContextFile{
|
||||
Path: "/proj/AGENTS.md",
|
||||
Content: "CTX-MARKER-OK",
|
||||
}); err != nil {
|
||||
t.Fatalf("AddContextFile: %v", err)
|
||||
}
|
||||
if err := k.AddSkill(&skills.Skill{
|
||||
Name: "greeter",
|
||||
Description: "SKILL-DESC-MARKER",
|
||||
Content: "do greetings",
|
||||
Path: "/skills/greeter.md",
|
||||
}); err != nil {
|
||||
t.Fatalf("AddSkill: %v", err)
|
||||
}
|
||||
|
||||
composed := k.composeSystemPrompt(k.basePrompt)
|
||||
for _, want := range []string{
|
||||
"BASE-PROMPT-MARKER",
|
||||
"CTX-MARKER-OK",
|
||||
"/proj/AGENTS.md",
|
||||
"greeter",
|
||||
"SKILL-DESC-MARKER",
|
||||
} {
|
||||
if !strings.Contains(composed, want) {
|
||||
t.Errorf("composed prompt missing %q\n--- composed ---\n%s", want, composed)
|
||||
}
|
||||
}
|
||||
|
||||
// Removing the skill should remove its marker from the next composition.
|
||||
k.RemoveSkill("greeter")
|
||||
composed = k.composeSystemPrompt(k.basePrompt)
|
||||
if strings.Contains(composed, "SKILL-DESC-MARKER") {
|
||||
t.Errorf("composed prompt still contains removed skill description:\n%s", composed)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRuntimeMutations_AreThreadSafe stresses the mutation API from multiple
|
||||
// goroutines to surface data races under `go test -race`.
|
||||
func TestRuntimeMutations_AreThreadSafe(t *testing.T) {
|
||||
// Use a non-nil agent so applyComposedSystemPrompt actually invokes
|
||||
// agent.SetSystemPrompt (a no-op agent is fine — we only need the
|
||||
// systemPrompt mutation + fantasy rebuild path to run concurrently so
|
||||
// -race can observe any unsynchronized writes).
|
||||
k := &Kit{basePrompt: "base", agent: &agent.Agent{}}
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 8
|
||||
const iterations = 50
|
||||
|
||||
for g := range goroutines {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
_ = k.AddSkill(&skills.Skill{
|
||||
Name: "skill",
|
||||
Content: "content",
|
||||
})
|
||||
_ = k.AddContextFile(&ContextFile{
|
||||
Path: "/shared/AGENTS.md",
|
||||
Content: "shared",
|
||||
})
|
||||
_ = k.GetSkills()
|
||||
_ = k.GetContextFiles()
|
||||
_ = k.composeSystemPrompt("base")
|
||||
k.RemoveSkill("skill")
|
||||
k.RemoveContextFile("/shared/AGENTS.md")
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
+138
-1
@@ -139,13 +139,150 @@ func (m *Kit) ClearSkillCache() {
|
||||
}
|
||||
|
||||
// ReloadSkills re-discovers skills from disk, replacing the current set.
|
||||
// This is called by file watchers when skill files change.
|
||||
// This is called by file watchers when skill files change. The system prompt
|
||||
// is recomposed and applied to the running agent so subsequent turns see the
|
||||
// new skill set.
|
||||
func (m *Kit) ReloadSkills() error {
|
||||
newSkills, err := loadSkills(m.opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reloading skills: %w", err)
|
||||
}
|
||||
m.runtimeMu.Lock()
|
||||
m.skills = newSkills
|
||||
m.runtimeMu.Unlock()
|
||||
m.ClearSkillCache()
|
||||
m.applyComposedSystemPrompt()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Runtime skill management (Issue #36)
|
||||
// ---------------------------------------------------------------------------
|
||||
//
|
||||
// The methods below let SDK consumers (chatbot hosts, multi-tenant agents)
|
||||
// mutate the active skill set after Kit construction. Each mutation recomposes
|
||||
// the system prompt and applies it to the underlying agent so the LLM sees
|
||||
// the new skill metadata on its next turn.
|
||||
|
||||
// AddSkill registers a single skill on this Kit instance. The skill object
|
||||
// can be built programmatically (no file on disk required) — only Name and
|
||||
// Content are mandatory. If a skill with the same Name is already loaded the
|
||||
// new skill replaces it. Returns an error when skill is nil or has an empty
|
||||
// name.
|
||||
//
|
||||
// After mutation the system prompt is recomposed and applied to the running
|
||||
// agent so the next turn sees the updated skill metadata. AddSkill is safe to
|
||||
// call from any goroutine.
|
||||
func (m *Kit) AddSkill(skill *Skill) error {
|
||||
if skill == nil {
|
||||
return fmt.Errorf("AddSkill: skill is nil")
|
||||
}
|
||||
if skill.Name == "" {
|
||||
return fmt.Errorf("AddSkill: skill name is required")
|
||||
}
|
||||
|
||||
m.runtimeMu.Lock()
|
||||
replaced := false
|
||||
for i, s := range m.skills {
|
||||
if s.Name == skill.Name {
|
||||
m.skills[i] = skill
|
||||
replaced = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !replaced {
|
||||
m.skills = append(m.skills, skill)
|
||||
}
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
m.ClearSkillCache()
|
||||
m.applyComposedSystemPrompt()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAndAddSkill loads a skill from a filesystem path (single .md/.txt file)
|
||||
// and adds it via [Kit.AddSkill]. Returns the loaded skill on success.
|
||||
func (m *Kit) LoadAndAddSkill(path string) (*Skill, error) {
|
||||
s, err := skills.LoadSkill(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LoadAndAddSkill: %w", err)
|
||||
}
|
||||
if err := m.AddSkill(s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// RemoveSkill removes the named skill from this Kit instance and recomposes
|
||||
// the system prompt. Returns true when a skill with that name was found and
|
||||
// removed, false otherwise.
|
||||
func (m *Kit) RemoveSkill(name string) bool {
|
||||
m.runtimeMu.Lock()
|
||||
found := false
|
||||
for i, s := range m.skills {
|
||||
if s.Name == name {
|
||||
m.skills = append(m.skills[:i], m.skills[i+1:]...)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
m.ClearSkillCache()
|
||||
m.applyComposedSystemPrompt()
|
||||
return true
|
||||
}
|
||||
|
||||
// SetSkills replaces the active skill set with the provided slice. Pass nil
|
||||
// or an empty slice to remove all skills. The system prompt is recomposed and
|
||||
// applied. Skills with empty names are rejected and no mutation is performed.
|
||||
func (m *Kit) SetSkills(skillList []*Skill) error {
|
||||
// Validate first so a bad input doesn't partially mutate state.
|
||||
for i, s := range skillList {
|
||||
if s == nil {
|
||||
return fmt.Errorf("SetSkills: skill at index %d is nil", i)
|
||||
}
|
||||
if s.Name == "" {
|
||||
return fmt.Errorf("SetSkills: skill at index %d has empty name", i)
|
||||
}
|
||||
}
|
||||
|
||||
copied := make([]*Skill, len(skillList))
|
||||
copy(copied, skillList)
|
||||
|
||||
m.runtimeMu.Lock()
|
||||
m.skills = copied
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
m.ClearSkillCache()
|
||||
m.applyComposedSystemPrompt()
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyComposedSystemPrompt recomposes the system prompt from the captured
|
||||
// base prompt + current contextFiles + current skills + date/cwd, and pushes
|
||||
// the result onto the underlying agent. No-op when the agent is unset (i.e.
|
||||
// during construction).
|
||||
func (m *Kit) applyComposedSystemPrompt() {
|
||||
if m.agent == nil {
|
||||
return
|
||||
}
|
||||
m.runtimeMu.RLock()
|
||||
base := m.basePrompt
|
||||
m.runtimeMu.RUnlock()
|
||||
composed := m.composeSystemPrompt(base)
|
||||
m.agent.SetSystemPrompt(composed)
|
||||
}
|
||||
|
||||
// RefreshSystemPrompt manually recomposes the system prompt from the current
|
||||
// skills and context files and applies it to the agent. Call this after a
|
||||
// batch of low-level mutations or to force a re-render of the date/cwd
|
||||
// section. Most callers don't need to invoke this directly because
|
||||
// AddSkill, RemoveSkill, SetSkills, AddContextFile, RemoveContextFile, and
|
||||
// SetContextFiles all refresh automatically.
|
||||
func (m *Kit) RefreshSystemPrompt() {
|
||||
m.applyComposedSystemPrompt()
|
||||
}
|
||||
|
||||
+27
-71
@@ -7,45 +7,36 @@ import (
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/prompts"
|
||||
"github.com/mark3labs/kit/internal/skills"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Template Parsing Bridge for Extensions (Phase 3)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// varRegex matches {{variable}} placeholders in templates.
|
||||
var varRegex = regexp.MustCompile(`\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}`)
|
||||
|
||||
// ParseTemplate extracts {{variables}} from template content.
|
||||
// ParseTemplate extracts {{variables}} from template content. The template
|
||||
// grammar is shared with skill prompt templates, so a template parses
|
||||
// identically regardless of which API loads it.
|
||||
func ParseTemplate(name, content string) extensions.PromptTemplate {
|
||||
matches := varRegex.FindAllStringSubmatch(content, -1)
|
||||
vars := make([]string, 0, len(matches))
|
||||
seen := make(map[string]bool)
|
||||
for _, m := range matches {
|
||||
if len(m) > 1 && !seen[m[1]] {
|
||||
seen[m[1]] = true
|
||||
vars = append(vars, m[1])
|
||||
}
|
||||
tpl := skills.NewPromptTemplate(name, content)
|
||||
vars := tpl.Variables
|
||||
if vars == nil {
|
||||
vars = []string{}
|
||||
}
|
||||
return extensions.PromptTemplate{
|
||||
Name: name,
|
||||
Content: content,
|
||||
Name: tpl.Name,
|
||||
Content: tpl.Content,
|
||||
Variables: vars,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderTemplate substitutes variables into template content.
|
||||
// Handles {{name}} and {{ name }} (any whitespace) placeholders.
|
||||
// Handles {{name}} and {{ name }} (any whitespace) placeholders; missing
|
||||
// variables are left as-is.
|
||||
func RenderTemplate(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return varRegex.ReplaceAllStringFunc(tpl.Content, func(m string) string {
|
||||
sub := varRegex.FindStringSubmatch(m)
|
||||
if len(sub) > 1 {
|
||||
if v, ok := vars[sub[1]]; ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return m
|
||||
})
|
||||
t := skills.PromptTemplate{Content: tpl.Content}
|
||||
return t.Expand(vars)
|
||||
}
|
||||
|
||||
// ParseArguments parses command-line style arguments.
|
||||
@@ -183,44 +174,12 @@ func SimpleParseArguments(input string, count int) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// parseFields splits input respecting quoted strings.
|
||||
// parseFields splits input into arguments respecting quoted strings and
|
||||
// backslash escaping. It delegates to the canonical tokenizer in
|
||||
// internal/prompts so extension argument parsing and builtin prompt-template
|
||||
// parsing agree on grammar.
|
||||
func parseFields(input string) []string {
|
||||
var fields []string
|
||||
var current strings.Builder
|
||||
inQuote := false
|
||||
quoteChar := rune(0)
|
||||
|
||||
for _, r := range input {
|
||||
switch r {
|
||||
case '"', '\'':
|
||||
if !inQuote {
|
||||
inQuote = true
|
||||
quoteChar = r
|
||||
} else if r == quoteChar {
|
||||
inQuote = false
|
||||
quoteChar = 0
|
||||
} else {
|
||||
current.WriteRune(r)
|
||||
}
|
||||
case ' ', '\t':
|
||||
if inQuote {
|
||||
current.WriteRune(r)
|
||||
} else {
|
||||
if current.Len() > 0 {
|
||||
fields = append(fields, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
}
|
||||
default:
|
||||
current.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
fields = append(fields, current.String())
|
||||
}
|
||||
|
||||
return fields
|
||||
return prompts.ParseCommandArgs(input)
|
||||
}
|
||||
|
||||
// EvaluateModelConditional checks if condition matches current model.
|
||||
@@ -417,21 +376,18 @@ func MatchModelGlob(model, pattern string) bool {
|
||||
}
|
||||
|
||||
// ExtractProviderFromPath extracts provider from a path-like model string.
|
||||
//
|
||||
// Deprecated: Use GetCurrentProvider instead.
|
||||
func ExtractProviderFromPath(model string) string {
|
||||
parts := strings.Split(model, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
return GetCurrentProvider(model)
|
||||
}
|
||||
|
||||
// ExtractModelFromPath extracts model ID from a path-like model string.
|
||||
//
|
||||
// Deprecated: Use RemoveProviderFromModel instead, which correctly handles
|
||||
// model IDs containing "/" (e.g. "openrouter/meta/llama").
|
||||
func ExtractModelFromPath(model string) string {
|
||||
parts := strings.Split(model, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[1]
|
||||
}
|
||||
return model
|
||||
return RemoveProviderFromModel(model)
|
||||
}
|
||||
|
||||
// IsBareModelID checks if a string is a bare model ID (no provider).
|
||||
|
||||
@@ -0,0 +1,357 @@
|
||||
package kit_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// TestOptionFunctionsPlumbing verifies that the functional options apply their
|
||||
// values to the underlying Options struct. This does not create a provider, so
|
||||
// it runs without API keys.
|
||||
func TestOptionFunctionsPlumbing(t *testing.T) {
|
||||
o := &kit.Options{}
|
||||
opts := []kit.Option{
|
||||
kit.WithModel("anthropic/claude-sonnet-4-5-20250929"),
|
||||
kit.WithSystemPrompt("be terse"),
|
||||
kit.WithMaxTokens(4321),
|
||||
kit.WithThinkingLevel("high"),
|
||||
kit.WithProviderAPIKey("sk-test"),
|
||||
kit.WithProviderURL("https://example.test/v1"),
|
||||
kit.WithConfigFile("/tmp/.kit.yml"),
|
||||
kit.WithStreaming(false),
|
||||
kit.WithDebug(),
|
||||
kit.Ephemeral(),
|
||||
}
|
||||
for _, fn := range opts {
|
||||
fn(o)
|
||||
}
|
||||
|
||||
if o.Model != "anthropic/claude-sonnet-4-5-20250929" {
|
||||
t.Errorf("WithModel: got %q", o.Model)
|
||||
}
|
||||
if o.SystemPrompt != "be terse" {
|
||||
t.Errorf("WithSystemPrompt: got %q", o.SystemPrompt)
|
||||
}
|
||||
if o.MaxTokens != 4321 {
|
||||
t.Errorf("WithMaxTokens: got %d", o.MaxTokens)
|
||||
}
|
||||
if o.ThinkingLevel != "high" {
|
||||
t.Errorf("WithThinkingLevel: got %q", o.ThinkingLevel)
|
||||
}
|
||||
if o.ProviderAPIKey != "sk-test" {
|
||||
t.Errorf("WithProviderAPIKey: got %q", o.ProviderAPIKey)
|
||||
}
|
||||
if o.ProviderURL != "https://example.test/v1" {
|
||||
t.Errorf("WithProviderURL: got %q", o.ProviderURL)
|
||||
}
|
||||
if o.ConfigFile != "/tmp/.kit.yml" {
|
||||
t.Errorf("WithConfigFile: got %q", o.ConfigFile)
|
||||
}
|
||||
if o.Streaming == nil {
|
||||
t.Error("WithStreaming: expected Streaming to be set (non-nil)")
|
||||
} else if *o.Streaming {
|
||||
t.Error("WithStreaming(false): expected *Streaming=false")
|
||||
}
|
||||
if !o.Debug {
|
||||
t.Error("WithDebug: expected Debug=true")
|
||||
}
|
||||
if !o.NoSession {
|
||||
t.Error("Ephemeral: expected NoSession=true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOptionOrderingOverrides verifies later options override earlier ones.
|
||||
func TestOptionOrderingOverrides(t *testing.T) {
|
||||
o := &kit.Options{}
|
||||
kit.WithModel("a/b")(o)
|
||||
kit.WithModel("c/d")(o)
|
||||
if o.Model != "c/d" {
|
||||
t.Errorf("later WithModel should win; got %q", o.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// TestKitConfigIsolation is the regression test for issue #40: two Kit
|
||||
// instances constructed in the same process must own independent configuration
|
||||
// stores. Setting the thinking level (or model) on one must not affect the
|
||||
// other. Against the previous global-viper implementation this test fails
|
||||
// because both Kits read and write the same process-global store.
|
||||
func TestKitConfigIsolation(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
a, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
ThinkingLevel: "low",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
NoExtensions: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Kit A: %v", err)
|
||||
}
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
b, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
ThinkingLevel: "high",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
NoExtensions: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Kit B: %v", err)
|
||||
}
|
||||
defer func() { _ = b.Close() }()
|
||||
|
||||
// Each instance must retain its own configured thinking level. Under the
|
||||
// old global-viper implementation, B's construction overwrote A's value.
|
||||
if got := a.GetThinkingLevel(); got != "low" {
|
||||
t.Errorf("Kit A thinking level = %q; want %q (config leaked from B)", got, "low")
|
||||
}
|
||||
if got := b.GetThinkingLevel(); got != "high" {
|
||||
t.Errorf("Kit B thinking level = %q; want %q", got, "high")
|
||||
}
|
||||
|
||||
// Mutating one at runtime must not bleed into the other.
|
||||
if err := a.SetThinkingLevel(ctx, "medium"); err != nil {
|
||||
t.Fatalf("SetThinkingLevel on A: %v", err)
|
||||
}
|
||||
if got := a.GetThinkingLevel(); got != "medium" {
|
||||
t.Errorf("after SetThinkingLevel, Kit A = %q; want %q", got, "medium")
|
||||
}
|
||||
if got := b.GetThinkingLevel(); got != "high" {
|
||||
t.Errorf("after mutating A, Kit B leaked to %q; want %q", got, "high")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewAgentDefaultsStreamingOn verifies that the ergonomic constructor
|
||||
// enables streaming by default and applies functional options.
|
||||
func TestNewAgentDefaultsStreamingOn(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
k, err := kit.NewAgent(ctx,
|
||||
kit.WithModel("anthropic/claude-sonnet-4-5-20250929"),
|
||||
kit.WithMaxTokens(2048),
|
||||
kit.Ephemeral(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAgent failed: %v", err)
|
||||
}
|
||||
defer func() { _ = k.Close() }()
|
||||
|
||||
if !k.ConfigValueIsSetForTest("max-tokens") {
|
||||
t.Error("NewAgent did not propagate WithMaxTokens to the instance store")
|
||||
}
|
||||
if !k.ConfigBoolForTest("stream") {
|
||||
t.Error("NewAgent should enable streaming by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewAgentStreamingOptOut verifies WithStreaming(false) disables the
|
||||
// default-on streaming behaviour of NewAgent.
|
||||
func TestNewAgentStreamingOptOut(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
k, err := kit.NewAgent(ctx,
|
||||
kit.WithModel("anthropic/claude-sonnet-4-5-20250929"),
|
||||
kit.WithStreaming(false),
|
||||
kit.Ephemeral(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAgent failed: %v", err)
|
||||
}
|
||||
defer func() { _ = k.Close() }()
|
||||
|
||||
if k.ConfigBoolForTest("stream") {
|
||||
t.Error("WithStreaming(false) should disable streaming")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewZeroOptionsKeepsStreamingDefault is the regression test for the
|
||||
// unconditional `v.Set("stream", opts.Streaming)` bug: a zero-valued Options
|
||||
// (Streaming == nil) must NOT force stream=false. With Streaming unset,
|
||||
// streaming resolves through the precedence chain, whose SDK default is true.
|
||||
func TestNewZeroOptionsKeepsStreamingDefault(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
k, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
SkipConfig: true, // isolate from any ~/.kit.yml / env stream setting
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = k.Close() }()
|
||||
|
||||
if !k.ConfigBoolForTest("stream") {
|
||||
t.Error("zero-valued Options must not force stream=false; expected the default (true)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSkillsViperKeys verifies that the three skills config keys (no-skills,
|
||||
// skill, skills-dir) flow through viper when set via a config file, matching
|
||||
// the pattern used by no-extensions and no-core-tools. This test does not
|
||||
// require an API key because it only exercises Options struct plumbing.
|
||||
func TestSkillsViperKeys(t *testing.T) {
|
||||
t.Run("NoSkills option disables skill loading", func(t *testing.T) {
|
||||
o := &kit.Options{}
|
||||
o.NoSkills = true
|
||||
if !o.NoSkills {
|
||||
t.Error("Options.NoSkills = true not reflected on struct")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Skills paths set on Options", func(t *testing.T) {
|
||||
o := &kit.Options{
|
||||
Skills: []string{"/a/skill.md", "/b/skill.md"},
|
||||
}
|
||||
if len(o.Skills) != 2 {
|
||||
t.Errorf("Options.Skills: got %d paths, want 2", len(o.Skills))
|
||||
}
|
||||
if o.Skills[0] != "/a/skill.md" {
|
||||
t.Errorf("Options.Skills[0] = %q; want %q", o.Skills[0], "/a/skill.md")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SkillsDir set on Options", func(t *testing.T) {
|
||||
o := &kit.Options{
|
||||
SkillsDir: "/custom/skills",
|
||||
}
|
||||
if o.SkillsDir != "/custom/skills" {
|
||||
t.Errorf("Options.SkillsDir = %q; want %q", o.SkillsDir, "/custom/skills")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSkillsConfigFileKeys verifies that no-skills, skill, and skills-dir
|
||||
// config file keys are read via viper and applied correctly. Requires an API
|
||||
// key because kit.New() is called to exercise the full config-load path.
|
||||
func TestSkillsConfigFileKeys(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("no-skills config key disables skill loading", func(t *testing.T) {
|
||||
// Write a config file with no-skills: true.
|
||||
cfgFile := t.TempDir() + "/.kit.yml"
|
||||
if err := os.WriteFile(cfgFile, []byte("no-skills: true\n"), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
ConfigFile: cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
if got := host.GetSkills(); len(got) != 0 {
|
||||
t.Errorf("no-skills:true in config: expected 0 skills, got %d", len(got))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skill config key loads explicit skill files", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
skillFile := dir + "/cfg-skill.md"
|
||||
if err := os.WriteFile(skillFile, []byte("---\nname: cfg-skill\ndescription: from config\n---\nContent.\n"), 0o644); err != nil {
|
||||
t.Fatalf("failed to write skill file: %v", err)
|
||||
}
|
||||
|
||||
cfgContent := "skill:\n - " + skillFile + "\n"
|
||||
cfgFile := dir + "/.kit.yml"
|
||||
if err := os.WriteFile(cfgFile, []byte(cfgContent), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
ConfigFile: cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
skills := host.GetSkills()
|
||||
if len(skills) != 1 {
|
||||
t.Fatalf("expected 1 skill from config, got %d", len(skills))
|
||||
}
|
||||
if skills[0].Name != "cfg-skill" {
|
||||
t.Errorf("skill name = %q; want %q", skills[0].Name, "cfg-skill")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skills-dir config key overrides auto-discovery root", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgContent := "skills-dir: " + dir + "\n"
|
||||
cfgFile := dir + "/.kit.yml"
|
||||
if err := os.WriteFile(cfgFile, []byte(cfgContent), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
ConfigFile: cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
// Empty dir → 0 skills; the key point is no error during init.
|
||||
_ = host.GetSkills()
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewStreamingExplicitOptOut verifies that a raw Options can still disable
|
||||
// streaming by setting Streaming to a pointer to false.
|
||||
func TestNewStreamingExplicitOptOut(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
|
||||
streamOff := false
|
||||
ctx := context.Background()
|
||||
k, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
SkipConfig: true,
|
||||
Streaming: &streamOff,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = k.Close() }()
|
||||
|
||||
if k.ConfigBoolForTest("stream") {
|
||||
t.Error("Streaming=&false should disable streaming")
|
||||
}
|
||||
}
|
||||
@@ -88,7 +88,8 @@ api.OnAgentStart(func(e ext.AgentStartEvent, ctx ext.Context) {
|
||||
// e.Prompt string
|
||||
})
|
||||
|
||||
// Agent finished responding.
|
||||
// Agent finished responding. Carries per-turn aggregates so observer-style
|
||||
// extensions don't need to maintain parallel bookkeeping.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
// e.Response string
|
||||
// e.StopReason string — "error" (on failure), "completed" (when LLM returns
|
||||
@@ -96,6 +97,33 @@ api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
// (e.g. "stop", "length" (max output tokens hit), "tool-calls", "content-filter").
|
||||
// To detect errors, check e.StopReason == "error".
|
||||
// Do NOT compare against "completed" for success — instead check != "error".
|
||||
//
|
||||
// Per-turn aggregates (computed by Kit's runtime):
|
||||
// e.ToolCallCount int — total tool invocations this turn
|
||||
// e.ToolNames []string — tool names in call order (duplicates preserved)
|
||||
// e.LLMCallCount int — LLM round-trips / tool-loop iterations
|
||||
// e.InputTokensDelta int — sum of input tokens across LLM calls this turn
|
||||
// e.OutputTokensDelta int
|
||||
// e.CacheReadTokensDelta int
|
||||
// e.CacheWriteTokensDelta int
|
||||
// e.CostDelta float64 — USD cost (zero when pricing unknown / OAuth)
|
||||
// e.DurationMs int64 — wall-clock duration AgentStart→AgentEnd
|
||||
})
|
||||
|
||||
// Per-LLM-call usage — fires after each provider round-trip with token + cost
|
||||
// deltas attributed to that specific call. A single turn typically produces
|
||||
// multiple LLMUsageEvents (one per tool-loop iteration). Use this for accurate
|
||||
// budget enforcement that needs to react between calls instead of waiting
|
||||
// for the turn to finish.
|
||||
api.OnLLMUsage(func(e ext.LLMUsageEvent, ctx ext.Context) {
|
||||
// e.InputTokens, e.OutputTokens int
|
||||
// e.CacheReadTokens, e.CacheWriteTokens int
|
||||
// e.Cost float64 — USD; zero when pricing unknown / OAuth
|
||||
// e.Model, e.Provider string — model used for THIS call
|
||||
// (may differ across calls if SetModel was called)
|
||||
// e.StepNumber int — zero-based step index in this turn
|
||||
// e.FinishReason string — "stop" / "tool_calls" / "length" / ...
|
||||
// e.RequestID string — optional provider correlation id (may be empty)
|
||||
})
|
||||
```
|
||||
|
||||
@@ -528,11 +556,38 @@ stats := ctx.GetContextStats() // .EstimatedTokens, .ContextLimit, .UsagePer
|
||||
msgs := ctx.GetMessages() // []ext.SessionMessage on current branch
|
||||
path := ctx.GetSessionPath() // file path of session JSONL
|
||||
|
||||
// Persist custom data in the session tree:
|
||||
// Append-only log in the session tree (fork-aware, walked on every branch read):
|
||||
id, err := ctx.AppendEntry("my-type", "data string")
|
||||
entries := ctx.GetEntries("my-type") // []ext.ExtensionEntry{ID, EntryType, Data, Timestamp}
|
||||
```
|
||||
|
||||
### Session State (last-write-wins)
|
||||
|
||||
Key-value store scoped to the session, persisted to a sidecar file
|
||||
(`<session>.ext-state.json`) outside the conversation tree. Reads are O(1)
|
||||
(no branch walk), writes don't grow the JSONL, and the store is not
|
||||
duplicated on fork. State is invisible to the LLM and survives session
|
||||
resume. For ephemeral / in-memory sessions, state lives only in memory.
|
||||
|
||||
```go
|
||||
ctx.SetState("myext:budget-cap", "10.00") // last write wins
|
||||
val, ok := ctx.GetState("myext:budget-cap") // (string, bool)
|
||||
ctx.DeleteState("myext:budget-cap") // no-op if missing
|
||||
keys := ctx.ListState() // []string, unspecified order
|
||||
```
|
||||
|
||||
**When to use which:**
|
||||
|
||||
| Need | Use |
|
||||
|------|-----|
|
||||
| Snapshot state ("current value of X") | `SetState` / `GetState` |
|
||||
| Audit log / event history | `AppendEntry` / `GetEntries` |
|
||||
| One-shot per-turn signal | enriched `AgentEndEvent` fields |
|
||||
| Per-LLM-call observation | `OnLLMUsage` event |
|
||||
|
||||
Namespace keys with your extension name (e.g. `"myext:budget-cap"`) to avoid
|
||||
collisions across extensions.
|
||||
|
||||
### Model Management
|
||||
|
||||
```go
|
||||
|
||||
@@ -1104,6 +1104,19 @@ if extAPI.HasExtensions() {
|
||||
tools := extAPI.GetToolInfos()
|
||||
extAPI.SetActiveTools([]string{"bash", "read"})
|
||||
|
||||
// Session-scoped extension state (last-write-wins key-value store).
|
||||
// Backed by an in-memory map and a per-session sidecar file
|
||||
// (<session>.ext-state.json) outside the conversation tree.
|
||||
extAPI.SetState("myext:budget-cap", "10.00")
|
||||
val, ok := extAPI.GetState("myext:budget-cap")
|
||||
extAPI.DeleteState("myext:budget-cap")
|
||||
keys := extAPI.ListState()
|
||||
|
||||
// Load any existing state from the sidecar and install a saver hook so
|
||||
// subsequent SetState/DeleteState mutations are flushed atomically.
|
||||
// No-op for ephemeral / in-memory sessions. Safe to call multiple times.
|
||||
_ = extAPI.InitStatePersistence()
|
||||
|
||||
// Events
|
||||
extAPI.EmitSessionStart()
|
||||
extAPI.EmitModelChange("new/model", "old/model", "extension")
|
||||
|
||||
@@ -56,6 +56,26 @@ kit install --all # Install all extensions without prompting
|
||||
kit skill # Install the Kit extensions skill via skills.sh
|
||||
```
|
||||
|
||||
### Skills CLI flags
|
||||
|
||||
Control which skills are loaded at startup:
|
||||
|
||||
```bash
|
||||
# Load a specific skill file
|
||||
kit --skill path/to/skill.md "prompt"
|
||||
|
||||
# Load multiple skill files or directories (flag is repeatable)
|
||||
kit --skill ./skill1.md --skill ./skill2.md "prompt"
|
||||
|
||||
# Load all skills from a custom directory instead of the default locations
|
||||
kit --skills-dir /path/to/skills "prompt"
|
||||
|
||||
# Disable all skill loading (auto-discovery and explicit)
|
||||
kit --no-skills "prompt"
|
||||
```
|
||||
|
||||
Skills are auto-discovered from `~/.config/kit/skills/`, `.kit/skills/`, and `.agents/skills/` by default. Use `--skills-dir` to override the project-local search root, or `--skill` to load files explicitly (which disables auto-discovery). `--no-skills` suppresses all skill loading regardless of other flags.
|
||||
|
||||
## Interactive slash commands
|
||||
|
||||
These commands are available inside the Kit TUI during an interactive session:
|
||||
@@ -110,6 +130,23 @@ Press **Ctrl+X s** during streaming to inject a system-level instruction mid-tur
|
||||
|
||||
Example: While the model is writing code, press Ctrl+X s and type "Use async/await instead" to change the implementation approach.
|
||||
|
||||
### Image attachments
|
||||
|
||||
Attach images to your next prompt straight from the clipboard:
|
||||
|
||||
- Copy an image (e.g. a screenshot) to the system clipboard, then press **Ctrl+V** in the input to attach it.
|
||||
- Press **Ctrl+U** to clear all pending image attachments.
|
||||
- Attachments are sent alongside your text when you submit, and cleared afterward.
|
||||
|
||||
When a terminal supports color, Kit renders a small low-resolution **thumbnail preview** of each pending image directly in the input, below the `[N image(s) attached]` indicator, so you can confirm the right image was attached before sending.
|
||||
|
||||
The preview is drawn with Unicode half-block characters and ordinary terminal colors — not a graphics protocol — so it renders correctly inside terminal multiplexers like **tmux** and **zellij**. Thumbnails are capped to a small cell box for a glanceable, low-res look.
|
||||
|
||||
- Best fidelity needs a **truecolor** terminal (`COLORTERM=truecolor`); Kit degrades to 256-color where truecolor is unavailable.
|
||||
- On terminals with neither, the preview is skipped and the `[N image(s) attached]` text indicator is shown alone.
|
||||
|
||||
You can also attach image files by referencing them with `@path/to/image.png` — binary files are auto-detected by MIME type. See [Quick Start](/quick-start) for the `@` attachment syntax.
|
||||
|
||||
## Prompt templates
|
||||
|
||||
### Creating templates
|
||||
|
||||
@@ -48,6 +48,14 @@ These flags control Kit's behavior. When a prompt is passed as a positional argu
|
||||
| `--prompt-template` | — | — | Load a specific prompt template by name |
|
||||
| `--no-prompt-templates` | — | `false` | Disable prompt template loading |
|
||||
|
||||
## Skills
|
||||
|
||||
| Flag | Short | Default | Description |
|
||||
|------|-------|---------|-------------|
|
||||
| `--skill` | — | — | Load skill file or directory (repeatable) |
|
||||
| `--skills-dir` | — | — | Override the project-local skills directory for auto-discovery |
|
||||
| `--no-skills` | — | `false` | Disable skill loading (auto-discovery and explicit) |
|
||||
|
||||
## Generation parameters
|
||||
|
||||
| Flag | Short | Default | Description |
|
||||
|
||||
@@ -47,6 +47,9 @@ stream: true
|
||||
| `theme` | object or string | — | UI theme ([inline overrides or file path](/themes)) |
|
||||
| `prompt-templates` | bool | `true` | Enable prompt template loading |
|
||||
| `prompt-template` | string | — | Specific template to load by name |
|
||||
| `no-skills` | bool | `false` | Disable skill loading (auto-discovery and explicit) |
|
||||
| `skill` | list | — | Explicit skill files or directories to load (disables auto-discovery) |
|
||||
| `skills-dir` | string | — | Override the project-local directory used for skill auto-discovery |
|
||||
|
||||
## Environment variables
|
||||
|
||||
@@ -88,6 +91,9 @@ mcpServers:
|
||||
type: remote
|
||||
url: "https://pubmed.mcp.example.com"
|
||||
noOAuth: true # skip OAuth for public servers
|
||||
headers:
|
||||
- "ApiKey: ${env://API_KEY}" # required env var
|
||||
- "X-Tenant: ${env://TENANT_ID:-default}" # with fallback default
|
||||
|
||||
builds:
|
||||
type: remote
|
||||
@@ -106,9 +112,10 @@ mcpServers:
|
||||
| `allowedTools` | list | Whitelist of tool names to expose |
|
||||
| `excludedTools` | list | Blacklist of tool names to hide |
|
||||
| `noOAuth` | bool | Skip OAuth for this server (for public servers that don't require auth) |
|
||||
| `headers` | list of strings | HTTP headers to attach to every request, each as a `"Key: Value"` string. Values support env-substitution: `${env://VAR}` or `${env://VAR:-default}`. |
|
||||
| `tasksMode` | string | When to augment `tools/call` with MCP task metadata: `auto` (default — only when the server advertises task support), `never`, or `always`. See [MCP tasks](#mcp-tasks-long-running-tools). |
|
||||
|
||||
A legacy format with `transport`, `args`, `env`, and `headers` fields is also supported.
|
||||
A legacy format with `transport`, `args`, and `env` fields is also supported; `headers` works in both the current and legacy formats.
|
||||
|
||||
### MCP tasks (long-running tools)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ description: All extension capabilities — lifecycle events, tools, commands, w
|
||||
|
||||
## Lifecycle events
|
||||
|
||||
Extensions can hook into 26 lifecycle events:
|
||||
Extensions can hook into 27 lifecycle events:
|
||||
|
||||
| Event | Description |
|
||||
|-------|-------------|
|
||||
@@ -15,7 +15,8 @@ Extensions can hook into 26 lifecycle events:
|
||||
| `OnSessionShutdown` | Session ending |
|
||||
| `OnBeforeAgentStart` | Before the agent loop begins |
|
||||
| `OnAgentStart` | Agent loop started |
|
||||
| `OnAgentEnd` | Agent loop completed |
|
||||
| `OnAgentEnd` | Agent loop completed (carries per-turn aggregates: tool counts, token deltas, cost, duration) |
|
||||
| `OnLLMUsage` | Per-LLM-call token + cost delta (fires once per provider round-trip) |
|
||||
| `OnToolCall` | Tool call requested by the model |
|
||||
| `OnToolCallInputStart` | LLM began generating tool call arguments (tool name known, args streaming) |
|
||||
| `OnToolCallInputDelta` | Streamed JSON fragment of tool call arguments |
|
||||
@@ -45,11 +46,52 @@ api.OnToolCall(func(event ext.ToolCallEvent, ctx ext.Context) {
|
||||
ctx.PrintInfo("Calling tool: " + event.Name)
|
||||
})
|
||||
|
||||
api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) {
|
||||
ctx.PrintInfo("Agent finished")
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
// Per-turn aggregates populated by Kit's runtime — no parallel
|
||||
// bookkeeping required in the handler.
|
||||
ctx.PrintInfo(fmt.Sprintf(
|
||||
"Turn finished: %d tool calls (%v), %d LLM round-trips, $%.4f, %dms",
|
||||
e.ToolCallCount, e.ToolNames, e.LLMCallCount, e.CostDelta, e.DurationMs,
|
||||
))
|
||||
})
|
||||
|
||||
// Per-LLM-call usage — fires multiple times per turn (once per round-trip).
|
||||
// Use for accurate budget enforcement between calls.
|
||||
api.OnLLMUsage(func(e ext.LLMUsageEvent, ctx ext.Context) {
|
||||
ctx.PrintInfo(fmt.Sprintf(
|
||||
"%s/%s step=%d tokens=↑%d ↓%d cost=$%.4f (%s)",
|
||||
e.Provider, e.Model, e.StepNumber,
|
||||
e.InputTokens, e.OutputTokens, e.Cost, e.FinishReason,
|
||||
))
|
||||
})
|
||||
```
|
||||
|
||||
**`AgentEndEvent` fields** (in addition to `Response` and `StopReason`):
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `ToolCallCount` | `int` | Total tool invocations during the turn |
|
||||
| `ToolNames` | `[]string` | Tool names in call order (duplicates preserved) |
|
||||
| `LLMCallCount` | `int` | LLM round-trips / tool-loop iterations |
|
||||
| `InputTokensDelta` | `int` | Sum of input tokens across all LLM calls this turn |
|
||||
| `OutputTokensDelta` | `int` | Sum of output tokens across all LLM calls this turn |
|
||||
| `CacheReadTokensDelta` | `int` | Sum of cache-read tokens this turn |
|
||||
| `CacheWriteTokensDelta` | `int` | Sum of cache-write tokens this turn |
|
||||
| `CostDelta` | `float64` | Cost in USD (zero when pricing is unknown or OAuth credentials) |
|
||||
| `DurationMs` | `int64` | Wall-clock time from `AgentStart` to `AgentEnd` |
|
||||
|
||||
**`LLMUsageEvent` fields**:
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `InputTokens` / `OutputTokens` | `int` | Per-call token deltas |
|
||||
| `CacheReadTokens` / `CacheWriteTokens` | `int` | Per-call cache token deltas |
|
||||
| `Cost` | `float64` | Per-call USD cost (zero when pricing unknown) |
|
||||
| `Model` / `Provider` | `string` | Model used for this specific call — may differ from earlier calls if `ctx.SetModel` was called mid-turn |
|
||||
| `StepNumber` | `int` | Zero-based step index within the turn |
|
||||
| `FinishReason` | `string` | Provider finish reason for this call (`"stop"`, `"tool_calls"`, `"length"`, ...) |
|
||||
| `RequestID` | `string` | Optional provider correlation id (may be empty) |
|
||||
|
||||
## Tools
|
||||
|
||||
Register custom tools that the LLM can invoke:
|
||||
@@ -338,6 +380,36 @@ api.OnCustomEvent("my-extension:data-ready", func(data any, ctx ext.Context) {
|
||||
})
|
||||
```
|
||||
|
||||
## Session state
|
||||
|
||||
Last-write-wins key-value store, scoped to the current session and persisted to a sidecar file (`<session>.ext-state.json`) outside the conversation tree:
|
||||
|
||||
```go
|
||||
ctx.SetState("myext:budget-cap", "10.00")
|
||||
|
||||
if cap, ok := ctx.GetState("myext:budget-cap"); ok {
|
||||
// ...
|
||||
}
|
||||
|
||||
ctx.DeleteState("myext:budget-cap")
|
||||
keys := ctx.ListState() // []string, unspecified order
|
||||
```
|
||||
|
||||
Reads are O(1) (no branch walk), writes don't grow the session JSONL, and the store is not duplicated when the conversation forks. State is invisible to the LLM and survives session resume.
|
||||
|
||||
### When to use which persistence primitive
|
||||
|
||||
| Need | Use | Why |
|
||||
|------|-----|-----|
|
||||
| Snapshot state ("current value of X") | `SetState` / `GetState` | O(1) reads, sidecar file, last-write-wins |
|
||||
| Audit log / event history | `AppendEntry` / `GetEntries` | Append-only, lives in conversation tree, fork-aware |
|
||||
| One-shot per-turn signal | Enriched `AgentEndEvent` fields | No persistence needed; runtime tracks it for you |
|
||||
| Per-LLM-call observation | `OnLLMUsage` event | Already attributed to model/provider/step |
|
||||
|
||||
Using `AppendEntry` for snapshot state has a cost: it's O(branch_length) to read, fsyncs into the JSONL on every write, and the entry list duplicates on every fork. Prefer `SetState` for "what's the current value of X?"-style data.
|
||||
|
||||
For ephemeral / in-memory sessions (no JSONL path) the state lives only in memory for the lifetime of the runner.
|
||||
|
||||
## Bridged SDK APIs
|
||||
|
||||
Extensions can access powerful internal SDK capabilities that enable advanced features like conversation tree navigation, dynamic skill loading, template parsing, and model resolution.
|
||||
|
||||
@@ -50,6 +50,7 @@ Kit ships with a rich set of example extensions in the `examples/extensions/` di
|
||||
| [`context-inject.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/context-inject.go) | Inject context into conversations |
|
||||
| [`summarize.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/summarize.go) | Conversation summarization |
|
||||
| [`lsp-diagnostics.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/lsp-diagnostics.go) | LSP diagnostic integration |
|
||||
| [`usage-budget.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/usage-budget.go) | Per-call usage callback (`OnLLMUsage`), session state (`SetState`/`GetState`), and enriched `OnAgentEnd` per-turn report |
|
||||
|
||||
## Bridged SDK APIs
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user