Compare commits

...

19 Commits

Author SHA1 Message Date
Ed Zynda e07c94f49d feat(mcp): add dynamic MCP server loading and unloading
- Add AddServer/RemoveServer to MCPToolManager for runtime server management
- Add RemoveConnection to MCPConnectionPool for per-server teardown
- Add AddMCPServer/RemoveMCPServer/ListMCPServers to Agent and SDK Kit
- Lazily create connection pool so AddServer works without prior LoadTools
- Wire onToolsChanged callback to trigger agent tool list rebuild
- Make MCPToolManager.Close nil-safe when pool was never initialized

Tests:
- Integration tests with real stdio MCP server (Python echo server)
- Agent-level tests using mock LLM model (no API key needed)
- Unit tests for error paths, callbacks, idempotency, nil safety
- SDK type surface tests
2026-04-09 13:54:11 +03:00
Ed Zynda b87146a284 feat(sdk): add MCPTokenStoreFactory for custom OAuth token storage
- Add MCPTokenStoreFactory option to kit.Options allowing SDK consumers
  to provide custom token storage backends for remote MCP servers
- Thread TokenStoreFactory through the full chain: kit.Options →
  kitsetup → agent → MCPToolManager → MCPConnectionPool
- Add createTokenStore() helper on connection pool that delegates to the
  factory or falls back to the default FileTokenStore
- Export MCPTokenStore, MCPToken, MCPTokenStoreFactory, and ErrMCPNoToken
  in pkg/kit/types.go following SDK naming conventions
- Default behavior (file-based storage) is preserved when factory is nil
2026-04-09 13:27:40 +03:00
Ed Zynda 186d9f7f44 fix(ui): route raw fmt.Print calls through proper renderers
- event_handler: route default extension print level through DisplayInfo
  instead of bare fmt.Println for consistent styling and timestamps
- factory: remove orphan fmt.Println("") before system messages; the
  renderer already manages its own spacing
- app: PrintFromExtension non-interactive fallback now respects level,
  writing errors/info to stderr with prefix to keep stdout clean
- app: PrintBlockFromExtension non-interactive fallback writes framed
  blocks to stderr instead of raw text to stdout
2026-04-09 13:00:23 +03:00
Ed Zynda 3a8ffc2104 feat(models): add per-model system prompt support
- Add systemPrompt field to GenerationParams and config structs
- On init, replace default system prompt with per-model prompt when
  user hasn't explicitly set one (via flag, config, or SDK option)
- On model switch, detect per-model prompt and compose it with
  AGENTS.md, skills, and date/cwd context
- Fix viper.IsSet bug: BindPFlag causes IsSet to return true for
  unset flags, so compare against defaultSystemPrompt instead
- Agent.SetModel now updates stored system prompt from config
- Export LoadModelSettingsFromConfig, LoadSystemPromptValue, and
  LookupModelForSettings for use by Kit.SetModel
- Add tests for prompt apply, precedence, file path, and
  modelSettings override
2026-04-09 12:35:00 +03:00
Ed Zynda e54570162e feat(models): add per-model generation parameter defaults
- Add modelSettings config section for attaching generation params
  (temperature, topP, topK, frequencyPenalty, presencePenalty,
  maxTokens, stopSequences, thinkingLevel) to any model by
  provider/model key
- Add params field to customModels definitions for inline defaults
- Change BuildProviderConfig and SetModel to use viper.IsSet so
  unset params remain nil, allowing model-level defaults to apply
- Wire ApplyModelSettings into CreateProvider with priority order:
  CLI flags > global config > modelSettings > customModels params
- Add GenerationParams to ModelInfo in the registry
- Update default config template with modelSettings and customModels
  params examples
2026-04-09 12:07:42 +03:00
Ed Zynda 34bb97a40e chore(deps): update dependencies
- bump mcp-go to v0.47.1
- update cloud auth, otel, and various indirect deps
2026-04-08 20:51:59 +03:00
Ed Zynda f5c1a16f8a feat(session): make compaction create new leaf with no parent
Change compaction behavior so the compaction entry has no parent (empty
ParentID), creating a new root for post-compaction history. This ensures
old compacted messages are not traversed when building LLM context.

- Modify AppendCompaction to create entries with empty ParentID
- Update BuildContext to collect kept messages via FirstKeptEntryID
- Update GetContextEntryIDs with same logic
- Add comprehensive tests for compaction behavior
- Add web viewer support for displaying compaction entries
2026-04-08 18:52:44 +03:00
Ed Zynda b29d7d2166 refactor(acpserver): remove redundant thinking tag parsing
Remove dead code now that pkg/kit transparently handles <thinking> and
 tags at the agent layer. The ACP server no longer needs to:

- Track inThinkingTag state across chunks
- Parse and split reasoning/text from MessageUpdateEvent chunks
- Maintain tag format constants

MessageUpdateEvent now contains clean text, and ReasoningDeltaEvent
contains structured reasoning - no duplicate filtering needed.
2026-04-08 16:55:53 +03:00
Ed Zynda 3ea0db69ea fix(ui): wrap user messages to terminal width
- Add width parameter to UserBlock and apply lipgloss.Wrap() before
  passing content to herald Tip alert
- Subtract 4 from width to account for alert bar prefix and margin
- Pass renderer width from RenderUserMessage to UserBlock
- Mirrors the assistant message wrapping added in e33564c
2026-04-08 15:15:27 +03:00
Ed Zynda 4304a5e899 feat(ui): change steer keybind to Ctrl+X s leader key chord
- Replace single Ctrl+S with Ctrl+X leader prefix followed by "s"
- Add leaderKeyActive flag to AppModel for two-key chord state
- Ctrl+X sets the leader flag; next keypress completes or cancels chord
- Update hint text in input component (adjust width thresholds)
- Update /help command output to reflect new keybind
2026-04-08 15:04:48 +03:00
Ed Zynda 4019c1e4f7 fix(ui): remove character limits from all textarea inputs
- Main message input: 5000 -> unlimited
- Prompt dialog input: 1000 -> unlimited
- Tool approval input: 1000 -> unlimited

Setting CharLimit to 0 disables the limit in Bubble Tea's textarea.
2026-04-08 14:23:34 +03:00
Ed Zynda 30ad7c1d0b feat(sdk): persist session messages incrementally per agent step
- Add StepMessagesHandler callback to agent's GenerateWithLoopAndStreaming
  so callers can persist messages as each step completes
- Wire onStepMessages in Kit.generate() to call session.AppendMessage
  for each step's messages immediately on completion
- Track PersistedMessageCount on GenerateWithLoopResult so runTurn
  skips already-persisted messages in post-generation cleanup
- Tool calls are always persisted as assistant+tool pairs (never orphaned)
- Document concurrency and incremental persistence requirements on
  the SessionManager interface for custom implementations
2026-04-08 14:15:05 +03:00
Ed Zynda e33564c569 fix(ui): wrap assistant messages to terminal width
- ToMarkdown() received a width param but never used it
- Apply lipgloss.Wrap() after herald-md render to break long lines
- Preserves ANSI styles/colors through the wrapping pass
- Fixes overflow for all markdown paths: assistant messages, tool
  bodies, and overlay text
2026-04-08 13:34:33 +03:00
Ed Zynda 5ff28445fd fix(ui): truncate queued and steering message blocks to prevent overflow
- Limit each queued/steering block to 3 visible content lines with ellipsis
- Account for soft-wrapping when counting visual lines
- Truncation is visual only; full text is preserved for scrollback
- Add truncateMessageForBlock helper with wrap-aware line counting
- Add 7 unit tests covering short, exact, overflow, wrapping, and mixed cases
2026-04-08 13:24:26 +03:00
Ed Zynda 13d177e5d0 fix(extensions): use structured logging that respects log levels
Switch from standard log.Printf to charmbracelet/log for extension loading
messages. This ensures DEBUG output only appears when explicitly enabled.

- Remove unconditional WARN log for failed extension loads
- Convert DEBUG loaded extension message to structured log.Debug call
2026-04-08 00:39:21 +03:00
Ed Zynda 3ffc995f27 feat(sdk): add NewTool/NewParallelTool for dependency-free custom tools
- Add ToolOutput struct, TextResult/ErrorResult helpers, and
  ToolCallIDFromContext so SDK consumers can create custom tools
  without importing charm.land/fantasy
- Add NewTool (sequential) and NewParallelTool (concurrent) generic
  constructors with automatic JSON schema generation from struct tags
- Remove dead UpdateUsageFromResponse method and fantasy import from
  internal/ui/cli.go
- Update SDK skill, README, and www/ docs with custom tool examples
  and corrected hook signatures
2026-04-07 22:05:42 +03:00
Ed Zynda b2bd016135 fix(tui): redirect log output to file to prevent TUI corruption
- Add tea.LogToFile in runInteractiveModeBubbleTea to send stdlib log
  output to /tmp/kit/kit.log instead of stderr
- Replace charmbracelet/log with stdlib log in extensions loader,
  runner, watcher, prompts loader, and pkg/kit so all log calls go
  through the redirected stdlib logger
- Leave charmbracelet/log in CLI-only commands (install, acp) and
  acpserver where stderr logging is correct
2026-04-07 21:20:04 +03:00
Ed Zynda 812dedaea2 feat(pkg/kit): add SessionManager interface for custom session backends
Add SessionManager interface to allow pluggable session storage backends.
This enables users to implement custom session managers for databases,
cloud storage, or other persistence mechanisms instead of the default
JSONL file-based TreeManager.

Changes:
- Add SessionManager interface with methods for message storage,
  tree navigation, compaction, and extension data
- Add treeManagerAdapter to wrap existing TreeManager for backward compatibility
- Update Kit struct to use SessionManager interface instead of concrete type
- Add SessionManager option to Options struct
- Update all session-related methods to use interface
- Add documentation for custom SessionManager usage

The default behavior is preserved - when no SessionManager is provided,
Kit automatically uses the TreeManager via the adapter.
2026-04-07 17:41:46 +03:00
Ed Zynda f65b6737f2 feat(sdk): add SkipConfig and DisableCoreTools options
Add two new Options fields for programmatic SDK usage:

- SkipConfig: Skip .kit.yml file loading while still using viper defaults
  and environment variables. Useful for fully programmatic configuration.

- DisableCoreTools: Allow creating agents with 0 tools (chat-only mode) or
  with only custom tools. When true and Tools is empty, no tools are loaded.
  When combined with custom Tools, only those tools are loaded.

Updates documentation in README, pkg/kit/README, skills/kit-sdk/SKILL,
and www/pages/sdk/options.
2026-04-07 17:10:58 +03:00
54 changed files with 4507 additions and 614 deletions
+28 -1
View File
@@ -531,7 +531,12 @@ host, err := kit.New(ctx, &kit.Options{
NoSession: true, // Ephemeral mode
// Tool options
ExtraTools: []kit.Tool{...}, // Additional tools alongside defaults
Tools: []kit.Tool{...}, // Replace default tool set entirely
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Compaction
AutoCompact: true, // Auto-compact near context limit
@@ -540,6 +545,28 @@ host, err := kit.New(ctx, &kit.Options{
})
```
### Custom Tools
Create custom tools with automatic schema generation — no external dependencies needed:
```go
type SearchInput struct {
Query string `json:"query" description:"Search query"`
}
searchTool := kit.NewTool("search", "Search the codebase",
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
return kit.TextResult("Found: ..."), nil
},
)
host, _ := kit.New(ctx, &kit.Options{
ExtraTools: []kit.Tool{searchTool}, // adds alongside built-in tools
})
```
Use `kit.NewParallelTool` for tools safe to run concurrently. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
### With Callbacks
```go
+12
View File
@@ -2003,6 +2003,18 @@ func writeJSONError(err error) {
//
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
// Redirect all log output (stdlib and charm) to a file so that log
// messages don't write to stderr and corrupt the TUI. Bubble Tea
// captures stdout for rendering; any stray stderr output from
// background goroutines (watchers, extension handlers, SDK internals)
// will visually corrupt the terminal.
logDir := filepath.Join(os.TempDir(), "kit")
_ = os.MkdirAll(logDir, 0o700)
logFile, logErr := tea.LogToFile(filepath.Join(logDir, "kit.log"), "kit")
if logErr == nil {
defer func() { _ = logFile.Close() }()
}
// Determine terminal size; fall back gracefully.
termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil || termWidth == 0 {
+12 -12
View File
@@ -21,7 +21,7 @@ require (
github.com/fsnotify/fsnotify v1.9.0
github.com/indaco/herald v0.13.0
github.com/indaco/herald-md v0.3.0
github.com/mark3labs/mcp-go v0.47.0
github.com/mark3labs/mcp-go v0.47.1
github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0
github.com/traefik/yaegi v0.16.1
@@ -31,7 +31,7 @@ require (
require (
cloud.google.com/go v0.123.0 // indirect
cloud.google.com/go/auth v0.19.0 // indirect
cloud.google.com/go/auth v0.20.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
@@ -58,9 +58,9 @@ require (
github.com/charmbracelet/harmonica v0.2.0 // indirect
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4 // indirect
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143 // indirect
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143 // indirect
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
github.com/charmbracelet/x/json v0.2.0 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect
@@ -81,7 +81,7 @@ require (
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/kaptinlin/go-i18n v0.3.0 // indirect
github.com/kaptinlin/go-i18n v0.3.1 // indirect
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
github.com/kaptinlin/jsonschema v0.7.7 // indirect
github.com/kaptinlin/messageformat-go v0.4.19 // indirect
@@ -103,8 +103,8 @@ require (
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/yuin/goldmark v1.8.2 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
go.opentelemetry.io/otel v1.43.0 // indirect
go.opentelemetry.io/otel/metric v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect
@@ -114,9 +114,9 @@ require (
golang.org/x/net v0.52.0 // indirect
golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/time v0.15.0 // indirect
google.golang.org/api v0.274.0 // indirect
google.golang.org/api v0.275.0 // indirect
google.golang.org/genai v1.52.1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d // indirect
google.golang.org/grpc v1.80.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
@@ -128,13 +128,13 @@ require (
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.22 // indirect
github.com/mattn/go-isatty v0.0.21 // indirect
github.com/mattn/go-runewidth v0.0.23 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.10 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/sys v0.43.0 // indirect
golang.org/x/text v0.35.0
)
+28 -29
View File
@@ -10,8 +10,8 @@ charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs=
charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
cloud.google.com/go/auth v0.19.0 h1:DGYwtbcsGsT1ywuxsIoWi1u/vlks0moIblQHgSDgQkQ=
cloud.google.com/go/auth v0.19.0/go.mod h1:2Aph7BT2KnaSFOM0JDPyiYgNh6PL9vGMiP8CUIXZ+IY=
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
@@ -96,14 +96,14 @@ github.com/charmbracelet/x/conpty v0.1.1 h1:s1bUxjoi7EpqiXysVtC+a8RrvPPNcNvAjfi4
github.com/charmbracelet/x/conpty v0.1.1/go.mod h1:OmtR77VODEFbiTzGE9G1XiRJAga6011PIm4u5fTNZpk=
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4 h1:pIj18ZCZO4WOVj7jwjLoUb1lC7rS/I8oC3fZWXugNaY=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143 h1:zmBor0ftFNqVFp9U59ZoEDRUCIYSGOGSIfGGkNZRufs=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4 h1:VSd4zShIAf/4FgEDFJpapEcAPrc7h3dyyN7V9JlJpQw=
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143 h1:aEppolah2k9c0LzKX2fk5ryuyQ0Lq8kCOjkvMw1b8o4=
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
@@ -185,8 +185,8 @@ github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
github.com/kaptinlin/go-i18n v0.3.0 h1:wP76dvYg04bvwTb+8NB+CmdZ2kL7lSSCQ9B/kFv7QHo=
github.com/kaptinlin/go-i18n v0.3.0/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
github.com/kaptinlin/go-i18n v0.3.1 h1:plXi3XQE1aYamFi8TU0K6actODmw2+5FSobmhTkfQ/0=
github.com/kaptinlin/go-i18n v0.3.1/go.mod h1:ZRoAHj7elWYamfbv7wev7Ajch6LOzjtBaq8nWe8HIVk=
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
github.com/kaptinlin/jsonpointer v0.4.17/go.mod h1:SsfsjqnHG5zuKo1DTBzk1VknaHlL4osHw+X9kZKukpU=
github.com/kaptinlin/jsonschema v0.7.7 h1:41BlQJ9dskH0oE5DSzBUrl/w4JQYIr6N6L0B5GNyDoM=
@@ -201,12 +201,12 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mark3labs/mcp-go v0.47.0 h1:h44yeM3DduDyQgzImYWu4pt6VRkqP/0p/95AGhWngnA=
github.com/mark3labs/mcp-go v0.47.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.22 h1:76lXsPn6FyHtTY+jt2fTTvsMUCZq1k0qwRsAMuxzKAk=
github.com/mattn/go-runewidth v0.0.22/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mark3labs/mcp-go v0.47.1 h1:A9sJJ20mscl/ssLYHjodfaoBmq6uuhMG7pAPNYaQymQ=
github.com/mark3labs/mcp-go v0.47.1/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
@@ -272,18 +272,18 @@ github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0/go.mod h1:Sje3i3MjSPKTSPvVWCaL8ugBzJwik3u4smCjUeuupqg=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
@@ -298,9 +298,8 @@ golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
@@ -309,16 +308,16 @@ golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/api v0.274.0 h1:aYhycS5QQCwxHLwfEHRRLf9yNsfvp1JadKKWBE54RFA=
google.golang.org/api v0.274.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
google.golang.org/api v0.275.0 h1:vfY5d9vFVJeWEZT65QDd9hbndr7FyZ2+6mIzGAh71NI=
google.golang.org/api v0.275.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
google.golang.org/genai v1.52.1 h1:dYoljKtLDXMiBdVaClSJ/ZPwZ7j1N0lGjMhwOKOQUlk=
google.golang.org/genai v1.52.1/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d h1:wT2n40TBqFY6wiwazVK9/iTWbsQrgk5ZfCSVFLO9LQA=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
+2 -112
View File
@@ -23,18 +23,6 @@ import (
// Version is injected at build time; fallback to "dev".
var Version = "dev"
// thinkingTagOpen and thinkingTagClose are the XML-style tags that some models
// (Qwen, DeepSeek) wrap reasoning content in. We parse these to extract
// reasoning/thinking content and send it as ACP thought updates.
// Also support <think> format used by some models.
const (
thinkingTagOpen = "<thinking>"
thinkingTagClose = "</thinking>"
shortThinkTagOpen = "<think>"
shortThinkTagClose = "</think>"
)
// Agent implements the acp.Agent interface, delegating to Kit for LLM
// execution, tool calls, and session management.
type Agent struct {
conn *acp.AgentSideConnection
@@ -42,10 +30,6 @@ type Agent struct {
// toolCallCounter provides unique IDs for tool calls within a turn.
toolCallCounter atomic.Int64
// inThinkingTag tracks whether we're currently inside a <thinking> tag
// when parsing streaming content from models that wrap reasoning in XML tags.
inThinkingTag bool
}
// NewAgent creates a new ACP agent backed by Kit.
@@ -144,9 +128,6 @@ func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.Promp
log.Debug("acp: prompt", "session", sessionID, "prompt_len", len(promptText), "files", len(files))
// Reset thinking tag state for this new prompt turn
a.inThinkingTag = false
// Create a cancellable context for this prompt turn.
promptCtx, cancel := context.WithCancel(ctx)
sess.setCancel(cancel)
@@ -230,24 +211,8 @@ func (a *Agent) subscribeEvents(ctx context.Context, k *kit.Kit, sessionID acp.S
var update *acp.SessionUpdate
switch ev := e.(type) {
case kit.MessageUpdateEvent:
// Handle models that wrap reasoning in <thinking> tags (Qwen, DeepSeek)
// Parse the chunk and separate reasoning from regular text
reasoning, text := a.parseThinkingTags(ev.Chunk)
// Send reasoning update if we have reasoning content
if reasoning != "" {
u := acp.UpdateAgentThoughtText(reasoning)
_ = a.conn.SessionUpdate(ctx, acp.SessionNotification{
SessionId: sessionID,
Update: u,
})
}
// Send text update if we have text content
if text != "" {
u := acp.UpdateAgentMessageText(text)
update = &u
}
u := acp.UpdateAgentMessageText(ev.Chunk)
update = &u
case kit.ReasoningDeltaEvent:
u := acp.UpdateAgentThoughtText(ev.Delta)
@@ -430,81 +395,6 @@ func extractPromptContent(blocks []acp.ContentBlock) (string, []kit.LLMFilePart)
return strings.Join(textParts, "\n"), files
}
// parseThinkingTags parses a text chunk for <thinking> or tags and separates
// reasoning content from regular text. This handles models (Qwen, DeepSeek)
// that wrap reasoning in XML-style tags instead of using proper reasoning events.
// Returns (reasoningContent, textContent).
func (a *Agent) parseThinkingTags(chunk string) (reasoning string, text string) {
// Handle empty chunk
if chunk == "" {
return "", ""
}
// Determine which tag format to use (long or short)
openTag := thinkingTagOpen
closeTag := thinkingTagClose
if strings.Contains(chunk, shortThinkTagOpen) || strings.Contains(chunk, shortThinkTagClose) {
openTag = shortThinkTagOpen
closeTag = shortThinkTagClose
} else if !strings.Contains(chunk, thinkingTagOpen) && !strings.Contains(chunk, thinkingTagClose) && !a.inThinkingTag {
// No tags at all and not in thinking mode - return as text
return "", chunk
}
// Check for opening tag
if strings.Contains(chunk, openTag) {
parts := strings.SplitN(chunk, openTag, 2)
// Content before the opening tag is regular text
if !a.inThinkingTag && parts[0] != "" {
text = parts[0]
}
a.inThinkingTag = true
// Content after the opening tag is reasoning
if len(parts) > 1 {
// Check if the same chunk contains the closing tag
if strings.Contains(parts[1], closeTag) {
innerParts := strings.SplitN(parts[1], closeTag, 2)
reasoning = innerParts[0]
a.inThinkingTag = false
// Content after closing tag is regular text
if len(innerParts) > 1 && innerParts[1] != "" {
text += innerParts[1]
}
} else if parts[1] != "" {
// No closing tag yet, all remaining content is reasoning
reasoning = parts[1]
}
}
return reasoning, text
}
// Check for closing tag
if strings.Contains(chunk, closeTag) {
parts := strings.SplitN(chunk, closeTag, 2)
a.inThinkingTag = false
// Content before closing tag is reasoning
reasoning = parts[0]
// Content after closing tag is regular text
if len(parts) > 1 && parts[1] != "" {
text = parts[1]
}
return reasoning, text
}
// No tags found - content goes to current mode
if a.inThinkingTag {
return chunk, ""
}
return "", chunk
}
// isTextMimeType returns true if the MIME type indicates text content.
func isTextMimeType(mimeType string) bool {
return strings.HasPrefix(mimeType, "text/") ||
+119 -8
View File
@@ -30,11 +30,21 @@ type AgentConfig struct {
// If nil, remote MCP servers that require OAuth will fail to connect.
AuthHandler tools.MCPAuthHandler
// TokenStoreFactory, if non-nil, creates a custom token store for each
// remote MCP server's OAuth tokens. When nil, the default file-based
// token store is used.
TokenStoreFactory tools.TokenStoreFactory
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used. This allows SDK users to provide a custom tool set (e.g.
// CodingTools or tools with a custom WorkDir).
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ToolWrapper is an optional function that wraps the combined tool list
// before it is passed to the LLM agent. Used by the extensions system
// to intercept tool calls/results.
@@ -84,6 +94,14 @@ type ReasoningCompleteHandler func()
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
type ToolOutputHandler = core.ToolOutputCallback
// StepMessagesHandler is a function type for persisting messages after each
// complete step in a multi-step agent turn. The handler receives the messages
// produced by the step (typically an assistant message with tool calls followed
// by a tool-role message with results, or a final assistant message with text).
// This enables incremental session persistence so that progress is saved as
// it happens rather than only at the end of the turn.
type StepMessagesHandler func(stepMessages []fantasy.Message)
// StepUsageHandler is a function type for handling token usage after each
// complete step in a multi-step agent turn. This enables real-time cost
// tracking during long-running tool-calling conversations.
@@ -136,6 +154,11 @@ type GenerateWithLoopResult struct {
TotalUsage fantasy.Usage
// StopReason is the LLM provider's finish reason for the final response.
StopReason string
// PersistedMessageCount is the number of new messages (beyond the original
// input) that were already persisted incrementally via OnStepMessages during
// generation. The caller should skip these when doing post-generation
// persistence to avoid duplicates.
PersistedMessageCount int
}
// NewAgent creates a new Agent with core tools and optional MCP tool integration.
@@ -153,8 +176,16 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
// Register core tools (direct AgentTool implementations, no MCP overhead).
// Use caller-provided tools if set, otherwise default to all core tools.
coreTools := agentConfig.CoreTools
if len(coreTools) == 0 {
// DisableCoreTools allows explicitly having zero tools (for chat-only mode).
var coreTools []fantasy.AgentTool
if agentConfig.DisableCoreTools && len(agentConfig.CoreTools) == 0 {
// Explicitly zero tools - chat-only mode
coreTools = nil
} else if len(agentConfig.CoreTools) > 0 {
// Custom tools provided - use them
coreTools = agentConfig.CoreTools
} else {
// Default: load all core tools
coreTools = core.AllTools()
}
@@ -210,6 +241,9 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
if agentConfig.AuthHandler != nil {
toolManager.SetAuthHandler(agentConfig.AuthHandler)
}
if agentConfig.TokenStoreFactory != nil {
toolManager.SetTokenStoreFactory(agentConfig.TokenStoreFactory)
}
if agentConfig.DebugLogger != nil {
toolManager.SetDebugLogger(agentConfig.DebugLogger)
}
@@ -364,7 +398,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
) (*GenerateWithLoopResult, error) {
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
onResponse, onToolCallContent, nil, nil, nil, nil, nil)
onResponse, onToolCallContent, nil, nil, nil, nil, nil, nil)
}
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
@@ -377,6 +411,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
onReasoningDelta ReasoningDeltaHandler,
onReasoningComplete ReasoningCompleteHandler,
onToolOutput ToolOutputHandler,
onStepMessages StepMessagesHandler,
onStepUsage StepUsageHandler,
) (*GenerateWithLoopResult, error) {
@@ -416,6 +451,10 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
// when it returns an error, but the OnStepFinish callback fires
// for every step that completed before the error occurred.
var completedStepMessages []fantasy.Message
// persistedCount tracks how many new messages (beyond the original
// input) were persisted incrementally via onStepMessages, so the
// caller can skip them during post-generation persistence.
var persistedCount int
// Use the streaming agent
streamCall := fantasy.AgentStreamCall{
@@ -501,6 +540,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
// persisted even if a later step is cancelled.
completedStepMessages = append(completedStepMessages, step.Messages...)
// Persist step messages incrementally so progress is saved
// as it happens rather than only at the end of the turn.
if onStepMessages != nil && len(step.Messages) > 0 {
onStepMessages(step.Messages)
persistedCount += len(step.Messages)
}
if ctx.Err() != nil {
return ctx.Err()
}
@@ -579,7 +625,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
partialMessages = append(partialMessages, messages...)
partialMessages = append(partialMessages, completedStepMessages...)
return &GenerateWithLoopResult{
ConversationMessages: partialMessages,
ConversationMessages: partialMessages,
PersistedMessageCount: persistedCount,
}, err
}
return nil, err
@@ -594,7 +641,9 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
onResponse(result.Response.Content.Text())
}
return convertAgentResult(result, messages), nil
r := convertAgentResult(result, messages)
r.PersistedMessageCount = persistedCount
return r, nil
}
// Non-streaming path with no callbacks — use the simpler Generate call.
@@ -785,6 +834,59 @@ func (a *Agent) SetExtraTools(extraTools []fantasy.AgentTool) {
a.rebuildFantasyAgent()
}
// AddMCPServer connects to a new MCP server at runtime and makes its tools
// available to the agent. Returns the number of tools loaded.
// If the agent has no tool manager (no MCP servers were configured at init),
// one is created automatically.
func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
// Ensure MCP tools from initial load are settled first.
a.ensureMCPTools()
if a.toolManager == nil {
a.toolManager = tools.NewMCPToolManager()
a.toolManager.SetModel(a.model)
a.toolManager.SetOnToolsChanged(func() {
a.rebuildFantasyAgent()
})
}
count, err := a.toolManager.AddServer(ctx, name, cfg)
if err != nil {
return 0, err
}
// AddServer's onToolsChanged callback triggers rebuildFantasyAgent,
// but only if it was wired. Ensure rebuild happens regardless.
a.rebuildFantasyAgent()
return count, nil
}
// RemoveMCPServer disconnects an MCP server and removes its tools from the agent.
func (a *Agent) RemoveMCPServer(name string) error {
if a.toolManager == nil {
return fmt.Errorf("no MCP servers loaded")
}
// Ensure MCP tools from initial load are settled first.
a.ensureMCPTools()
err := a.toolManager.RemoveServer(name)
if err != nil {
return err
}
// RemoveServer's onToolsChanged callback triggers rebuildFantasyAgent,
// but ensure rebuild happens regardless.
a.rebuildFantasyAgent()
return nil
}
// GetMCPToolManager returns the underlying MCP tool manager.
// Returns nil if no MCP servers have been configured.
func (a *Agent) GetMCPToolManager() *tools.MCPToolManager {
return a.toolManager
}
// GetLoadingMessage returns the loading message from provider creation.
func (a *Agent) GetLoadingMessage() string {
return a.loadingMessage
@@ -798,9 +900,11 @@ func (a *Agent) GetLoadedServerNames() []string {
return a.toolManager.GetLoadedServerNames()
}
// SetModel swaps the agent's LLM provider to a new model. The existing tools,
// system prompt, and configuration are preserved. The old provider is closed
// if it has a closer. Returns the previous model string for notification.
// SetModel swaps the agent's LLM provider to a new model. The existing tools
// and configuration are preserved. When the new model's ProviderConfig carries
// a system prompt (from per-model settings), it replaces the agent's stored
// prompt so the rebuilt fantasy agent uses it. The old provider is closed if
// it has a closer.
func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) error {
// Ensure MCP tools are loaded before rebuilding (SetModel may be called
// before the first LLM call).
@@ -827,6 +931,13 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
a.skipMaxOutputTokens = providerResult.SkipMaxOutputTokens
a.modelConfig = config
// Update system prompt when the config carries one (from per-model
// settings or the global config). This allows model-specific system
// prompts to take effect on model switch.
if config.SystemPrompt != "" {
a.systemPrompt = config.SystemPrompt
}
// Update provider type.
if config.ModelString != "" {
if p, _, err := models.ParseModelString(config.ModelString); err == nil {
+242
View File
@@ -0,0 +1,242 @@
package agent
import (
"context"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/config"
)
// mockModel is a minimal LanguageModel that satisfies the interface
// without making real API calls. Used to test tool management wiring.
type mockModel struct{}
func (m *mockModel) Generate(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{}, nil
}
func (m *mockModel) Stream(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return nil, nil
}
func (m *mockModel) GenerateObject(_ context.Context, _ fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
return &fantasy.ObjectResponse{}, nil
}
func (m *mockModel) StreamObject(_ context.Context, _ fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
return nil, nil
}
func (m *mockModel) Provider() string { return "mock" }
func (m *mockModel) Model() string { return "mock-model" }
// testdataDir returns the absolute path to the tools testdata directory.
func testdataDir(t *testing.T) string {
t.Helper()
_, file, _, ok := runtime.Caller(0)
if !ok {
t.Fatal("cannot determine test file path")
}
return filepath.Join(filepath.Dir(file), "..", "tools", "testdata")
}
// echoServerConfig returns an MCPServerConfig for the test echo MCP server.
func echoServerConfig(t *testing.T) config.MCPServerConfig {
t.Helper()
script := filepath.Join(testdataDir(t), "echo_server.py")
if _, err := os.Stat(script); err != nil {
t.Skipf("echo_server.py not found: %v", err)
}
return config.MCPServerConfig{
Command: []string{"python3", script},
}
}
// newTestAgent creates a minimal Agent with a mock model and no core tools,
// suitable for testing MCP server management without an API key.
func newTestAgent() *Agent {
model := &mockModel{}
a := &Agent{
model: model,
coreTools: nil,
extraTools: nil,
maxSteps: 10,
systemPrompt: "test",
fantasyAgent: fantasy.NewAgent(model),
}
return a
}
func TestAgent_AddMCPServer(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
a := newTestAgent()
defer func() { _ = a.Close() }()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cfg := echoServerConfig(t)
// Initially no MCP tools.
if a.GetMCPToolCount() != 0 {
t.Fatalf("Expected 0 MCP tools initially, got %d", a.GetMCPToolCount())
}
// Add a server.
count, err := a.AddMCPServer(ctx, "echo", cfg)
if err != nil {
t.Fatalf("AddMCPServer failed: %v", err)
}
if count != 2 {
t.Errorf("Expected 2 tools, got %d", count)
}
// Verify tools are in the agent's tool list.
if a.GetMCPToolCount() != 2 {
t.Errorf("Expected 2 MCP tools, got %d", a.GetMCPToolCount())
}
allTools := a.GetTools()
toolNames := make(map[string]bool)
for _, tool := range allTools {
toolNames[tool.Info().Name] = true
}
if !toolNames["echo__echo"] {
t.Error("Expected tool 'echo__echo' in agent tools")
}
if !toolNames["echo__greet"] {
t.Error("Expected tool 'echo__greet' in agent tools")
}
// Verify loaded server names.
names := a.GetLoadedServerNames()
found := false
for _, n := range names {
if n == "echo" {
found = true
}
}
if !found {
t.Errorf("Expected 'echo' in loaded server names: %v", names)
}
}
func TestAgent_RemoveMCPServer(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
a := newTestAgent()
defer func() { _ = a.Close() }()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cfg := echoServerConfig(t)
// Add then remove.
_, err := a.AddMCPServer(ctx, "echo", cfg)
if err != nil {
t.Fatalf("AddMCPServer failed: %v", err)
}
err = a.RemoveMCPServer("echo")
if err != nil {
t.Fatalf("RemoveMCPServer failed: %v", err)
}
// Verify tools removed.
if a.GetMCPToolCount() != 0 {
t.Errorf("Expected 0 MCP tools after removal, got %d", a.GetMCPToolCount())
}
// Verify agent's tool list has no MCP tools.
for _, tool := range a.GetTools() {
if strings.Contains(tool.Info().Name, "echo__") {
t.Errorf("Found leftover tool after removal: %s", tool.Info().Name)
}
}
}
func TestAgent_RemoveMCPServer_NoToolManager(t *testing.T) {
a := newTestAgent()
defer func() { _ = a.Close() }()
err := a.RemoveMCPServer("nonexistent")
if err == nil {
t.Fatal("Expected error when no tool manager exists")
}
if !strings.Contains(err.Error(), "no MCP servers loaded") {
t.Errorf("Expected 'no MCP servers loaded' error, got: %v", err)
}
}
func TestAgent_AddMCPServer_CreatesToolManager(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
a := newTestAgent()
defer func() { _ = a.Close() }()
// Initially no tool manager.
if a.GetMCPToolManager() != nil {
t.Fatal("Expected nil tool manager initially")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cfg := echoServerConfig(t)
_, err := a.AddMCPServer(ctx, "echo", cfg)
if err != nil {
t.Fatalf("AddMCPServer failed: %v", err)
}
// Tool manager should now exist.
if a.GetMCPToolManager() == nil {
t.Fatal("Expected tool manager to be created by AddMCPServer")
}
}
func TestAgent_AddRemoveAdd_MCP(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
a := newTestAgent()
defer func() { _ = a.Close() }()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cfg := echoServerConfig(t)
// Add → Remove → Add cycle.
_, err := a.AddMCPServer(ctx, "echo", cfg)
if err != nil {
t.Fatalf("First add failed: %v", err)
}
err = a.RemoveMCPServer("echo")
if err != nil {
t.Fatalf("Remove failed: %v", err)
}
count, err := a.AddMCPServer(ctx, "echo", cfg)
if err != nil {
t.Fatalf("Re-add failed: %v", err)
}
if count != 2 {
t.Errorf("Expected 2 tools on re-add, got %d", count)
}
if a.GetMCPToolCount() != 2 {
t.Errorf("Expected 2 MCP tools after re-add, got %d", a.GetMCPToolCount())
}
}
+10
View File
@@ -38,9 +38,17 @@ type AgentCreationOptions struct {
DebugLogger tools.DebugLogger // Optional debug logger
// AuthHandler handles OAuth authorization for remote MCP servers
AuthHandler tools.MCPAuthHandler
// TokenStoreFactory, if non-nil, creates a custom token store for each
// remote MCP server's OAuth tokens. When nil, the default file-based
// token store is used.
TokenStoreFactory tools.TokenStoreFactory
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used.
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ToolWrapper wraps the combined tool list before agent creation.
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
// ExtraTools are additional tools to include (e.g. from extensions).
@@ -62,7 +70,9 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
StreamingEnabled: opts.StreamingEnabled,
DebugLogger: opts.DebugLogger,
AuthHandler: opts.AuthHandler,
TokenStoreFactory: opts.TokenStoreFactory,
CoreTools: opts.CoreTools,
DisableCoreTools: opts.DisableCoreTools,
ToolWrapper: opts.ToolWrapper,
ExtraTools: opts.ExtraTools,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
+16 -6
View File
@@ -930,7 +930,8 @@ func (a *App) QuitFromExtension() {
// controls styling: "" for plain text, "info" for a system message block,
// "error" for an error block. In interactive mode it sends an
// ExtensionPrintEvent through the program so the TUI can render it with the
// appropriate renderer. In non-interactive mode it falls back to stdout.
// appropriate renderer. In non-interactive mode it falls back to stderr with
// a level prefix so errors are distinguishable from plain output.
func (a *App) PrintFromExtension(level, text string) {
a.mu.Lock()
prog := a.program
@@ -939,8 +940,16 @@ func (a *App) PrintFromExtension(level, text string) {
prog.Send(ExtensionPrintEvent{Text: text, Level: level})
return
}
// Non-interactive fallback: write directly to stdout.
fmt.Println(text)
// Non-interactive fallback: write to stderr with a level prefix so that
// errors and info messages are distinguishable from plain output.
switch level {
case "error":
fmt.Fprintf(os.Stderr, "[ERROR] %s\n", text)
case "info":
fmt.Fprintf(os.Stderr, "[INFO] %s\n", text)
default:
fmt.Println(text)
}
}
// SetEditorTextFromExtension sends an EditorTextSetEvent to the TUI to
@@ -1122,11 +1131,12 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
})
return
}
// Non-interactive fallback.
// Non-interactive fallback: render a simple framed block to stderr so
// it is visually distinct from plain stdout output.
if opts.Subtitle != "" {
fmt.Printf("%s\n — %s\n", opts.Text, opts.Subtitle)
fmt.Fprintf(os.Stderr, "--- %s ---\n%s\n", opts.Subtitle, opts.Text)
} else {
fmt.Println(opts.Text)
fmt.Fprintf(os.Stderr, "---\n%s\n---\n", opts.Text)
}
}
+64 -1
View File
@@ -157,6 +157,21 @@ type Theme struct {
Markdown MarkdownThemeConfig `json:"markdown,omitzero" yaml:"markdown,omitempty"`
}
// GenerationParams defines generation parameter defaults that can be attached
// to individual models. These act as model-level defaults — CLI flags and
// global config values take precedence when explicitly set.
type GenerationParams struct {
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
}
// CustomModelConfig defines a custom model that can be used with custom/custom
// or other custom/ prefixed models. These models are loaded from the config file
// and merged into the custom provider in the model registry.
@@ -171,6 +186,11 @@ type CustomModelConfig struct {
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
Cost CostConfig `json:"cost" yaml:"cost"`
Limit LimitConfig `json:"limit" yaml:"limit"`
// Generation parameter defaults for this model.
// These are applied when the user hasn't explicitly set the corresponding
// CLI flag or global config value.
Params GenerationParams `json:"params,omitzero" yaml:"params,omitempty"`
}
// CostConfig defines the pricing for a custom model.
@@ -219,6 +239,12 @@ type Config struct {
// Custom model definitions (under custom/ provider)
CustomModels map[string]CustomModelConfig `json:"customModels,omitempty" yaml:"customModels,omitempty"`
// Per-model generation parameter overrides. Keys are "provider/model" strings
// (e.g. "anthropic/claude-sonnet-4-5-20250929", "openai/gpt-4o"). These
// settings act as model-level defaults — CLI flags and global config values
// take precedence when explicitly set.
ModelSettings map[string]GenerationParams `json:"modelSettings,omitempty" yaml:"modelSettings,omitempty"`
}
// GetTransportType returns the transport type for the server config, mapping
@@ -367,7 +393,7 @@ mcpServers:
# debug: false # Enable debug logging
# system-prompt: "/path/to/system-prompt.txt" # System prompt text file
# Model generation parameters (all optional)
# Model generation parameters (all optional, apply globally to all models)
# max-tokens: 4096 # Maximum tokens in response
# temperature: 0.7 # Randomness (0.0-1.0)
# top-p: 0.95 # Nucleus sampling (0.0-1.0)
@@ -376,9 +402,46 @@ mcpServers:
# presence-penalty: 0.0 # Penalize present tokens (0.0-2.0)
# stop-sequences: ["Human:", "Assistant:"] # Custom stop sequences
# Per-model generation parameter overrides (apply to specific models)
# These act as model-level defaults — CLI flags and global settings above take precedence.
# Keys are "provider/model" strings matching the model you use.
# modelSettings:
# anthropic/claude-sonnet-4-5-20250929:
# temperature: 0.3
# maxTokens: 8192
# openai/gpt-4o:
# temperature: 0.7
# topP: 0.95
# topK: 40
# frequencyPenalty: 0.1
# presencePenalty: 0.1
# anthropic/claude-opus-4-6:
# thinkingLevel: "high"
# maxTokens: 16384
# systemPrompt: "You are a deep reasoning assistant." # or a file path
# API Configuration (can also use environment variables)
# provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google
# provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama
# Custom model definitions (under custom/ provider)
# customModels:
# my-local-llama:
# name: "Local Llama 3"
# baseUrl: "http://localhost:8080/v1"
# family: "llama"
# temperature: true
# cost:
# input: 0.0
# output: 0.0
# limit:
# context: 131072
# output: 8192
# params: # Generation parameter defaults for this model
# temperature: 0.8
# topP: 0.95
# topK: 40
# systemPrompt: "You are a helpful local assistant."
`
_, err = file.WriteString(content)
+1 -6
View File
@@ -34,15 +34,10 @@ func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
for _, p := range paths {
ext, err := loadSingleExtension(p)
if err != nil {
log.Warn("skipping extension", "path", p, "err", err)
continue
}
loaded = append(loaded, *ext)
log.Debug("loaded extension", "path", p,
"handlers", countHandlers(ext),
"tools", len(ext.Tools),
"commands", len(ext.Commands),
"tool_renderers", len(ext.ToolRenderers))
log.Debug("loaded extension", "path", p, "handlers", countHandlers(ext), "tools", len(ext.Tools), "commands", len(ext.Commands), "tool_renderers", len(ext.ToolRenderers))
}
return loaded, nil
}
+3 -8
View File
@@ -2,12 +2,12 @@ package extensions
import (
"fmt"
"log"
"os"
"sort"
"strings"
"sync"
"github.com/charmbracelet/log"
"github.com/spf13/viper"
)
@@ -370,10 +370,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
for _, handler := range handlers {
result, err := safeCall(handler, event, ctx)
if err != nil {
log.Warn("extension handler error",
"path", ext.Path,
"event", event.Type(),
"err", err)
log.Printf("WARN extension handler error: path=%s event=%s err=%v", ext.Path, event.Type(), err)
continue
}
if result == nil {
@@ -707,9 +704,7 @@ func (r *Runner) EmitCustomEvent(name, data string) {
safeInvoke := func(h func(string)) {
defer func() {
if rec := recover(); rec != nil {
log.Warn("custom event handler panicked",
"event", name,
"err", fmt.Sprintf("%v", rec))
log.Printf("WARN custom event handler panicked: event=%s err=%v", name, rec)
}
}()
h(data)
+6 -6
View File
@@ -3,13 +3,13 @@ package extensions
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/charmbracelet/log"
"github.com/fsnotify/fsnotify"
)
@@ -39,7 +39,7 @@ func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
for _, dir := range dirs {
// Watch the directory itself.
if err := fsw.Add(dir); err != nil {
log.Debug("watcher: skipping directory", "dir", dir, "err", err)
log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err)
continue
}
@@ -52,7 +52,7 @@ func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
if entry.IsDir() {
subdir := filepath.Join(dir, entry.Name())
if err := fsw.Add(subdir); err != nil {
log.Debug("watcher: skipping subdirectory", "dir", subdir, "err", err)
log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err)
}
}
}
@@ -101,7 +101,7 @@ func (w *Watcher) Start(ctx context.Context) {
continue
}
log.Debug("watcher: file changed", "file", event.Name, "op", event.Op)
log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op)
// Debounce: reset timer on each event.
if timer != nil {
@@ -113,14 +113,14 @@ func (w *Watcher) Start(ctx context.Context) {
case <-timerC:
timerC = nil
timer = nil
log.Debug("watcher: reloading extensions")
log.Printf("DEBUG watcher: reloading extensions")
w.onReload()
case err, ok := <-w.watcher.Errors:
if !ok {
return
}
log.Warn("watcher: error", "err", err)
log.Printf("WARN watcher: error: %v", err)
}
}
}
+49 -20
View File
@@ -33,6 +33,10 @@ type AgentSetupOptions struct {
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used. Allows SDK users to pass custom tools (e.g. with WithWorkDir).
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ExtraTools are additional tools added alongside core, MCP, and extension
// tools. They do not replace the defaults — they extend them.
ExtraTools []fantasy.AgentTool
@@ -61,6 +65,10 @@ type AgentSetupOptions struct {
// AuthHandler handles OAuth authorization for remote MCP servers.
// When set, remote transports are configured with OAuth support.
AuthHandler tools.MCPAuthHandler
// TokenStoreFactory, if non-nil, creates a custom token store for each
// remote MCP server's OAuth tokens. When nil, the default file-based
// token store is used.
TokenStoreFactory tools.TokenStoreFactory
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
@@ -78,36 +86,55 @@ type AgentSetupResult struct {
// BuildProviderConfig creates a *models.ProviderConfig from the current viper
// state. All entry points (root, script, SDK) converge through this function.
//
// Generation parameter pointers (Temperature, TopP, etc.) are only set when
// the user has explicitly configured them via CLI flag, environment variable,
// or global config file. This allows per-model defaults from modelSettings
// and customModels to fill in unset parameters downstream.
func BuildProviderConfig() (*models.ProviderConfig, string, error) {
systemPrompt, err := config.LoadSystemPrompt(viper.GetString("system-prompt"))
if err != nil {
return nil, "", fmt.Errorf("failed to load system prompt: %w", err)
}
temperature := float32(viper.GetFloat64("temperature"))
topP := float32(viper.GetFloat64("top-p"))
topK := int32(viper.GetInt("top-k"))
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
numGPU := int32(viper.GetInt("num-gpu-layers"))
mainGPU := int32(viper.GetInt("main-gpu"))
cfg := &models.ProviderConfig{
ModelString: viper.GetString("model"),
SystemPrompt: systemPrompt,
ProviderAPIKey: viper.GetString("provider-api-key"),
ProviderURL: viper.GetString("provider-url"),
MaxTokens: viper.GetInt("max-tokens"),
Temperature: &temperature,
TopP: &topP,
TopK: &topK,
FrequencyPenalty: &frequencyPenalty,
PresencePenalty: &presencePenalty,
StopSequences: viper.GetStringSlice("stop-sequences"),
NumGPU: &numGPU,
MainGPU: &mainGPU,
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
ModelString: viper.GetString("model"),
SystemPrompt: systemPrompt,
ProviderAPIKey: viper.GetString("provider-api-key"),
ProviderURL: viper.GetString("provider-url"),
MaxTokens: viper.GetInt("max-tokens"),
StopSequences: viper.GetStringSlice("stop-sequences"),
NumGPU: &numGPU,
MainGPU: &mainGPU,
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
}
// Only set generation parameter pointers when the user has explicitly
// provided a value. This leaves nil pointers for unset params, allowing
// per-model defaults (modelSettings / customModels params) to apply.
if viper.IsSet("temperature") {
v := float32(viper.GetFloat64("temperature"))
cfg.Temperature = &v
}
if viper.IsSet("top-p") {
v := float32(viper.GetFloat64("top-p"))
cfg.TopP = &v
}
if viper.IsSet("top-k") {
v := int32(viper.GetInt("top-k"))
cfg.TopK = &v
}
if viper.IsSet("frequency-penalty") {
v := float32(viper.GetFloat64("frequency-penalty"))
cfg.FrequencyPenalty = &v
}
if viper.IsSet("presence-penalty") {
v := float32(viper.GetFloat64("presence-penalty"))
cfg.PresencePenalty = &v
}
return cfg, systemPrompt, nil
@@ -196,7 +223,9 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
SpinnerFunc: opts.SpinnerFunc,
DebugLogger: debugLogger,
AuthHandler: opts.AuthHandler,
TokenStoreFactory: opts.TokenStoreFactory,
CoreTools: opts.CoreTools,
DisableCoreTools: opts.DisableCoreTools,
ToolWrapper: toolWrapper,
ExtraTools: extraTools,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
+234 -11
View File
@@ -2,6 +2,8 @@ package models
import (
"log"
"os"
"strings"
"github.com/spf13/viper"
)
@@ -31,7 +33,7 @@ func loadCustomModelsFromConfig() map[string]ModelInfo {
// modelConfigToModelInfo converts a CustomModelConfig to a ModelInfo.
func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
return ModelInfo{
info := ModelInfo{
ID: modelID,
Name: cfg.Name,
Attachment: cfg.Attachment,
@@ -48,21 +50,242 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
Output: cfg.Limit.Output,
},
}
// Convert custom model generation params if any are set.
if p := convertGenerationParams(cfg.Params); p != nil {
info.Params = p
}
return info
}
// LoadModelSettingsFromConfig loads per-model generation parameter overrides
// from the config file. Keys are "provider/model" strings. Returns nil if
// no model settings are configured.
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
if !viper.IsSet("modelSettings") {
return nil
}
var settings map[string]GenerationParamsConfig
if err := viper.UnmarshalKey("modelSettings", &settings); err != nil {
log.Printf("Warning: Failed to parse modelSettings: %v", err)
return nil
}
result := make(map[string]*GenerationParams, len(settings))
for modelKey, cfg := range settings {
if p := convertGenerationParams(cfg); p != nil {
result[modelKey] = p
}
}
return result
}
// convertGenerationParams converts a GenerationParamsConfig to a GenerationParams.
// Returns nil if no parameters are set.
func convertGenerationParams(cfg GenerationParamsConfig) *GenerationParams {
p := &GenerationParams{}
any := false
if cfg.MaxTokens != nil {
p.MaxTokens = cfg.MaxTokens
any = true
}
if cfg.Temperature != nil {
p.Temperature = cfg.Temperature
any = true
}
if cfg.TopP != nil {
p.TopP = cfg.TopP
any = true
}
if cfg.TopK != nil {
p.TopK = cfg.TopK
any = true
}
if cfg.FrequencyPenalty != nil {
p.FrequencyPenalty = cfg.FrequencyPenalty
any = true
}
if cfg.PresencePenalty != nil {
p.PresencePenalty = cfg.PresencePenalty
any = true
}
if len(cfg.StopSequences) > 0 {
p.StopSequences = cfg.StopSequences
any = true
}
if cfg.ThinkingLevel != "" {
p.ThinkingLevel = ParseThinkingLevel(cfg.ThinkingLevel)
any = true
}
if cfg.SystemPrompt != "" {
p.SystemPrompt = cfg.SystemPrompt
any = true
}
if !any {
return nil
}
return p
}
// ApplyModelSettings merges per-model generation parameter defaults from the
// registry into a ProviderConfig. Model-level params are only applied for
// fields where the user has not explicitly set a value (i.e., the
// corresponding viper key is not set via CLI flag or global config).
//
// The lookup order is:
// 1. modelSettings["provider/model"] from config (highest model-level priority)
// 2. ModelInfo.Params from custom model definitions
//
// Both are overridden by explicit CLI flags / global config values.
func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
provider, modelName, err := ParseModelString(config.ModelString)
if err != nil {
return
}
// Collect model-level params: modelSettings override > custom model params.
// modelSettings takes priority because it's the more specific/intentional config.
var params *GenerationParams
// First check modelSettings from config.
if settings := LoadModelSettingsFromConfig(); settings != nil {
modelKey := provider + "/" + modelName
if p, ok := settings[modelKey]; ok {
params = p
}
}
// Fall back to ModelInfo.Params (from custom model definitions).
if params == nil && modelInfo != nil && modelInfo.Params != nil {
params = modelInfo.Params
}
if params == nil {
return
}
// Apply each parameter only when the user hasn't explicitly set it.
// We check viper.IsSet() which returns true only when the key was
// set via CLI flag, environment variable, or config file global section.
if params.MaxTokens != nil && !isExplicitlySet("max-tokens") {
config.MaxTokens = *params.MaxTokens
}
if params.Temperature != nil && !isExplicitlySet("temperature") {
config.Temperature = params.Temperature
}
if params.TopP != nil && !isExplicitlySet("top-p") {
config.TopP = params.TopP
}
if params.TopK != nil && !isExplicitlySet("top-k") {
config.TopK = params.TopK
}
if params.FrequencyPenalty != nil && !isExplicitlySet("frequency-penalty") {
config.FrequencyPenalty = params.FrequencyPenalty
}
if params.PresencePenalty != nil && !isExplicitlySet("presence-penalty") {
config.PresencePenalty = params.PresencePenalty
}
if len(params.StopSequences) > 0 && !isExplicitlySet("stop-sequences") {
config.StopSequences = params.StopSequences
}
if params.ThinkingLevel != "" && !isExplicitlySet("thinking-level") {
config.ThinkingLevel = params.ThinkingLevel
}
if params.SystemPrompt != "" && config.SystemPrompt == "" {
// Resolve file paths: if the value points to an existing file, read it.
// We check config.SystemPrompt == "" rather than isExplicitlySet because
// viper.BindPFlag causes IsSet to return true even for unset flags.
config.SystemPrompt = LoadSystemPromptValue(params.SystemPrompt)
}
}
// LoadSystemPromptValue resolves a system prompt value that may be either
// inline text or a file path. If the value is a path to an existing file,
// its contents are read and returned. Otherwise the string is returned as-is.
// This mirrors config.LoadSystemPrompt but lives in the models package to
// avoid circular dependencies.
func LoadSystemPromptValue(input string) string {
if input == "" {
return ""
}
if info, err := os.Stat(input); err == nil && !info.IsDir() {
content, err := os.ReadFile(input)
if err != nil {
log.Printf("Warning: failed to read system prompt file %q: %v", input, err)
return input
}
return strings.TrimSpace(string(content))
}
return input
}
// isExplicitlySet returns true when the user has explicitly set a config key
// via CLI flag, environment variable, or the global section of the config file.
// Model-level defaults should not override explicitly set values.
func isExplicitlySet(key string) bool {
// viper.IsSet returns true if the key has been set in any of the
// data stores (flag, env, config file, default). We need to check
// whether the value was set at the global config level (not just
// as a default). For generation params, the global config keys use
// hyphenated names (e.g. "max-tokens", "top-p").
//
// Since viper merges all sources, IsSet returns true even for config
// file values. This means global config file values (e.g.
// temperature: 0.7 at the top level) will correctly take precedence
// over model-level defaults, which is the desired behavior.
return viper.IsSet(key)
}
// GenerationParams holds per-model generation parameter defaults.
// These are stored on ModelInfo and applied during provider creation.
// Nil pointer fields mean "no model-level default" — the global config
// or CLI flag value (if any) will be used instead.
type GenerationParams struct {
MaxTokens *int
Temperature *float32
TopP *float32
TopK *int32
FrequencyPenalty *float32
PresencePenalty *float32
StopSequences []string
ThinkingLevel ThinkingLevel
SystemPrompt string // Per-model system prompt (inline text or file path)
}
// CustomModelConfig defines a custom model configuration loaded from the config file.
// This is a duplicate here to avoid circular dependencies with internal/config.
type CustomModelConfig struct {
Name string `json:"name" yaml:"name"`
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
Family string `json:"family,omitempty" yaml:"family,omitempty"`
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
Cost CostConfig `json:"cost" yaml:"cost"`
Limit LimitConfig `json:"limit" yaml:"limit"`
Name string `json:"name" yaml:"name"`
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
Family string `json:"family,omitempty" yaml:"family,omitempty"`
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
Cost CostConfig `json:"cost" yaml:"cost"`
Limit LimitConfig `json:"limit" yaml:"limit"`
Params GenerationParamsConfig `json:"params,omitzero" yaml:"params,omitempty"`
}
// GenerationParamsConfig is the JSON/YAML-serializable form of generation
// parameter defaults. Used in both customModels[].params and modelSettings[].
type GenerationParamsConfig struct {
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
}
// CostConfig defines the pricing for a custom model.
+422
View File
@@ -0,0 +1,422 @@
package models
import (
"os"
"testing"
"github.com/spf13/viper"
)
func TestConvertGenerationParams(t *testing.T) {
t.Run("empty config returns nil", func(t *testing.T) {
cfg := GenerationParamsConfig{}
p := convertGenerationParams(cfg)
if p != nil {
t.Errorf("expected nil, got %+v", p)
}
})
t.Run("temperature only", func(t *testing.T) {
temp := float32(0.7)
cfg := GenerationParamsConfig{Temperature: &temp}
p := convertGenerationParams(cfg)
if p == nil {
t.Fatal("expected non-nil")
}
if p.Temperature == nil || *p.Temperature != 0.7 {
t.Errorf("expected temperature 0.7, got %v", p.Temperature)
}
if p.TopP != nil {
t.Errorf("expected nil TopP, got %v", p.TopP)
}
})
t.Run("all params set", func(t *testing.T) {
maxTokens := 8192
temp := float32(0.5)
topP := float32(0.9)
topK := int32(50)
freqPenalty := float32(0.1)
presPenalty := float32(0.2)
cfg := GenerationParamsConfig{
MaxTokens: &maxTokens,
Temperature: &temp,
TopP: &topP,
TopK: &topK,
FrequencyPenalty: &freqPenalty,
PresencePenalty: &presPenalty,
StopSequences: []string{"STOP"},
ThinkingLevel: "high",
}
p := convertGenerationParams(cfg)
if p == nil {
t.Fatal("expected non-nil")
}
if p.MaxTokens == nil || *p.MaxTokens != 8192 {
t.Errorf("expected maxTokens 8192, got %v", p.MaxTokens)
}
if p.Temperature == nil || *p.Temperature != 0.5 {
t.Errorf("expected temperature 0.5, got %v", p.Temperature)
}
if p.TopP == nil || *p.TopP != 0.9 {
t.Errorf("expected topP 0.9, got %v", p.TopP)
}
if p.TopK == nil || *p.TopK != 50 {
t.Errorf("expected topK 50, got %v", p.TopK)
}
if p.FrequencyPenalty == nil || *p.FrequencyPenalty != 0.1 {
t.Errorf("expected frequencyPenalty 0.1, got %v", p.FrequencyPenalty)
}
if p.PresencePenalty == nil || *p.PresencePenalty != 0.2 {
t.Errorf("expected presencePenalty 0.2, got %v", p.PresencePenalty)
}
if len(p.StopSequences) != 1 || p.StopSequences[0] != "STOP" {
t.Errorf("expected stop sequences [STOP], got %v", p.StopSequences)
}
if p.ThinkingLevel != ThinkingHigh {
t.Errorf("expected thinking level high, got %v", p.ThinkingLevel)
}
})
t.Run("thinking level parsing", func(t *testing.T) {
cfg := GenerationParamsConfig{ThinkingLevel: "medium"}
p := convertGenerationParams(cfg)
if p == nil {
t.Fatal("expected non-nil")
}
if p.ThinkingLevel != ThinkingMedium {
t.Errorf("expected thinking level medium, got %v", p.ThinkingLevel)
}
})
t.Run("system prompt only", func(t *testing.T) {
cfg := GenerationParamsConfig{SystemPrompt: "You are helpful."}
p := convertGenerationParams(cfg)
if p == nil {
t.Fatal("expected non-nil")
}
if p.SystemPrompt != "You are helpful." {
t.Errorf("expected system prompt, got %q", p.SystemPrompt)
}
})
}
func TestModelConfigToModelInfoWithParams(t *testing.T) {
temp := float32(0.8)
topP := float32(0.95)
cfg := CustomModelConfig{
Name: "Test Model",
BaseURL: "http://localhost:8080/v1",
Temperature: true,
Params: GenerationParamsConfig{
Temperature: &temp,
TopP: &topP,
},
}
info := modelConfigToModelInfo("test-model", cfg)
if info.Params == nil {
t.Fatal("expected non-nil Params")
}
if info.Params.Temperature == nil || *info.Params.Temperature != 0.8 {
t.Errorf("expected temperature 0.8, got %v", info.Params.Temperature)
}
if info.Params.TopP == nil || *info.Params.TopP != 0.95 {
t.Errorf("expected topP 0.95, got %v", info.Params.TopP)
}
}
func TestModelConfigToModelInfoWithoutParams(t *testing.T) {
cfg := CustomModelConfig{
Name: "Test Model",
BaseURL: "http://localhost:8080/v1",
}
info := modelConfigToModelInfo("test-model", cfg)
if info.Params != nil {
t.Errorf("expected nil Params, got %+v", info.Params)
}
}
func TestApplyModelSettings(t *testing.T) {
// Save and restore viper state.
originalViper := viper.AllSettings()
defer func() {
viper.Reset()
for k, v := range originalViper {
viper.Set(k, v)
}
}()
t.Run("applies model params when not explicitly set", func(t *testing.T) {
viper.Reset()
temp := float32(0.8)
topK := int32(50)
maxTokens := 4096
modelInfo := &ModelInfo{
ID: "test-model",
Params: &GenerationParams{
Temperature: &temp,
TopK: &topK,
MaxTokens: &maxTokens,
},
}
config := &ProviderConfig{
ModelString: "custom/test-model",
}
ApplyModelSettings(config, modelInfo)
if config.Temperature == nil || *config.Temperature != 0.8 {
t.Errorf("expected temperature 0.8, got %v", config.Temperature)
}
if config.TopK == nil || *config.TopK != 50 {
t.Errorf("expected topK 50, got %v", config.TopK)
}
if config.MaxTokens != 4096 {
t.Errorf("expected maxTokens 4096, got %d", config.MaxTokens)
}
})
t.Run("explicit viper values take precedence", func(t *testing.T) {
viper.Reset()
viper.Set("temperature", 0.3)
temp := float32(0.8)
modelInfo := &ModelInfo{
ID: "test-model",
Params: &GenerationParams{
Temperature: &temp,
},
}
explicitTemp := float32(0.3)
config := &ProviderConfig{
ModelString: "custom/test-model",
Temperature: &explicitTemp,
}
ApplyModelSettings(config, modelInfo)
// Temperature should NOT be overridden because it's explicitly set in viper
if config.Temperature == nil || *config.Temperature != 0.3 {
t.Errorf("expected temperature 0.3 (explicit), got %v", config.Temperature)
}
})
t.Run("nil model info is safe", func(t *testing.T) {
viper.Reset()
config := &ProviderConfig{
ModelString: "custom/test-model",
}
// Should not panic
ApplyModelSettings(config, nil)
if config.Temperature != nil {
t.Errorf("expected nil temperature, got %v", config.Temperature)
}
})
t.Run("model info without params is safe", func(t *testing.T) {
viper.Reset()
modelInfo := &ModelInfo{ID: "test-model"}
config := &ProviderConfig{
ModelString: "custom/test-model",
}
ApplyModelSettings(config, modelInfo)
if config.Temperature != nil {
t.Errorf("expected nil temperature, got %v", config.Temperature)
}
})
t.Run("modelSettings from viper takes priority over ModelInfo.Params", func(t *testing.T) {
viper.Reset()
// Set up modelSettings in viper (simulating config file)
viper.Set("modelSettings", map[string]any{
"custom/test-model": map[string]any{
"temperature": 0.5,
"topK": 30,
},
})
// ModelInfo has different params
temp := float32(0.8)
topK := int32(50)
modelInfo := &ModelInfo{
ID: "test-model",
Params: &GenerationParams{
Temperature: &temp,
TopK: &topK,
},
}
config := &ProviderConfig{
ModelString: "custom/test-model",
}
ApplyModelSettings(config, modelInfo)
// modelSettings should win over ModelInfo.Params
if config.Temperature == nil || *config.Temperature != 0.5 {
t.Errorf("expected temperature 0.5 (from modelSettings), got %v", config.Temperature)
}
if config.TopK == nil || *config.TopK != 30 {
t.Errorf("expected topK 30 (from modelSettings), got %v", config.TopK)
}
})
t.Run("stop sequences applied from model params", func(t *testing.T) {
viper.Reset()
modelInfo := &ModelInfo{
ID: "test-model",
Params: &GenerationParams{
StopSequences: []string{"STOP", "END"},
},
}
config := &ProviderConfig{
ModelString: "custom/test-model",
}
ApplyModelSettings(config, modelInfo)
if len(config.StopSequences) != 2 || config.StopSequences[0] != "STOP" {
t.Errorf("expected stop sequences [STOP END], got %v", config.StopSequences)
}
})
t.Run("thinking level applied from model params", func(t *testing.T) {
viper.Reset()
modelInfo := &ModelInfo{
ID: "test-model",
Params: &GenerationParams{
ThinkingLevel: ThinkingHigh,
},
}
config := &ProviderConfig{
ModelString: "custom/test-model",
}
ApplyModelSettings(config, modelInfo)
if config.ThinkingLevel != ThinkingHigh {
t.Errorf("expected thinking level high, got %v", config.ThinkingLevel)
}
})
t.Run("system prompt applied from model params", func(t *testing.T) {
viper.Reset()
modelInfo := &ModelInfo{
ID: "test-model",
Params: &GenerationParams{
SystemPrompt: "You are a coding assistant.",
},
}
config := &ProviderConfig{
ModelString: "custom/test-model",
}
ApplyModelSettings(config, modelInfo)
if config.SystemPrompt != "You are a coding assistant." {
t.Errorf("expected system prompt to be set, got %q", config.SystemPrompt)
}
})
t.Run("explicit system prompt takes precedence", func(t *testing.T) {
viper.Reset()
modelInfo := &ModelInfo{
ID: "test-model",
Params: &GenerationParams{
SystemPrompt: "Model-specific prompt",
},
}
config := &ProviderConfig{
ModelString: "custom/test-model",
SystemPrompt: "Global prompt",
}
ApplyModelSettings(config, modelInfo)
// Global system prompt should NOT be overridden because config
// already has a non-empty SystemPrompt.
if config.SystemPrompt != "Global prompt" {
t.Errorf("expected global prompt preserved, got %q", config.SystemPrompt)
}
})
t.Run("system prompt from file path", func(t *testing.T) {
viper.Reset()
// Create a temp file with a system prompt
tmpFile, err := os.CreateTemp("", "kit-test-prompt-*.txt")
if err != nil {
t.Fatal(err)
}
defer func() { _ = os.Remove(tmpFile.Name()) }()
if _, err := tmpFile.WriteString(" Prompt from file "); err != nil {
t.Fatal(err)
}
_ = tmpFile.Close()
modelInfo := &ModelInfo{
ID: "test-model",
Params: &GenerationParams{
SystemPrompt: tmpFile.Name(),
},
}
config := &ProviderConfig{
ModelString: "custom/test-model",
}
ApplyModelSettings(config, modelInfo)
if config.SystemPrompt != "Prompt from file" {
t.Errorf("expected trimmed file content, got %q", config.SystemPrompt)
}
})
t.Run("modelSettings system prompt overrides custom model params", func(t *testing.T) {
viper.Reset()
viper.Set("modelSettings", map[string]any{
"custom/test-model": map[string]any{
"systemPrompt": "From modelSettings",
},
})
modelInfo := &ModelInfo{
ID: "test-model",
Params: &GenerationParams{
SystemPrompt: "From custom model",
},
}
config := &ProviderConfig{
ModelString: "custom/test-model",
}
ApplyModelSettings(config, modelInfo)
if config.SystemPrompt != "From modelSettings" {
t.Errorf("expected modelSettings prompt, got %q", config.SystemPrompt)
}
})
}
+5
View File
@@ -241,6 +241,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
validateModelConfig(config, modelInfo)
}
// Apply per-model generation parameter defaults. Model-level params are
// only applied for fields where the user hasn't explicitly set a value
// via CLI flag or global config.
ApplyModelSettings(config, modelInfo)
// Create the base provider
var result *ProviderResult
var createErr error
+17
View File
@@ -26,6 +26,11 @@ type ModelInfo struct {
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
BaseURL string // Per-model base URL override (custom models only)
APIKey string // Per-model API key override (custom models only)
// Params holds per-model generation parameter defaults. These are applied
// when the user hasn't explicitly set the corresponding CLI flag or global
// config value. Nil pointer fields mean "no model-level default".
Params *GenerationParams
}
// SupportsCaching returns true if this model family supports prompt caching.
@@ -236,6 +241,18 @@ func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
return &modelInfo
}
// LookupModelForSettings is a convenience function that parses a
// "provider/model" string and looks up the ModelInfo in the global registry.
// Returns nil when the model string is invalid or the model is unknown.
// Used by Kit.SetModel to pre-apply per-model settings before CreateProvider.
func LookupModelForSettings(modelString string) *ModelInfo {
provider, modelName, err := ParseModelString(modelString)
if err != nil {
return nil
}
return GetGlobalRegistry().LookupModel(provider, modelName)
}
// getRequiredEnvVars returns the required environment variables for a provider.
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
providerInfo, exists := r.providers[provider]
+2 -6
View File
@@ -2,11 +2,10 @@ package prompts
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/log"
)
// LoadOptions configures how templates are discovered and loaded.
@@ -74,10 +73,7 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
DroppedPath: tpl.FilePath,
Reason: fmt.Sprintf("template from %s overridden by %s", source, existing.Source),
})
log.Debug("template collision",
"name", tpl.Name,
"dropped", tpl.FilePath,
"kept", existing.FilePath)
log.Printf("DEBUG template collision: name=%s dropped=%s kept=%s", tpl.Name, tpl.FilePath, existing.FilePath)
} else {
tpl.Source = source
seen[tpl.Name] = tpl
+317
View File
@@ -0,0 +1,317 @@
package session
import (
"slices"
"testing"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/message"
)
// TestCompactionCreatesNewLeaf verifies that after compaction, the compaction
// entry has no parent (creating a new root), and BuildContext returns only
// the summary and kept messages, not the old compacted messages.
func TestCompactionCreatesNewLeaf(t *testing.T) {
tm := InMemoryTreeSession("/test")
// Add some messages: M1, M2 (old, will be compacted), M3, M4 (kept)
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1 - old"}}}
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2 - old"}}}
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 3 - kept"}}}
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - kept"}}}
_, _ = tm.AppendMessage(msg1)
_, _ = tm.AppendMessage(msg2)
id3, _ := tm.AppendMessage(msg3)
id4, _ := tm.AppendMessage(msg4)
// Verify initial state - all messages should be in context
messages, _, _ := tm.BuildContext()
if len(messages) != 4 {
t.Fatalf("expected 4 messages before compaction, got %d", len(messages))
}
// Verify entry IDs
entryIDs := tm.GetContextEntryIDs()
if len(entryIDs) != 4 {
t.Fatalf("expected 4 entry IDs before compaction, got %d", len(entryIDs))
}
// Now add a compaction entry, simulating that M3 is the first kept entry
summary := "Summary of old messages"
compactionID, err := tm.AppendCompaction(summary, id3, 1000, 500, 2, []string{}, []string{})
if err != nil {
t.Fatalf("failed to append compaction: %v", err)
}
// Verify the compaction entry has no parent (empty ParentID)
compactionEntry := tm.GetEntry(compactionID).(*CompactionEntry)
if compactionEntry.ParentID != "" {
t.Errorf("compaction entry should have no parent, got %q", compactionEntry.ParentID)
}
// Verify the leaf is now the compaction entry
if tm.GetLeafID() != compactionID {
t.Errorf("leaf should be compaction entry %q, got %q", compactionID, tm.GetLeafID())
}
// Now BuildContext should return: [summary] + [M3, M4]
messages, _, _ = tm.BuildContext()
if len(messages) != 3 {
t.Fatalf("expected 3 messages after compaction (summary + 2 kept), got %d", len(messages))
}
// First message should be the summary
if messages[0].Role != fantasy.MessageRoleSystem {
t.Errorf("first message should be system summary, got %s", messages[0].Role)
}
summaryText := messages[0].Content[0].(fantasy.TextPart).Text
if summaryText != "[Conversation summary — earlier messages were compacted]\n\n"+summary {
t.Errorf("unexpected summary text: %s", summaryText)
}
// Second message should be M3 (kept)
if messages[1].Role != fantasy.MessageRoleUser {
t.Errorf("second message should be user (M3), got %s", messages[1].Role)
}
m3Text := messages[1].Content[0].(fantasy.TextPart).Text
if m3Text != "Message 3 - kept" {
t.Errorf("unexpected M3 text: %s", m3Text)
}
// Third message should be M4 (kept)
if messages[2].Role != fantasy.MessageRoleAssistant {
t.Errorf("third message should be assistant (M4), got %s", messages[2].Role)
}
m4Text := messages[2].Content[0].(fantasy.TextPart).Text
if m4Text != "Message 4 - kept" {
t.Errorf("unexpected M4 text: %s", m4Text)
}
// Verify GetContextEntryIDs returns correct IDs
entryIDs = tm.GetContextEntryIDs()
if len(entryIDs) != 3 {
t.Fatalf("expected 3 entry IDs after compaction (empty for summary + 2 kept), got %d: %v", len(entryIDs), entryIDs)
}
// First entry ID should be empty (summary has no entry)
if entryIDs[0] != "" {
t.Errorf("first entry ID should be empty (summary), got %q", entryIDs[0])
}
// Second and third should be id3 and id4 (the kept messages)
if entryIDs[1] != id3 {
t.Errorf("second entry ID should be %q (M3), got %q", id3, entryIDs[1])
}
if entryIDs[2] != id4 {
t.Errorf("third entry ID should be %q (M4), got %q", id4, entryIDs[2])
}
}
// TestCompactionWithNewMessagesAfterCompaction verifies that messages appended
// after compaction are correctly included in the context.
func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) {
tm := InMemoryTreeSession("/test")
// Add initial messages
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1"}}}
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2"}}}
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 3 - kept"}}}
_, _ = tm.AppendMessage(msg1)
_, _ = tm.AppendMessage(msg2)
id3, _ := tm.AppendMessage(msg3)
// Compact, keeping only M3
_, _ = tm.AppendCompaction("Summary", id3, 1000, 500, 2, []string{}, []string{})
// Add a new message after compaction
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}}
_, _ = tm.AppendMessage(msg4)
// BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)]
messages, _, _ := tm.BuildContext()
if len(messages) != 3 {
t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages)
}
// Verify order: summary, M4 (new), M3 (kept)
if messages[0].Role != fantasy.MessageRoleSystem {
t.Errorf("first message should be summary, got %s", messages[0].Role)
}
if messages[1].Role != fantasy.MessageRoleAssistant {
t.Errorf("second message should be assistant (M4), got %s", messages[1].Role)
}
m4Text := messages[1].Content[0].(fantasy.TextPart).Text
if m4Text != "Message 4 - after compaction" {
t.Errorf("unexpected M4 text: %s", m4Text)
}
if messages[2].Role != fantasy.MessageRoleUser {
t.Errorf("third message should be user (M3), got %s", messages[2].Role)
}
// Verify that M1 is NOT in the context
for i, msg := range messages {
if msg.Role == fantasy.MessageRoleUser {
text := msg.Content[0].(fantasy.TextPart).Text
if text == "Message 1" {
t.Errorf("Message 1 (compacted) should not be in context at index %d", i)
}
}
}
}
// TestCompactionWithNoKeptMessages verifies compaction when all messages are compacted.
func TestCompactionWithNoKeptMessages(t *testing.T) {
tm := InMemoryTreeSession("/test")
// Add messages that will all be compacted
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1"}}}
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2"}}}
if _, err := tm.AppendMessage(msg1); err != nil {
t.Fatalf("failed to append message: %v", err)
}
if _, err := tm.AppendMessage(msg2); err != nil {
t.Fatalf("failed to append message: %v", err)
}
// Compact with no kept messages (empty firstKeptEntryID)
summary := "All messages summarized"
compactionID, _ := tm.AppendCompaction(summary, "", 1000, 100, 2, []string{}, []string{})
// Verify the compaction entry has no parent
compactionEntry := tm.GetEntry(compactionID).(*CompactionEntry)
if compactionEntry.ParentID != "" {
t.Errorf("compaction entry should have no parent, got %q", compactionEntry.ParentID)
}
// BuildContext should return only the summary
messages, _, _ := tm.BuildContext()
if len(messages) != 1 {
t.Fatalf("expected 1 message (summary only), got %d: %+v", len(messages), messages)
}
if messages[0].Role != fantasy.MessageRoleSystem {
t.Errorf("message should be system summary, got %s", messages[0].Role)
}
}
// TestMultipleCompactions verifies that multiple compactions work correctly.
func TestMultipleCompactions(t *testing.T) {
tm := InMemoryTreeSession("/test")
// First batch of messages
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Batch 1 - User"}}}
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Batch 1 - Assistant"}}}
id1, _ := tm.AppendMessage(msg1)
id2, _ := tm.AppendMessage(msg2)
// First compaction
_, _ = tm.AppendCompaction("Summary 1", id1, 1000, 500, 1, []string{}, []string{})
// Second batch
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Batch 2 - User"}}}
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Batch 2 - Assistant"}}}
id3, _ := tm.AppendMessage(msg3)
id4, _ := tm.AppendMessage(msg4)
// Second compaction (compacting the first compaction + batch 2)
// Note: id3 is the first kept entry, so id3 and id4 should be preserved
compactionID2, _ := tm.AppendCompaction("Summary 2", id3, 1000, 500, 3, []string{}, []string{})
// Verify second compaction has no parent
compactionEntry2 := tm.GetEntry(compactionID2).(*CompactionEntry)
if compactionEntry2.ParentID != "" {
t.Errorf("second compaction entry should have no parent, got %q", compactionEntry2.ParentID)
}
// Add final message
msg5 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Final message"}}}
id5, _ := tm.AppendMessage(msg5)
// BuildContext should include:
// - Summary 2 (from second compaction)
// - msg5 (final message)
// - msg3, msg4 (kept from second compaction)
// But NOT Summary 1 or msg1, msg2 (they're before the first kept entry of compaction 2)
messages, _, _ := tm.BuildContext()
// Should have: Summary 2 + msg5 + msg3 + msg4 = 4 messages
if len(messages) != 4 {
t.Fatalf("expected 4 messages (Summary 2 + msg5 + msg3 + msg4), got %d: %+v", len(messages), messages)
}
// First should be Summary 2
if messages[0].Role != fantasy.MessageRoleSystem {
t.Errorf("first message should be system (Summary 2), got %s", messages[0].Role)
}
summaryText := messages[0].Content[0].(fantasy.TextPart).Text
if summaryText != "[Conversation summary — earlier messages were compacted]\n\nSummary 2" {
t.Errorf("unexpected summary: %s", summaryText)
}
// Verify msg5 is included
foundFinal := false
for _, msg := range messages {
if msg.Role == fantasy.MessageRoleUser {
text := msg.Content[0].(fantasy.TextPart).Text
if text == "Final message" {
foundFinal = true
break
}
}
}
if !foundFinal {
t.Error("Final message (msg5) should be in context")
}
// Verify msg1, msg2 are NOT included (compacted by first compaction, then second)
for _, msg := range messages {
if msg.Role == fantasy.MessageRoleUser || msg.Role == fantasy.MessageRoleAssistant {
text := msg.Content[0].(fantasy.TextPart).Text
if text == "Batch 1 - User" || text == "Batch 1 - Assistant" {
t.Errorf("Batch 1 messages should not be in context, found: %s", text)
}
}
}
// Verify entry IDs
entryIDs := tm.GetContextEntryIDs()
if len(entryIDs) != 4 {
t.Fatalf("expected 4 entry IDs, got %d: %v", len(entryIDs), entryIDs)
}
// First should be empty (summary)
if entryIDs[0] != "" {
t.Errorf("first entry ID should be empty (summary), got %q", entryIDs[0])
}
// Check that id5 is in the list
if !slices.Contains(entryIDs, id5) {
t.Errorf("id5 (final message) should be in entry IDs, got %v", entryIDs)
}
// Verify id3 and id4 ARE in the list (they were kept)
foundID3, foundID4 := false, false
for _, id := range entryIDs {
if id == id3 {
foundID3 = true
}
if id == id4 {
foundID4 = true
}
}
if !foundID3 {
t.Errorf("id3 (kept message) should be in entry IDs, got %v", entryIDs)
}
if !foundID4 {
t.Errorf("id4 (kept message) should be in entry IDs, got %v", entryIDs)
}
// Verify id1 and id2 are NOT in the list (they were compacted away)
for _, id := range entryIDs {
if id == id1 || id == id2 {
t.Errorf("id1 or id2 (compacted) should not be in entry IDs, found %q in %v", id, entryIDs)
}
}
}
+179 -30
View File
@@ -509,11 +509,19 @@ func (tm *TreeManager) AppendExtensionData(extType, data string) (string, error)
// AppendCompaction adds a compaction entry to the tree. The entry records
// the summary and the ID of the first entry that should be preserved in the
// LLM context. Messages before that entry are replaced by the summary.
//
// The compaction entry becomes a new "root" for the post-compaction branch
// with no parent (empty ParentID). This breaks the parent chain so that old
// compacted messages are no longer traversed when building context. The kept
// messages are explicitly collected via FirstKeptEntryID in BuildContext.
func (tm *TreeManager) AppendCompaction(summary, firstKeptEntryID string, tokensBefore, tokensAfter, messagesRemoved int, readFiles, modifiedFiles []string) (string, error) {
tm.mu.Lock()
defer tm.mu.Unlock()
entry := NewCompactionEntry(tm.leafID, summary, firstKeptEntryID, tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
// The compaction entry has no parent, making it a new "root" for the
// post-compaction branch. This ensures old compacted messages are not
// traversed when walking from the current leaf.
entry := NewCompactionEntry("", summary, firstKeptEntryID, tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
if err := tm.appendAndPersist(entry); err != nil {
return "", err
}
@@ -683,14 +691,18 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
// Find the last compaction entry on this branch — it determines
// which older messages are replaced by the summary.
var lastCompaction *CompactionEntry
var compactionIndex = -1
for i := len(branch) - 1; i >= 0; i-- {
if c, ok := branch[i].(*CompactionEntry); ok {
lastCompaction = c
compactionIndex = i
break
}
}
// If there is a compaction, inject the summary first.
// If there is a compaction, inject the summary first and collect
// the kept messages starting from FirstKeptEntryID (since the
// compaction entry's parent chain doesn't include them).
if lastCompaction != nil {
messages = append(messages, fantasy.Message{
Role: fantasy.MessageRoleSystem,
@@ -700,21 +712,104 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
},
},
})
}
// Determine whether to skip entries (everything before firstKeptEntryID).
skipping := lastCompaction != nil
for _, entry := range branch {
// Once we reach the first kept entry, stop skipping.
if skipping {
entryID := tm.EntryID(entry)
if entryID == lastCompaction.FirstKeptEntryID {
skipping = false
} else {
// Collect entries from the compaction entry itself (at compactionIndex)
// and any entries before it in the branch (newer messages).
for i := compactionIndex; i < len(branch); i++ {
entry := branch[i]
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue // skip malformed entries
}
msgs := msg.ToLLMMessages()
messages = append(messages, msgs...)
case *BranchSummaryEntry:
// Convert branch summary to a user message for context.
if e.Summary != "" {
messages = append(messages, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
},
},
})
}
case *ModelChangeEntry:
provider = e.Provider
modelID = e.ModelID
case *CompactionEntry:
// Already handled above (summary injected).
continue
}
}
// Now collect the kept messages starting from FirstKeptEntryID.
// These are not in the current branch because the compaction entry
// is parented to the first kept entry's parent, not the first kept entry.
// We iterate through entries in order (not using getBranchLocked) to avoid
// walking back to old compacted messages.
// We stop when we reach the compaction entry to avoid double-counting
// messages that were added after the compaction.
if lastCompaction.FirstKeptEntryID != "" {
found := false
for _, entry := range tm.entries {
entryID := tm.EntryID(entry)
// Skip entries until we reach the first kept entry.
if !found {
if entryID == lastCompaction.FirstKeptEntryID {
found = true
} else {
continue
}
}
// Stop when we reach the compaction entry itself.
// Messages after the compaction are collected from the branch walk above.
if entryID == lastCompaction.ID {
break
}
// Process this kept entry.
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue
}
msgs := msg.ToLLMMessages()
messages = append(messages, msgs...)
case *BranchSummaryEntry:
if e.Summary != "" {
messages = append(messages, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
},
},
})
}
case *ModelChangeEntry:
provider = e.Provider
modelID = e.ModelID
}
}
}
return messages, provider, modelID
}
// No compaction - process the entire branch normally.
for _, entry := range branch {
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
@@ -740,10 +835,6 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
case *ModelChangeEntry:
provider = e.Provider
modelID = e.ModelID
case *CompactionEntry:
// Already handled above (the last one on the branch).
continue
}
}
@@ -853,31 +944,92 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
// Find the last compaction entry for skip logic.
var lastCompaction *CompactionEntry
var compactionIndex = -1
for i := len(branch) - 1; i >= 0; i-- {
if c, ok := branch[i].(*CompactionEntry); ok {
lastCompaction = c
compactionIndex = i
break
}
}
var ids []string
// If there's a compaction summary injected, it has no entry ID.
// If there's a compaction, we need to collect IDs from:
// 1. Entries after the compaction entry in the branch (newer messages)
// 2. Entries from FirstKeptEntryID onwards (kept messages)
if lastCompaction != nil {
ids = append(ids, "") // placeholder for the summary system message
}
// Placeholder for the summary system message (no entry ID).
ids = append(ids, "")
skipping := lastCompaction != nil
for _, entry := range branch {
if skipping {
entryID := tm.EntryID(entry)
if entryID == lastCompaction.FirstKeptEntryID {
skipping = false
} else {
continue
// Collect IDs from entries after the compaction entry (newer messages).
for i := compactionIndex + 1; i < len(branch); i++ {
entry := branch[i]
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue
}
msgs := msg.ToLLMMessages()
for range msgs {
ids = append(ids, e.ID)
}
case *BranchSummaryEntry:
if e.Summary != "" {
ids = append(ids, e.ID)
}
}
}
// Collect IDs from the kept messages starting at FirstKeptEntryID.
// We iterate through entries in order (not using getBranchLocked) to avoid
// walking back to old compacted messages.
// We stop when we reach the compaction entry to avoid double-counting.
if lastCompaction.FirstKeptEntryID != "" {
found := false
for _, entry := range tm.entries {
entryID := tm.EntryID(entry)
// Skip entries until we reach the first kept entry.
if !found {
if entryID == lastCompaction.FirstKeptEntryID {
found = true
} else {
continue
}
}
// Stop when we reach the compaction entry itself.
if entryID == lastCompaction.ID {
break
}
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue
}
msgs := msg.ToLLMMessages()
for range msgs {
ids = append(ids, e.ID)
}
case *BranchSummaryEntry:
if e.Summary != "" {
ids = append(ids, e.ID)
}
}
}
}
return ids
}
// No compaction - collect IDs from the entire branch.
for _, entry := range branch {
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
@@ -893,9 +1045,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
if e.Summary != "" {
ids = append(ids, e.ID)
}
case *CompactionEntry:
continue
}
}
+51 -18
View File
@@ -60,15 +60,16 @@ type MCPConnection struct {
// creation, health monitoring, and cleanup. The pool runs background health checks
// to proactively identify and remove unhealthy connections.
type MCPConnectionPool struct {
connections map[string]*MCPConnection
config *ConnectionPoolConfig
mu sync.RWMutex
model fantasy.LanguageModel
ctx context.Context
cancel context.CancelFunc
debug bool
debugLogger DebugLogger
oauthFlow *OAuthFlowRunner
connections map[string]*MCPConnection
config *ConnectionPoolConfig
mu sync.RWMutex
model fantasy.LanguageModel
ctx context.Context
cancel context.CancelFunc
debug bool
debugLogger DebugLogger
oauthFlow *OAuthFlowRunner
tokenStoreFactory TokenStoreFactory // custom factory for per-server token stores (nil = default FileTokenStore)
}
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
@@ -76,19 +77,20 @@ type MCPConnectionPool struct {
// goroutine for periodic health checks that runs until Close is called.
// The model parameter is used for MCP servers that require sampling support.
// Thread-safe for concurrent use immediately after creation.
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler) *MCPConnectionPool {
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler, tokenStoreFactory TokenStoreFactory) *MCPConnectionPool {
if config == nil {
config = DefaultConnectionPoolConfig()
}
ctx, cancel := context.WithCancel(context.Background())
pool := &MCPConnectionPool{
connections: make(map[string]*MCPConnection),
config: config,
model: model,
ctx: ctx,
cancel: cancel,
debug: debug,
connections: make(map[string]*MCPConnection),
config: config,
model: model,
ctx: ctx,
cancel: cancel,
debug: debug,
tokenStoreFactory: tokenStoreFactory,
}
if authHandler != nil {
@@ -367,7 +369,7 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
// scopes are discovered automatically via dynamic client registration and
// server metadata (RFC 9728).
if p.oauthFlow != nil {
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
@@ -414,7 +416,7 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
// scopes are discovered automatically via dynamic client registration and
// server metadata (RFC 9728).
if p.oauthFlow != nil {
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
@@ -437,6 +439,16 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
return streamableClient, nil
}
// createTokenStore creates a token store for the given server URL.
// If a custom TokenStoreFactory is configured, it is used; otherwise the
// default file-backed token store is created.
func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenStore, error) {
if p.tokenStoreFactory != nil {
return p.tokenStoreFactory(serverURL)
}
return NewFileTokenStore(serverURL)
}
// initializeClient initializes the client
func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error {
initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
@@ -583,6 +595,27 @@ func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient {
return clients
}
// RemoveConnection closes and removes a single connection from the pool.
// Returns an error if the connection does not exist or if closing fails.
// Thread-safe for concurrent use.
func (p *MCPConnectionPool) RemoveConnection(serverName string) error {
p.mu.Lock()
defer p.mu.Unlock()
conn, exists := p.connections[serverName]
if !exists {
return fmt.Errorf("connection %q not found in pool", serverName)
}
err := conn.client.Close()
delete(p.connections, serverName)
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Removed connection %s", serverName))
}
return err
}
// Close gracefully shuts down the connection pool, closing all client connections
// and stopping the background health check goroutine. It attempts to close all
// connections even if some fail, logging any errors encountered.
+147 -10
View File
@@ -20,19 +20,25 @@ import (
// pooling, health checks, tool name prefixing to avoid conflicts, and sampling support for LLM interactions.
// Thread-safe for concurrent tool invocations.
type MCPToolManager struct {
connectionPool *MCPConnectionPool
tools []fantasy.AgentTool
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
mu sync.Mutex // protects tools and toolMap during parallel loading
model fantasy.LanguageModel // LLM model for sampling
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
config *config.Config
debug bool
debugLogger DebugLogger
connectionPool *MCPConnectionPool
tools []fantasy.AgentTool
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
mu sync.Mutex // protects tools and toolMap during parallel loading
model fantasy.LanguageModel // LLM model for sampling
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
tokenStoreFactory TokenStoreFactory // factory for creating per-server token stores (nil = default FileTokenStore)
config *config.Config
debug bool
debugLogger DebugLogger
// onServerLoaded, if non-nil, is called when each server finishes loading.
// Called with server name, tool count, and error (nil on success).
onServerLoaded func(serverName string, toolCount int, err error)
// onToolsChanged, if non-nil, is called after AddServer or RemoveServer
// mutates the tool list. The agent layer uses this to trigger a
// rebuildFantasyAgent so the LLM sees the updated tools.
onToolsChanged func()
}
// toolMapping stores the mapping between prefixed tool names and their original details
@@ -69,6 +75,14 @@ func (m *MCPToolManager) SetAuthHandler(handler MCPAuthHandler) {
m.authHandler = handler
}
// SetTokenStoreFactory sets a custom factory for creating per-server OAuth token
// stores. When set, the factory is called for each remote MCP server instead of
// using the default file-based token store. This method should be called before
// LoadTools.
func (m *MCPToolManager) SetTokenStoreFactory(factory TokenStoreFactory) {
m.tokenStoreFactory = factory
}
// SetDebugLogger sets the debug logger for the tool manager.
// The logger will be used to output detailed debugging information about MCP connections,
// tool loading, and execution. If a connection pool exists, it will also be configured
@@ -87,6 +101,126 @@ func (m *MCPToolManager) SetOnServerLoaded(cb func(serverName string, toolCount
m.onServerLoaded = cb
}
// SetOnToolsChanged sets the callback that's invoked after AddServer or
// RemoveServer mutates the tool list. The agent layer uses this to trigger
// a rebuild of the fantasy agent so the LLM sees the updated tool set.
func (m *MCPToolManager) SetOnToolsChanged(cb func()) {
m.onToolsChanged = cb
}
// AddServer connects to a new MCP server at runtime and loads its tools.
// The server's tools are immediately available to the agent after this call.
// Returns the number of tools loaded from the server.
//
// If the connection pool has not been initialised yet (i.e. LoadTools was never
// called), AddServer creates one automatically using the manager's current
// configuration.
//
// Returns an error if a server with the same name is already loaded, or if
// the connection or tool loading fails.
func (m *MCPToolManager) AddServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
m.mu.Lock()
// Check for duplicate.
if _, exists := m.toolMap[name+"__"]; exists {
m.mu.Unlock()
return 0, fmt.Errorf("MCP server %q is already loaded", name)
}
// More thorough duplicate check: scan toolMap for any key with the server prefix.
prefix := name + "__"
for k := range m.toolMap {
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
m.mu.Unlock()
return 0, fmt.Errorf("MCP server %q is already loaded", name)
}
}
m.mu.Unlock()
// Lazily create the connection pool if LoadTools was never called.
m.ensureConnectionPool()
count, err := m.loadServerTools(ctx, name, cfg)
if err != nil {
return 0, fmt.Errorf("failed to add MCP server %q: %w", name, err)
}
// Notify listeners.
if m.onServerLoaded != nil {
m.onServerLoaded(name, count, nil)
}
if m.onToolsChanged != nil {
m.onToolsChanged()
}
return count, nil
}
// RemoveServer disconnects an MCP server and removes all its tools.
// After this call the agent will no longer see or be able to call tools from
// the named server. Returns an error if the server is not loaded.
func (m *MCPToolManager) RemoveServer(name string) error {
prefix := name + "__"
m.mu.Lock()
// Check the server actually has tools loaded.
found := false
for k := range m.toolMap {
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
found = true
break
}
}
if !found {
m.mu.Unlock()
return fmt.Errorf("MCP server %q is not loaded", name)
}
// Remove tools belonging to this server.
newTools := make([]fantasy.AgentTool, 0, len(m.tools))
for _, t := range m.tools {
if len(t.Info().Name) < len(prefix) || t.Info().Name[:len(prefix)] != prefix {
newTools = append(newTools, t)
}
}
m.tools = newTools
// Remove tool mappings.
for k := range m.toolMap {
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
delete(m.toolMap, k)
}
}
m.mu.Unlock()
// Close the connection in the pool (best-effort).
if m.connectionPool != nil {
_ = m.connectionPool.RemoveConnection(name)
}
if m.onToolsChanged != nil {
m.onToolsChanged()
}
return nil
}
// ensureConnectionPool lazily creates a connection pool if one does not exist.
// This allows AddServer to work even if LoadTools was never called.
func (m *MCPToolManager) ensureConnectionPool() {
if m.connectionPool != nil {
return
}
debug := false
if m.config != nil {
debug = m.config.Debug
}
if m.debugLogger == nil {
m.debugLogger = NewSimpleDebugLogger(debug)
}
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, debug, m.authHandler, m.tokenStoreFactory)
m.connectionPool.SetDebugLogger(m.debugLogger)
}
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
// It initializes the connection pool, connects to each configured server, and loads their tools.
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
@@ -99,7 +233,7 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, cfg *config.Config) erro
if m.debugLogger == nil {
m.debugLogger = NewSimpleDebugLogger(cfg.Debug)
}
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, cfg.Debug, m.authHandler)
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, cfg.Debug, m.authHandler, m.tokenStoreFactory)
m.connectionPool.SetDebugLogger(m.debugLogger)
// Load all servers in parallel. Each server connection (subprocess
@@ -290,6 +424,9 @@ func (m *MCPToolManager) GetLoadedServerNames() []string {
// proper cleanup of stdio processes, network connections, and other resources.
// It is safe to call Close multiple times.
func (m *MCPToolManager) Close() error {
if m.connectionPool == nil {
return nil
}
return m.connectionPool.Close()
}
@@ -0,0 +1,323 @@
package tools
import (
"context"
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/mark3labs/kit/internal/config"
)
// testdataDir returns the absolute path to the testdata directory.
func testdataDir(t *testing.T) string {
t.Helper()
_, file, _, ok := runtime.Caller(0)
if !ok {
t.Fatal("cannot determine test file path")
}
return filepath.Join(filepath.Dir(file), "testdata")
}
// echoServerConfig returns an MCPServerConfig for the test echo MCP server.
func echoServerConfig(t *testing.T) config.MCPServerConfig {
t.Helper()
script := filepath.Join(testdataDir(t), "echo_server.py")
if _, err := os.Stat(script); err != nil {
t.Skipf("echo_server.py not found: %v", err)
}
return config.MCPServerConfig{
Command: []string{"python3", script},
}
}
// TestMCPToolManager_AddServer_Integration tests adding a real MCP server
// at runtime and verifying tools are loaded.
func TestMCPToolManager_AddServer_Integration(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
manager := NewMCPToolManager()
defer func() { _ = manager.Close() }()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cfg := echoServerConfig(t)
// Track callbacks.
var mu sync.Mutex
var loadedServer string
var loadedCount int
toolsChangedCount := 0
manager.SetOnServerLoaded(func(name string, count int, err error) {
mu.Lock()
loadedServer = name
loadedCount = count
mu.Unlock()
})
manager.SetOnToolsChanged(func() {
mu.Lock()
toolsChangedCount++
mu.Unlock()
})
// Add the server.
count, err := manager.AddServer(ctx, "echo", cfg)
if err != nil {
t.Fatalf("AddServer failed: %v", err)
}
if count != 2 {
t.Errorf("Expected 2 tools from echo server, got %d", count)
}
// Verify callbacks fired.
mu.Lock()
if loadedServer != "echo" {
t.Errorf("Expected onServerLoaded for 'echo', got %q", loadedServer)
}
if loadedCount != 2 {
t.Errorf("Expected onServerLoaded count=2, got %d", loadedCount)
}
if toolsChangedCount != 1 {
t.Errorf("Expected onToolsChanged called once, got %d", toolsChangedCount)
}
mu.Unlock()
// Verify tools are accessible.
tools := manager.GetTools()
if len(tools) != 2 {
t.Fatalf("Expected 2 tools, got %d", len(tools))
}
// Verify tool names are prefixed.
toolNames := make(map[string]bool)
for _, tool := range tools {
toolNames[tool.Info().Name] = true
}
if !toolNames["echo__echo"] {
t.Error("Expected tool 'echo__echo'")
}
if !toolNames["echo__greet"] {
t.Error("Expected tool 'echo__greet'")
}
// Verify server appears in loaded names.
names := manager.GetLoadedServerNames()
if !slices.Contains(names, "echo") {
t.Errorf("Expected 'echo' in loaded server names, got: %v", names)
}
}
// TestMCPToolManager_RemoveServer_Integration tests removing a real MCP server
// and verifying tools are cleaned up.
func TestMCPToolManager_RemoveServer_Integration(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
manager := NewMCPToolManager()
defer func() { _ = manager.Close() }()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cfg := echoServerConfig(t)
// Add the server first.
count, err := manager.AddServer(ctx, "echo", cfg)
if err != nil {
t.Fatalf("AddServer failed: %v", err)
}
if count != 2 {
t.Fatalf("Expected 2 tools, got %d", count)
}
var mu sync.Mutex
toolsChangedCount := 0
manager.SetOnToolsChanged(func() {
mu.Lock()
toolsChangedCount++
mu.Unlock()
})
// Remove the server.
err = manager.RemoveServer("echo")
if err != nil {
t.Fatalf("RemoveServer failed: %v", err)
}
// Verify tools are gone.
tools := manager.GetTools()
if len(tools) != 0 {
t.Errorf("Expected 0 tools after removal, got %d", len(tools))
}
// Verify callback fired.
mu.Lock()
if toolsChangedCount != 1 {
t.Errorf("Expected onToolsChanged called once, got %d", toolsChangedCount)
}
mu.Unlock()
// Verify server is gone from loaded names.
names := manager.GetLoadedServerNames()
for _, n := range names {
if n == "echo" {
t.Error("Server 'echo' should not appear in loaded names after removal")
}
}
// Removing again should error.
err = manager.RemoveServer("echo")
if err == nil {
t.Fatal("Expected error removing already-removed server")
}
if !strings.Contains(err.Error(), "not loaded") {
t.Errorf("Expected 'not loaded' error, got: %v", err)
}
}
// TestMCPToolManager_AddRemoveMultiple_Integration tests adding and removing
// multiple servers, verifying tool isolation.
func TestMCPToolManager_AddRemoveMultiple_Integration(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
manager := NewMCPToolManager()
defer func() { _ = manager.Close() }()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cfg := echoServerConfig(t)
// Add two servers with the same binary but different names.
count1, err := manager.AddServer(ctx, "server-a", cfg)
if err != nil {
t.Fatalf("AddServer server-a failed: %v", err)
}
count2, err := manager.AddServer(ctx, "server-b", cfg)
if err != nil {
t.Fatalf("AddServer server-b failed: %v", err)
}
totalTools := count1 + count2
if totalTools != 4 {
t.Fatalf("Expected 4 total tools (2+2), got %d", totalTools)
}
tools := manager.GetTools()
if len(tools) != 4 {
t.Fatalf("Expected 4 tools, got %d", len(tools))
}
// Remove server-a, verify server-b tools remain.
err = manager.RemoveServer("server-a")
if err != nil {
t.Fatalf("RemoveServer server-a failed: %v", err)
}
tools = manager.GetTools()
if len(tools) != 2 {
t.Fatalf("Expected 2 tools after removing server-a, got %d", len(tools))
}
// Remaining tools should all be from server-b.
for _, tool := range tools {
if !strings.HasPrefix(tool.Info().Name, "server-b__") {
t.Errorf("Expected tool from server-b, got: %s", tool.Info().Name)
}
}
// Remove server-b.
err = manager.RemoveServer("server-b")
if err != nil {
t.Fatalf("RemoveServer server-b failed: %v", err)
}
tools = manager.GetTools()
if len(tools) != 0 {
t.Errorf("Expected 0 tools after removing all servers, got %d", len(tools))
}
}
// TestMCPToolManager_AddServer_DuplicateDetection_Integration tests that
// adding a server with the same name as an already loaded server errors.
func TestMCPToolManager_AddServer_DuplicateDetection_Integration(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
manager := NewMCPToolManager()
defer func() { _ = manager.Close() }()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cfg := echoServerConfig(t)
// Add the server.
_, err := manager.AddServer(ctx, "echo", cfg)
if err != nil {
t.Fatalf("First AddServer failed: %v", err)
}
// Try to add again with the same name.
_, err = manager.AddServer(ctx, "echo", cfg)
if err == nil {
t.Fatal("Expected error adding duplicate server")
}
if !strings.Contains(err.Error(), "already loaded") {
t.Errorf("Expected 'already loaded' error, got: %v", err)
}
}
// TestMCPToolManager_AddAfterRemove_Integration tests that a server can be
// re-added after being removed.
func TestMCPToolManager_AddAfterRemove_Integration(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
manager := NewMCPToolManager()
defer func() { _ = manager.Close() }()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cfg := echoServerConfig(t)
// Add, remove, re-add.
_, err := manager.AddServer(ctx, "echo", cfg)
if err != nil {
t.Fatalf("First AddServer failed: %v", err)
}
err = manager.RemoveServer("echo")
if err != nil {
t.Fatalf("RemoveServer failed: %v", err)
}
count, err := manager.AddServer(ctx, "echo", cfg)
if err != nil {
t.Fatalf("Re-AddServer failed: %v", err)
}
if count != 2 {
t.Errorf("Expected 2 tools on re-add, got %d", count)
}
tools := manager.GetTools()
if len(tools) != 2 {
t.Errorf("Expected 2 tools after re-add, got %d", len(tools))
}
}
+155
View File
@@ -0,0 +1,155 @@
package tools
import (
"context"
"strings"
"sync"
"testing"
"time"
"github.com/mark3labs/kit/internal/config"
)
// TestMCPToolManager_AddServer_DuplicateName verifies that adding a server
// with a name that already exists returns an error.
func TestMCPToolManager_AddServer_DuplicateName(t *testing.T) {
manager := NewMCPToolManager()
cfg := config.MCPServerConfig{
Command: []string{"non-existent-command"},
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// First add will fail (bad command), but let's test the duplicate detection
// by simulating a loaded server via LoadTools first.
loadCfg := &config.Config{
MCPServers: map[string]config.MCPServerConfig{
"test-server": cfg,
},
}
// This will fail to load but creates the connection pool.
_ = manager.LoadTools(ctx, loadCfg)
// Now try to add the same server name — the tools didn't load (bad command),
// so AddServer should not find a duplicate and should fail with connection error.
_, err := manager.AddServer(ctx, "test-server", cfg)
if err == nil {
t.Fatal("Expected error when adding server with bad command, got nil")
}
// It should be a connection error, not a duplicate error.
if strings.Contains(err.Error(), "already loaded") {
t.Fatalf("Should not report duplicate since server failed to load initially: %v", err)
}
}
// TestMCPToolManager_RemoveServer_NotLoaded verifies that removing a server
// that doesn't exist returns an appropriate error.
func TestMCPToolManager_RemoveServer_NotLoaded(t *testing.T) {
manager := NewMCPToolManager()
err := manager.RemoveServer("nonexistent")
if err == nil {
t.Fatal("Expected error when removing non-existent server, got nil")
}
if !strings.Contains(err.Error(), "not loaded") {
t.Errorf("Expected 'not loaded' error, got: %v", err)
}
}
// TestMCPToolManager_AddServer_CreatesConnectionPool verifies that AddServer
// lazily creates a connection pool when LoadTools was never called.
func TestMCPToolManager_AddServer_CreatesConnectionPool(t *testing.T) {
manager := NewMCPToolManager()
// Connection pool should be nil initially.
if manager.connectionPool != nil {
t.Fatal("Expected nil connection pool before any operation")
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// AddServer with a bad command — should fail, but the pool should be created.
_, err := manager.AddServer(ctx, "lazy-server", config.MCPServerConfig{
Command: []string{"non-existent-command"},
})
if err == nil {
t.Fatal("Expected error for bad command")
}
// Connection pool should have been created.
if manager.connectionPool == nil {
t.Fatal("Expected connection pool to be created lazily by AddServer")
}
}
// TestMCPToolManager_OnToolsChanged_Callback verifies that the onToolsChanged
// callback fires on RemoveServer (we can't easily test AddServer with a real
// MCP server, but we can test the callback wiring).
func TestMCPToolManager_OnToolsChanged_Callback(t *testing.T) {
manager := NewMCPToolManager()
var mu sync.Mutex
callCount := 0
manager.SetOnToolsChanged(func() {
mu.Lock()
callCount++
mu.Unlock()
})
// RemoveServer on non-existent should NOT fire callback.
_ = manager.RemoveServer("nonexistent")
mu.Lock()
if callCount != 0 {
t.Errorf("Expected 0 callback calls for failed remove, got %d", callCount)
}
mu.Unlock()
}
// TestMCPToolManager_Close_NilPool verifies Close is safe when the connection
// pool was never initialized.
func TestMCPToolManager_Close_NilPool(t *testing.T) {
manager := NewMCPToolManager()
err := manager.Close()
if err != nil {
t.Fatalf("Expected nil error from Close with nil pool, got: %v", err)
}
}
// TestMCPConnectionPool_RemoveConnection_NotFound verifies that removing a
// non-existent connection returns an error.
func TestMCPConnectionPool_RemoveConnection_NotFound(t *testing.T) {
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), nil, false, nil, nil)
defer func() { _ = pool.Close() }()
err := pool.RemoveConnection("nonexistent")
if err == nil {
t.Fatal("Expected error for non-existent connection")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("Expected 'not found' error, got: %v", err)
}
}
// TestMCPToolManager_EnsureConnectionPool_Idempotent verifies that
// ensureConnectionPool doesn't recreate an existing pool.
func TestMCPToolManager_EnsureConnectionPool_Idempotent(t *testing.T) {
manager := NewMCPToolManager()
// First call creates the pool.
manager.ensureConnectionPool()
pool1 := manager.connectionPool
if pool1 == nil {
t.Fatal("Expected pool to be created")
}
// Second call should be a no-op.
manager.ensureConnectionPool()
pool2 := manager.connectionPool
if pool1 != pool2 {
t.Fatal("Expected ensureConnectionPool to be idempotent")
}
}
+7
View File
@@ -6,6 +6,7 @@ import (
"net/url"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
)
// MCPAuthHandler is the internal interface for handling MCP OAuth flows.
@@ -21,6 +22,12 @@ type MCPAuthHandler interface {
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
}
// TokenStoreFactory creates a transport.TokenStore for a given MCP server URL.
// When provided to the connection pool, it is called once per remote MCP server
// instead of using the default file-based token store. Implementations can
// return any transport.TokenStore — in-memory, database-backed, encrypted, etc.
type TokenStoreFactory func(serverURL string) (transport.TokenStore, error)
// OAuthFlowRunner handles the OAuth authorization flow when an MCP server
// returns an OAuthAuthorizationRequiredError. It coordinates dynamic client
// registration, PKCE generation, user authorization (via MCPAuthHandler),
+111
View File
@@ -0,0 +1,111 @@
#!/usr/bin/env python3
"""Minimal MCP server over stdio for testing. Exposes one tool: echo."""
import json
import sys
def read_message():
"""Read a JSON-RPC message from stdin."""
line = sys.stdin.readline()
if not line:
return None
return json.loads(line.strip())
def write_message(msg):
"""Write a JSON-RPC message to stdout."""
sys.stdout.write(json.dumps(msg) + "\n")
sys.stdout.flush()
def handle(msg):
method = msg.get("method", "")
mid = msg.get("id")
if method == "initialize":
write_message({
"jsonrpc": "2.0",
"id": mid,
"result": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"serverInfo": {"name": "test-echo", "version": "1.0.0"},
},
})
elif method == "notifications/initialized":
pass # no response needed
elif method == "tools/list":
write_message({
"jsonrpc": "2.0",
"id": mid,
"result": {
"tools": [
{
"name": "echo",
"description": "Echoes the input text back.",
"inputSchema": {
"type": "object",
"properties": {
"text": {"type": "string", "description": "Text to echo"}
},
"required": ["text"],
},
},
{
"name": "greet",
"description": "Returns a greeting.",
"inputSchema": {
"type": "object",
"properties": {
"name": {"type": "string", "description": "Name to greet"}
},
"required": ["name"],
},
},
]
},
})
elif method == "tools/call":
tool_name = msg["params"]["name"]
args = msg["params"].get("arguments", {})
if tool_name == "echo":
text = args.get("text", "")
write_message({
"jsonrpc": "2.0",
"id": mid,
"result": {
"content": [{"type": "text", "text": text}]
},
})
elif tool_name == "greet":
name = args.get("name", "World")
write_message({
"jsonrpc": "2.0",
"id": mid,
"result": {
"content": [{"type": "text", "text": f"Hello, {name}!"}]
},
})
else:
write_message({
"jsonrpc": "2.0",
"id": mid,
"error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
})
elif method == "ping":
write_message({"jsonrpc": "2.0", "id": mid, "result": {}})
else:
if mid is not None:
write_message({
"jsonrpc": "2.0",
"id": mid,
"error": {"code": -32601, "message": f"Unknown method: {method}"},
})
if __name__ == "__main__":
while True:
msg = read_message()
if msg is None:
break
handle(msg)
-28
View File
@@ -5,7 +5,6 @@ import (
"os"
"time"
"charm.land/fantasy"
"charm.land/lipgloss/v2"
"golang.org/x/term"
@@ -173,33 +172,6 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) {
fmt.Println(c.renderer.RenderDebugConfigMessage(config, time.Now()).Content)
}
// UpdateUsageFromResponse records token usage using metadata from the fantasy
// response. Only actual API-reported tokens are used for cost tracking.
// If the provider doesn't report token counts, no usage is recorded.
func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText string) {
if c.usageTracker == nil {
return
}
usage := response.Usage
inputTokens := int(usage.InputTokens)
outputTokens := int(usage.OutputTokens)
// Only use actual API-reported tokens for cost tracking.
// We intentionally do NOT estimate tokens - estimation is inaccurate
// and should never be used for cost calculations.
if inputTokens > 0 {
cacheReadTokens := int(usage.CacheReadTokens)
cacheWriteTokens := int(usage.CacheCreationTokens)
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
// Per-response usage is a single API call, so it represents the
// actual context window fill level.
c.usageTracker.SetContextTokens(inputTokens + outputTokens)
}
// If inputTokens is 0, the provider didn't report usage - we skip recording
// rather than estimating, to ensure cost accuracy.
}
// DisplayUsageAfterResponse renders and displays token usage information immediately
// following an AI response. This provides real-time feedback about the cost and
// token consumption of each interaction.
+3 -1
View File
@@ -139,7 +139,9 @@ func (h *CLIEventHandler) Handle(msg tea.Msg) {
case "block":
h.cli.DisplayExtensionBlock(e.Text, e.BorderColor, e.Subtitle)
default:
fmt.Println(e.Text)
// Route unstyled extension prints through the system message
// renderer so they get consistent formatting and timestamps.
h.cli.DisplayInfo(e.Text)
}
case app.StepCompleteEvent:
+1 -3
View File
@@ -109,9 +109,7 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
}
}
fmt.Println("")
// Display model info
// Display model info (the system message block provides its own spacing).
if provider != "unknown" && model != "unknown" {
cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", provider, model))
}
+7 -7
View File
@@ -69,7 +69,7 @@ type InputComponent struct {
hideHint bool
// agentBusy indicates the agent is currently working. When true, the
// hint text shows steering shortcut (Ctrl+S) instead of submit.
// hint text shows steering shortcut (Ctrl+X s) instead of submit.
agentBusy bool
// pendingImages holds clipboard images attached to the next submission.
@@ -109,7 +109,7 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
ta.Placeholder = "Type your message..."
ta.ShowLineNumbers = false
ta.Prompt = ""
ta.CharLimit = 5000
ta.CharLimit = 0
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
ta.SetHeight(3) // Default to 3 lines like huh
ta.Focus()
@@ -514,12 +514,12 @@ func (s *InputComponent) View() tea.View {
availableHintWidth := s.width - 3
if s.agentBusy {
// When the agent is working, show steering shortcut.
if availableHintWidth >= 55 {
hint = "enter queue • ctrl+s steer • esc esc cancel"
} else if availableHintWidth >= 35 {
hint = "↵ queue • ^S steer • esc×2 cancel"
if availableHintWidth >= 60 {
hint = "enter queue • ctrl+x s steer • esc esc cancel"
} else if availableHintWidth >= 40 {
hint = "↵ queue • ^X s steer • esc×2 cancel"
} else {
hint = "^S steer"
hint = "^X s steer"
}
} else if availableHintWidth >= 67 {
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
+1 -1
View File
@@ -152,7 +152,7 @@ func (r *MessageRenderer) SetWidth(width int) {
// RenderUserMessage renders a user's input message using herald Tip alert
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
rendered := render.UserBlock(content, r.ty, style.GetTheme())
rendered := render.UserBlock(content, r.width, r.ty, style.GetTheme())
return UIMessage{
Type: UserMessage,
+143 -59
View File
@@ -477,7 +477,7 @@ type AppModel struct {
queuedMessages []string
// steeringMessages stores the text of prompts that were sent as steer
// messages (injected mid-turn via Ctrl+S). Rendered with a "STEERING"
// messages (injected mid-turn via Ctrl+X s). Rendered with a "STEERING"
// badge above the input. Cleared when the steer is consumed.
steeringMessages []string
@@ -498,6 +498,11 @@ type AppModel struct {
// A second ESC within 2 seconds will cancel the current step.
canceling bool
// leaderKeyActive tracks whether the Ctrl+X leader key prefix has been
// pressed. The next keypress is interpreted as a chord suffix (e.g. "s"
// for steer). Cleared on any subsequent keypress.
leaderKeyActive bool
// providerName is the LLM provider for the startup message.
providerName string
@@ -1268,6 +1273,71 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Batch(cmds...)
}
// ── Leader key chord handling (Ctrl+X prefix) ──────────────
// If the leader key was previously pressed, the current key
// completes the chord. We consume it regardless of match so
// the prefix doesn't leak to child components.
if m.leaderKeyActive {
m.leaderKeyActive = false
switch msg.String() {
case "s":
// Ctrl+X s → Steer: inject the current input as a steering
// message into the running agent turn.
if m.state == stateWorking && m.appCtrl != nil {
var text string
if ic, ok := m.input.(*InputComponent); ok {
text = strings.TrimSpace(ic.textarea.Value())
}
if text != "" {
// Clear the input, collect pending images, and push to history.
var images []uicore.ImageAttachment
if ic, ok := m.input.(*InputComponent); ok {
ic.pushHistory(text)
ic.textarea.SetValue("")
images = ic.ClearPendingImages()
}
// Preprocess @file references.
processedText := text
if m.cwd != "" {
processedText = fileutil.ProcessFileAttachments(text, m.cwd)
}
// Convert image attachments to kit.LLMFilePart for the app layer.
var fileParts []kit.LLMFilePart
for _, img := range images {
fileParts = append(fileParts, kit.LLMFilePart{
Data: img.Data,
MediaType: img.MediaType,
})
}
// Build display text (include image count if any).
displayText := text
if len(images) > 0 {
displayText = fmt.Sprintf("%s\n[%d image(s) attached]", text, len(images))
}
// Inject the steer message.
sLen := m.appCtrl.SteerWithFiles(processedText, fileParts)
if sLen > 0 {
m.steeringMessages = append(m.steeringMessages, displayText)
m.layoutDirty = true
} else {
// Started immediately (agent was idle).
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
m.flushStreamAndPendingUserMessages()
if m.state != stateWorking {
m.state = stateWorking
}
}
}
}
}
// Chord consumed — don't propagate to children.
return m, tea.Batch(cmds...)
}
switch msg.String() {
case "esc":
if m.state == stateWorking {
@@ -1286,61 +1356,10 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
// In other states pass ESC through to children below.
case "ctrl+s":
// Steer: inject the current input as a steering message into the
// running agent turn. Only active during stateWorking — in input
// state, Ctrl+S is passed through to children (no-op by default).
if m.state == stateWorking && m.appCtrl != nil {
var text string
if ic, ok := m.input.(*InputComponent); ok {
text = strings.TrimSpace(ic.textarea.Value())
}
if text != "" {
// Clear the input, collect pending images, and push to history.
var images []uicore.ImageAttachment
if ic, ok := m.input.(*InputComponent); ok {
ic.pushHistory(text)
ic.textarea.SetValue("")
images = ic.ClearPendingImages()
}
// Preprocess @file references.
processedText := text
if m.cwd != "" {
processedText = fileutil.ProcessFileAttachments(text, m.cwd)
}
// Convert image attachments to kit.LLMFilePart for the app layer.
var fileParts []kit.LLMFilePart
for _, img := range images {
fileParts = append(fileParts, kit.LLMFilePart{
Data: img.Data,
MediaType: img.MediaType,
})
}
// Build display text (include image count if any).
displayText := text
if len(images) > 0 {
displayText = fmt.Sprintf("%s\n[%d image(s) attached]", text, len(images))
}
// Inject the steer message.
sLen := m.appCtrl.SteerWithFiles(processedText, fileParts)
if sLen > 0 {
m.steeringMessages = append(m.steeringMessages, displayText)
m.layoutDirty = true
} else {
// Started immediately (agent was idle).
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
m.flushStreamAndPendingUserMessages()
if m.state != stateWorking {
m.state = stateWorking
}
}
}
return m, tea.Batch(cmds...)
}
case "ctrl+x":
// Activate leader key prefix — the next keypress completes the chord.
m.leaderKeyActive = true
return m, tea.Batch(cmds...)
}
// Route key events to the focused child. Check for editor
@@ -2462,22 +2481,34 @@ func (m *AppModel) renderHeaderFooter(getter func() *WidgetData) string {
return renderContentBlock(data.Text, m.width, opts...)
}
// maxQueuedMessageLines is the maximum number of visible content lines
// rendered for each queued or steering message block. Messages exceeding
// this limit are truncated with an ellipsis to prevent large pastes from
// overflowing the screen and squeezing the stream region to zero.
const maxQueuedMessageLines = 3
// renderQueuedMessages renders queued and steering prompts as styled content
// blocks with badges, anchored between the separator and input. Steering
// messages use a distinct "STEERING" badge to differentiate from queued ones.
// Long messages are visually truncated to maxQueuedMessageLines.
func (m *AppModel) renderQueuedMessages() string {
if len(m.queuedMessages) == 0 && len(m.steeringMessages) == 0 {
return ""
}
theme := style.GetTheme()
// Available content width inside the block: container minus border (1)
// minus left padding (2). Used to estimate line wrapping for truncation.
contentWidth := max(m.width-3, 10)
var blocks []string
// Render steering messages first (higher priority).
if len(m.steeringMessages) > 0 {
badge := style.CreateBadge("STEERING", theme.Warning)
for _, msg := range m.steeringMessages {
content := msg + "\n" + badge
display := truncateMessageForBlock(msg, maxQueuedMessageLines, contentWidth)
content := display + "\n" + badge
rendered := renderContentBlock(
content,
m.width,
@@ -2492,7 +2523,8 @@ func (m *AppModel) renderQueuedMessages() string {
if len(m.queuedMessages) > 0 {
badge := style.CreateBadge("QUEUED", theme.Accent)
for _, msg := range m.queuedMessages {
content := msg + "\n" + badge
display := truncateMessageForBlock(msg, maxQueuedMessageLines, contentWidth)
content := display + "\n" + badge
rendered := renderContentBlock(
content,
m.width,
@@ -2506,6 +2538,58 @@ func (m *AppModel) renderQueuedMessages() string {
return strings.Join(blocks, "\n")
}
// truncateMessageForBlock truncates a message to at most maxLines visible
// lines, accounting for soft-wrapping at the given width. If the message is
// truncated, the last visible line is replaced with an ellipsis ("…").
func truncateMessageForBlock(msg string, maxLines, width int) string {
if width <= 0 {
width = 1
}
lines := strings.Split(msg, "\n")
// Count visible lines (each hard line may wrap into multiple visual lines).
var kept []string
visibleCount := 0
truncated := false
for _, line := range lines {
// Calculate how many visual lines this hard line occupies.
lineWidth := lipgloss.Width(line)
wrapped := 1
if lineWidth > width {
wrapped = (lineWidth + width - 1) / width // ceil division
}
if visibleCount+wrapped > maxLines {
// This line would exceed the limit. Keep a partial if we
// still have room for at least one more visual line.
remaining := maxLines - visibleCount
if remaining > 0 {
// Truncate the line to fit the remaining visual lines.
runes := []rune(line)
maxRunes := remaining * width
if maxRunes < len(runes) {
kept = append(kept, string(runes[:maxRunes]))
} else {
kept = append(kept, line)
}
}
truncated = true
break
}
kept = append(kept, line)
visibleCount += wrapped
}
if !truncated {
return msg
}
return strings.Join(kept, "\n") + "…"
}
// --------------------------------------------------------------------------
// Print helpers — add content to ScrollList
// --------------------------------------------------------------------------
@@ -2876,7 +2960,7 @@ func (m *AppModel) printHelpMessage() {
"**Keys:**\n" +
"- `Ctrl+C`: Exit at any time\n" +
"- `ESC` (x2): Cancel ongoing LLM generation\n" +
"- `Ctrl+S`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
"- `Ctrl+X s`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
"- `Enter` (while working): Queue message for after the agent finishes\n\n" +
"You can also just type your message to chat with the AI assistant."
m.printSystemMessage(help)
+105
View File
@@ -2,6 +2,7 @@ package ui
import (
"errors"
"strings"
"testing"
tea "charm.land/bubbletea/v2"
@@ -892,3 +893,107 @@ func TestSubmit_duringWorking_stays(t *testing.T) {
t.Fatalf("expected Run('queued prompt') called, got %v", ctrl.runCalls)
}
}
// --------------------------------------------------------------------------
// truncateMessageForBlock
// --------------------------------------------------------------------------
// TestTruncateMessageForBlock_shortMessage verifies that short messages are
// returned unchanged.
func TestTruncateMessageForBlock_shortMessage(t *testing.T) {
msg := "hello world"
got := truncateMessageForBlock(msg, 3, 80)
if got != msg {
t.Fatalf("expected unchanged message, got %q", got)
}
}
// TestTruncateMessageForBlock_exactLines verifies that a message with exactly
// maxLines hard lines is returned unchanged.
func TestTruncateMessageForBlock_exactLines(t *testing.T) {
msg := "line1\nline2\nline3"
got := truncateMessageForBlock(msg, 3, 80)
if got != msg {
t.Fatalf("expected unchanged message, got %q", got)
}
}
// TestTruncateMessageForBlock_tooManyLines verifies that messages exceeding
// maxLines are truncated with an ellipsis.
func TestTruncateMessageForBlock_tooManyLines(t *testing.T) {
msg := "line1\nline2\nline3\nline4\nline5"
got := truncateMessageForBlock(msg, 3, 80)
want := "line1\nline2\nline3…"
if got != want {
t.Fatalf("expected %q, got %q", want, got)
}
}
// TestTruncateMessageForBlock_longWrappingLine verifies that a single long
// line that would wrap beyond maxLines is truncated.
func TestTruncateMessageForBlock_longWrappingLine(t *testing.T) {
// 100 chars at width 20 = 5 visual lines, exceeds maxLines=3
msg := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
got := truncateMessageForBlock(msg, 3, 20)
// Should be truncated to 3*20=60 runes + "…"
if len([]rune(got)) != 61 { // 60 runes + "…"
t.Fatalf("expected 61 runes (60 + ellipsis), got %d runes: %q", len([]rune(got)), got)
}
if got[len(got)-3:] != "…" { // "…" is 3 bytes in UTF-8
t.Fatal("expected trailing ellipsis")
}
}
// TestTruncateMessageForBlock_emptyMessage verifies that empty messages are
// returned unchanged.
func TestTruncateMessageForBlock_emptyMessage(t *testing.T) {
got := truncateMessageForBlock("", 3, 80)
if got != "" {
t.Fatalf("expected empty string, got %q", got)
}
}
// TestTruncateMessageForBlock_mixedWrapAndHardLines verifies truncation when
// some hard lines wrap and the total exceeds maxLines.
func TestTruncateMessageForBlock_mixedWrapAndHardLines(t *testing.T) {
// First line: 40 chars at width 20 = 2 visual lines
// Second line: "short" = 1 visual line (total: 3, exactly at limit)
// Third line: would exceed
msg := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nshort\nextra"
got := truncateMessageForBlock(msg, 3, 20)
want := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nshort…"
if got != want {
t.Fatalf("expected %q, got %q", want, got)
}
}
// TestRenderQueuedMessages_truncatesLongMessages verifies that the rendered
// queued message view truncates long messages instead of showing them in full.
func TestRenderQueuedMessages_truncatesLongMessages(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
m.width = 80
// Queue a very long message (20 lines).
var b strings.Builder
for i := range 20 {
if i > 0 {
b.WriteByte('\n')
}
b.WriteString("This is a long line of text for testing purposes")
}
m.queuedMessages = []string{b.String()}
rendered := m.renderQueuedMessages()
if rendered == "" {
t.Fatal("expected non-empty rendered output")
}
// The full message would be ~20+ lines. With truncation to 3 content
// lines + badge + padding, it should be much shorter.
lines := len(strings.Split(rendered, "\n"))
// 3 content lines + 1 badge + 2 padding + border overhead ≈ ~7 lines max
if lines > 10 {
t.Fatalf("expected truncated output to be ≤10 lines, got %d lines", lines)
}
}
+1 -1
View File
@@ -78,7 +78,7 @@ func newInputPrompt(message, placeholder, defaultValue string, width, height int
ta.Placeholder = placeholder
ta.ShowLineNumbers = false
ta.Prompt = ""
ta.CharLimit = 1000
ta.CharLimit = 0
ta.SetWidth(width - 12) // account for border + padding
ta.SetHeight(1)
ta.Focus()
+9 -1
View File
@@ -14,11 +14,19 @@ import (
)
// UserBlock renders a user message with herald Tip styling.
func UserBlock(content string, ty *herald.Typography, theme style.Theme) string {
// The width parameter controls line wrapping so long messages don't overflow.
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
if strings.TrimSpace(content) == "" {
content = "(empty message)"
}
// Wrap content before passing to herald Alert so long lines break
// inside the alert box. Subtract 4 to account for the alert bar
// prefix ("│ ") and a small margin.
if width > 4 {
content = lipgloss.Wrap(content, width-4, "")
}
rendered := ty.Tip(content)
return styleMarginBottom(theme, rendered)
}
+5 -3
View File
@@ -85,11 +85,13 @@ func GetMarkdownTypography() *herald.Typography {
return ty
}
// ToMarkdown renders markdown content using herald-md.
// The width parameter is currently unused as herald handles wrapping
// based on terminal width internally.
// ToMarkdown renders markdown content using herald-md and wraps the result
// to the given width so that long lines do not overflow the terminal.
func ToMarkdown(content string, width int) string {
ty := GetMarkdownTypography()
rendered := heraldmd.Render(ty, []byte(content))
if width > 0 {
rendered = lipgloss.Wrap(rendered, width, "")
}
return rendered
}
+1 -1
View File
@@ -23,7 +23,7 @@ func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInp
ta := textarea.New()
ta.Placeholder = ""
ta.ShowLineNumbers = false
ta.CharLimit = 1000
ta.CharLimit = 0
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
ta.SetHeight(4) // Default to 3 lines like huh
ta.Focus()
+24 -2
View File
@@ -68,8 +68,12 @@ host, err := kit.New(ctx, &kit.Options{
NoSession: true, // Ephemeral mode
// Tool options
Tools: []kit.Tool{kit.NewBashTool()}, // Replace default tool set
ExtraTools: []kit.Tool{myTool}, // Add alongside defaults
Tools: []kit.Tool{kit.NewBashTool()}, // Replace default tool set
ExtraTools: []kit.Tool{myTool}, // Add alongside defaults
DisableCoreTools: true, // Use no core tools (0 tools)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Compaction
AutoCompact: true, // Auto-compact near context limit
@@ -172,6 +176,24 @@ msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message
- `GetSessionID()` - Get session UUID
- `Close()` - Clean up resources
### Options
Key `Options` fields for SDK usage:
| Field | Description |
|-------|-------------|
| `Model` | Override model (e.g., "anthropic/claude-sonnet-4-5-20250929") |
| `SystemPrompt` | Override system prompt |
| `ConfigFile` | Load specific config file (empty = search defaults) |
| `SkipConfig` | Skip `.kit.yml` loading (defaults + env vars still apply) |
| `Tools` | Replace core tools with custom set |
| `ExtraTools` | Add tools alongside defaults |
| `DisableCoreTools` | Use no core tools (0 tools, for chat-only) |
| `NoSession` | Ephemeral mode (no session persistence) |
| `SessionPath` | Open specific session file |
| `Continue` | Resume most recent session |
| `Debug` | Enable debug logging |
## Environment Variables
All CLI environment variables work with the SDK:
+231
View File
@@ -0,0 +1,231 @@
package kit
import (
"strings"
"time"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/session"
)
// treeManagerAdapter adapts TreeManager to SessionManager interface.
// This is unexported - users don't interact with it directly.
type treeManagerAdapter struct {
inner *session.TreeManager
}
// NewTreeManagerAdapter creates an adapter (exported for use in New function).
// This is used by the SDK when no custom SessionManager is provided.
func NewTreeManagerAdapter(tm *session.TreeManager) SessionManager {
return &treeManagerAdapter{inner: tm}
}
// AppendMessage implements SessionManager.
func (a *treeManagerAdapter) AppendMessage(msg LLMMessage) (string, error) {
// LLMMessage is just an alias for fantasy.Message, so no conversion needed
return a.inner.AppendLLMMessage(msg)
}
// GetMessages implements SessionManager.
func (a *treeManagerAdapter) GetMessages() []LLMMessage {
// LLMMessage is just an alias for fantasy.Message
return a.inner.GetLLMMessages()
}
// BuildContext implements SessionManager.
func (a *treeManagerAdapter) BuildContext() ([]LLMMessage, string, string) {
msgs, provider, modelID := a.inner.BuildContext()
return msgs, provider, modelID
}
// Branch implements SessionManager.
func (a *treeManagerAdapter) Branch(entryID string) error {
return a.inner.Branch(entryID)
}
// GetCurrentBranch implements SessionManager.
func (a *treeManagerAdapter) GetCurrentBranch() []BranchEntry {
branch := a.inner.GetBranch("")
var result []BranchEntry
for _, entry := range branch {
be := a.convertEntry(entry)
if be != nil {
result = append(result, *be)
}
}
return result
}
// GetChildren implements SessionManager.
func (a *treeManagerAdapter) GetChildren(parentID string) []string {
return a.inner.GetChildren(parentID)
}
// GetEntry implements SessionManager.
func (a *treeManagerAdapter) GetEntry(entryID string) *BranchEntry {
entry := a.inner.GetEntry(entryID)
if entry == nil {
return nil
}
return a.convertEntry(entry)
}
// GetSessionID implements SessionManager.
func (a *treeManagerAdapter) GetSessionID() string {
return a.inner.GetSessionID()
}
// GetSessionName implements SessionManager.
func (a *treeManagerAdapter) GetSessionName() string {
return a.inner.GetSessionName()
}
// SetSessionName implements SessionManager.
func (a *treeManagerAdapter) SetSessionName(name string) error {
_, err := a.inner.AppendSessionInfo(name)
return err
}
// GetCreatedAt implements SessionManager.
func (a *treeManagerAdapter) GetCreatedAt() time.Time {
return a.inner.GetHeader().Timestamp
}
// IsPersisted implements SessionManager.
func (a *treeManagerAdapter) IsPersisted() bool {
return a.inner.IsPersisted()
}
// AppendCompaction implements SessionManager.
func (a *treeManagerAdapter) AppendCompaction(summary string, firstKeptEntryID string,
tokensBefore, tokensAfter int, messagesRemoved int, readFiles, modifiedFiles []string) (string, error) {
return a.inner.AppendCompaction(summary, firstKeptEntryID,
tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
}
// GetLastCompaction implements SessionManager.
func (a *treeManagerAdapter) GetLastCompaction() *CompactionEntry {
c := a.inner.GetLastCompaction()
if c == nil {
return nil
}
return &CompactionEntry{
ID: c.ID,
Summary: c.Summary,
FirstKeptEntryID: c.FirstKeptEntryID,
TokensBefore: c.TokensBefore,
TokensAfter: c.TokensAfter,
MessagesRemoved: c.MessagesRemoved,
ReadFiles: c.ReadFiles,
ModifiedFiles: c.ModifiedFiles,
Timestamp: c.Timestamp,
}
}
// AppendExtensionData implements SessionManager.
func (a *treeManagerAdapter) AppendExtensionData(extType, data string) (string, error) {
return a.inner.AppendExtensionData(extType, data)
}
// GetExtensionData implements SessionManager.
func (a *treeManagerAdapter) GetExtensionData(extType string) []ExtensionDataEntry {
entries := a.inner.GetExtensionData(extType)
var result []ExtensionDataEntry
for _, e := range entries {
result = append(result, ExtensionDataEntry{
ID: e.ID,
ExtType: e.ExtType,
Data: e.Data,
Timestamp: e.Timestamp,
})
}
return result
}
// AppendModelChange implements SessionManager.
func (a *treeManagerAdapter) AppendModelChange(provider, modelID string) (string, error) {
return a.inner.AppendModelChange(provider, modelID)
}
// GetContextEntryIDs implements SessionManager.
func (a *treeManagerAdapter) GetContextEntryIDs() []string {
return a.inner.GetContextEntryIDs()
}
// Close implements SessionManager.
func (a *treeManagerAdapter) Close() error {
return a.inner.Close()
}
// Helper: Convert internal entry types to BranchEntry
func (a *treeManagerAdapter) convertEntry(entry any) *BranchEntry {
switch e := entry.(type) {
case *session.MessageEntry:
msg, err := e.ToMessage()
if err != nil {
return nil
}
// Build content text from parts
var content strings.Builder
for _, part := range msg.Parts {
if textPart, ok := part.(TextContent); ok {
content.WriteString(textPart.Text)
}
}
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeMessage,
Role: string(msg.Role),
Content: content.String(),
Model: e.Model,
Provider: e.Provider,
Timestamp: e.Timestamp,
RawParts: msg.Parts,
}
case *session.BranchSummaryEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeBranchSummary,
Content: e.Summary,
Timestamp: e.Timestamp,
}
case *session.ModelChangeEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeModelChange,
Content: "Model changed to " + e.Provider + "/" + e.ModelID,
Model: e.ModelID,
Provider: e.Provider,
Timestamp: e.Timestamp,
}
case *session.CompactionEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeCompaction,
Content: e.Summary,
Timestamp: e.Timestamp,
}
case *session.ExtensionDataEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeExtensionData,
Content: "Extension data: " + e.ExtType,
Timestamp: e.Timestamp,
}
default:
return nil
}
}
// convertKitMessagesToFantasy converts kit LLM messages to fantasy messages.
// Since LLMMessage is an alias for fantasy.Message, this is a no-op.
func convertKitMessagesToFantasy(msgs []LLMMessage) []fantasy.Message {
// LLMMessage is just an alias for fantasy.Message, so we can type convert
return msgs
}
+12 -12
View File
@@ -21,9 +21,9 @@ type ContextStats struct {
const defaultReserveTokens = 16384
// EstimateContextTokens returns the estimated token count of the current
// conversation based on tree session messages.
// conversation based on session messages.
func (m *Kit) EstimateContextTokens() int {
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
return compaction.EstimateMessageTokens(messages)
}
@@ -42,8 +42,8 @@ func (m *Kit) ShouldCompact() bool {
reserveTokens = m.compactionOpts.ReserveTokens
}
messages := m.treeSession.GetLLMMessages()
return compaction.ShouldCompact(messages, info.Limit.Context, reserveTokens)
messages := m.session.GetMessages()
return compaction.ShouldCompact(convertKitMessagesToFantasy(messages), info.Limit.Context, reserveTokens)
}
// GetContextStats returns current context usage statistics including
@@ -55,7 +55,7 @@ func (m *Kit) ShouldCompact() bool {
// because it includes system prompts, tool definitions, and other overhead
// that the heuristic cannot account for.
func (m *Kit) GetContextStats() ContextStats {
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
// Prefer the real API-reported input token count when available.
m.lastInputTokensMu.RLock()
@@ -114,7 +114,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
}
}
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
if len(messages) < 2 {
return nil, fmt.Errorf("cannot compact: need at least 2 messages")
}
@@ -145,7 +145,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// Carry forward file tracking from previous compaction.
var prev *compaction.PreviousCompaction
if lastCompaction := m.treeSession.GetLastCompaction(); lastCompaction != nil {
if lastCompaction := m.session.GetLastCompaction(); lastCompaction != nil {
prev = &compaction.PreviousCompaction{
ReadFiles: lastCompaction.ReadFiles,
ModifiedFiles: lastCompaction.ModifiedFiles,
@@ -171,7 +171,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// Non-destructive: append a CompactionEntry to the session tree instead
// of clearing and rewriting messages.
entryIDs := m.treeSession.GetContextEntryIDs()
entryIDs := m.session.GetContextEntryIDs()
firstKeptEntryID := ""
if result.CutPoint >= 0 && result.CutPoint < len(entryIDs) {
firstKeptEntryID = entryIDs[result.CutPoint]
@@ -188,9 +188,9 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// custom summary. It still determines the cut point and persists a
// CompactionEntry.
func (m *Kit) applyCustomCompaction(summary string, messages []LLMMessage, opts *CompactionOptions) (*CompactionResult, error) {
originalTokens := compaction.EstimateMessageTokens(messages)
originalTokens := compaction.EstimateMessageTokens(convertKitMessagesToFantasy(messages))
cutPoint := compaction.FindCutPoint(messages, opts.KeepRecentTokens)
cutPoint := compaction.FindCutPoint(convertKitMessagesToFantasy(messages), opts.KeepRecentTokens)
if cutPoint == 0 {
cutPoint = len(messages) - 1
if cutPoint < 1 {
@@ -198,7 +198,7 @@ func (m *Kit) applyCustomCompaction(summary string, messages []LLMMessage, opts
}
}
entryIDs := m.treeSession.GetContextEntryIDs()
entryIDs := m.session.GetContextEntryIDs()
firstKeptEntryID := ""
if cutPoint >= 0 && cutPoint < len(entryIDs) {
firstKeptEntryID = entryIDs[cutPoint]
@@ -234,7 +234,7 @@ func (m *Kit) persistAndEmitCompaction(
originalTokens, compactedTokens, messagesRemoved int,
readFiles, modifiedFiles []string,
) error {
if _, err := m.treeSession.AppendCompaction(
if _, err := m.session.AppendCompaction(
summary,
firstKeptEntryID,
originalTokens,
+34 -11
View File
@@ -227,28 +227,51 @@ func (e *extensionAPI) GetMessageRenderer(name string) *extensions.MessageRender
// Session data
func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage {
return iterBranchMessages(e.kit.treeSession, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage {
return extensions.SessionMessage{
ID: me.ID,
Role: string(msg.Role),
Content: msg.Content(),
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
if e.kit.session == nil {
return nil
}
// Try to use the legacy iterBranchMessages for backward compatibility
// with the default TreeManager adapter
if adapter, ok := e.kit.session.(*treeManagerAdapter); ok {
return iterBranchMessages(adapter.inner, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage {
return extensions.SessionMessage{
ID: me.ID,
Role: string(msg.Role),
Content: msg.Content(),
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
}
})
}
// For custom SessionManagers, use the public interface
branch := e.kit.session.GetCurrentBranch()
var result []extensions.SessionMessage
for _, entry := range branch {
if entry.Type == EntryTypeMessage {
result = append(result, extensions.SessionMessage{
ID: entry.ID,
Role: entry.Role,
Content: entry.Content,
Timestamp: entry.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
})
}
})
}
return result
}
func (e *extensionAPI) AppendEntry(extType, data string) (string, error) {
if e.kit.treeSession == nil {
if e.kit.session == nil {
return "", fmt.Errorf("no session available")
}
return e.kit.treeSession.AppendExtensionData(extType, data)
return e.kit.session.AppendExtensionData(extType, data)
}
func (e *extensionAPI) GetEntries(extType string) []extensions.ExtensionEntry {
if e.kit.treeSession == nil {
if e.kit.session == nil {
return nil
}
entries := e.kit.treeSession.GetExtensionData(extType)
entries := e.kit.session.GetExtensionData(extType)
result := make([]extensions.ExtensionEntry, 0, len(entries))
for _, e := range entries {
result = append(result, extensions.ExtensionEntry{
+352 -95
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
@@ -11,7 +12,6 @@ import (
"time"
"charm.land/fantasy"
charmlog "github.com/charmbracelet/log"
"github.com/mark3labs/kit/internal/agent"
"github.com/mark3labs/kit/internal/config"
@@ -39,7 +39,7 @@ type ContextFile struct {
// agents, sessions, and model configurations.
type Kit struct {
agent *agent.Agent
treeSession *session.TreeManager
session SessionManager
modelString string
events *eventBus
autoCompact bool
@@ -51,6 +51,12 @@ type Kit struct {
authHandler MCPAuthHandler // OAuth handler for remote MCP servers (may need Close)
opts *Options // stored for reload operations (skills, etc.)
// hasCustomSystemPrompt is true when the user explicitly configured a
// system prompt (via --system-prompt flag, config file, or SDK option).
// When false, per-model system prompts from modelSettings/customModels
// can replace the default prompt on model switch.
hasCustomSystemPrompt bool
// Hook registries — interception layer (see hooks.go).
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult]
@@ -140,6 +146,79 @@ func (m *Kit) MCPToolsReady() bool {
return m.agent.MCPToolsReady()
}
// MCPServerStatus describes the runtime state of a loaded MCP server.
type MCPServerStatus struct {
// Name is the configured server name.
Name string
// ToolCount is the number of tools loaded from this server.
ToolCount int
}
// AddMCPServer connects to a new MCP server at runtime and makes its tools
// available to the agent immediately. The server's tools are prefixed with the
// server name (e.g. "myserver__tool_name") to avoid naming conflicts, matching
// the behaviour of servers loaded at initialization.
//
// Returns the number of tools loaded from the server.
//
// AddMCPServer is safe to call while the agent is idle. If a turn is in
// progress ([Kit.IsGenerating] returns true), the new tools will be visible
// starting from the next LLM step.
//
// Example:
//
// n, err := k.AddMCPServer(ctx, "github", kit.MCPServerConfig{
// Command: []string{"npx", "-y", "@modelcontextprotocol/server-github"},
// Environment: map[string]string{"GITHUB_TOKEN": os.Getenv("GITHUB_TOKEN")},
// })
func (m *Kit) AddMCPServer(ctx context.Context, name string, cfg MCPServerConfig) (int, error) {
return m.agent.AddMCPServer(ctx, name, cfg)
}
// RemoveMCPServer disconnects an MCP server and removes all its tools from
// the agent. After this call the agent will no longer see or be able to call
// tools from the named server.
//
// RemoveMCPServer is safe to call while the agent is idle. If a turn is in
// progress, the tools are removed at the next LLM step. Any in-flight tool
// calls to the removed server will fail gracefully.
//
// Returns an error if the named server is not currently loaded.
func (m *Kit) RemoveMCPServer(name string) error {
return m.agent.RemoveMCPServer(name)
}
// ListMCPServers returns the status of all currently loaded MCP servers.
// The returned slice is a snapshot; it is safe to read concurrently.
func (m *Kit) ListMCPServers() []MCPServerStatus {
names := m.agent.GetLoadedServerNames()
if len(names) == 0 {
return nil
}
// Build a tool count per server by scanning tool names for the prefix.
toolNames := m.GetToolNames()
countByServer := make(map[string]int, len(names))
for _, tn := range toolNames {
for _, sn := range names {
prefix := sn + "__"
if len(tn) > len(prefix) && tn[:len(prefix)] == prefix {
countByServer[sn]++
break
}
}
}
result := make([]MCPServerStatus, 0, len(names))
for _, n := range names {
result = append(result, MCPServerStatus{
Name: n,
ToolCount: countByServer[n],
})
}
return result
}
// GetExtensionToolCount returns the number of tools registered by extensions.
func (m *Kit) GetExtensionToolCount() int {
return m.agent.GetExtensionToolCount()
@@ -172,27 +251,39 @@ type StructuredMessage struct {
// flattens all content to a single text string, this preserves tool calls,
// tool results, reasoning blocks, and finish markers as distinct typed parts.
func (m *Kit) GetStructuredMessages() []StructuredMessage {
return iterBranchMessages(m.treeSession, func(me *session.MessageEntry, msg message.Message) StructuredMessage {
return StructuredMessage{
ID: me.ID,
ParentID: me.ParentID,
Role: msg.Role,
Parts: msg.Parts,
Model: msg.Model,
Provider: msg.Provider,
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
if m.session == nil {
return nil
}
branch := m.session.GetCurrentBranch()
var results []StructuredMessage
for _, entry := range branch {
if entry.Type != EntryTypeMessage {
continue
}
})
results = append(results, StructuredMessage{
ID: entry.ID,
ParentID: entry.ParentID,
Role: MessageRole(entry.Role),
Parts: entry.RawParts,
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
})
}
return results
}
// iterBranchMessages iterates over the current branch's MessageEntry items,
// converting each to a message.Message and calling fn to build the result.
// Returns nil if there is no tree session. Skips entries that are not
// Returns nil if there is no session. Skips entries that are not
// MessageEntry or that fail conversion.
// Deprecated: Use SessionManager.GetCurrentBranch() directly.
func iterBranchMessages[T any](tm *session.TreeManager, fn func(*session.MessageEntry, message.Message) T) []T {
if tm == nil {
return nil
}
branch := tm.GetBranch("")
var results []T
for _, entry := range branch {
@@ -209,9 +300,12 @@ func iterBranchMessages[T any](tm *session.TreeManager, fn func(*session.Message
return results
}
// SetModel changes the active model at runtime. The existing tools, system
// prompt, and session are preserved. The model string should be in
// "provider/model" format (e.g. "anthropic/claude-sonnet-4-5-20250929").
// SetModel changes the active model at runtime. The existing tools and
// session are preserved. When the new model has a per-model system prompt
// (from modelSettings or customModels params), it is composed with the
// current AGENTS.md context and skills before being applied.
// The model string should be in "provider/model" format
// (e.g. "anthropic/claude-sonnet-4-5-20250929").
// Returns an error if the model string is invalid or the provider cannot
// be created.
func (m *Kit) SetModel(ctx context.Context, modelString string) error {
@@ -227,7 +321,7 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
// With message-level caching, thinking and caching can work together.
// No need to disable caching when thinking is enabled.
config := &models.ProviderConfig{
cfg := &models.ProviderConfig{
ModelString: modelString,
SystemPrompt: systemPrompt,
ProviderAPIKey: viper.GetString("provider-api-key"),
@@ -237,18 +331,50 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
ThinkingLevel: thinkingLevel,
DisableCaching: false, // Caching enabled by default, works with thinking
}
temperature := float32(viper.GetFloat64("temperature"))
config.Temperature = &temperature
topP := float32(viper.GetFloat64("top-p"))
config.TopP = &topP
topK := int32(viper.GetInt("top-k"))
config.TopK = &topK
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
config.FrequencyPenalty = &frequencyPenalty
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
config.PresencePenalty = &presencePenalty
if err := m.agent.SetModel(ctx, config); err != nil {
// Only set generation parameter pointers when the user has explicitly
// provided a value. This leaves nil pointers for unset params, allowing
// per-model defaults (modelSettings / customModels params) to apply.
if viper.IsSet("temperature") {
v := float32(viper.GetFloat64("temperature"))
cfg.Temperature = &v
}
if viper.IsSet("top-p") {
v := float32(viper.GetFloat64("top-p"))
cfg.TopP = &v
}
if viper.IsSet("top-k") {
v := int32(viper.GetInt("top-k"))
cfg.TopK = &v
}
if viper.IsSet("frequency-penalty") {
v := float32(viper.GetFloat64("frequency-penalty"))
cfg.FrequencyPenalty = &v
}
if viper.IsSet("presence-penalty") {
v := float32(viper.GetFloat64("presence-penalty"))
cfg.PresencePenalty = &v
}
// When the user hasn't set a custom global system prompt, check for a
// per-model system prompt. Pre-apply model settings to discover it,
// then compose with AGENTS.md context and skills if found.
if !m.hasCustomSystemPrompt {
// Temporarily clear the system prompt so ApplyModelSettings can
// detect that no explicit prompt is set and apply the per-model one.
cfg.SystemPrompt = ""
models.ApplyModelSettings(cfg, models.LookupModelForSettings(modelString))
if cfg.SystemPrompt != "" {
// Per-model system prompt found — compose with runtime context.
cfg.SystemPrompt = m.composeSystemPrompt(cfg.SystemPrompt)
} else {
// No per-model prompt — restore the global composed prompt.
cfg.SystemPrompt = systemPrompt
}
}
if err := m.agent.SetModel(ctx, cfg); err != nil {
return err
}
@@ -264,6 +390,32 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
return nil
}
// composeSystemPrompt takes a base system prompt and composes it with the
// current runtime context: AGENTS.md content, skills metadata, and date/cwd.
// This mirrors the composition done during Kit.New() initialization.
func (m *Kit) composeSystemPrompt(basePrompt string) string {
cwd, _ := os.Getwd()
pb := skills.NewPromptBuilder(basePrompt)
// Inject AGENTS.md content as project context.
for _, cf := range m.contextFiles {
pb.WithSection("", fmt.Sprintf("Instructions from: %s\n\n%s", cf.Path, cf.Content))
}
// Inject skills metadata.
if len(m.skills) > 0 {
pb.WithSkills(m.skills)
}
// Append current date/time and working directory.
pb.WithSection("", fmt.Sprintf(
"Current date and time: %s\nCurrent working directory: %s",
time.Now().Format("Monday, January 2, 2006, 3:04:05 PM MST"), cwd,
))
return pb.Build()
}
// GetAvailableModels returns a list of known models from the registry. Each
// entry includes provider, model ID, context limit, and whether the model
// supports reasoning. This is an advisory list — models not in the registry
@@ -445,6 +597,17 @@ type Options struct {
Tools []Tool // Custom tool set. If empty, AllTools() is used.
ExtraTools []Tool // Additional tools added alongside core/MCP/extension tools.
// SkipConfig, when true, skips loading .kit.yml configuration files.
// Viper defaults (setSDKDefaults) and environment variables (KIT_*)
// are still applied. Use this for fully programmatic configuration.
SkipConfig bool
// DisableCoreTools, when true, prevents loading any core tools.
// Use with Tools or ExtraTools to provide only custom tools.
// If both DisableCoreTools is true and Tools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// Session configuration
SessionDir string // Base directory for session discovery (default: cwd)
SessionPath string // Open a specific session file by path
@@ -474,6 +637,17 @@ type Options struct {
// display a URL in a custom UI, redirect to a web app, etc.).
MCPAuthHandler MCPAuthHandler
// MCPTokenStoreFactory, if non-nil, is called to create a token store for
// each remote MCP server that requires OAuth. The factory receives the
// server's URL and returns a [MCPTokenStore] implementation.
//
// When nil (default), tokens are persisted to a JSON file at
// $XDG_CONFIG_HOME/.kit/mcp_tokens.json (or ~/.config/.kit/mcp_tokens.json).
//
// Use this to store tokens in a database, encrypt them, keep them
// in-memory, or write them to a custom file path.
MCPTokenStoreFactory MCPTokenStoreFactory
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading during Kit initialization. The callback receives the server name,
// tool count, and any error. Called from a background goroutine; safe to
@@ -483,6 +657,11 @@ type Options struct {
// CLI is optional CLI-specific configuration. SDK users leave this nil.
CLI *CLIOptions
// SessionManager allows custom session storage backends.
// If nil (default), Kit uses the built-in file-based TreeManager.
// When provided, SessionPath, Continue, and NoSession options are ignored.
SessionManager SessionManager
}
// CLIOptions holds fields only relevant to the CLI binary. SDK users should
@@ -554,16 +733,17 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
// provider creation, session init) then runs outside the lock, allowing
// parallel subagent spawns to proceed concurrently.
var (
providerConfig *models.ProviderConfig
modelString string
cwd string
contextFiles []*ContextFile
loadedSkills []*Skill
mcpConfig *config.Config
debug bool
noExtensions bool
maxSteps int
streaming bool
providerConfig *models.ProviderConfig
modelString string
cwd string
contextFiles []*ContextFile
loadedSkills []*Skill
mcpConfig *config.Config
debug bool
noExtensions bool
maxSteps int
streaming bool
hasCustomSystemPrompt bool
)
if err := func() error {
@@ -577,7 +757,8 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
// Initialize config (loads config files and env vars).
// Only initialize if not already done (e.g., by CLI's cobra.OnInitialize).
// Check if model is already set, which indicates config was loaded.
if viper.GetString("model") == "" {
// SkipConfig bypasses .kit.yml file loading (viper defaults and env vars still apply).
if !opts.SkipConfig && viper.GetString("model") == "" {
if err := InitConfig(opts.ConfigFile, false); err != nil {
return fmt.Errorf("failed to initialize config: %w", err)
}
@@ -618,8 +799,41 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
// Always compose the system prompt with runtime context: base prompt +
// AGENTS.md context + skills metadata + date/cwd.
//
// If the configured model has a per-model system prompt (via
// modelSettings or customModels params) and the user hasn't
// explicitly set system-prompt, use the per-model prompt as the
// base instead of the global default.
{
basePrompt := viper.GetString("system-prompt")
// Track whether the user explicitly configured a custom system
// prompt. When they haven't (basePrompt is the built-in default
// or empty), per-model system prompts can replace it on switch.
userSetSystemPrompt := basePrompt != "" && basePrompt != defaultSystemPrompt
hasCustomSystemPrompt = userSetSystemPrompt
// Check for per-model system prompt override when no explicit
// global system-prompt was configured by the user.
if !userSetSystemPrompt {
modelStr := viper.GetString("model")
if modelStr != "" {
if mi := models.LookupModelForSettings(modelStr); mi != nil {
var perModelParams *models.GenerationParams
// modelSettings takes priority over custom model params.
if ms := models.LoadModelSettingsFromConfig(); ms != nil {
perModelParams = ms[modelStr]
}
if perModelParams == nil && mi.Params != nil {
perModelParams = mi.Params
}
if perModelParams != nil && perModelParams.SystemPrompt != "" {
basePrompt = models.LoadSystemPromptValue(perModelParams.SystemPrompt)
}
}
}
}
pb := skills.NewPromptBuilder(basePrompt)
// Inject AGENTS.md content as project context.
@@ -689,6 +903,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
MCPConfig: mcpConfig,
Quiet: opts.Quiet,
CoreTools: opts.Tools,
DisableCoreTools: opts.DisableCoreTools,
ExtraTools: opts.ExtraTools,
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
ProviderConfig: providerConfig,
@@ -709,12 +924,19 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
defaultHandler, authErr := NewDefaultMCPAuthHandler()
if authErr != nil {
// Non-fatal: OAuth just won't be available for remote servers.
charmlog.Warn("Failed to create OAuth handler; remote MCP servers requiring auth will fail", "error", authErr)
log.Printf("WARN Failed to create OAuth handler; remote MCP servers requiring auth will fail: %v", authErr)
} else {
setupOpts.AuthHandler = defaultHandler
}
}
// Set up custom token store factory for MCP OAuth tokens.
// The SDK MCPTokenStoreFactory is structurally identical to
// tools.TokenStoreFactory, so it can be assigned directly.
if opts.MCPTokenStoreFactory != nil {
setupOpts.TokenStoreFactory = tools.TokenStoreFactory(opts.MCPTokenStoreFactory)
}
if opts.CLI != nil {
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
@@ -727,32 +949,42 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
return nil, err
}
// Initialize tree session.
treeSession, err := InitTreeSession(opts)
if err != nil {
_ = agentResult.Agent.Close()
return nil, fmt.Errorf("failed to initialize session: %w", err)
// Initialize session manager.
var sessionManager SessionManager
if opts.SessionManager != nil {
// Use custom session manager provided by user.
sessionManager = opts.SessionManager
} else {
// DEFAULT: Use built-in TreeManager (existing behavior).
treeSession, err := InitTreeSession(opts)
if err != nil {
_ = agentResult.Agent.Close()
return nil, fmt.Errorf("failed to initialize session: %w", err)
}
// Wrap TreeManager in adapter to satisfy SessionManager interface.
sessionManager = NewTreeManagerAdapter(treeSession)
}
k := &Kit{
agent: agentResult.Agent,
treeSession: treeSession,
modelString: modelString,
events: newEventBus(),
autoCompact: opts.AutoCompact,
compactionOpts: opts.CompactionOptions,
contextFiles: contextFiles,
skills: loadedSkills,
extRunner: agentResult.ExtRunner,
bufferedLogger: agentResult.BufferedLogger,
authHandler: setupOpts.AuthHandler,
opts: opts,
beforeToolCall: beforeToolCall,
afterToolResult: afterToolResult,
beforeTurn: beforeTurn,
afterTurn: afterTurn,
contextPrepare: contextPrepare,
beforeCompact: beforeCompact,
agent: agentResult.Agent,
session: sessionManager,
modelString: modelString,
events: newEventBus(),
autoCompact: opts.AutoCompact,
compactionOpts: opts.CompactionOptions,
contextFiles: contextFiles,
skills: loadedSkills,
extRunner: agentResult.ExtRunner,
bufferedLogger: agentResult.BufferedLogger,
authHandler: setupOpts.AuthHandler,
opts: opts,
hasCustomSystemPrompt: hasCustomSystemPrompt,
beforeToolCall: beforeToolCall,
afterToolResult: afterToolResult,
beforeTurn: beforeTurn,
afterTurn: afterTurn,
contextPrepare: contextPrepare,
beforeCompact: beforeCompact,
}
// Bridge extension events to SDK hooks.
@@ -1270,14 +1502,22 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
IsStderr: isStderr,
})
},
// Persist step messages incrementally so that progress survives
// crashes and long-running turns don't lose work. Each step's
// messages are persisted as a unit: for tool-calling steps this is
// the assistant message (with tool_use parts) + tool-role message
// (with tool_result parts) as a pair; for the final step it's the
// assistant text/reasoning message alone.
func(stepMessages []fantasy.Message) {
for _, msg := range stepMessages {
_, _ = m.session.AppendMessage(msg)
}
},
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
// Emit step usage event for real-time cost tracking
if viper.GetBool("debug") {
charmlog.Debug("Kit.generate emitting StepUsageEvent",
"input", inputTokens,
"output", outputTokens,
"cacheRead", cacheReadTokens,
"cacheCreate", cacheCreationTokens,
log.Printf("DEBUG Kit.generate emitting StepUsageEvent: input=%d output=%d cacheRead=%d cacheCreate=%d",
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens,
)
}
m.events.emit(StepUsageEvent{
@@ -1295,11 +1535,17 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
// 2. Persist pre-generation messages to the tree session.
// 3. Build context from the tree (walks leaf-to-root for current branch).
// 4. Emit turn/message start events.
// 5. Run generation.
// 6. Emit turn/message end events.
// 7. Persist post-generation messages (tool calls, results, assistant).
// 5. Run generation (messages are persisted incrementally per step).
// 6. Persist any remaining messages not covered by incremental persistence.
// 7. Emit turn/message end events.
// 8. Run AfterTurn hooks.
//
// During generation, each completed step's messages are persisted immediately
// via the onStepMessages callback. Tool calls are always persisted as
// call/response pairs (assistant + tool messages together). Reasoning and
// text-only assistant messages are persisted as soon as their step completes.
// This ensures long-running turns don't lose progress on crash or cancellation.
//
// promptLabel is the human-readable label emitted in TurnStartEvent.Prompt.
// prompt is the raw user text passed to BeforeTurn hooks.
func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, preMessages []fantasy.Message) (*TurnResult, error) {
@@ -1344,9 +1590,9 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
}
}
// Persist pre-generation messages to tree session.
// Persist pre-generation messages to session.
for _, msg := range preMessages {
_, _ = m.treeSession.AppendLLMMessage(msg)
_, _ = m.session.AppendMessage(msg)
}
// Auto-compact if enabled and conversation is near the context limit.
@@ -1354,8 +1600,8 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
_, _ = m.compactInternal(ctx, m.compactionOpts, "", true) // best-effort, automatic
}
// Build context from the tree so only the current branch is sent.
messages := m.treeSession.GetLLMMessages()
// Build context from the session so only the current branch is sent.
messages, _, _ := m.session.BuildContext()
// Run ContextPrepare hooks — extensions can filter, reorder, or inject messages.
if hookResult := m.contextPrepare.run(ContextPrepareHook{Messages: messages}); hookResult != nil && hookResult.Messages != nil {
@@ -1369,16 +1615,18 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
result, err := m.generate(ctx, messages)
if err != nil {
// Persist any messages from completed steps (tool call/result
// pairs) so partial progress is not lost. The agent layer only
// includes fully-paired tool_use + tool_result messages in
// completedStepMessages, so there are no orphaned entries that
// would break subsequent API requests. The user message and any
// completed work remain in the session; only the in-progress
// (pending) message or tool call is discarded.
if result != nil && len(result.ConversationMessages) > sentCount {
for _, msg := range result.ConversationMessages[sentCount:] {
_, _ = m.treeSession.AppendLLMMessage(msg)
// Persist any messages from completed steps that were NOT already
// persisted incrementally by the onStepMessages callback. The agent
// layer only includes fully-paired tool_use + tool_result messages
// in completedStepMessages, so there are no orphaned entries that
// would break subsequent API requests.
if result != nil {
newMessages := result.ConversationMessages[sentCount:]
alreadyPersisted := result.PersistedMessageCount
if alreadyPersisted < len(newMessages) {
for _, msg := range newMessages[alreadyPersisted:] {
_, _ = m.session.AppendMessage(msg)
}
}
}
m.events.emit(TurnEndEvent{Error: err})
@@ -1389,12 +1637,17 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
responseText := result.FinalResponse.Content.Text()
// Persist new messages (tool calls, tool results, assistant response)
// BEFORE emitting events so that extension handlers calling
// GetContextStats() see up-to-date token counts.
// Persist any new messages that were NOT already persisted incrementally
// by the onStepMessages callback during generation. This handles the
// non-streaming path (where onStepMessages is not called) and any edge
// cases where the final response messages weren't covered by step callbacks.
if len(result.ConversationMessages) > sentCount {
for _, msg := range result.ConversationMessages[sentCount:] {
_, _ = m.treeSession.AppendLLMMessage(msg)
newMessages := result.ConversationMessages[sentCount:]
alreadyPersisted := result.PersistedMessageCount
if alreadyPersisted < len(newMessages) {
for _, msg := range newMessages[alreadyPersisted:] {
_, _ = m.session.AppendMessage(msg)
}
}
}
@@ -1476,7 +1729,7 @@ func (m *Kit) Steer(ctx context.Context, instruction string) (string, error) {
// Returns an error if there are no previous messages in the session.
func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
// Verify there is conversation history to follow up on.
if len(m.treeSession.GetLLMMessages()) == 0 {
if len(m.session.GetMessages()) == 0 {
return "", fmt.Errorf("cannot follow up: no previous messages")
}
@@ -1632,10 +1885,12 @@ func (m *Kit) PromptResultWithMessages(ctx context.Context, messages []string) (
return m.runTurn(ctx, promptLabel, messages[len(messages)-1], preMessages)
}
// ClearSession resets the tree session's leaf pointer to the root, starting
// ClearSession resets the session's leaf pointer to the root, starting
// a fresh conversation branch.
func (m *Kit) ClearSession() {
m.treeSession.ResetLeaf()
if m.session != nil {
_ = m.session.Branch("")
}
}
// GetModelString returns the current model string identifier (e.g.,
@@ -1704,8 +1959,8 @@ func (m *Kit) Close() error {
if m.extRunner != nil && m.extRunner.HasHandlers(extensions.SessionShutdown) {
_, _ = m.extRunner.Emit(extensions.SessionShutdownEvent{})
}
if m.treeSession != nil {
_ = m.treeSession.Close()
if m.session != nil {
_ = m.session.Close()
}
// Release the OAuth callback port if we own the handler.
if closer, ok := m.authHandler.(interface{ Close() error }); ok {
@@ -1713,3 +1968,5 @@ func (m *Kit) Close() error {
}
return m.agent.Close()
}
// Conversion helpers are defined in adapter.go.
+56
View File
@@ -0,0 +1,56 @@
package kit_test
import (
"testing"
kit "github.com/mark3labs/kit/pkg/kit"
)
// TestMCPServerStatus_TypeSurface verifies the MCPServerStatus type is
// accessible and has the expected fields.
func TestMCPServerStatus_TypeSurface(t *testing.T) {
s := kit.MCPServerStatus{
Name: "test-server",
ToolCount: 5,
}
if s.Name != "test-server" {
t.Errorf("Expected Name 'test-server', got %q", s.Name)
}
if s.ToolCount != 5 {
t.Errorf("Expected ToolCount 5, got %d", s.ToolCount)
}
}
// TestMCPServerConfig_ForDynamicAdd verifies that MCPServerConfig can be
// constructed with the expected fields for dynamic server management.
func TestMCPServerConfig_ForDynamicAdd(t *testing.T) {
// Stdio server config.
stdio := kit.MCPServerConfig{
Command: []string{"npx", "-y", "@modelcontextprotocol/server-github"},
Environment: map[string]string{"GITHUB_TOKEN": "test-token"},
}
if len(stdio.Command) != 3 {
t.Errorf("Expected 3 command parts, got %d", len(stdio.Command))
}
if stdio.Environment["GITHUB_TOKEN"] != "test-token" {
t.Error("Expected GITHUB_TOKEN in environment")
}
// Remote server config.
remote := kit.MCPServerConfig{
URL: "https://mcp.example.com/sse",
Headers: []string{"Authorization: Bearer test"},
}
if remote.URL != "https://mcp.example.com/sse" {
t.Errorf("Unexpected URL: %s", remote.URL)
}
// Config with tool filtering.
filtered := kit.MCPServerConfig{
Command: []string{"some-server"},
AllowedTools: []string{"read", "write"},
}
if len(filtered.AllowedTools) != 2 {
t.Errorf("Expected 2 allowed tools, got %d", len(filtered.AllowedTools))
}
}
+144
View File
@@ -0,0 +1,144 @@
package kit
import (
"time"
)
// SessionManager defines the contract for conversation storage backends.
// Implementations can use files (default), databases, cloud storage, etc.
//
// Implementations must be safe for concurrent use. During generation,
// AppendMessage is called incrementally from the agent's step-completion
// callback while read methods (GetMessages, GetCurrentBranch, etc.) may be
// called concurrently from the UI or extension goroutines.
type SessionManager interface {
// AppendMessage adds a message to the current branch and returns its entry ID.
// The entry ID is used for tree navigation and must be unique within the session.
//
// During generation, AppendMessage is called incrementally after each
// completed agent step rather than in a batch at the end of the turn.
// For tool-calling steps, the assistant message (containing tool_use parts)
// and the tool-role message (containing tool_result parts) are appended
// together as a pair. This ensures the session never contains an orphaned
// tool call without its result, which would break subsequent LLM requests.
AppendMessage(msg LLMMessage) (entryID string, err error)
// GetMessages returns all messages on the current branch (from root to leaf),
// including any compaction summaries at the appropriate positions.
GetMessages() []LLMMessage
// BuildContext returns the message history to send to the LLM, applying
// compaction rules and branch summaries as needed.
// Returns: messages, currentProvider, currentModelID
BuildContext() (messages []LLMMessage, provider string, modelID string)
// Branch moves the leaf pointer to the given entry ID, creating a branch point.
// Subsequent AppendMessage calls extend from this new position.
// entryID can be empty to reset to root (new conversation branch).
Branch(entryID string) error
// GetCurrentBranch returns the path from root to current leaf as entry metadata.
// Used for UI display and navigation.
GetCurrentBranch() []BranchEntry
// GetChildren returns direct child entry IDs for a given parent entry.
// Used to display branch points in the conversation tree.
GetChildren(parentID string) []string
// GetEntry returns a specific entry by ID, or nil if not found.
GetEntry(entryID string) *BranchEntry
// GetSessionID returns the unique session identifier (UUID).
GetSessionID() string
// GetSessionName returns the user-defined display name, or empty.
GetSessionName() string
// SetSessionName sets a display name for the session.
SetSessionName(name string) error
// GetCreatedAt returns when the session was created.
GetCreatedAt() time.Time
// IsPersisted returns true if this session writes to durable storage.
IsPersisted() bool
// AppendCompaction adds a compaction entry that summarizes older messages.
// firstKeptEntryID is the ID of the first message to preserve in context.
// readFiles and modifiedFiles track file changes for the compaction summary.
AppendCompaction(summary string, firstKeptEntryID string,
tokensBefore, tokensAfter int, messagesRemoved int, readFiles, modifiedFiles []string) (string, error)
// GetLastCompaction returns the most recent compaction entry on the current
// branch, or nil if none exists.
GetLastCompaction() *CompactionEntry
// AppendExtensionData stores custom extension data in the session tree.
// Extensions use this to persist state across restarts.
AppendExtensionData(extType, data string) (string, error)
// GetExtensionData returns all extension data entries of the given type
// on the current branch. If extType is empty, returns all extension data.
GetExtensionData(extType string) []ExtensionDataEntry
// AppendModelChange records a provider/model switch in the session.
AppendModelChange(provider, modelID string) (string, error)
// GetContextEntryIDs returns the entry IDs corresponding to the messages
// returned by BuildContext, in the same order. Used by compaction to
// determine which entries to summarize.
GetContextEntryIDs() []string
// Close releases resources (database connections, file handles, etc.).
Close() error
}
// BranchEntry represents a single node in the conversation tree.
// This is a SDK-friendly struct (not the internal entry types).
type BranchEntry struct {
ID string
ParentID string
Type EntryType // "message", "branch_summary", "model_change", "compaction", "extension_data"
Role string // for messages: "user", "assistant", "system", "tool"
Content string // text content or summary
Model string // model used (for messages and model_change)
Provider string // provider used
Timestamp time.Time
Children []string // child entry IDs (for tree display)
// RawParts contains the full typed content parts for structured access.
// Only populated for message entries.
RawParts []ContentPart
}
// EntryType identifies the kind of entry in the session tree.
type EntryType string
const (
EntryTypeMessage EntryType = "message"
EntryTypeBranchSummary EntryType = "branch_summary"
EntryTypeModelChange EntryType = "model_change"
EntryTypeCompaction EntryType = "compaction"
EntryTypeExtensionData EntryType = "extension_data"
)
// CompactionEntry represents a context compaction/summarization event.
type CompactionEntry struct {
ID string
Summary string
FirstKeptEntryID string
TokensBefore int
TokensAfter int
MessagesRemoved int
ReadFiles []string
ModifiedFiles []string
Timestamp time.Time
}
// ExtensionDataEntry represents custom extension data stored in the session.
type ExtensionDataEntry struct {
ID string
ExtType string
Data string
Timestamp time.Time
}
+111 -80
View File
@@ -8,7 +8,6 @@ import (
"time"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/message"
"github.com/mark3labs/kit/internal/session"
)
@@ -47,49 +46,73 @@ func OpenTreeSession(path string) (*TreeManager, error) {
// --- Instance methods on Kit ---
// GetSessionManager returns the session manager, or nil if not configured.
func (m *Kit) GetSessionManager() SessionManager {
return m.session
}
// GetTreeSession returns the tree session manager, or nil if not configured.
// Deprecated: Use GetSessionManager instead.
func (m *Kit) GetTreeSession() *TreeManager {
return m.treeSession
// Try to unwrap the adapter if using default implementation
if adapter, ok := m.session.(*treeManagerAdapter); ok {
return adapter.inner
}
return nil
}
// SetSessionManager replaces the session manager on a Kit instance.
func (m *Kit) SetSessionManager(sm SessionManager) {
m.session = sm
}
// SetTreeSession replaces the tree session on a Kit instance. This is used by
// the CLI when it handles session creation externally (e.g. --resume with a
// TUI picker) and needs to inject the result into a Kit-like workflow.
// Deprecated: Use SetSessionManager instead.
func (m *Kit) SetTreeSession(ts *TreeManager) {
m.treeSession = ts
m.session = NewTreeManagerAdapter(ts)
}
// GetSessionPath returns the file path of the active tree session, or empty
// for in-memory sessions or when no tree session is configured.
// GetSessionPath returns the file path of the active session, or empty
// for in-memory sessions or when no file-based session is configured.
func (m *Kit) GetSessionPath() string {
if m.treeSession != nil {
return m.treeSession.GetFilePath()
// Only file-based sessions have a path
// Try to get it from the underlying TreeManager if using default adapter
if m.session == nil {
return ""
}
// Check if it's the default adapter
if adapter, ok := m.session.(*treeManagerAdapter); ok {
return adapter.inner.GetFilePath()
}
return ""
}
// GetSessionID returns the UUID of the active tree session, or empty when no
// tree session is configured.
// GetSessionID returns the UUID of the active session, or empty when no
// session is configured.
func (m *Kit) GetSessionID() string {
if m.treeSession != nil {
return m.treeSession.GetSessionID()
if m.session == nil {
return ""
}
return ""
return m.session.GetSessionID()
}
// Branch moves the tree session's leaf pointer to the given entry ID, creating
// Branch moves the session's leaf pointer to the given entry ID, creating
// a branch point. Subsequent Prompt() calls will extend from the new position.
func (m *Kit) Branch(entryID string) error {
return m.treeSession.Branch(entryID)
if m.session == nil {
return fmt.Errorf("no session available")
}
return m.session.Branch(entryID)
}
// SetSessionName sets a user-defined display name for the active tree session.
// SetSessionName sets a user-defined display name for the active session.
func (m *Kit) SetSessionName(name string) error {
if m.treeSession == nil {
return fmt.Errorf("session naming requires a tree session")
if m.session == nil {
return fmt.Errorf("session naming requires a session")
}
_, err := m.treeSession.AppendSessionInfo(name)
return err
return m.session.SetSessionName(name)
}
// ---------------------------------------------------------------------------
@@ -97,27 +120,27 @@ func (m *Kit) SetSessionName(name string) error {
// ---------------------------------------------------------------------------
// GetTreeNode returns a node by ID with full metadata and children.
// Returns nil if entry not found or no tree session.
// Returns nil if entry not found or no session.
func (m *Kit) GetTreeNode(entryID string) *TreeNode {
if m.treeSession == nil {
if m.session == nil {
return nil
}
entry := m.treeSession.GetEntry(entryID)
entry := m.session.GetEntry(entryID)
if entry == nil {
return nil
}
return m.entryToTreeNode(entry)
return m.branchEntryToTreeNode(entry)
}
// GetCurrentBranch returns the path from root to current leaf as TreeNodes.
func (m *Kit) GetCurrentBranch() []TreeNode {
if m.treeSession == nil {
if m.session == nil {
return nil
}
branch := m.treeSession.GetBranch("")
branch := m.session.GetCurrentBranch()
var nodes []TreeNode
for _, entry := range branch {
node := m.entryToTreeNode(entry)
node := m.branchEntryToTreeNode(&entry)
if node != nil {
nodes = append(nodes, *node)
}
@@ -127,34 +150,34 @@ func (m *Kit) GetCurrentBranch() []TreeNode {
// GetChildren returns direct child IDs of an entry.
func (m *Kit) GetChildren(parentID string) []string {
if m.treeSession == nil {
if m.session == nil {
return nil
}
return m.treeSession.GetChildren(parentID)
return m.session.GetChildren(parentID)
}
// NavigateTo branches/forks the session to the specified entry ID.
// Returns an error if the session is unavailable or the entry ID is not found.
func (m *Kit) NavigateTo(entryID string) error {
if m.treeSession == nil {
return fmt.Errorf("no tree session available")
if m.session == nil {
return fmt.Errorf("no session available")
}
return m.treeSession.Branch(entryID)
return m.session.Branch(entryID)
}
// SummarizeBranch uses the LLM to summarize the conversation between two
// entry IDs. Returns the summary text, or an error if the range is invalid,
// the session is unavailable, or the LLM call fails.
func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
if m.treeSession == nil {
return "", fmt.Errorf("no tree session available")
if m.session == nil {
return "", fmt.Errorf("no session available")
}
// Get the branch and find the range
branch := m.treeSession.GetBranch("")
branch := m.session.GetCurrentBranch()
var startIdx, endIdx = -1, -1
for i, entry := range branch {
id := m.treeSession.EntryID(entry)
id := entry.ID
if id == fromID {
startIdx = i
}
@@ -170,7 +193,7 @@ func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
// Build text to summarize
var content strings.Builder
for i := startIdx; i <= endIdx; i++ {
node := m.entryToTreeNode(branch[i])
node := m.branchEntryToTreeNode(&branch[i])
if node != nil && node.Content != "" {
fmt.Fprintf(&content, "[%s] %s\n\n", node.Role, node.Content)
}
@@ -195,73 +218,81 @@ func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
// CollapseBranch replaces a branch range with a summary entry.
// Returns an error if the session is unavailable or the operation fails.
func (m *Kit) CollapseBranch(fromID, toID, summary string) error {
if m.treeSession == nil {
return fmt.Errorf("no tree session available")
if m.session == nil {
return fmt.Errorf("no session available")
}
_, err := m.treeSession.AppendBranchSummary(fromID, summary)
return err
// Note: This operation is not directly supported by SessionManager interface
// as it requires AppendBranchSummary which is TreeManager-specific.
// For custom SessionManagers, this would need to be implemented differently.
// For now, we try to use the underlying TreeManager if available.
if adapter, ok := m.session.(*treeManagerAdapter); ok {
_, err := adapter.inner.AppendBranchSummary(fromID, summary)
return err
}
return fmt.Errorf("CollapseBranch not supported by custom session manager")
}
// entryToTreeNode converts a session entry to a TreeNode.
func (m *Kit) entryToTreeNode(entry any) *TreeNode {
switch e := entry.(type) {
case *session.MessageEntry:
msg, err := e.ToMessage()
if err != nil {
return nil
}
// branchEntryToTreeNode converts a BranchEntry to a TreeNode.
func (m *Kit) branchEntryToTreeNode(entry *BranchEntry) *TreeNode {
if entry == nil {
return nil
}
switch entry.Type {
case EntryTypeMessage:
// Build content from RawParts
var content strings.Builder
for _, p := range msg.Parts {
for _, p := range entry.RawParts {
switch pt := p.(type) {
case message.TextContent:
case TextContent:
content.WriteString(pt.Text)
case message.ReasoningContent:
case ReasoningContent:
content.WriteString(pt.Thinking)
case message.ToolCall:
case ToolCall:
fmt.Fprintf(&content, "[tool_call: %s]", pt.Name)
case message.ToolResult:
case ToolResult:
fmt.Fprintf(&content, "[tool_result: %s]", pt.Content)
}
}
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "message",
Role: string(msg.Role),
Role: entry.Role,
Content: content.String(),
Model: msg.Model,
Provider: msg.Provider,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.BranchSummaryEntry:
case EntryTypeBranchSummary:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "branch_summary",
Content: e.Summary,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.ModelChangeEntry:
case EntryTypeModelChange:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "model_change",
Content: fmt.Sprintf("Model changed to %s/%s", e.Provider, e.ModelID),
Model: e.Provider + "/" + e.ModelID,
Provider: e.Provider,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.ExtensionDataEntry:
case EntryTypeExtensionData:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "extension_data",
Content: fmt.Sprintf("Extension data: %s", e.ExtType),
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
default:
return nil
+119
View File
@@ -1,6 +1,8 @@
package kit
import (
"context"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/core"
@@ -16,6 +18,123 @@ type ToolOption = core.ToolOption
// If empty, os.Getwd() is used at execution time.
var WithWorkDir = core.WithWorkDir
// --- Custom tool creation ---
// ToolOutput is the return value from custom tool handlers created with
// [NewTool] or [NewParallelTool]. It provides a dependency-free way to
// return results without importing the underlying LLM framework.
type ToolOutput struct {
// Content is the text content returned to the LLM.
Content string
// IsError, when true, signals to the LLM that the tool call failed.
IsError bool
// Data contains optional binary data (images, audio, etc.).
Data []byte
// MediaType is the MIME type for binary Data (e.g. "image/png").
MediaType string
// Metadata is optional opaque metadata attached to the response.
// It is not sent to the LLM but may be consumed by hooks or the UI.
Metadata any
}
// TextResult creates a successful text [ToolOutput].
func TextResult(content string) ToolOutput {
return ToolOutput{Content: content}
}
// ErrorResult creates an error [ToolOutput]. The LLM will see the content
// as a tool error, allowing it to retry or adjust its approach.
func ErrorResult(content string) ToolOutput {
return ToolOutput{Content: content, IsError: true}
}
// toolCallIDKey is the context key for the tool call ID.
type toolCallIDKey struct{}
// ToolCallIDFromContext extracts the tool call ID from the context.
// The call ID is set automatically by [NewTool] and [NewParallelTool]
// before invoking the handler. Returns an empty string if no ID is present.
func ToolCallIDFromContext(ctx context.Context) string {
s, _ := ctx.Value(toolCallIDKey{}).(string)
return s
}
// NewTool creates a custom [Tool] with automatic JSON schema generation from
// the TInput struct type. The handler receives a typed input (deserialized
// from the LLM's JSON arguments) and returns a [ToolResult].
//
// Struct tags on TInput control the generated schema:
//
// json:"name" → parameter name
// description:"..." → parameter description shown to the LLM
// enum:"a,b,c" → restrict valid values
// omitempty → marks the parameter as optional
//
// The tool call ID is injected into the context and can be retrieved with
// [ToolCallIDFromContext].
//
// Example:
//
// type WeatherInput struct {
// City string `json:"city" description:"City name"`
// }
//
// tool := kit.NewTool("get_weather", "Get weather for a city",
// func(ctx context.Context, input WeatherInput) (kit.ToolResult, error) {
// return kit.TextResult("72°F, sunny in " + input.City), nil
// },
// )
func NewTool[TInput any](name, description string, fn func(ctx context.Context, input TInput) (ToolOutput, error)) Tool {
return fantasy.NewAgentTool(name, description,
func(ctx context.Context, input TInput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
ctx = context.WithValue(ctx, toolCallIDKey{}, call.ID)
result, err := fn(ctx, input)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
resp := fantasy.ToolResponse{
Content: result.Content,
IsError: result.IsError,
Data: result.Data,
MediaType: result.MediaType,
}
if result.Metadata != nil {
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
}
return resp, nil
},
)
}
// NewParallelTool is like [NewTool] but marks the tool as safe for concurrent
// execution alongside other tools. Use this when the tool has no side effects
// or when concurrent calls are safe.
func NewParallelTool[TInput any](name, description string, fn func(ctx context.Context, input TInput) (ToolOutput, error)) Tool {
return fantasy.NewParallelAgentTool(name, description,
func(ctx context.Context, input TInput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
ctx = context.WithValue(ctx, toolCallIDKey{}, call.ID)
result, err := fn(ctx, input)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
resp := fantasy.ToolResponse{
Content: result.Content,
IsError: result.IsError,
Data: result.Data,
MediaType: result.MediaType,
}
if result.Metadata != nil {
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
}
return resp, nil
},
)
}
// --- Individual tool constructors ---
// NewReadTool creates a file-reading tool.
+119
View File
@@ -0,0 +1,119 @@
package kit_test
import (
"context"
"testing"
kit "github.com/mark3labs/kit/pkg/kit"
)
// TestNewTool_BasicTextResult verifies that NewTool creates a working tool
// that returns text content via ToolOutput.
func TestNewTool_BasicTextResult(t *testing.T) {
type Input struct {
Name string `json:"name"`
}
tool := kit.NewTool("greet", "Greet someone",
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
return kit.TextResult("hello " + input.Name), nil
},
)
info := tool.Info()
if info.Name != "greet" {
t.Errorf("Info().Name = %q, want %q", info.Name, "greet")
}
if info.Description != "Greet someone" {
t.Errorf("Info().Description = %q, want %q", info.Description, "Greet someone")
}
if info.Parallel {
t.Error("NewTool should not mark tool as parallel")
}
}
// TestNewParallelTool_MarkedParallel verifies that NewParallelTool marks the
// tool as safe for concurrent execution.
func TestNewParallelTool_MarkedParallel(t *testing.T) {
type Input struct {
Query string `json:"query"`
}
tool := kit.NewParallelTool("search", "Search for things",
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
return kit.TextResult("found: " + input.Query), nil
},
)
info := tool.Info()
if info.Name != "search" {
t.Errorf("Info().Name = %q, want %q", info.Name, "search")
}
if !info.Parallel {
t.Error("NewParallelTool should mark tool as parallel")
}
}
// TestTextResult verifies the TextResult convenience constructor.
func TestTextResult(t *testing.T) {
r := kit.TextResult("ok")
if r.Content != "ok" {
t.Errorf("Content = %q, want %q", r.Content, "ok")
}
if r.IsError {
t.Error("TextResult should not set IsError")
}
}
// TestErrorResult verifies the ErrorResult convenience constructor.
func TestErrorResult(t *testing.T) {
r := kit.ErrorResult("bad input")
if r.Content != "bad input" {
t.Errorf("Content = %q, want %q", r.Content, "bad input")
}
if !r.IsError {
t.Error("ErrorResult should set IsError")
}
}
// TestToolCallIDFromContext verifies round-trip context injection.
func TestToolCallIDFromContext(t *testing.T) {
// Empty context returns empty string.
if id := kit.ToolCallIDFromContext(context.Background()); id != "" {
t.Errorf("expected empty string from bare context, got %q", id)
}
}
// TestToolOutput_Metadata verifies that metadata can be set on ToolOutput.
func TestToolOutput_Metadata(t *testing.T) {
r := kit.ToolOutput{
Content: "data",
Metadata: map[string]string{"key": "value"},
}
if r.Metadata == nil {
t.Error("expected non-nil Metadata")
}
m, ok := r.Metadata.(map[string]string)
if !ok {
t.Fatalf("expected map[string]string, got %T", r.Metadata)
}
if m["key"] != "value" {
t.Errorf("Metadata[key] = %q, want %q", m["key"], "value")
}
}
// TestToolOutput_BinaryData verifies that binary data fields work correctly.
func TestToolOutput_BinaryData(t *testing.T) {
data := []byte{0x89, 0x50, 0x4E, 0x47}
r := kit.ToolOutput{
Content: "image result",
Data: data,
MediaType: "image/png",
}
if len(r.Data) != 4 {
t.Errorf("Data len = %d, want 4", len(r.Data))
}
if r.MediaType != "image/png" {
t.Errorf("MediaType = %q, want %q", r.MediaType, "image/png")
}
}
+24
View File
@@ -11,6 +11,7 @@ import (
"github.com/mark3labs/kit/internal/message"
"github.com/mark3labs/kit/internal/models"
"github.com/mark3labs/kit/internal/session"
"github.com/mark3labs/mcp-go/client/transport"
)
// ==== Message Types (internal/message/content.go) ====
@@ -204,6 +205,29 @@ type CompactionResult = compaction.CompactionResult
// CompactionOptions configures compaction behaviour.
type CompactionOptions = compaction.CompactionOptions
// ==== MCP OAuth Types ====
// MCPTokenStore persists OAuth tokens for a single MCP server. Implementations
// must be safe for concurrent use.
//
// This is a type alias for the mcp-go transport.TokenStore interface. SDK
// consumers can implement this interface to provide custom storage backends
// (database, encrypted file, in-memory, etc.).
type MCPTokenStore = transport.TokenStore
// MCPToken represents an OAuth token for an MCP server, containing access
// and refresh tokens along with expiration metadata.
type MCPToken = transport.Token
// MCPTokenStoreFactory creates an [MCPTokenStore] for a given MCP server URL.
// It is called once per remote MCP server during connection setup.
type MCPTokenStoreFactory func(serverURL string) (MCPTokenStore, error)
// ErrMCPNoToken is the sentinel error that [MCPTokenStore] implementations
// should return from GetToken when no token is stored for the server.
// Callers can check for this with errors.Is.
var ErrMCPNoToken = transport.ErrNoToken
// ==== Constructor & Helper Functions ====
// ParseModelString parses a model string in "provider/model" format.
+144 -2
View File
@@ -85,10 +85,15 @@ host, err := kit.New(ctx, &kit.Options{
SessionPath: "/path/to/session.jsonl", // open specific session file
Continue: true, // resume most recent session for SessionDir
NoSession: true, // ephemeral in-memory session, no disk persistence
SessionManager: myCustomSession, // custom SessionManager implementation (advanced)
// Tools
Tools: []kit.Tool{kit.NewBashTool()}, // REPLACES entire default tool set
ExtraTools: []kit.Tool{myTool}, // ADDS alongside core/MCP/extension tools
Tools: []kit.Tool{kit.NewBashTool()}, // REPLACES entire default tool set
ExtraTools: []kit.Tool{myTool}, // ADDS alongside core/MCP/extension tools
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Skills
Skills: []string{"/path/to/skill.md"}, // explicit skill files (empty = auto-discover)
@@ -342,6 +347,77 @@ Lower values run first. Within the same priority, registration order applies. Fi
## Tools
### Creating custom tools
Use `kit.NewTool` to create custom tools. The JSON schema is auto-generated from the input struct — no external dependencies required:
```go
type WeatherInput struct {
City string `json:"city" description:"City name, e.g. 'San Francisco'"`
}
weatherTool := kit.NewTool("get_weather", "Get current weather for a city",
func(ctx context.Context, input WeatherInput) (kit.ToolOutput, error) {
// Your logic here (API calls, database lookups, etc.)
return kit.TextResult("72°F, sunny in " + input.City), nil
},
)
host, _ := kit.New(ctx, &kit.Options{
ExtraTools: []kit.Tool{weatherTool},
})
```
**Struct tags** control the generated schema:
| Tag | Purpose | Example |
|-----|---------|---------|
| `json:"name"` | Parameter name | `json:"city"` |
| `description:"..."` | Description shown to the LLM | `description:"City name"` |
| `enum:"a,b,c"` | Restrict valid values | `enum:"json,text,csv"` |
| `omitempty` | Marks parameter as optional | `json:"limit,omitempty"` |
**Return helpers:**
| Function | Description |
|----------|-------------|
| `kit.TextResult(content)` | Successful text result |
| `kit.ErrorResult(content)` | Error result (LLM sees it as a tool error) |
**ToolOutput fields** (for advanced use):
```go
kit.ToolOutput{
Content: "result text", // text returned to the LLM
IsError: false, // true = LLM sees this as an error
Data: pngBytes, // optional binary data (images, audio)
MediaType: "image/png", // MIME type for binary Data
Metadata: map[string]any{}, // opaque metadata for hooks/UI (not sent to LLM)
}
```
**Parallel tools** — mark as safe for concurrent execution:
```go
searchTool := kit.NewParallelTool("search", "Search the web",
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
return kit.TextResult("results..."), nil
},
)
```
**Tool call ID** — available in context for logging/tracing:
```go
tool := kit.NewTool("my_tool", "...",
func(ctx context.Context, input MyInput) (kit.ToolOutput, error) {
callID := kit.ToolCallIDFromContext(ctx) // correlation ID from the LLM
log.Printf("[%s] my_tool called", callID)
return kit.TextResult("ok"), nil
},
)
```
### Built-in tool constructors
```go
@@ -431,6 +507,72 @@ kit.DeleteSession("/path/to/session.jsonl")
tm, _ := kit.OpenTreeSession("/path/to/session.jsonl") // open for direct access
```
### Custom Session Manager (Advanced)
You can provide a custom session manager to store conversation history in your own backend (database, cloud storage, etc.) instead of the default JSONL files.
```go
// Implement the SessionManager interface
type MyDatabaseSessionManager struct {
db *sql.DB
// ... other fields
}
func (s *MyDatabaseSessionManager) AppendMessage(msg kit.LLMMessage) (string, error) {
// Store message in your database
}
func (s *MyDatabaseSessionManager) GetMessages() []kit.LLMMessage {
// Retrieve messages from your database
}
// ... implement all other SessionManager methods
// Use with Kit
host, _ := kit.New(ctx, &kit.Options{
SessionManager: myCustomSession, // Your custom implementation
Model: "anthropic/claude-sonnet-latest",
})
```
**SessionManager Interface:**
```go
type SessionManager interface {
AppendMessage(msg kit.LLMMessage) (entryID string, err error)
GetMessages() []kit.LLMMessage
BuildContext() (messages []kit.LLMMessage, provider string, modelID string)
Branch(entryID string) error
GetCurrentBranch() []kit.BranchEntry
GetChildren(parentID string) []string
GetEntry(entryID string) *kit.BranchEntry
GetSessionID() string
GetSessionName() string
SetSessionName(name string) error
GetCreatedAt() time.Time
IsPersisted() bool
AppendCompaction(summary string, firstKeptEntryID string,
tokensBefore, tokensAfter int, messagesRemoved int, readFiles, modifiedFiles []string) (string, error)
GetLastCompaction() *kit.CompactionEntry
AppendExtensionData(extType, data string) (string, error)
GetExtensionData(extType string) []kit.ExtensionDataEntry
AppendModelChange(provider, modelID string) (string, error)
GetContextEntryIDs() []string
Close() error
}
```
**Use Cases:**
- **PocketBase integration**: Store sessions as PocketBase records
- **Cloud storage**: Persist sessions to S3, GCS, or Azure Blob
- **Multi-user apps**: Store sessions per user in a database
- **Custom retention**: Implement your own session cleanup policies
**Note:** When using a custom SessionManager, the following Options are ignored:
- `SessionPath` - your manager handles its own storage
- `Continue` - your manager handles session selection
- `NoSession` - use an in-memory implementation instead
---
## Model Management
+49 -21
View File
@@ -7,17 +7,16 @@ description: Monitor tool calls and streaming output with the Kit Go SDK.
## Event-based monitoring
For more granular control, use the event subscription API:
Subscribe to events for real-time monitoring. Each method returns an unsubscribe function:
```go
// Subscribe returns an unsubscribe function
unsub := host.OnToolCall(func(event kit.ToolCallEvent) {
fmt.Printf("Tool: %s, Args: %s\n", event.Name, event.Args)
fmt.Printf("Tool: %s, Args: %s\n", event.ToolName, event.ToolArgs)
})
defer unsub()
unsub2 := host.OnToolResult(func(event kit.ToolResultEvent) {
fmt.Printf("Result: %s (error: %v)\n", event.Name, event.IsError)
fmt.Printf("Result: %s (error: %v)\n", event.ToolName, event.IsError)
})
defer unsub2()
@@ -44,33 +43,62 @@ defer unsub6()
## Hook system
Hooks allow you to intercept and modify behavior. Unlike events, hooks can modify or cancel operations:
Hooks can **modify or cancel** operations. Unlike events (read-only), hooks are read-write interceptors.
### BeforeToolCall — block tool execution
```go
// Intercept tool calls before execution
host.OnBeforeToolCall(0, func(ctx context.Context, name string, args string) (string, error) {
if name == "bash" {
log.Println("Bash command:", args)
host.OnBeforeToolCall(kit.HookPriorityNormal, func(h kit.BeforeToolCallHook) *kit.BeforeToolCallResult {
// h.ToolCallID, h.ToolName, h.ToolArgs
if h.ToolName == "bash" && strings.Contains(h.ToolArgs, "rm -rf") {
return &kit.BeforeToolCallResult{Block: true, Reason: "dangerous command"}
}
return args, nil // return modified args or error to cancel
return nil // allow
})
```
// Process results after tool execution
host.OnAfterToolResult(0, func(ctx context.Context, name string, result string) (string, error) {
return result, nil
})
### AfterToolResult — modify tool output
// Before/after each agent turn
host.OnBeforeTurn(0, func(ctx context.Context) error {
return nil
})
host.OnAfterTurn(0, func(ctx context.Context) error {
```go
host.OnAfterToolResult(kit.HookPriorityNormal, func(h kit.AfterToolResultHook) *kit.AfterToolResultResult {
// h.ToolCallID, h.ToolName, h.ToolArgs, h.Result, h.IsError
if h.ToolName == "read" {
filtered := redactSecrets(h.Result)
return &kit.AfterToolResultResult{Result: &filtered}
}
return nil
})
```
The first argument is a priority (lower = runs first).
### BeforeTurn — modify prompt, inject messages
```go
host.OnBeforeTurn(kit.HookPriorityNormal, func(h kit.BeforeTurnHook) *kit.BeforeTurnResult {
// h.Prompt
newPrompt := h.Prompt + "\nAlways respond in JSON."
return &kit.BeforeTurnResult{Prompt: &newPrompt}
// Also available: SystemPrompt *string, InjectText *string
})
```
### AfterTurn — observation only
```go
host.OnAfterTurn(kit.HookPriorityNormal, func(h kit.AfterTurnHook) {
// h.Response, h.Error
log.Printf("Turn completed: %d chars", len(h.Response))
})
```
### Hook priorities
```go
kit.HookPriorityHigh = 0 // runs first
kit.HookPriorityNormal = 50 // default
kit.HookPriorityLow = 100 // runs last
```
Lower values run first. First non-nil result wins.
## Subagent event monitoring
+33 -2
View File
@@ -29,8 +29,12 @@ host, err := kit.New(ctx, &kit.Options{
NoSession: true,
// Tools
Tools: []kit.Tool{...}, // Replace default tool set entirely
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
Tools: []kit.Tool{...}, // Replace default tool set entirely
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Compaction
AutoCompact: true,
@@ -58,7 +62,34 @@ host, err := kit.New(ctx, &kit.Options{
| `NoSession` | `bool` | `false` | Ephemeral mode (no persistence) |
| `Tools` | `[]Tool` | — | Replace the entire default tool set |
| `ExtraTools` | `[]Tool` | — | Additional tools alongside core/MCP/extension tools |
| `DisableCoreTools` | `bool` | `false` | Use no core tools (0 tools, for chat-only) |
| `SkipConfig` | `bool` | `false` | Skip .kit.yml file loading |
| `AutoCompact` | `bool` | `false` | Auto-compact when near context limit |
| `CompactionOptions` | `*CompactionOptions` | — | Configuration for auto-compaction |
| `Skills` | `[]string` | — | Explicit skill files/dirs to load |
| `SkillsDir` | `string` | — | Override default skills directory |
## Tool configuration
**`Tools`** replaces ALL default tools (core + MCP + extension). **`ExtraTools`** adds tools alongside the defaults. Use `Tools` to restrict capabilities; use `ExtraTools` to extend them.
Create custom tools with `kit.NewTool` — no external dependencies needed:
```go
type LookupInput struct {
ID string `json:"id" description:"Record ID to look up"`
}
lookupTool := kit.NewTool("lookup", "Look up a record by ID",
func(ctx context.Context, input LookupInput) (kit.ToolOutput, error) {
record := db.Find(input.ID)
return kit.TextResult(record.String()), nil
},
)
host, _ := kit.New(ctx, &kit.Options{
ExtraTools: []kit.Tool{lookupTool},
})
```
See [Overview](/sdk/overview#custom-tools) for full custom tool documentation.
+38
View File
@@ -68,6 +68,44 @@ The SDK provides several prompt variants:
| `Steer(ctx, instruction)` | System-level steering without user message |
| `FollowUp(ctx, text)` | Continue without new user input |
## Custom tools
Create custom tools with `kit.NewTool`. The JSON schema is auto-generated from the input struct — no external dependencies required:
```go
type WeatherInput struct {
City string `json:"city" description:"City name"`
}
weatherTool := kit.NewTool("get_weather", "Get current weather for a city",
func(ctx context.Context, input WeatherInput) (kit.ToolOutput, error) {
return kit.TextResult("72°F, sunny in " + input.City), nil
},
)
host, _ := kit.New(ctx, &kit.Options{
ExtraTools: []kit.Tool{weatherTool},
})
```
Struct tags control the schema:
- `json:"name"` — parameter name
- `description:"..."` — description shown to the LLM
- `enum:"a,b,c"` — restrict valid values
- `omitempty` — marks the parameter as optional
Return values:
| Helper | Description |
|--------|-------------|
| `kit.TextResult(s)` | Successful text result |
| `kit.ErrorResult(s)` | Error result (LLM sees it as a tool error) |
For advanced use, return a `kit.ToolOutput` struct directly with `Data`, `MediaType`, and `Metadata` fields.
Use `kit.NewParallelTool` for tools that are safe to run concurrently. Use `kit.ToolCallIDFromContext(ctx)` to retrieve the LLM-assigned call ID for logging or tracing.
## Event system
Subscribe to events for monitoring:
+179
View File
@@ -901,6 +901,126 @@ a:hover { text-decoration: underline; }
color: var(--text-muted);
}
/* ============================================================
Compaction Card
============================================================ */
.compaction-card {
margin: 16px 0;
border: 1px solid var(--border);
border-radius: var(--radius);
background: var(--surface);
overflow: hidden;
}
.compaction-header {
display: flex;
align-items: center;
gap: 10px;
padding: 12px 16px;
cursor: pointer;
user-select: none;
transition: background var(--transition);
background: var(--surface-raised);
}
.compaction-header:hover {
background: var(--surface-overlay);
}
.compaction-icon {
width: 18px;
height: 18px;
color: var(--yellow);
flex-shrink: 0;
}
.compaction-title {
font-size: 13px;
font-weight: 600;
color: var(--text-secondary);
flex: 1;
}
.compaction-badge {
font-size: 11px;
font-weight: 500;
color: var(--text-muted);
background: var(--surface);
padding: 2px 8px;
border-radius: 10px;
border: 1px solid var(--border);
}
.compaction-chevron {
width: 16px;
height: 16px;
color: var(--text-faint);
transition: transform var(--transition);
flex-shrink: 0;
}
.compaction-card.expanded .compaction-chevron {
transform: rotate(180deg);
}
.compaction-content {
max-height: 0;
overflow: hidden;
transition: max-height var(--transition);
border-top: 1px solid transparent;
}
.compaction-card.expanded .compaction-content {
max-height: 2000px;
overflow-y: auto;
border-top-color: var(--border);
}
.compaction-summary {
padding: 16px;
font-size: 13.5px;
line-height: 1.7;
color: var(--text-secondary);
}
.compaction-summary .md-content h1,
.compaction-summary .md-content h2,
.compaction-summary .md-content h3 {
color: var(--text);
margin: 16px 0 8px;
}
.compaction-summary .md-content h1 { font-size: 1.3em; }
.compaction-summary .md-content h2 { font-size: 1.15em; }
.compaction-summary .md-content h3 { font-size: 1.05em; }
.compaction-summary .md-content ul,
.compaction-summary .md-content ol {
padding-left: 20px;
}
.compaction-stats {
display: flex;
gap: 16px;
padding: 12px 16px;
background: var(--surface-raised);
border-top: 1px solid var(--border-subtle);
font-size: 11.5px;
color: var(--text-muted);
flex-wrap: wrap;
}
.compaction-stat {
display: flex;
align-items: center;
gap: 4px;
}
.compaction-stat strong {
color: var(--text-secondary);
font-weight: 600;
}
/* ============================================================
System Prompt Display
============================================================ */
@@ -1460,6 +1580,7 @@ a:hover { text-decoration: underline; }
let userMsgCount = 0;
let assistantMsgCount = 0;
let toolCallCount = 0;
let compactionCount = 0;
// Render each entry
for (const entry of path) {
@@ -1491,6 +1612,9 @@ a:hover { text-decoration: underline; }
renderSystemNotice('Label', entry.label || '', 'label');
} else if (entry.type === 'session_info') {
// Already handled above for header
} else if (entry.type === 'compaction') {
compactionCount++;
renderCompaction(entry);
}
}
@@ -1501,6 +1625,7 @@ a:hover { text-decoration: underline; }
<div class="stat-item"><strong>${userMsgCount}</strong> user message${userMsgCount !== 1 ? 's' : ''}</div>
<div class="stat-item"><strong>${assistantMsgCount}</strong> assistant message${assistantMsgCount !== 1 ? 's' : ''}</div>
${toolCallCount > 0 ? `<div class="stat-item"><strong>${toolCallCount}</strong> tool call${toolCallCount !== 1 ? 's' : ''}</div>` : ''}
${compactionCount > 0 ? `<div class="stat-item"><strong>${compactionCount}</strong> compaction${compactionCount !== 1 ? 's' : ''}</div>` : ''}
${header && header.cwd ? `<div class="stat-item">📁 ${escapeHtml(header.cwd)}</div>` : ''}
</div>`;
$conversation.insertAdjacentHTML('beforeend', statsHtml);
@@ -2030,6 +2155,60 @@ a:hover { text-decoration: underline; }
$conversation.appendChild(el);
}
// ============================================================
// Compaction Display
// ============================================================
function renderCompaction(entry) {
const el = document.createElement('div');
el.className = 'compaction-card fade-in';
const cardId = 'compaction-' + Math.random().toString(36).substr(2, 9);
// Build stats
const stats = [];
if (entry.messages_removed > 0) {
stats.push(`<div class="compaction-stat"><strong>${entry.messages_removed}</strong> messages compacted</div>`);
}
if (entry.tokens_before > 0 && entry.tokens_after > 0) {
const saved = entry.tokens_before - entry.tokens_after;
stats.push(`<div class="compaction-stat"><strong>${saved}</strong> tokens saved</div>`);
}
// Format timestamp
const timeStr = formatTime(entry.timestamp);
el.innerHTML = `
<div class="compaction-header" onclick="toggleCompaction('${cardId}')">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" fill="currentColor" class="compaction-icon">
<path d="M2.5 3.5a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 .5.5v1a.5.5 0 0 1-.5.5h-10a.5.5 0 0 1-.5-.5v-1Zm0 4a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 .5.5v1a.5.5 0 0 1-.5.5h-10a.5.5 0 0 1-.5-.5v-1Zm0 4a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 .5.5v1a.5.5 0 0 1-.5.5h-10a.5.5 0 0 1-.5-.5v-1Z"/>
</svg>
<span class="compaction-title">Context Compacted</span>
${timeStr ? `<span class="compaction-badge">${timeStr}</span>` : ''}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" fill="currentColor" class="compaction-chevron" id="${cardId}-chevron">
<path d="M4.427 9.427a.25.25 0 0 0 0 .353l3 3a.25.25 0 0 0 .353 0l3-3a.25.25 0 0 0-.353-.353L8 11.646V4.75a.75.75 0 0 0-1.5 0v6.896L4.78 9.427a.25.25 0 0 0-.353 0Z"/>
</svg>
</div>
<div class="compaction-content" id="${cardId}">
${entry.summary ? `<div class="compaction-summary"><div class="md-content">${renderMarkdown(entry.summary)}</div></div>` : ''}
${stats.length > 0 ? `<div class="compaction-stats">${stats.join('')}</div>` : ''}
</div>`;
$conversation.appendChild(el);
}
// Toggle compaction card expansion
window.toggleCompaction = function(cardId) {
const card = document.getElementById(cardId).closest('.compaction-card');
const chevron = document.getElementById(cardId + '-chevron');
const isExpanded = card.classList.contains('expanded');
if (isExpanded) {
card.classList.remove('expanded');
} else {
card.classList.add('expanded');
}
};
// ============================================================
// System Prompt Display (collapsible)
// ============================================================