mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2de98d32be | |||
| 83127467c5 | |||
| e07c94f49d | |||
| b87146a284 | |||
| 186d9f7f44 | |||
| 3a8ffc2104 | |||
| e54570162e | |||
| 34bb97a40e | |||
| f5c1a16f8a | |||
| b29d7d2166 | |||
| 3ea0db69ea | |||
| 4304a5e899 | |||
| 4019c1e4f7 | |||
| 30ad7c1d0b | |||
| e33564c569 | |||
| 5ff28445fd | |||
| 13d177e5d0 | |||
| 3ffc995f27 | |||
| b2bd016135 |
@@ -545,6 +545,28 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
### Custom Tools
|
||||
|
||||
Create custom tools with automatic schema generation — no external dependencies needed:
|
||||
|
||||
```go
|
||||
type SearchInput struct {
|
||||
Query string `json:"query" description:"Search query"`
|
||||
}
|
||||
|
||||
searchTool := kit.NewTool("search", "Search the codebase",
|
||||
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("Found: ..."), nil
|
||||
},
|
||||
)
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{searchTool}, // adds alongside built-in tools
|
||||
})
|
||||
```
|
||||
|
||||
Use `kit.NewParallelTool` for tools safe to run concurrently. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
|
||||
|
||||
### With Callbacks
|
||||
|
||||
```go
|
||||
|
||||
+12
@@ -2003,6 +2003,18 @@ func writeJSONError(err error) {
|
||||
//
|
||||
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
|
||||
// Redirect all log output (stdlib and charm) to a file so that log
|
||||
// messages don't write to stderr and corrupt the TUI. Bubble Tea
|
||||
// captures stdout for rendering; any stray stderr output from
|
||||
// background goroutines (watchers, extension handlers, SDK internals)
|
||||
// will visually corrupt the terminal.
|
||||
logDir := filepath.Join(os.TempDir(), "kit")
|
||||
_ = os.MkdirAll(logDir, 0o700)
|
||||
logFile, logErr := tea.LogToFile(filepath.Join(logDir, "kit.log"), "kit")
|
||||
if logErr == nil {
|
||||
defer func() { _ = logFile.Close() }()
|
||||
}
|
||||
|
||||
// Determine terminal size; fall back gracefully.
|
||||
termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd()))
|
||||
if err != nil || termWidth == 0 {
|
||||
|
||||
@@ -21,7 +21,7 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/indaco/herald v0.13.0
|
||||
github.com/indaco/herald-md v0.3.0
|
||||
github.com/mark3labs/mcp-go v0.47.0
|
||||
github.com/mark3labs/mcp-go v0.47.1
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
@@ -31,7 +31,7 @@ require (
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/auth v0.19.0 // indirect
|
||||
cloud.google.com/go/auth v0.20.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
|
||||
@@ -58,9 +58,9 @@ require (
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143 // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143 // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
@@ -81,7 +81,7 @@ require (
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.3.0 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.3.1 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.7 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.19 // indirect
|
||||
@@ -103,8 +103,8 @@ require (
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
|
||||
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
@@ -114,9 +114,9 @@ require (
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.274.0 // indirect
|
||||
google.golang.org/api v0.275.0 // indirect
|
||||
google.golang.org/genai v1.52.1 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
@@ -128,13 +128,13 @@ require (
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.22 // indirect
|
||||
github.com/mattn/go-isatty v0.0.21 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.23 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.35.0
|
||||
)
|
||||
|
||||
@@ -10,8 +10,8 @@ charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs=
|
||||
charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.19.0 h1:DGYwtbcsGsT1ywuxsIoWi1u/vlks0moIblQHgSDgQkQ=
|
||||
cloud.google.com/go/auth v0.19.0/go.mod h1:2Aph7BT2KnaSFOM0JDPyiYgNh6PL9vGMiP8CUIXZ+IY=
|
||||
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
|
||||
cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
@@ -96,14 +96,14 @@ github.com/charmbracelet/x/conpty v0.1.1 h1:s1bUxjoi7EpqiXysVtC+a8RrvPPNcNvAjfi4
|
||||
github.com/charmbracelet/x/conpty v0.1.1/go.mod h1:OmtR77VODEFbiTzGE9G1XiRJAga6011PIm4u5fTNZpk=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4 h1:pIj18ZCZO4WOVj7jwjLoUb1lC7rS/I8oC3fZWXugNaY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143 h1:zmBor0ftFNqVFp9U59ZoEDRUCIYSGOGSIfGGkNZRufs=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4 h1:VSd4zShIAf/4FgEDFJpapEcAPrc7h3dyyN7V9JlJpQw=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260330094520-2dce04b6f8a4/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143 h1:aEppolah2k9c0LzKX2fk5ryuyQ0Lq8kCOjkvMw1b8o4=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -185,8 +185,8 @@ github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
|
||||
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
|
||||
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
|
||||
github.com/kaptinlin/go-i18n v0.3.0 h1:wP76dvYg04bvwTb+8NB+CmdZ2kL7lSSCQ9B/kFv7QHo=
|
||||
github.com/kaptinlin/go-i18n v0.3.0/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
|
||||
github.com/kaptinlin/go-i18n v0.3.1 h1:plXi3XQE1aYamFi8TU0K6actODmw2+5FSobmhTkfQ/0=
|
||||
github.com/kaptinlin/go-i18n v0.3.1/go.mod h1:ZRoAHj7elWYamfbv7wev7Ajch6LOzjtBaq8nWe8HIVk=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17/go.mod h1:SsfsjqnHG5zuKo1DTBzk1VknaHlL4osHw+X9kZKukpU=
|
||||
github.com/kaptinlin/jsonschema v0.7.7 h1:41BlQJ9dskH0oE5DSzBUrl/w4JQYIr6N6L0B5GNyDoM=
|
||||
@@ -201,12 +201,12 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.47.0 h1:h44yeM3DduDyQgzImYWu4pt6VRkqP/0p/95AGhWngnA=
|
||||
github.com/mark3labs/mcp-go v0.47.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.22 h1:76lXsPn6FyHtTY+jt2fTTvsMUCZq1k0qwRsAMuxzKAk=
|
||||
github.com/mattn/go-runewidth v0.0.22/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mark3labs/mcp-go v0.47.1 h1:A9sJJ20mscl/ssLYHjodfaoBmq6uuhMG7pAPNYaQymQ=
|
||||
github.com/mark3labs/mcp-go v0.47.1/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
|
||||
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
|
||||
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@@ -272,18 +272,18 @@ github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0/go.mod h1:Sje3i3MjSPKTSPvVWCaL8ugBzJwik3u4smCjUeuupqg=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
@@ -298,9 +298,8 @@ golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
@@ -309,16 +308,16 @@ golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.274.0 h1:aYhycS5QQCwxHLwfEHRRLf9yNsfvp1JadKKWBE54RFA=
|
||||
google.golang.org/api v0.274.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/api v0.275.0 h1:vfY5d9vFVJeWEZT65QDd9hbndr7FyZ2+6mIzGAh71NI=
|
||||
google.golang.org/api v0.275.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
|
||||
google.golang.org/genai v1.52.1 h1:dYoljKtLDXMiBdVaClSJ/ZPwZ7j1N0lGjMhwOKOQUlk=
|
||||
google.golang.org/genai v1.52.1/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d h1:wT2n40TBqFY6wiwazVK9/iTWbsQrgk5ZfCSVFLO9LQA=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
|
||||
+2
-112
@@ -23,18 +23,6 @@ import (
|
||||
// Version is injected at build time; fallback to "dev".
|
||||
var Version = "dev"
|
||||
|
||||
// thinkingTagOpen and thinkingTagClose are the XML-style tags that some models
|
||||
// (Qwen, DeepSeek) wrap reasoning content in. We parse these to extract
|
||||
// reasoning/thinking content and send it as ACP thought updates.
|
||||
// Also support <think> format used by some models.
|
||||
const (
|
||||
thinkingTagOpen = "<thinking>"
|
||||
thinkingTagClose = "</thinking>"
|
||||
shortThinkTagOpen = "<think>"
|
||||
shortThinkTagClose = "</think>"
|
||||
)
|
||||
|
||||
// Agent implements the acp.Agent interface, delegating to Kit for LLM
|
||||
// execution, tool calls, and session management.
|
||||
type Agent struct {
|
||||
conn *acp.AgentSideConnection
|
||||
@@ -42,10 +30,6 @@ type Agent struct {
|
||||
|
||||
// toolCallCounter provides unique IDs for tool calls within a turn.
|
||||
toolCallCounter atomic.Int64
|
||||
|
||||
// inThinkingTag tracks whether we're currently inside a <thinking> tag
|
||||
// when parsing streaming content from models that wrap reasoning in XML tags.
|
||||
inThinkingTag bool
|
||||
}
|
||||
|
||||
// NewAgent creates a new ACP agent backed by Kit.
|
||||
@@ -144,9 +128,6 @@ func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.Promp
|
||||
|
||||
log.Debug("acp: prompt", "session", sessionID, "prompt_len", len(promptText), "files", len(files))
|
||||
|
||||
// Reset thinking tag state for this new prompt turn
|
||||
a.inThinkingTag = false
|
||||
|
||||
// Create a cancellable context for this prompt turn.
|
||||
promptCtx, cancel := context.WithCancel(ctx)
|
||||
sess.setCancel(cancel)
|
||||
@@ -230,24 +211,8 @@ func (a *Agent) subscribeEvents(ctx context.Context, k *kit.Kit, sessionID acp.S
|
||||
var update *acp.SessionUpdate
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
// Handle models that wrap reasoning in <thinking> tags (Qwen, DeepSeek)
|
||||
// Parse the chunk and separate reasoning from regular text
|
||||
reasoning, text := a.parseThinkingTags(ev.Chunk)
|
||||
|
||||
// Send reasoning update if we have reasoning content
|
||||
if reasoning != "" {
|
||||
u := acp.UpdateAgentThoughtText(reasoning)
|
||||
_ = a.conn.SessionUpdate(ctx, acp.SessionNotification{
|
||||
SessionId: sessionID,
|
||||
Update: u,
|
||||
})
|
||||
}
|
||||
|
||||
// Send text update if we have text content
|
||||
if text != "" {
|
||||
u := acp.UpdateAgentMessageText(text)
|
||||
update = &u
|
||||
}
|
||||
u := acp.UpdateAgentMessageText(ev.Chunk)
|
||||
update = &u
|
||||
|
||||
case kit.ReasoningDeltaEvent:
|
||||
u := acp.UpdateAgentThoughtText(ev.Delta)
|
||||
@@ -430,81 +395,6 @@ func extractPromptContent(blocks []acp.ContentBlock) (string, []kit.LLMFilePart)
|
||||
return strings.Join(textParts, "\n"), files
|
||||
}
|
||||
|
||||
// parseThinkingTags parses a text chunk for <thinking> or tags and separates
|
||||
// reasoning content from regular text. This handles models (Qwen, DeepSeek)
|
||||
// that wrap reasoning in XML-style tags instead of using proper reasoning events.
|
||||
// Returns (reasoningContent, textContent).
|
||||
func (a *Agent) parseThinkingTags(chunk string) (reasoning string, text string) {
|
||||
// Handle empty chunk
|
||||
if chunk == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Determine which tag format to use (long or short)
|
||||
openTag := thinkingTagOpen
|
||||
closeTag := thinkingTagClose
|
||||
|
||||
if strings.Contains(chunk, shortThinkTagOpen) || strings.Contains(chunk, shortThinkTagClose) {
|
||||
openTag = shortThinkTagOpen
|
||||
closeTag = shortThinkTagClose
|
||||
} else if !strings.Contains(chunk, thinkingTagOpen) && !strings.Contains(chunk, thinkingTagClose) && !a.inThinkingTag {
|
||||
// No tags at all and not in thinking mode - return as text
|
||||
return "", chunk
|
||||
}
|
||||
|
||||
// Check for opening tag
|
||||
if strings.Contains(chunk, openTag) {
|
||||
parts := strings.SplitN(chunk, openTag, 2)
|
||||
|
||||
// Content before the opening tag is regular text
|
||||
if !a.inThinkingTag && parts[0] != "" {
|
||||
text = parts[0]
|
||||
}
|
||||
|
||||
a.inThinkingTag = true
|
||||
|
||||
// Content after the opening tag is reasoning
|
||||
if len(parts) > 1 {
|
||||
// Check if the same chunk contains the closing tag
|
||||
if strings.Contains(parts[1], closeTag) {
|
||||
innerParts := strings.SplitN(parts[1], closeTag, 2)
|
||||
reasoning = innerParts[0]
|
||||
a.inThinkingTag = false
|
||||
|
||||
// Content after closing tag is regular text
|
||||
if len(innerParts) > 1 && innerParts[1] != "" {
|
||||
text += innerParts[1]
|
||||
}
|
||||
} else if parts[1] != "" {
|
||||
// No closing tag yet, all remaining content is reasoning
|
||||
reasoning = parts[1]
|
||||
}
|
||||
}
|
||||
return reasoning, text
|
||||
}
|
||||
|
||||
// Check for closing tag
|
||||
if strings.Contains(chunk, closeTag) {
|
||||
parts := strings.SplitN(chunk, closeTag, 2)
|
||||
a.inThinkingTag = false
|
||||
|
||||
// Content before closing tag is reasoning
|
||||
reasoning = parts[0]
|
||||
|
||||
// Content after closing tag is regular text
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
text = parts[1]
|
||||
}
|
||||
return reasoning, text
|
||||
}
|
||||
|
||||
// No tags found - content goes to current mode
|
||||
if a.inThinkingTag {
|
||||
return chunk, ""
|
||||
}
|
||||
return "", chunk
|
||||
}
|
||||
|
||||
// isTextMimeType returns true if the MIME type indicates text content.
|
||||
func isTextMimeType(mimeType string) bool {
|
||||
return strings.HasPrefix(mimeType, "text/") ||
|
||||
|
||||
+104
-6
@@ -30,6 +30,11 @@ type AgentConfig struct {
|
||||
// If nil, remote MCP servers that require OAuth will fail to connect.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used. This allows SDK users to provide a custom tool set (e.g.
|
||||
// CodingTools or tools with a custom WorkDir).
|
||||
@@ -89,6 +94,14 @@ type ReasoningCompleteHandler func()
|
||||
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
|
||||
type ToolOutputHandler = core.ToolOutputCallback
|
||||
|
||||
// StepMessagesHandler is a function type for persisting messages after each
|
||||
// complete step in a multi-step agent turn. The handler receives the messages
|
||||
// produced by the step (typically an assistant message with tool calls followed
|
||||
// by a tool-role message with results, or a final assistant message with text).
|
||||
// This enables incremental session persistence so that progress is saved as
|
||||
// it happens rather than only at the end of the turn.
|
||||
type StepMessagesHandler func(stepMessages []fantasy.Message)
|
||||
|
||||
// StepUsageHandler is a function type for handling token usage after each
|
||||
// complete step in a multi-step agent turn. This enables real-time cost
|
||||
// tracking during long-running tool-calling conversations.
|
||||
@@ -141,6 +154,11 @@ type GenerateWithLoopResult struct {
|
||||
TotalUsage fantasy.Usage
|
||||
// StopReason is the LLM provider's finish reason for the final response.
|
||||
StopReason string
|
||||
// PersistedMessageCount is the number of new messages (beyond the original
|
||||
// input) that were already persisted incrementally via OnStepMessages during
|
||||
// generation. The caller should skip these when doing post-generation
|
||||
// persistence to avoid duplicates.
|
||||
PersistedMessageCount int
|
||||
}
|
||||
|
||||
// NewAgent creates a new Agent with core tools and optional MCP tool integration.
|
||||
@@ -223,6 +241,9 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
if agentConfig.AuthHandler != nil {
|
||||
toolManager.SetAuthHandler(agentConfig.AuthHandler)
|
||||
}
|
||||
if agentConfig.TokenStoreFactory != nil {
|
||||
toolManager.SetTokenStoreFactory(agentConfig.TokenStoreFactory)
|
||||
}
|
||||
if agentConfig.DebugLogger != nil {
|
||||
toolManager.SetDebugLogger(agentConfig.DebugLogger)
|
||||
}
|
||||
@@ -377,7 +398,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil, nil)
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
@@ -390,6 +411,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onReasoningDelta ReasoningDeltaHandler,
|
||||
onReasoningComplete ReasoningCompleteHandler,
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepMessages StepMessagesHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
|
||||
@@ -429,6 +451,10 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// when it returns an error, but the OnStepFinish callback fires
|
||||
// for every step that completed before the error occurred.
|
||||
var completedStepMessages []fantasy.Message
|
||||
// persistedCount tracks how many new messages (beyond the original
|
||||
// input) were persisted incrementally via onStepMessages, so the
|
||||
// caller can skip them during post-generation persistence.
|
||||
var persistedCount int
|
||||
|
||||
// Use the streaming agent
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
@@ -514,6 +540,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// persisted even if a later step is cancelled.
|
||||
completedStepMessages = append(completedStepMessages, step.Messages...)
|
||||
|
||||
// Persist step messages incrementally so progress is saved
|
||||
// as it happens rather than only at the end of the turn.
|
||||
if onStepMessages != nil && len(step.Messages) > 0 {
|
||||
onStepMessages(step.Messages)
|
||||
persistedCount += len(step.Messages)
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
@@ -592,7 +625,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
partialMessages = append(partialMessages, messages...)
|
||||
partialMessages = append(partialMessages, completedStepMessages...)
|
||||
return &GenerateWithLoopResult{
|
||||
ConversationMessages: partialMessages,
|
||||
ConversationMessages: partialMessages,
|
||||
PersistedMessageCount: persistedCount,
|
||||
}, err
|
||||
}
|
||||
return nil, err
|
||||
@@ -607,7 +641,9 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
r := convertAgentResult(result, messages)
|
||||
r.PersistedMessageCount = persistedCount
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Non-streaming path with no callbacks — use the simpler Generate call.
|
||||
@@ -798,6 +834,59 @@ func (a *Agent) SetExtraTools(extraTools []fantasy.AgentTool) {
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// AddMCPServer connects to a new MCP server at runtime and makes its tools
|
||||
// available to the agent. Returns the number of tools loaded.
|
||||
// If the agent has no tool manager (no MCP servers were configured at init),
|
||||
// one is created automatically.
|
||||
func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
|
||||
// Ensure MCP tools from initial load are settled first.
|
||||
a.ensureMCPTools()
|
||||
|
||||
if a.toolManager == nil {
|
||||
a.toolManager = tools.NewMCPToolManager()
|
||||
a.toolManager.SetModel(a.model)
|
||||
a.toolManager.SetOnToolsChanged(func() {
|
||||
a.rebuildFantasyAgent()
|
||||
})
|
||||
}
|
||||
|
||||
count, err := a.toolManager.AddServer(ctx, name, cfg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// AddServer's onToolsChanged callback triggers rebuildFantasyAgent,
|
||||
// but only if it was wired. Ensure rebuild happens regardless.
|
||||
a.rebuildFantasyAgent()
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// RemoveMCPServer disconnects an MCP server and removes its tools from the agent.
|
||||
func (a *Agent) RemoveMCPServer(name string) error {
|
||||
if a.toolManager == nil {
|
||||
return fmt.Errorf("no MCP servers loaded")
|
||||
}
|
||||
|
||||
// Ensure MCP tools from initial load are settled first.
|
||||
a.ensureMCPTools()
|
||||
|
||||
err := a.toolManager.RemoveServer(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveServer's onToolsChanged callback triggers rebuildFantasyAgent,
|
||||
// but ensure rebuild happens regardless.
|
||||
a.rebuildFantasyAgent()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMCPToolManager returns the underlying MCP tool manager.
|
||||
// Returns nil if no MCP servers have been configured.
|
||||
func (a *Agent) GetMCPToolManager() *tools.MCPToolManager {
|
||||
return a.toolManager
|
||||
}
|
||||
|
||||
// GetLoadingMessage returns the loading message from provider creation.
|
||||
func (a *Agent) GetLoadingMessage() string {
|
||||
return a.loadingMessage
|
||||
@@ -811,9 +900,11 @@ func (a *Agent) GetLoadedServerNames() []string {
|
||||
return a.toolManager.GetLoadedServerNames()
|
||||
}
|
||||
|
||||
// SetModel swaps the agent's LLM provider to a new model. The existing tools,
|
||||
// system prompt, and configuration are preserved. The old provider is closed
|
||||
// if it has a closer. Returns the previous model string for notification.
|
||||
// SetModel swaps the agent's LLM provider to a new model. The existing tools
|
||||
// and configuration are preserved. When the new model's ProviderConfig carries
|
||||
// a system prompt (from per-model settings), it replaces the agent's stored
|
||||
// prompt so the rebuilt fantasy agent uses it. The old provider is closed if
|
||||
// it has a closer.
|
||||
func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) error {
|
||||
// Ensure MCP tools are loaded before rebuilding (SetModel may be called
|
||||
// before the first LLM call).
|
||||
@@ -840,6 +931,13 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
a.skipMaxOutputTokens = providerResult.SkipMaxOutputTokens
|
||||
a.modelConfig = config
|
||||
|
||||
// Update system prompt when the config carries one (from per-model
|
||||
// settings or the global config). This allows model-specific system
|
||||
// prompts to take effect on model switch.
|
||||
if config.SystemPrompt != "" {
|
||||
a.systemPrompt = config.SystemPrompt
|
||||
}
|
||||
|
||||
// Update provider type.
|
||||
if config.ModelString != "" {
|
||||
if p, _, err := models.ParseModelString(config.ModelString); err == nil {
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// mockModel is a minimal LanguageModel that satisfies the interface
|
||||
// without making real API calls. Used to test tool management wiring.
|
||||
type mockModel struct{}
|
||||
|
||||
func (m *mockModel) Generate(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{}, nil
|
||||
}
|
||||
func (m *mockModel) Stream(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockModel) GenerateObject(_ context.Context, _ fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return &fantasy.ObjectResponse{}, nil
|
||||
}
|
||||
func (m *mockModel) StreamObject(_ context.Context, _ fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockModel) Provider() string { return "mock" }
|
||||
func (m *mockModel) Model() string { return "mock-model" }
|
||||
|
||||
// testdataDir returns the absolute path to the tools testdata directory.
|
||||
func testdataDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
_, file, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("cannot determine test file path")
|
||||
}
|
||||
return filepath.Join(filepath.Dir(file), "..", "tools", "testdata")
|
||||
}
|
||||
|
||||
// echoServerConfig returns an MCPServerConfig for the test echo MCP server.
|
||||
func echoServerConfig(t *testing.T) config.MCPServerConfig {
|
||||
t.Helper()
|
||||
script := filepath.Join(testdataDir(t), "echo_server.py")
|
||||
if _, err := os.Stat(script); err != nil {
|
||||
t.Skipf("echo_server.py not found: %v", err)
|
||||
}
|
||||
return config.MCPServerConfig{
|
||||
Command: []string{"python3", script},
|
||||
}
|
||||
}
|
||||
|
||||
// newTestAgent creates a minimal Agent with a mock model and no core tools,
|
||||
// suitable for testing MCP server management without an API key.
|
||||
func newTestAgent() *Agent {
|
||||
model := &mockModel{}
|
||||
a := &Agent{
|
||||
model: model,
|
||||
coreTools: nil,
|
||||
extraTools: nil,
|
||||
maxSteps: 10,
|
||||
systemPrompt: "test",
|
||||
fantasyAgent: fantasy.NewAgent(model),
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestAgent_AddMCPServer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Initially no MCP tools.
|
||||
if a.GetMCPToolCount() != 0 {
|
||||
t.Fatalf("Expected 0 MCP tools initially, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
// Add a server.
|
||||
count, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools, got %d", count)
|
||||
}
|
||||
|
||||
// Verify tools are in the agent's tool list.
|
||||
if a.GetMCPToolCount() != 2 {
|
||||
t.Errorf("Expected 2 MCP tools, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
allTools := a.GetTools()
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range allTools {
|
||||
toolNames[tool.Info().Name] = true
|
||||
}
|
||||
if !toolNames["echo__echo"] {
|
||||
t.Error("Expected tool 'echo__echo' in agent tools")
|
||||
}
|
||||
if !toolNames["echo__greet"] {
|
||||
t.Error("Expected tool 'echo__greet' in agent tools")
|
||||
}
|
||||
|
||||
// Verify loaded server names.
|
||||
names := a.GetLoadedServerNames()
|
||||
found := false
|
||||
for _, n := range names {
|
||||
if n == "echo" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected 'echo' in loaded server names: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveMCPServer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add then remove.
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
err = a.RemoveMCPServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify tools removed.
|
||||
if a.GetMCPToolCount() != 0 {
|
||||
t.Errorf("Expected 0 MCP tools after removal, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
// Verify agent's tool list has no MCP tools.
|
||||
for _, tool := range a.GetTools() {
|
||||
if strings.Contains(tool.Info().Name, "echo__") {
|
||||
t.Errorf("Found leftover tool after removal: %s", tool.Info().Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveMCPServer_NoToolManager(t *testing.T) {
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
err := a.RemoveMCPServer("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no tool manager exists")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no MCP servers loaded") {
|
||||
t.Errorf("Expected 'no MCP servers loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddMCPServer_CreatesToolManager(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
// Initially no tool manager.
|
||||
if a.GetMCPToolManager() != nil {
|
||||
t.Fatal("Expected nil tool manager initially")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Tool manager should now exist.
|
||||
if a.GetMCPToolManager() == nil {
|
||||
t.Fatal("Expected tool manager to be created by AddMCPServer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddRemoveAdd_MCP(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add → Remove → Add cycle.
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First add failed: %v", err)
|
||||
}
|
||||
|
||||
err = a.RemoveMCPServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("Remove failed: %v", err)
|
||||
}
|
||||
|
||||
count, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Re-add failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools on re-add, got %d", count)
|
||||
}
|
||||
if a.GetMCPToolCount() != 2 {
|
||||
t.Errorf("Expected 2 MCP tools after re-add, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
}
|
||||
@@ -38,6 +38,10 @@ type AgentCreationOptions struct {
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used.
|
||||
CoreTools []fantasy.AgentTool
|
||||
@@ -66,6 +70,7 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
DebugLogger: opts.DebugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
TokenStoreFactory: opts.TokenStoreFactory,
|
||||
CoreTools: opts.CoreTools,
|
||||
DisableCoreTools: opts.DisableCoreTools,
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
|
||||
+44
-18
@@ -930,7 +930,8 @@ func (a *App) QuitFromExtension() {
|
||||
// controls styling: "" for plain text, "info" for a system message block,
|
||||
// "error" for an error block. In interactive mode it sends an
|
||||
// ExtensionPrintEvent through the program so the TUI can render it with the
|
||||
// appropriate renderer. In non-interactive mode it falls back to stdout.
|
||||
// appropriate renderer. In non-interactive mode it falls back to stderr with
|
||||
// a level prefix so errors are distinguishable from plain output.
|
||||
func (a *App) PrintFromExtension(level, text string) {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
@@ -939,8 +940,16 @@ func (a *App) PrintFromExtension(level, text string) {
|
||||
prog.Send(ExtensionPrintEvent{Text: text, Level: level})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback: write directly to stdout.
|
||||
fmt.Println(text)
|
||||
// Non-interactive fallback: write to stderr with a level prefix so that
|
||||
// errors and info messages are distinguishable from plain output.
|
||||
switch level {
|
||||
case "error":
|
||||
fmt.Fprintf(os.Stderr, "[ERROR] %s\n", text)
|
||||
case "info":
|
||||
fmt.Fprintf(os.Stderr, "[INFO] %s\n", text)
|
||||
default:
|
||||
fmt.Println(text)
|
||||
}
|
||||
}
|
||||
|
||||
// SetEditorTextFromExtension sends an EditorTextSetEvent to the TUI to
|
||||
@@ -1122,11 +1131,12 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback.
|
||||
// Non-interactive fallback: render a simple framed block to stderr so
|
||||
// it is visually distinct from plain stdout output.
|
||||
if opts.Subtitle != "" {
|
||||
fmt.Printf("%s\n — %s\n", opts.Text, opts.Subtitle)
|
||||
fmt.Fprintf(os.Stderr, "--- %s ---\n%s\n", opts.Subtitle, opts.Text)
|
||||
} else {
|
||||
fmt.Println(opts.Text)
|
||||
fmt.Fprintf(os.Stderr, "---\n%s\n---\n", opts.Text)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1155,9 +1165,10 @@ func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool)
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// NOTE: We do NOT call SetContextTokens here. Context fill is set once
|
||||
// at turn completion via updateUsageFromTurnResult using FinalUsage.InputTokens,
|
||||
// which reflects the full accumulated context. Per-step context tokens would
|
||||
// cause the display to jump around during multi-step tool calls.
|
||||
// at turn completion via updateUsageFromTurnResult, which sums all token
|
||||
// categories (Input + CacheRead + CacheCreate + Output) from FinalUsage.
|
||||
// Per-step context tokens would cause the display to jump around during
|
||||
// multi-step tool calls.
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
@@ -1221,15 +1232,30 @@ func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt strin
|
||||
}
|
||||
|
||||
// --- Context window fill (drives the % bar) ---
|
||||
// Use FinalUsage.InputTokens as the context window fill. The API's InputTokens
|
||||
// already includes the full conversation history (system prompt + all previous
|
||||
// messages + current user message). Adding OutputTokens would double-count since
|
||||
// the output becomes part of the input for the next turn.
|
||||
if result.FinalUsage != nil && result.FinalUsage.InputTokens > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: calling SetContextTokens=%d (FinalUsage.InputTokens)",
|
||||
result.FinalUsage.InputTokens)
|
||||
// Calculate context fill from the LAST API call's usage. The context
|
||||
// window is filled by everything sent to and received from the model:
|
||||
//
|
||||
// InputTokens — non-cached input (may be small with prompt caching)
|
||||
// CacheReadTokens — input tokens served from cache
|
||||
// CacheCreationTokens — input tokens written to cache this call
|
||||
// OutputTokens — assistant output (becomes input next turn)
|
||||
//
|
||||
// With Anthropic prompt caching, InputTokens can drop to near-zero while
|
||||
// CacheReadTokens holds the bulk of the context. We must sum all four to
|
||||
// get the true context window utilization.
|
||||
//
|
||||
// We use FinalUsage (last step only), NOT TotalUsage, because TotalUsage
|
||||
// sums across all tool-calling steps — and each step re-sends the full
|
||||
// conversation, so TotalUsage massively overstates the actual window fill.
|
||||
if result.FinalUsage != nil {
|
||||
u := result.FinalUsage
|
||||
contextFill := int(u.InputTokens) + int(u.CacheReadTokens) + int(u.CacheCreationTokens) + int(u.OutputTokens)
|
||||
if contextFill > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: SetContextTokens=%d (Input=%d + CacheRead=%d + CacheCreate=%d + Output=%d)",
|
||||
contextFill, u.InputTokens, u.CacheReadTokens, u.CacheCreationTokens, u.OutputTokens)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(contextFill)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(int(result.FinalUsage.InputTokens))
|
||||
}
|
||||
}
|
||||
|
||||
+19
-13
@@ -630,10 +630,12 @@ func TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly verifies that context
|
||||
// window fill uses InputTokens only (not input+output). The API's InputTokens
|
||||
// already includes the full conversation history; adding output would double-count.
|
||||
func TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly(t *testing.T) {
|
||||
// TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories verifies that
|
||||
// context window fill uses all token categories from the final API call:
|
||||
// InputTokens + CacheReadTokens + CacheCreationTokens + OutputTokens.
|
||||
// With Anthropic prompt caching, InputTokens can be near-zero while
|
||||
// CacheReadTokens holds the bulk of the context.
|
||||
func TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
@@ -641,22 +643,26 @@ func TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly(t *testing.T) {
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 200,
|
||||
InputTokens: 3,
|
||||
OutputTokens: 5,
|
||||
CacheReadTokens: 0,
|
||||
CacheCreationTokens: 4317,
|
||||
},
|
||||
FinalUsage: &kit.LLMUsage{
|
||||
InputTokens: 1000, // Full context including history
|
||||
OutputTokens: 200,
|
||||
InputTokens: 3, // Non-cached input (small with caching)
|
||||
OutputTokens: 5, // Assistant output
|
||||
CacheReadTokens: 0, // No cache reads on first call
|
||||
CacheCreationTokens: 4317, // System prompt + tools written to cache
|
||||
},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
// Context tokens should be InputTokens only (1000), not input+output (1200)
|
||||
// because InputTokens already includes the full conversation history
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 1000 {
|
||||
t.Fatalf("expected context tokens=1000 (InputTokens only), got calls=%d tokens=%d",
|
||||
usage.contextCalls, usage.lastContextTokens)
|
||||
// Context tokens should be Input + CacheRead + CacheCreate + Output = 4325
|
||||
expected := 3 + 0 + 4317 + 5
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != expected {
|
||||
t.Fatalf("expected context tokens=%d (all categories), got calls=%d tokens=%d",
|
||||
expected, usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,8 +21,10 @@ type UsageUpdater interface {
|
||||
// the provider does not return exact counts.
|
||||
EstimateAndUpdateUsage(inputText, outputText string)
|
||||
// SetContextTokens records the approximate current context window fill
|
||||
// level. This should be the final API call's input+output tokens (from
|
||||
// FinalResponse.Usage), NOT the aggregate TotalUsage.
|
||||
// level. This should be the sum of ALL token categories from the last
|
||||
// API call: InputTokens + CacheReadTokens + CacheCreationTokens +
|
||||
// OutputTokens. With Anthropic prompt caching, InputTokens can be
|
||||
// near-zero while CacheReadTokens holds the bulk of the context.
|
||||
SetContextTokens(tokens int)
|
||||
}
|
||||
|
||||
|
||||
@@ -157,6 +157,21 @@ type Theme struct {
|
||||
Markdown MarkdownThemeConfig `json:"markdown,omitzero" yaml:"markdown,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationParams defines generation parameter defaults that can be attached
|
||||
// to individual models. These act as model-level defaults — CLI flags and
|
||||
// global config values take precedence when explicitly set.
|
||||
type GenerationParams struct {
|
||||
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
|
||||
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
|
||||
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model that can be used with custom/custom
|
||||
// or other custom/ prefixed models. These models are loaded from the config file
|
||||
// and merged into the custom provider in the model registry.
|
||||
@@ -171,6 +186,11 @@ type CustomModelConfig struct {
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
|
||||
// Generation parameter defaults for this model.
|
||||
// These are applied when the user hasn't explicitly set the corresponding
|
||||
// CLI flag or global config value.
|
||||
Params GenerationParams `json:"params,omitzero" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
@@ -219,6 +239,12 @@ type Config struct {
|
||||
|
||||
// Custom model definitions (under custom/ provider)
|
||||
CustomModels map[string]CustomModelConfig `json:"customModels,omitempty" yaml:"customModels,omitempty"`
|
||||
|
||||
// Per-model generation parameter overrides. Keys are "provider/model" strings
|
||||
// (e.g. "anthropic/claude-sonnet-4-5-20250929", "openai/gpt-4o"). These
|
||||
// settings act as model-level defaults — CLI flags and global config values
|
||||
// take precedence when explicitly set.
|
||||
ModelSettings map[string]GenerationParams `json:"modelSettings,omitempty" yaml:"modelSettings,omitempty"`
|
||||
}
|
||||
|
||||
// GetTransportType returns the transport type for the server config, mapping
|
||||
@@ -367,7 +393,7 @@ mcpServers:
|
||||
# debug: false # Enable debug logging
|
||||
# system-prompt: "/path/to/system-prompt.txt" # System prompt text file
|
||||
|
||||
# Model generation parameters (all optional)
|
||||
# Model generation parameters (all optional, apply globally to all models)
|
||||
# max-tokens: 4096 # Maximum tokens in response
|
||||
# temperature: 0.7 # Randomness (0.0-1.0)
|
||||
# top-p: 0.95 # Nucleus sampling (0.0-1.0)
|
||||
@@ -376,9 +402,46 @@ mcpServers:
|
||||
# presence-penalty: 0.0 # Penalize present tokens (0.0-2.0)
|
||||
# stop-sequences: ["Human:", "Assistant:"] # Custom stop sequences
|
||||
|
||||
# Per-model generation parameter overrides (apply to specific models)
|
||||
# These act as model-level defaults — CLI flags and global settings above take precedence.
|
||||
# Keys are "provider/model" strings matching the model you use.
|
||||
# modelSettings:
|
||||
# anthropic/claude-sonnet-4-5-20250929:
|
||||
# temperature: 0.3
|
||||
# maxTokens: 8192
|
||||
# openai/gpt-4o:
|
||||
# temperature: 0.7
|
||||
# topP: 0.95
|
||||
# topK: 40
|
||||
# frequencyPenalty: 0.1
|
||||
# presencePenalty: 0.1
|
||||
# anthropic/claude-opus-4-6:
|
||||
# thinkingLevel: "high"
|
||||
# maxTokens: 16384
|
||||
# systemPrompt: "You are a deep reasoning assistant." # or a file path
|
||||
|
||||
# API Configuration (can also use environment variables)
|
||||
# provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google
|
||||
# provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama
|
||||
|
||||
# Custom model definitions (under custom/ provider)
|
||||
# customModels:
|
||||
# my-local-llama:
|
||||
# name: "Local Llama 3"
|
||||
# baseUrl: "http://localhost:8080/v1"
|
||||
# family: "llama"
|
||||
# temperature: true
|
||||
# cost:
|
||||
# input: 0.0
|
||||
# output: 0.0
|
||||
# limit:
|
||||
# context: 131072
|
||||
# output: 8192
|
||||
# params: # Generation parameter defaults for this model
|
||||
# temperature: 0.8
|
||||
# topP: 0.95
|
||||
# topK: 40
|
||||
# systemPrompt: "You are a helpful local assistant."
|
||||
`
|
||||
|
||||
_, err = file.WriteString(content)
|
||||
|
||||
@@ -34,15 +34,10 @@ func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
|
||||
for _, p := range paths {
|
||||
ext, err := loadSingleExtension(p)
|
||||
if err != nil {
|
||||
log.Warn("skipping extension", "path", p, "err", err)
|
||||
continue
|
||||
}
|
||||
loaded = append(loaded, *ext)
|
||||
log.Debug("loaded extension", "path", p,
|
||||
"handlers", countHandlers(ext),
|
||||
"tools", len(ext.Tools),
|
||||
"commands", len(ext.Commands),
|
||||
"tool_renderers", len(ext.ToolRenderers))
|
||||
log.Debug("loaded extension", "path", p, "handlers", countHandlers(ext), "tools", len(ext.Tools), "commands", len(ext.Commands), "tool_renderers", len(ext.ToolRenderers))
|
||||
}
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
@@ -2,12 +2,12 @@ package extensions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
@@ -370,10 +370,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
|
||||
for _, handler := range handlers {
|
||||
result, err := safeCall(handler, event, ctx)
|
||||
if err != nil {
|
||||
log.Warn("extension handler error",
|
||||
"path", ext.Path,
|
||||
"event", event.Type(),
|
||||
"err", err)
|
||||
log.Printf("WARN extension handler error: path=%s event=%s err=%v", ext.Path, event.Type(), err)
|
||||
continue
|
||||
}
|
||||
if result == nil {
|
||||
@@ -707,9 +704,7 @@ func (r *Runner) EmitCustomEvent(name, data string) {
|
||||
safeInvoke := func(h func(string)) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Warn("custom event handler panicked",
|
||||
"event", name,
|
||||
"err", fmt.Sprintf("%v", rec))
|
||||
log.Printf("WARN custom event handler panicked: event=%s err=%v", name, rec)
|
||||
}
|
||||
}()
|
||||
h(data)
|
||||
|
||||
@@ -3,13 +3,13 @@ package extensions
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
|
||||
for _, dir := range dirs {
|
||||
// Watch the directory itself.
|
||||
if err := fsw.Add(dir); err != nil {
|
||||
log.Debug("watcher: skipping directory", "dir", dir, "err", err)
|
||||
log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
|
||||
if entry.IsDir() {
|
||||
subdir := filepath.Join(dir, entry.Name())
|
||||
if err := fsw.Add(subdir); err != nil {
|
||||
log.Debug("watcher: skipping subdirectory", "dir", subdir, "err", err)
|
||||
log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -101,7 +101,7 @@ func (w *Watcher) Start(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("watcher: file changed", "file", event.Name, "op", event.Op)
|
||||
log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op)
|
||||
|
||||
// Debounce: reset timer on each event.
|
||||
if timer != nil {
|
||||
@@ -113,14 +113,14 @@ func (w *Watcher) Start(ctx context.Context) {
|
||||
case <-timerC:
|
||||
timerC = nil
|
||||
timer = nil
|
||||
log.Debug("watcher: reloading extensions")
|
||||
log.Printf("DEBUG watcher: reloading extensions")
|
||||
w.onReload()
|
||||
|
||||
case err, ok := <-w.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Warn("watcher: error", "err", err)
|
||||
log.Printf("WARN watcher: error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+44
-20
@@ -65,6 +65,10 @@ type AgentSetupOptions struct {
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
@@ -82,36 +86,55 @@ type AgentSetupResult struct {
|
||||
|
||||
// BuildProviderConfig creates a *models.ProviderConfig from the current viper
|
||||
// state. All entry points (root, script, SDK) converge through this function.
|
||||
//
|
||||
// Generation parameter pointers (Temperature, TopP, etc.) are only set when
|
||||
// the user has explicitly configured them via CLI flag, environment variable,
|
||||
// or global config file. This allows per-model defaults from modelSettings
|
||||
// and customModels to fill in unset parameters downstream.
|
||||
func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
systemPrompt, err := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to load system prompt: %w", err)
|
||||
}
|
||||
|
||||
temperature := float32(viper.GetFloat64("temperature"))
|
||||
topP := float32(viper.GetFloat64("top-p"))
|
||||
topK := int32(viper.GetInt("top-k"))
|
||||
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
|
||||
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
|
||||
numGPU := int32(viper.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(viper.GetInt("main-gpu"))
|
||||
|
||||
cfg := &models.ProviderConfig{
|
||||
ModelString: viper.GetString("model"),
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
Temperature: &temperature,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
FrequencyPenalty: &frequencyPenalty,
|
||||
PresencePenalty: &presencePenalty,
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
ModelString: viper.GetString("model"),
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
}
|
||||
|
||||
// Only set generation parameter pointers when the user has explicitly
|
||||
// provided a value. This leaves nil pointers for unset params, allowing
|
||||
// per-model defaults (modelSettings / customModels params) to apply.
|
||||
if viper.IsSet("temperature") {
|
||||
v := float32(viper.GetFloat64("temperature"))
|
||||
cfg.Temperature = &v
|
||||
}
|
||||
if viper.IsSet("top-p") {
|
||||
v := float32(viper.GetFloat64("top-p"))
|
||||
cfg.TopP = &v
|
||||
}
|
||||
if viper.IsSet("top-k") {
|
||||
v := int32(viper.GetInt("top-k"))
|
||||
cfg.TopK = &v
|
||||
}
|
||||
if viper.IsSet("frequency-penalty") {
|
||||
v := float32(viper.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &v
|
||||
}
|
||||
if viper.IsSet("presence-penalty") {
|
||||
v := float32(viper.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &v
|
||||
}
|
||||
|
||||
return cfg, systemPrompt, nil
|
||||
@@ -200,6 +223,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
TokenStoreFactory: opts.TokenStoreFactory,
|
||||
CoreTools: opts.CoreTools,
|
||||
DisableCoreTools: opts.DisableCoreTools,
|
||||
ToolWrapper: toolWrapper,
|
||||
|
||||
+234
-11
@@ -2,6 +2,8 @@ package models
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@@ -31,7 +33,7 @@ func loadCustomModelsFromConfig() map[string]ModelInfo {
|
||||
|
||||
// modelConfigToModelInfo converts a CustomModelConfig to a ModelInfo.
|
||||
func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
return ModelInfo{
|
||||
info := ModelInfo{
|
||||
ID: modelID,
|
||||
Name: cfg.Name,
|
||||
Attachment: cfg.Attachment,
|
||||
@@ -48,21 +50,242 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
Output: cfg.Limit.Output,
|
||||
},
|
||||
}
|
||||
|
||||
// Convert custom model generation params if any are set.
|
||||
if p := convertGenerationParams(cfg.Params); p != nil {
|
||||
info.Params = p
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// LoadModelSettingsFromConfig loads per-model generation parameter overrides
|
||||
// from the config file. Keys are "provider/model" strings. Returns nil if
|
||||
// no model settings are configured.
|
||||
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
|
||||
if !viper.IsSet("modelSettings") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings map[string]GenerationParamsConfig
|
||||
if err := viper.UnmarshalKey("modelSettings", &settings); err != nil {
|
||||
log.Printf("Warning: Failed to parse modelSettings: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]*GenerationParams, len(settings))
|
||||
for modelKey, cfg := range settings {
|
||||
if p := convertGenerationParams(cfg); p != nil {
|
||||
result[modelKey] = p
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// convertGenerationParams converts a GenerationParamsConfig to a GenerationParams.
|
||||
// Returns nil if no parameters are set.
|
||||
func convertGenerationParams(cfg GenerationParamsConfig) *GenerationParams {
|
||||
p := &GenerationParams{}
|
||||
any := false
|
||||
|
||||
if cfg.MaxTokens != nil {
|
||||
p.MaxTokens = cfg.MaxTokens
|
||||
any = true
|
||||
}
|
||||
if cfg.Temperature != nil {
|
||||
p.Temperature = cfg.Temperature
|
||||
any = true
|
||||
}
|
||||
if cfg.TopP != nil {
|
||||
p.TopP = cfg.TopP
|
||||
any = true
|
||||
}
|
||||
if cfg.TopK != nil {
|
||||
p.TopK = cfg.TopK
|
||||
any = true
|
||||
}
|
||||
if cfg.FrequencyPenalty != nil {
|
||||
p.FrequencyPenalty = cfg.FrequencyPenalty
|
||||
any = true
|
||||
}
|
||||
if cfg.PresencePenalty != nil {
|
||||
p.PresencePenalty = cfg.PresencePenalty
|
||||
any = true
|
||||
}
|
||||
if len(cfg.StopSequences) > 0 {
|
||||
p.StopSequences = cfg.StopSequences
|
||||
any = true
|
||||
}
|
||||
if cfg.ThinkingLevel != "" {
|
||||
p.ThinkingLevel = ParseThinkingLevel(cfg.ThinkingLevel)
|
||||
any = true
|
||||
}
|
||||
if cfg.SystemPrompt != "" {
|
||||
p.SystemPrompt = cfg.SystemPrompt
|
||||
any = true
|
||||
}
|
||||
|
||||
if !any {
|
||||
return nil
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// ApplyModelSettings merges per-model generation parameter defaults from the
|
||||
// registry into a ProviderConfig. Model-level params are only applied for
|
||||
// fields where the user has not explicitly set a value (i.e., the
|
||||
// corresponding viper key is not set via CLI flag or global config).
|
||||
//
|
||||
// The lookup order is:
|
||||
// 1. modelSettings["provider/model"] from config (highest model-level priority)
|
||||
// 2. ModelInfo.Params from custom model definitions
|
||||
//
|
||||
// Both are overridden by explicit CLI flags / global config values.
|
||||
func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
provider, modelName, err := ParseModelString(config.ModelString)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect model-level params: modelSettings override > custom model params.
|
||||
// modelSettings takes priority because it's the more specific/intentional config.
|
||||
var params *GenerationParams
|
||||
|
||||
// First check modelSettings from config.
|
||||
if settings := LoadModelSettingsFromConfig(); settings != nil {
|
||||
modelKey := provider + "/" + modelName
|
||||
if p, ok := settings[modelKey]; ok {
|
||||
params = p
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to ModelInfo.Params (from custom model definitions).
|
||||
if params == nil && modelInfo != nil && modelInfo.Params != nil {
|
||||
params = modelInfo.Params
|
||||
}
|
||||
|
||||
if params == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Apply each parameter only when the user hasn't explicitly set it.
|
||||
// We check viper.IsSet() which returns true only when the key was
|
||||
// set via CLI flag, environment variable, or config file global section.
|
||||
|
||||
if params.MaxTokens != nil && !isExplicitlySet("max-tokens") {
|
||||
config.MaxTokens = *params.MaxTokens
|
||||
}
|
||||
if params.Temperature != nil && !isExplicitlySet("temperature") {
|
||||
config.Temperature = params.Temperature
|
||||
}
|
||||
if params.TopP != nil && !isExplicitlySet("top-p") {
|
||||
config.TopP = params.TopP
|
||||
}
|
||||
if params.TopK != nil && !isExplicitlySet("top-k") {
|
||||
config.TopK = params.TopK
|
||||
}
|
||||
if params.FrequencyPenalty != nil && !isExplicitlySet("frequency-penalty") {
|
||||
config.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
if params.PresencePenalty != nil && !isExplicitlySet("presence-penalty") {
|
||||
config.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
if len(params.StopSequences) > 0 && !isExplicitlySet("stop-sequences") {
|
||||
config.StopSequences = params.StopSequences
|
||||
}
|
||||
if params.ThinkingLevel != "" && !isExplicitlySet("thinking-level") {
|
||||
config.ThinkingLevel = params.ThinkingLevel
|
||||
}
|
||||
if params.SystemPrompt != "" && config.SystemPrompt == "" {
|
||||
// Resolve file paths: if the value points to an existing file, read it.
|
||||
// We check config.SystemPrompt == "" rather than isExplicitlySet because
|
||||
// viper.BindPFlag causes IsSet to return true even for unset flags.
|
||||
config.SystemPrompt = LoadSystemPromptValue(params.SystemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSystemPromptValue resolves a system prompt value that may be either
|
||||
// inline text or a file path. If the value is a path to an existing file,
|
||||
// its contents are read and returned. Otherwise the string is returned as-is.
|
||||
// This mirrors config.LoadSystemPrompt but lives in the models package to
|
||||
// avoid circular dependencies.
|
||||
func LoadSystemPromptValue(input string) string {
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
if info, err := os.Stat(input); err == nil && !info.IsDir() {
|
||||
content, err := os.ReadFile(input)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to read system prompt file %q: %v", input, err)
|
||||
return input
|
||||
}
|
||||
return strings.TrimSpace(string(content))
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
// isExplicitlySet returns true when the user has explicitly set a config key
|
||||
// via CLI flag, environment variable, or the global section of the config file.
|
||||
// Model-level defaults should not override explicitly set values.
|
||||
func isExplicitlySet(key string) bool {
|
||||
// viper.IsSet returns true if the key has been set in any of the
|
||||
// data stores (flag, env, config file, default). We need to check
|
||||
// whether the value was set at the global config level (not just
|
||||
// as a default). For generation params, the global config keys use
|
||||
// hyphenated names (e.g. "max-tokens", "top-p").
|
||||
//
|
||||
// Since viper merges all sources, IsSet returns true even for config
|
||||
// file values. This means global config file values (e.g.
|
||||
// temperature: 0.7 at the top level) will correctly take precedence
|
||||
// over model-level defaults, which is the desired behavior.
|
||||
return viper.IsSet(key)
|
||||
}
|
||||
|
||||
// GenerationParams holds per-model generation parameter defaults.
|
||||
// These are stored on ModelInfo and applied during provider creation.
|
||||
// Nil pointer fields mean "no model-level default" — the global config
|
||||
// or CLI flag value (if any) will be used instead.
|
||||
type GenerationParams struct {
|
||||
MaxTokens *int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
FrequencyPenalty *float32
|
||||
PresencePenalty *float32
|
||||
StopSequences []string
|
||||
ThinkingLevel ThinkingLevel
|
||||
SystemPrompt string // Per-model system prompt (inline text or file path)
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model configuration loaded from the config file.
|
||||
// This is a duplicate here to avoid circular dependencies with internal/config.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
Params GenerationParamsConfig `json:"params,omitzero" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationParamsConfig is the JSON/YAML-serializable form of generation
|
||||
// parameter defaults. Used in both customModels[].params and modelSettings[].
|
||||
type GenerationParamsConfig struct {
|
||||
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
|
||||
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
|
||||
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
|
||||
@@ -0,0 +1,422 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func TestConvertGenerationParams(t *testing.T) {
|
||||
t.Run("empty config returns nil", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p != nil {
|
||||
t.Errorf("expected nil, got %+v", p)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("temperature only", func(t *testing.T) {
|
||||
temp := float32(0.7)
|
||||
cfg := GenerationParamsConfig{Temperature: &temp}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.Temperature == nil || *p.Temperature != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", p.Temperature)
|
||||
}
|
||||
if p.TopP != nil {
|
||||
t.Errorf("expected nil TopP, got %v", p.TopP)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("all params set", func(t *testing.T) {
|
||||
maxTokens := 8192
|
||||
temp := float32(0.5)
|
||||
topP := float32(0.9)
|
||||
topK := int32(50)
|
||||
freqPenalty := float32(0.1)
|
||||
presPenalty := float32(0.2)
|
||||
cfg := GenerationParamsConfig{
|
||||
MaxTokens: &maxTokens,
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
FrequencyPenalty: &freqPenalty,
|
||||
PresencePenalty: &presPenalty,
|
||||
StopSequences: []string{"STOP"},
|
||||
ThinkingLevel: "high",
|
||||
}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.MaxTokens == nil || *p.MaxTokens != 8192 {
|
||||
t.Errorf("expected maxTokens 8192, got %v", p.MaxTokens)
|
||||
}
|
||||
if p.Temperature == nil || *p.Temperature != 0.5 {
|
||||
t.Errorf("expected temperature 0.5, got %v", p.Temperature)
|
||||
}
|
||||
if p.TopP == nil || *p.TopP != 0.9 {
|
||||
t.Errorf("expected topP 0.9, got %v", p.TopP)
|
||||
}
|
||||
if p.TopK == nil || *p.TopK != 50 {
|
||||
t.Errorf("expected topK 50, got %v", p.TopK)
|
||||
}
|
||||
if p.FrequencyPenalty == nil || *p.FrequencyPenalty != 0.1 {
|
||||
t.Errorf("expected frequencyPenalty 0.1, got %v", p.FrequencyPenalty)
|
||||
}
|
||||
if p.PresencePenalty == nil || *p.PresencePenalty != 0.2 {
|
||||
t.Errorf("expected presencePenalty 0.2, got %v", p.PresencePenalty)
|
||||
}
|
||||
if len(p.StopSequences) != 1 || p.StopSequences[0] != "STOP" {
|
||||
t.Errorf("expected stop sequences [STOP], got %v", p.StopSequences)
|
||||
}
|
||||
if p.ThinkingLevel != ThinkingHigh {
|
||||
t.Errorf("expected thinking level high, got %v", p.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking level parsing", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{ThinkingLevel: "medium"}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.ThinkingLevel != ThinkingMedium {
|
||||
t.Errorf("expected thinking level medium, got %v", p.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
t.Run("system prompt only", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{SystemPrompt: "You are helpful."}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.SystemPrompt != "You are helpful." {
|
||||
t.Errorf("expected system prompt, got %q", p.SystemPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelConfigToModelInfoWithParams(t *testing.T) {
|
||||
temp := float32(0.8)
|
||||
topP := float32(0.95)
|
||||
cfg := CustomModelConfig{
|
||||
Name: "Test Model",
|
||||
BaseURL: "http://localhost:8080/v1",
|
||||
Temperature: true,
|
||||
Params: GenerationParamsConfig{
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
},
|
||||
}
|
||||
|
||||
info := modelConfigToModelInfo("test-model", cfg)
|
||||
|
||||
if info.Params == nil {
|
||||
t.Fatal("expected non-nil Params")
|
||||
}
|
||||
if info.Params.Temperature == nil || *info.Params.Temperature != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", info.Params.Temperature)
|
||||
}
|
||||
if info.Params.TopP == nil || *info.Params.TopP != 0.95 {
|
||||
t.Errorf("expected topP 0.95, got %v", info.Params.TopP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfigToModelInfoWithoutParams(t *testing.T) {
|
||||
cfg := CustomModelConfig{
|
||||
Name: "Test Model",
|
||||
BaseURL: "http://localhost:8080/v1",
|
||||
}
|
||||
|
||||
info := modelConfigToModelInfo("test-model", cfg)
|
||||
|
||||
if info.Params != nil {
|
||||
t.Errorf("expected nil Params, got %+v", info.Params)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyModelSettings(t *testing.T) {
|
||||
// Save and restore viper state.
|
||||
originalViper := viper.AllSettings()
|
||||
defer func() {
|
||||
viper.Reset()
|
||||
for k, v := range originalViper {
|
||||
viper.Set(k, v)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("applies model params when not explicitly set", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
temp := float32(0.8)
|
||||
topK := int32(50)
|
||||
maxTokens := 4096
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
TopK: &topK,
|
||||
MaxTokens: &maxTokens,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.Temperature == nil || *config.Temperature != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", config.Temperature)
|
||||
}
|
||||
if config.TopK == nil || *config.TopK != 50 {
|
||||
t.Errorf("expected topK 50, got %v", config.TopK)
|
||||
}
|
||||
if config.MaxTokens != 4096 {
|
||||
t.Errorf("expected maxTokens 4096, got %d", config.MaxTokens)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit viper values take precedence", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
viper.Set("temperature", 0.3)
|
||||
|
||||
temp := float32(0.8)
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
},
|
||||
}
|
||||
|
||||
explicitTemp := float32(0.3)
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
Temperature: &explicitTemp,
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Temperature should NOT be overridden because it's explicitly set in viper
|
||||
if config.Temperature == nil || *config.Temperature != 0.3 {
|
||||
t.Errorf("expected temperature 0.3 (explicit), got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil model info is safe", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
ApplyModelSettings(config, nil)
|
||||
|
||||
if config.Temperature != nil {
|
||||
t.Errorf("expected nil temperature, got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model info without params is safe", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{ID: "test-model"}
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.Temperature != nil {
|
||||
t.Errorf("expected nil temperature, got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelSettings from viper takes priority over ModelInfo.Params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
// Set up modelSettings in viper (simulating config file)
|
||||
viper.Set("modelSettings", map[string]any{
|
||||
"custom/test-model": map[string]any{
|
||||
"temperature": 0.5,
|
||||
"topK": 30,
|
||||
},
|
||||
})
|
||||
|
||||
// ModelInfo has different params
|
||||
temp := float32(0.8)
|
||||
topK := int32(50)
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
TopK: &topK,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// modelSettings should win over ModelInfo.Params
|
||||
if config.Temperature == nil || *config.Temperature != 0.5 {
|
||||
t.Errorf("expected temperature 0.5 (from modelSettings), got %v", config.Temperature)
|
||||
}
|
||||
if config.TopK == nil || *config.TopK != 30 {
|
||||
t.Errorf("expected topK 30 (from modelSettings), got %v", config.TopK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stop sequences applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
StopSequences: []string{"STOP", "END"},
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if len(config.StopSequences) != 2 || config.StopSequences[0] != "STOP" {
|
||||
t.Errorf("expected stop sequences [STOP END], got %v", config.StopSequences)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking level applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
ThinkingLevel: ThinkingHigh,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.ThinkingLevel != ThinkingHigh {
|
||||
t.Errorf("expected thinking level high, got %v", config.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("system prompt applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "You are a coding assistant.",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "You are a coding assistant." {
|
||||
t.Errorf("expected system prompt to be set, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit system prompt takes precedence", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "Model-specific prompt",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
SystemPrompt: "Global prompt",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Global system prompt should NOT be overridden because config
|
||||
// already has a non-empty SystemPrompt.
|
||||
if config.SystemPrompt != "Global prompt" {
|
||||
t.Errorf("expected global prompt preserved, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("system prompt from file path", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
// Create a temp file with a system prompt
|
||||
tmpFile, err := os.CreateTemp("", "kit-test-prompt-*.txt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
if _, err := tmpFile.WriteString(" Prompt from file "); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: tmpFile.Name(),
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "Prompt from file" {
|
||||
t.Errorf("expected trimmed file content, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelSettings system prompt overrides custom model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
viper.Set("modelSettings", map[string]any{
|
||||
"custom/test-model": map[string]any{
|
||||
"systemPrompt": "From modelSettings",
|
||||
},
|
||||
})
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "From custom model",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "From modelSettings" {
|
||||
t.Errorf("expected modelSettings prompt, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -241,6 +241,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
validateModelConfig(config, modelInfo)
|
||||
}
|
||||
|
||||
// Apply per-model generation parameter defaults. Model-level params are
|
||||
// only applied for fields where the user hasn't explicitly set a value
|
||||
// via CLI flag or global config.
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Create the base provider
|
||||
var result *ProviderResult
|
||||
var createErr error
|
||||
|
||||
@@ -26,6 +26,11 @@ type ModelInfo struct {
|
||||
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
||||
BaseURL string // Per-model base URL override (custom models only)
|
||||
APIKey string // Per-model API key override (custom models only)
|
||||
|
||||
// Params holds per-model generation parameter defaults. These are applied
|
||||
// when the user hasn't explicitly set the corresponding CLI flag or global
|
||||
// config value. Nil pointer fields mean "no model-level default".
|
||||
Params *GenerationParams
|
||||
}
|
||||
|
||||
// SupportsCaching returns true if this model family supports prompt caching.
|
||||
@@ -236,6 +241,18 @@ func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
|
||||
return &modelInfo
|
||||
}
|
||||
|
||||
// LookupModelForSettings is a convenience function that parses a
|
||||
// "provider/model" string and looks up the ModelInfo in the global registry.
|
||||
// Returns nil when the model string is invalid or the model is unknown.
|
||||
// Used by Kit.SetModel to pre-apply per-model settings before CreateProvider.
|
||||
func LookupModelForSettings(modelString string) *ModelInfo {
|
||||
provider, modelName, err := ParseModelString(modelString)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return GetGlobalRegistry().LookupModel(provider, modelName)
|
||||
}
|
||||
|
||||
// getRequiredEnvVars returns the required environment variables for a provider.
|
||||
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
|
||||
@@ -2,11 +2,10 @@ package prompts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
// LoadOptions configures how templates are discovered and loaded.
|
||||
@@ -74,10 +73,7 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
DroppedPath: tpl.FilePath,
|
||||
Reason: fmt.Sprintf("template from %s overridden by %s", source, existing.Source),
|
||||
})
|
||||
log.Debug("template collision",
|
||||
"name", tpl.Name,
|
||||
"dropped", tpl.FilePath,
|
||||
"kept", existing.FilePath)
|
||||
log.Printf("DEBUG template collision: name=%s dropped=%s kept=%s", tpl.Name, tpl.FilePath, existing.FilePath)
|
||||
} else {
|
||||
tpl.Source = source
|
||||
seen[tpl.Name] = tpl
|
||||
|
||||
@@ -0,0 +1,317 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
)
|
||||
|
||||
// TestCompactionCreatesNewLeaf verifies that after compaction, the compaction
|
||||
// entry has no parent (creating a new root), and BuildContext returns only
|
||||
// the summary and kept messages, not the old compacted messages.
|
||||
func TestCompactionCreatesNewLeaf(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add some messages: M1, M2 (old, will be compacted), M3, M4 (kept)
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1 - old"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2 - old"}}}
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 3 - kept"}}}
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - kept"}}}
|
||||
|
||||
_, _ = tm.AppendMessage(msg1)
|
||||
_, _ = tm.AppendMessage(msg2)
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
id4, _ := tm.AppendMessage(msg4)
|
||||
|
||||
// Verify initial state - all messages should be in context
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("expected 4 messages before compaction, got %d", len(messages))
|
||||
}
|
||||
|
||||
// Verify entry IDs
|
||||
entryIDs := tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 4 {
|
||||
t.Fatalf("expected 4 entry IDs before compaction, got %d", len(entryIDs))
|
||||
}
|
||||
|
||||
// Now add a compaction entry, simulating that M3 is the first kept entry
|
||||
summary := "Summary of old messages"
|
||||
compactionID, err := tm.AppendCompaction(summary, id3, 1000, 500, 2, []string{}, []string{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to append compaction: %v", err)
|
||||
}
|
||||
|
||||
// Verify the compaction entry has no parent (empty ParentID)
|
||||
compactionEntry := tm.GetEntry(compactionID).(*CompactionEntry)
|
||||
if compactionEntry.ParentID != "" {
|
||||
t.Errorf("compaction entry should have no parent, got %q", compactionEntry.ParentID)
|
||||
}
|
||||
|
||||
// Verify the leaf is now the compaction entry
|
||||
if tm.GetLeafID() != compactionID {
|
||||
t.Errorf("leaf should be compaction entry %q, got %q", compactionID, tm.GetLeafID())
|
||||
}
|
||||
|
||||
// Now BuildContext should return: [summary] + [M3, M4]
|
||||
messages, _, _ = tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages after compaction (summary + 2 kept), got %d", len(messages))
|
||||
}
|
||||
|
||||
// First message should be the summary
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be system summary, got %s", messages[0].Role)
|
||||
}
|
||||
summaryText := messages[0].Content[0].(fantasy.TextPart).Text
|
||||
if summaryText != "[Conversation summary — earlier messages were compacted]\n\n"+summary {
|
||||
t.Errorf("unexpected summary text: %s", summaryText)
|
||||
}
|
||||
|
||||
// Second message should be M3 (kept)
|
||||
if messages[1].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("second message should be user (M3), got %s", messages[1].Role)
|
||||
}
|
||||
m3Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m3Text != "Message 3 - kept" {
|
||||
t.Errorf("unexpected M3 text: %s", m3Text)
|
||||
}
|
||||
|
||||
// Third message should be M4 (kept)
|
||||
if messages[2].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("third message should be assistant (M4), got %s", messages[2].Role)
|
||||
}
|
||||
m4Text := messages[2].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - kept" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
|
||||
// Verify GetContextEntryIDs returns correct IDs
|
||||
entryIDs = tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 3 {
|
||||
t.Fatalf("expected 3 entry IDs after compaction (empty for summary + 2 kept), got %d: %v", len(entryIDs), entryIDs)
|
||||
}
|
||||
|
||||
// First entry ID should be empty (summary has no entry)
|
||||
if entryIDs[0] != "" {
|
||||
t.Errorf("first entry ID should be empty (summary), got %q", entryIDs[0])
|
||||
}
|
||||
|
||||
// Second and third should be id3 and id4 (the kept messages)
|
||||
if entryIDs[1] != id3 {
|
||||
t.Errorf("second entry ID should be %q (M3), got %q", id3, entryIDs[1])
|
||||
}
|
||||
if entryIDs[2] != id4 {
|
||||
t.Errorf("third entry ID should be %q (M4), got %q", id4, entryIDs[2])
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompactionWithNewMessagesAfterCompaction verifies that messages appended
|
||||
// after compaction are correctly included in the context.
|
||||
func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add initial messages
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2"}}}
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 3 - kept"}}}
|
||||
|
||||
_, _ = tm.AppendMessage(msg1)
|
||||
_, _ = tm.AppendMessage(msg2)
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
|
||||
// Compact, keeping only M3
|
||||
_, _ = tm.AppendCompaction("Summary", id3, 1000, 500, 2, []string{}, []string{})
|
||||
|
||||
// Add a new message after compaction
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}}
|
||||
_, _ = tm.AppendMessage(msg4)
|
||||
|
||||
// BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)]
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// Verify order: summary, M4 (new), M3 (kept)
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be summary, got %s", messages[0].Role)
|
||||
}
|
||||
if messages[1].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("second message should be assistant (M4), got %s", messages[1].Role)
|
||||
}
|
||||
m4Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - after compaction" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("third message should be user (M3), got %s", messages[2].Role)
|
||||
}
|
||||
|
||||
// Verify that M1 is NOT in the context
|
||||
for i, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Message 1" {
|
||||
t.Errorf("Message 1 (compacted) should not be in context at index %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompactionWithNoKeptMessages verifies compaction when all messages are compacted.
|
||||
func TestCompactionWithNoKeptMessages(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add messages that will all be compacted
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2"}}}
|
||||
|
||||
if _, err := tm.AppendMessage(msg1); err != nil {
|
||||
t.Fatalf("failed to append message: %v", err)
|
||||
}
|
||||
if _, err := tm.AppendMessage(msg2); err != nil {
|
||||
t.Fatalf("failed to append message: %v", err)
|
||||
}
|
||||
|
||||
// Compact with no kept messages (empty firstKeptEntryID)
|
||||
summary := "All messages summarized"
|
||||
compactionID, _ := tm.AppendCompaction(summary, "", 1000, 100, 2, []string{}, []string{})
|
||||
|
||||
// Verify the compaction entry has no parent
|
||||
compactionEntry := tm.GetEntry(compactionID).(*CompactionEntry)
|
||||
if compactionEntry.ParentID != "" {
|
||||
t.Errorf("compaction entry should have no parent, got %q", compactionEntry.ParentID)
|
||||
}
|
||||
|
||||
// BuildContext should return only the summary
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message (summary only), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("message should be system summary, got %s", messages[0].Role)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipleCompactions verifies that multiple compactions work correctly.
|
||||
func TestMultipleCompactions(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// First batch of messages
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Batch 1 - User"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Batch 1 - Assistant"}}}
|
||||
id1, _ := tm.AppendMessage(msg1)
|
||||
id2, _ := tm.AppendMessage(msg2)
|
||||
|
||||
// First compaction
|
||||
_, _ = tm.AppendCompaction("Summary 1", id1, 1000, 500, 1, []string{}, []string{})
|
||||
|
||||
// Second batch
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Batch 2 - User"}}}
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Batch 2 - Assistant"}}}
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
id4, _ := tm.AppendMessage(msg4)
|
||||
|
||||
// Second compaction (compacting the first compaction + batch 2)
|
||||
// Note: id3 is the first kept entry, so id3 and id4 should be preserved
|
||||
compactionID2, _ := tm.AppendCompaction("Summary 2", id3, 1000, 500, 3, []string{}, []string{})
|
||||
|
||||
// Verify second compaction has no parent
|
||||
compactionEntry2 := tm.GetEntry(compactionID2).(*CompactionEntry)
|
||||
if compactionEntry2.ParentID != "" {
|
||||
t.Errorf("second compaction entry should have no parent, got %q", compactionEntry2.ParentID)
|
||||
}
|
||||
|
||||
// Add final message
|
||||
msg5 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Final message"}}}
|
||||
id5, _ := tm.AppendMessage(msg5)
|
||||
|
||||
// BuildContext should include:
|
||||
// - Summary 2 (from second compaction)
|
||||
// - msg5 (final message)
|
||||
// - msg3, msg4 (kept from second compaction)
|
||||
// But NOT Summary 1 or msg1, msg2 (they're before the first kept entry of compaction 2)
|
||||
messages, _, _ := tm.BuildContext()
|
||||
|
||||
// Should have: Summary 2 + msg5 + msg3 + msg4 = 4 messages
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("expected 4 messages (Summary 2 + msg5 + msg3 + msg4), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// First should be Summary 2
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be system (Summary 2), got %s", messages[0].Role)
|
||||
}
|
||||
summaryText := messages[0].Content[0].(fantasy.TextPart).Text
|
||||
if summaryText != "[Conversation summary — earlier messages were compacted]\n\nSummary 2" {
|
||||
t.Errorf("unexpected summary: %s", summaryText)
|
||||
}
|
||||
|
||||
// Verify msg5 is included
|
||||
foundFinal := false
|
||||
for _, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Final message" {
|
||||
foundFinal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundFinal {
|
||||
t.Error("Final message (msg5) should be in context")
|
||||
}
|
||||
|
||||
// Verify msg1, msg2 are NOT included (compacted by first compaction, then second)
|
||||
for _, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser || msg.Role == fantasy.MessageRoleAssistant {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Batch 1 - User" || text == "Batch 1 - Assistant" {
|
||||
t.Errorf("Batch 1 messages should not be in context, found: %s", text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify entry IDs
|
||||
entryIDs := tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 4 {
|
||||
t.Fatalf("expected 4 entry IDs, got %d: %v", len(entryIDs), entryIDs)
|
||||
}
|
||||
|
||||
// First should be empty (summary)
|
||||
if entryIDs[0] != "" {
|
||||
t.Errorf("first entry ID should be empty (summary), got %q", entryIDs[0])
|
||||
}
|
||||
|
||||
// Check that id5 is in the list
|
||||
if !slices.Contains(entryIDs, id5) {
|
||||
t.Errorf("id5 (final message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
|
||||
// Verify id3 and id4 ARE in the list (they were kept)
|
||||
foundID3, foundID4 := false, false
|
||||
for _, id := range entryIDs {
|
||||
if id == id3 {
|
||||
foundID3 = true
|
||||
}
|
||||
if id == id4 {
|
||||
foundID4 = true
|
||||
}
|
||||
}
|
||||
if !foundID3 {
|
||||
t.Errorf("id3 (kept message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
if !foundID4 {
|
||||
t.Errorf("id4 (kept message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
|
||||
// Verify id1 and id2 are NOT in the list (they were compacted away)
|
||||
for _, id := range entryIDs {
|
||||
if id == id1 || id == id2 {
|
||||
t.Errorf("id1 or id2 (compacted) should not be in entry IDs, found %q in %v", id, entryIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -509,11 +509,19 @@ func (tm *TreeManager) AppendExtensionData(extType, data string) (string, error)
|
||||
// AppendCompaction adds a compaction entry to the tree. The entry records
|
||||
// the summary and the ID of the first entry that should be preserved in the
|
||||
// LLM context. Messages before that entry are replaced by the summary.
|
||||
//
|
||||
// The compaction entry becomes a new "root" for the post-compaction branch
|
||||
// with no parent (empty ParentID). This breaks the parent chain so that old
|
||||
// compacted messages are no longer traversed when building context. The kept
|
||||
// messages are explicitly collected via FirstKeptEntryID in BuildContext.
|
||||
func (tm *TreeManager) AppendCompaction(summary, firstKeptEntryID string, tokensBefore, tokensAfter, messagesRemoved int, readFiles, modifiedFiles []string) (string, error) {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
||||
entry := NewCompactionEntry(tm.leafID, summary, firstKeptEntryID, tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
|
||||
// The compaction entry has no parent, making it a new "root" for the
|
||||
// post-compaction branch. This ensures old compacted messages are not
|
||||
// traversed when walking from the current leaf.
|
||||
entry := NewCompactionEntry("", summary, firstKeptEntryID, tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -683,14 +691,18 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
// Find the last compaction entry on this branch — it determines
|
||||
// which older messages are replaced by the summary.
|
||||
var lastCompaction *CompactionEntry
|
||||
var compactionIndex = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if c, ok := branch[i].(*CompactionEntry); ok {
|
||||
lastCompaction = c
|
||||
compactionIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If there is a compaction, inject the summary first.
|
||||
// If there is a compaction, inject the summary first and collect
|
||||
// the kept messages starting from FirstKeptEntryID (since the
|
||||
// compaction entry's parent chain doesn't include them).
|
||||
if lastCompaction != nil {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
@@ -700,21 +712,104 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Determine whether to skip entries (everything before firstKeptEntryID).
|
||||
skipping := lastCompaction != nil
|
||||
for _, entry := range branch {
|
||||
// Once we reach the first kept entry, stop skipping.
|
||||
if skipping {
|
||||
entryID := tm.EntryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
// Collect entries from the compaction entry itself (at compactionIndex)
|
||||
// and any entries before it in the branch (newer messages).
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
// Convert branch summary to a user message for context.
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (summary injected).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Now collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not in the current branch because the compaction entry
|
||||
// is parented to the first kept entry's parent, not the first kept entry.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting
|
||||
// messages that were added after the compaction.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
// Messages after the compaction are collected from the branch walk above.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
// Process this kept entry.
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return messages, provider, modelID
|
||||
}
|
||||
|
||||
// No compaction - process the entire branch normally.
|
||||
for _, entry := range branch {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
@@ -740,10 +835,6 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (the last one on the branch).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -853,31 +944,92 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
|
||||
// Find the last compaction entry for skip logic.
|
||||
var lastCompaction *CompactionEntry
|
||||
var compactionIndex = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if c, ok := branch[i].(*CompactionEntry); ok {
|
||||
lastCompaction = c
|
||||
compactionIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var ids []string
|
||||
|
||||
// If there's a compaction summary injected, it has no entry ID.
|
||||
// If there's a compaction, we need to collect IDs from:
|
||||
// 1. Entries after the compaction entry in the branch (newer messages)
|
||||
// 2. Entries from FirstKeptEntryID onwards (kept messages)
|
||||
if lastCompaction != nil {
|
||||
ids = append(ids, "") // placeholder for the summary system message
|
||||
}
|
||||
// Placeholder for the summary system message (no entry ID).
|
||||
ids = append(ids, "")
|
||||
|
||||
skipping := lastCompaction != nil
|
||||
for _, entry := range branch {
|
||||
if skipping {
|
||||
entryID := tm.EntryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
continue
|
||||
// Collect IDs from entries after the compaction entry (newer messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect IDs from the kept messages starting at FirstKeptEntryID.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// No compaction - collect IDs from the entire branch.
|
||||
for _, entry := range branch {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
@@ -893,9 +1045,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *CompactionEntry:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -60,15 +60,16 @@ type MCPConnection struct {
|
||||
// creation, health monitoring, and cleanup. The pool runs background health checks
|
||||
// to proactively identify and remove unhealthy connections.
|
||||
type MCPConnectionPool struct {
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
mu sync.RWMutex
|
||||
model fantasy.LanguageModel
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
oauthFlow *OAuthFlowRunner
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
mu sync.RWMutex
|
||||
model fantasy.LanguageModel
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
oauthFlow *OAuthFlowRunner
|
||||
tokenStoreFactory TokenStoreFactory // custom factory for per-server token stores (nil = default FileTokenStore)
|
||||
}
|
||||
|
||||
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
|
||||
@@ -76,19 +77,20 @@ type MCPConnectionPool struct {
|
||||
// goroutine for periodic health checks that runs until Close is called.
|
||||
// The model parameter is used for MCP servers that require sampling support.
|
||||
// Thread-safe for concurrent use immediately after creation.
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler) *MCPConnectionPool {
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler, tokenStoreFactory TokenStoreFactory) *MCPConnectionPool {
|
||||
if config == nil {
|
||||
config = DefaultConnectionPoolConfig()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &MCPConnectionPool{
|
||||
connections: make(map[string]*MCPConnection),
|
||||
config: config,
|
||||
model: model,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
debug: debug,
|
||||
connections: make(map[string]*MCPConnection),
|
||||
config: config,
|
||||
model: model,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
debug: debug,
|
||||
tokenStoreFactory: tokenStoreFactory,
|
||||
}
|
||||
|
||||
if authHandler != nil {
|
||||
@@ -367,7 +369,7 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
|
||||
// scopes are discovered automatically via dynamic client registration and
|
||||
// server metadata (RFC 9728).
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
@@ -414,7 +416,7 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
|
||||
// scopes are discovered automatically via dynamic client registration and
|
||||
// server metadata (RFC 9728).
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
@@ -437,6 +439,16 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
|
||||
return streamableClient, nil
|
||||
}
|
||||
|
||||
// createTokenStore creates a token store for the given server URL.
|
||||
// If a custom TokenStoreFactory is configured, it is used; otherwise the
|
||||
// default file-backed token store is created.
|
||||
func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenStore, error) {
|
||||
if p.tokenStoreFactory != nil {
|
||||
return p.tokenStoreFactory(serverURL)
|
||||
}
|
||||
return NewFileTokenStore(serverURL)
|
||||
}
|
||||
|
||||
// initializeClient initializes the client
|
||||
func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error {
|
||||
initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
@@ -583,6 +595,27 @@ func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient {
|
||||
return clients
|
||||
}
|
||||
|
||||
// RemoveConnection closes and removes a single connection from the pool.
|
||||
// Returns an error if the connection does not exist or if closing fails.
|
||||
// Thread-safe for concurrent use.
|
||||
func (p *MCPConnectionPool) RemoveConnection(serverName string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
conn, exists := p.connections[serverName]
|
||||
if !exists {
|
||||
return fmt.Errorf("connection %q not found in pool", serverName)
|
||||
}
|
||||
|
||||
err := conn.client.Close()
|
||||
delete(p.connections, serverName)
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Removed connection %s", serverName))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the connection pool, closing all client connections
|
||||
// and stopping the background health check goroutine. It attempts to close all
|
||||
// connections even if some fail, logging any errors encountered.
|
||||
|
||||
+147
-10
@@ -20,19 +20,25 @@ import (
|
||||
// pooling, health checks, tool name prefixing to avoid conflicts, and sampling support for LLM interactions.
|
||||
// Thread-safe for concurrent tool invocations.
|
||||
type MCPToolManager struct {
|
||||
connectionPool *MCPConnectionPool
|
||||
tools []fantasy.AgentTool
|
||||
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
|
||||
mu sync.Mutex // protects tools and toolMap during parallel loading
|
||||
model fantasy.LanguageModel // LLM model for sampling
|
||||
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
|
||||
config *config.Config
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
connectionPool *MCPConnectionPool
|
||||
tools []fantasy.AgentTool
|
||||
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
|
||||
mu sync.Mutex // protects tools and toolMap during parallel loading
|
||||
model fantasy.LanguageModel // LLM model for sampling
|
||||
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
|
||||
tokenStoreFactory TokenStoreFactory // factory for creating per-server token stores (nil = default FileTokenStore)
|
||||
config *config.Config
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
|
||||
// onServerLoaded, if non-nil, is called when each server finishes loading.
|
||||
// Called with server name, tool count, and error (nil on success).
|
||||
onServerLoaded func(serverName string, toolCount int, err error)
|
||||
|
||||
// onToolsChanged, if non-nil, is called after AddServer or RemoveServer
|
||||
// mutates the tool list. The agent layer uses this to trigger a
|
||||
// rebuildFantasyAgent so the LLM sees the updated tools.
|
||||
onToolsChanged func()
|
||||
}
|
||||
|
||||
// toolMapping stores the mapping between prefixed tool names and their original details
|
||||
@@ -69,6 +75,14 @@ func (m *MCPToolManager) SetAuthHandler(handler MCPAuthHandler) {
|
||||
m.authHandler = handler
|
||||
}
|
||||
|
||||
// SetTokenStoreFactory sets a custom factory for creating per-server OAuth token
|
||||
// stores. When set, the factory is called for each remote MCP server instead of
|
||||
// using the default file-based token store. This method should be called before
|
||||
// LoadTools.
|
||||
func (m *MCPToolManager) SetTokenStoreFactory(factory TokenStoreFactory) {
|
||||
m.tokenStoreFactory = factory
|
||||
}
|
||||
|
||||
// SetDebugLogger sets the debug logger for the tool manager.
|
||||
// The logger will be used to output detailed debugging information about MCP connections,
|
||||
// tool loading, and execution. If a connection pool exists, it will also be configured
|
||||
@@ -87,6 +101,126 @@ func (m *MCPToolManager) SetOnServerLoaded(cb func(serverName string, toolCount
|
||||
m.onServerLoaded = cb
|
||||
}
|
||||
|
||||
// SetOnToolsChanged sets the callback that's invoked after AddServer or
|
||||
// RemoveServer mutates the tool list. The agent layer uses this to trigger
|
||||
// a rebuild of the fantasy agent so the LLM sees the updated tool set.
|
||||
func (m *MCPToolManager) SetOnToolsChanged(cb func()) {
|
||||
m.onToolsChanged = cb
|
||||
}
|
||||
|
||||
// AddServer connects to a new MCP server at runtime and loads its tools.
|
||||
// The server's tools are immediately available to the agent after this call.
|
||||
// Returns the number of tools loaded from the server.
|
||||
//
|
||||
// If the connection pool has not been initialised yet (i.e. LoadTools was never
|
||||
// called), AddServer creates one automatically using the manager's current
|
||||
// configuration.
|
||||
//
|
||||
// Returns an error if a server with the same name is already loaded, or if
|
||||
// the connection or tool loading fails.
|
||||
func (m *MCPToolManager) AddServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
|
||||
m.mu.Lock()
|
||||
// Check for duplicate.
|
||||
if _, exists := m.toolMap[name+"__"]; exists {
|
||||
m.mu.Unlock()
|
||||
return 0, fmt.Errorf("MCP server %q is already loaded", name)
|
||||
}
|
||||
// More thorough duplicate check: scan toolMap for any key with the server prefix.
|
||||
prefix := name + "__"
|
||||
for k := range m.toolMap {
|
||||
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
||||
m.mu.Unlock()
|
||||
return 0, fmt.Errorf("MCP server %q is already loaded", name)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Lazily create the connection pool if LoadTools was never called.
|
||||
m.ensureConnectionPool()
|
||||
|
||||
count, err := m.loadServerTools(ctx, name, cfg)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to add MCP server %q: %w", name, err)
|
||||
}
|
||||
|
||||
// Notify listeners.
|
||||
if m.onServerLoaded != nil {
|
||||
m.onServerLoaded(name, count, nil)
|
||||
}
|
||||
if m.onToolsChanged != nil {
|
||||
m.onToolsChanged()
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// RemoveServer disconnects an MCP server and removes all its tools.
|
||||
// After this call the agent will no longer see or be able to call tools from
|
||||
// the named server. Returns an error if the server is not loaded.
|
||||
func (m *MCPToolManager) RemoveServer(name string) error {
|
||||
prefix := name + "__"
|
||||
|
||||
m.mu.Lock()
|
||||
|
||||
// Check the server actually has tools loaded.
|
||||
found := false
|
||||
for k := range m.toolMap {
|
||||
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("MCP server %q is not loaded", name)
|
||||
}
|
||||
|
||||
// Remove tools belonging to this server.
|
||||
newTools := make([]fantasy.AgentTool, 0, len(m.tools))
|
||||
for _, t := range m.tools {
|
||||
if len(t.Info().Name) < len(prefix) || t.Info().Name[:len(prefix)] != prefix {
|
||||
newTools = append(newTools, t)
|
||||
}
|
||||
}
|
||||
m.tools = newTools
|
||||
|
||||
// Remove tool mappings.
|
||||
for k := range m.toolMap {
|
||||
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
||||
delete(m.toolMap, k)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Close the connection in the pool (best-effort).
|
||||
if m.connectionPool != nil {
|
||||
_ = m.connectionPool.RemoveConnection(name)
|
||||
}
|
||||
|
||||
if m.onToolsChanged != nil {
|
||||
m.onToolsChanged()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureConnectionPool lazily creates a connection pool if one does not exist.
|
||||
// This allows AddServer to work even if LoadTools was never called.
|
||||
func (m *MCPToolManager) ensureConnectionPool() {
|
||||
if m.connectionPool != nil {
|
||||
return
|
||||
}
|
||||
debug := false
|
||||
if m.config != nil {
|
||||
debug = m.config.Debug
|
||||
}
|
||||
if m.debugLogger == nil {
|
||||
m.debugLogger = NewSimpleDebugLogger(debug)
|
||||
}
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, debug, m.authHandler, m.tokenStoreFactory)
|
||||
m.connectionPool.SetDebugLogger(m.debugLogger)
|
||||
}
|
||||
|
||||
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
|
||||
// It initializes the connection pool, connects to each configured server, and loads their tools.
|
||||
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
|
||||
@@ -99,7 +233,7 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, cfg *config.Config) erro
|
||||
if m.debugLogger == nil {
|
||||
m.debugLogger = NewSimpleDebugLogger(cfg.Debug)
|
||||
}
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, cfg.Debug, m.authHandler)
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, cfg.Debug, m.authHandler, m.tokenStoreFactory)
|
||||
m.connectionPool.SetDebugLogger(m.debugLogger)
|
||||
|
||||
// Load all servers in parallel. Each server connection (subprocess
|
||||
@@ -290,6 +424,9 @@ func (m *MCPToolManager) GetLoadedServerNames() []string {
|
||||
// proper cleanup of stdio processes, network connections, and other resources.
|
||||
// It is safe to call Close multiple times.
|
||||
func (m *MCPToolManager) Close() error {
|
||||
if m.connectionPool == nil {
|
||||
return nil
|
||||
}
|
||||
return m.connectionPool.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,323 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// testdataDir returns the absolute path to the testdata directory.
|
||||
func testdataDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
_, file, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("cannot determine test file path")
|
||||
}
|
||||
return filepath.Join(filepath.Dir(file), "testdata")
|
||||
}
|
||||
|
||||
// echoServerConfig returns an MCPServerConfig for the test echo MCP server.
|
||||
func echoServerConfig(t *testing.T) config.MCPServerConfig {
|
||||
t.Helper()
|
||||
script := filepath.Join(testdataDir(t), "echo_server.py")
|
||||
if _, err := os.Stat(script); err != nil {
|
||||
t.Skipf("echo_server.py not found: %v", err)
|
||||
}
|
||||
return config.MCPServerConfig{
|
||||
Command: []string{"python3", script},
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_Integration tests adding a real MCP server
|
||||
// at runtime and verifying tools are loaded.
|
||||
func TestMCPToolManager_AddServer_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Track callbacks.
|
||||
var mu sync.Mutex
|
||||
var loadedServer string
|
||||
var loadedCount int
|
||||
toolsChangedCount := 0
|
||||
|
||||
manager.SetOnServerLoaded(func(name string, count int, err error) {
|
||||
mu.Lock()
|
||||
loadedServer = name
|
||||
loadedCount = count
|
||||
mu.Unlock()
|
||||
})
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
toolsChangedCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// Add the server.
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools from echo server, got %d", count)
|
||||
}
|
||||
|
||||
// Verify callbacks fired.
|
||||
mu.Lock()
|
||||
if loadedServer != "echo" {
|
||||
t.Errorf("Expected onServerLoaded for 'echo', got %q", loadedServer)
|
||||
}
|
||||
if loadedCount != 2 {
|
||||
t.Errorf("Expected onServerLoaded count=2, got %d", loadedCount)
|
||||
}
|
||||
if toolsChangedCount != 1 {
|
||||
t.Errorf("Expected onToolsChanged called once, got %d", toolsChangedCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Verify tools are accessible.
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("Expected 2 tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Verify tool names are prefixed.
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
toolNames[tool.Info().Name] = true
|
||||
}
|
||||
if !toolNames["echo__echo"] {
|
||||
t.Error("Expected tool 'echo__echo'")
|
||||
}
|
||||
if !toolNames["echo__greet"] {
|
||||
t.Error("Expected tool 'echo__greet'")
|
||||
}
|
||||
|
||||
// Verify server appears in loaded names.
|
||||
names := manager.GetLoadedServerNames()
|
||||
if !slices.Contains(names, "echo") {
|
||||
t.Errorf("Expected 'echo' in loaded server names, got: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_RemoveServer_Integration tests removing a real MCP server
|
||||
// and verifying tools are cleaned up.
|
||||
func TestMCPToolManager_RemoveServer_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add the server first.
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatalf("Expected 2 tools, got %d", count)
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
toolsChangedCount := 0
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
toolsChangedCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// Remove the server.
|
||||
err = manager.RemoveServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify tools are gone.
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("Expected 0 tools after removal, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Verify callback fired.
|
||||
mu.Lock()
|
||||
if toolsChangedCount != 1 {
|
||||
t.Errorf("Expected onToolsChanged called once, got %d", toolsChangedCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Verify server is gone from loaded names.
|
||||
names := manager.GetLoadedServerNames()
|
||||
for _, n := range names {
|
||||
if n == "echo" {
|
||||
t.Error("Server 'echo' should not appear in loaded names after removal")
|
||||
}
|
||||
}
|
||||
|
||||
// Removing again should error.
|
||||
err = manager.RemoveServer("echo")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error removing already-removed server")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not loaded") {
|
||||
t.Errorf("Expected 'not loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddRemoveMultiple_Integration tests adding and removing
|
||||
// multiple servers, verifying tool isolation.
|
||||
func TestMCPToolManager_AddRemoveMultiple_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add two servers with the same binary but different names.
|
||||
count1, err := manager.AddServer(ctx, "server-a", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer server-a failed: %v", err)
|
||||
}
|
||||
count2, err := manager.AddServer(ctx, "server-b", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer server-b failed: %v", err)
|
||||
}
|
||||
|
||||
totalTools := count1 + count2
|
||||
if totalTools != 4 {
|
||||
t.Fatalf("Expected 4 total tools (2+2), got %d", totalTools)
|
||||
}
|
||||
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 4 {
|
||||
t.Fatalf("Expected 4 tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Remove server-a, verify server-b tools remain.
|
||||
err = manager.RemoveServer("server-a")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer server-a failed: %v", err)
|
||||
}
|
||||
|
||||
tools = manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("Expected 2 tools after removing server-a, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Remaining tools should all be from server-b.
|
||||
for _, tool := range tools {
|
||||
if !strings.HasPrefix(tool.Info().Name, "server-b__") {
|
||||
t.Errorf("Expected tool from server-b, got: %s", tool.Info().Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove server-b.
|
||||
err = manager.RemoveServer("server-b")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer server-b failed: %v", err)
|
||||
}
|
||||
|
||||
tools = manager.GetTools()
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("Expected 0 tools after removing all servers, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_DuplicateDetection_Integration tests that
|
||||
// adding a server with the same name as an already loaded server errors.
|
||||
func TestMCPToolManager_AddServer_DuplicateDetection_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add the server.
|
||||
_, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to add again with the same name.
|
||||
_, err = manager.AddServer(ctx, "echo", cfg)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error adding duplicate server")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "already loaded") {
|
||||
t.Errorf("Expected 'already loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddAfterRemove_Integration tests that a server can be
|
||||
// re-added after being removed.
|
||||
func TestMCPToolManager_AddAfterRemove_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add, remove, re-add.
|
||||
_, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
err = manager.RemoveServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer failed: %v", err)
|
||||
}
|
||||
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Re-AddServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools on re-add, got %d", count)
|
||||
}
|
||||
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Errorf("Expected 2 tools after re-add, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// TestMCPToolManager_AddServer_DuplicateName verifies that adding a server
|
||||
// with a name that already exists returns an error.
|
||||
func TestMCPToolManager_AddServer_DuplicateName(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Command: []string{"non-existent-command"},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// First add will fail (bad command), but let's test the duplicate detection
|
||||
// by simulating a loaded server via LoadTools first.
|
||||
loadCfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"test-server": cfg,
|
||||
},
|
||||
}
|
||||
// This will fail to load but creates the connection pool.
|
||||
_ = manager.LoadTools(ctx, loadCfg)
|
||||
|
||||
// Now try to add the same server name — the tools didn't load (bad command),
|
||||
// so AddServer should not find a duplicate and should fail with connection error.
|
||||
_, err := manager.AddServer(ctx, "test-server", cfg)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when adding server with bad command, got nil")
|
||||
}
|
||||
// It should be a connection error, not a duplicate error.
|
||||
if strings.Contains(err.Error(), "already loaded") {
|
||||
t.Fatalf("Should not report duplicate since server failed to load initially: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_RemoveServer_NotLoaded verifies that removing a server
|
||||
// that doesn't exist returns an appropriate error.
|
||||
func TestMCPToolManager_RemoveServer_NotLoaded(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
err := manager.RemoveServer("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when removing non-existent server, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not loaded") {
|
||||
t.Errorf("Expected 'not loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_CreatesConnectionPool verifies that AddServer
|
||||
// lazily creates a connection pool when LoadTools was never called.
|
||||
func TestMCPToolManager_AddServer_CreatesConnectionPool(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
// Connection pool should be nil initially.
|
||||
if manager.connectionPool != nil {
|
||||
t.Fatal("Expected nil connection pool before any operation")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// AddServer with a bad command — should fail, but the pool should be created.
|
||||
_, err := manager.AddServer(ctx, "lazy-server", config.MCPServerConfig{
|
||||
Command: []string{"non-existent-command"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for bad command")
|
||||
}
|
||||
|
||||
// Connection pool should have been created.
|
||||
if manager.connectionPool == nil {
|
||||
t.Fatal("Expected connection pool to be created lazily by AddServer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_OnToolsChanged_Callback verifies that the onToolsChanged
|
||||
// callback fires on RemoveServer (we can't easily test AddServer with a real
|
||||
// MCP server, but we can test the callback wiring).
|
||||
func TestMCPToolManager_OnToolsChanged_Callback(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
var mu sync.Mutex
|
||||
callCount := 0
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
callCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// RemoveServer on non-existent should NOT fire callback.
|
||||
_ = manager.RemoveServer("nonexistent")
|
||||
|
||||
mu.Lock()
|
||||
if callCount != 0 {
|
||||
t.Errorf("Expected 0 callback calls for failed remove, got %d", callCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// TestMCPToolManager_Close_NilPool verifies Close is safe when the connection
|
||||
// pool was never initialized.
|
||||
func TestMCPToolManager_Close_NilPool(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
err := manager.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil error from Close with nil pool, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPConnectionPool_RemoveConnection_NotFound verifies that removing a
|
||||
// non-existent connection returns an error.
|
||||
func TestMCPConnectionPool_RemoveConnection_NotFound(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), nil, false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
err := pool.RemoveConnection("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for non-existent connection")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Expected 'not found' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_EnsureConnectionPool_Idempotent verifies that
|
||||
// ensureConnectionPool doesn't recreate an existing pool.
|
||||
func TestMCPToolManager_EnsureConnectionPool_Idempotent(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
// First call creates the pool.
|
||||
manager.ensureConnectionPool()
|
||||
pool1 := manager.connectionPool
|
||||
if pool1 == nil {
|
||||
t.Fatal("Expected pool to be created")
|
||||
}
|
||||
|
||||
// Second call should be a no-op.
|
||||
manager.ensureConnectionPool()
|
||||
pool2 := manager.connectionPool
|
||||
if pool1 != pool2 {
|
||||
t.Fatal("Expected ensureConnectionPool to be idempotent")
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
)
|
||||
|
||||
// MCPAuthHandler is the internal interface for handling MCP OAuth flows.
|
||||
@@ -21,6 +22,12 @@ type MCPAuthHandler interface {
|
||||
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
|
||||
}
|
||||
|
||||
// TokenStoreFactory creates a transport.TokenStore for a given MCP server URL.
|
||||
// When provided to the connection pool, it is called once per remote MCP server
|
||||
// instead of using the default file-based token store. Implementations can
|
||||
// return any transport.TokenStore — in-memory, database-backed, encrypted, etc.
|
||||
type TokenStoreFactory func(serverURL string) (transport.TokenStore, error)
|
||||
|
||||
// OAuthFlowRunner handles the OAuth authorization flow when an MCP server
|
||||
// returns an OAuthAuthorizationRequiredError. It coordinates dynamic client
|
||||
// registration, PKCE generation, user authorization (via MCPAuthHandler),
|
||||
|
||||
+111
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Minimal MCP server over stdio for testing. Exposes one tool: echo."""
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
def read_message():
|
||||
"""Read a JSON-RPC message from stdin."""
|
||||
line = sys.stdin.readline()
|
||||
if not line:
|
||||
return None
|
||||
return json.loads(line.strip())
|
||||
|
||||
|
||||
def write_message(msg):
|
||||
"""Write a JSON-RPC message to stdout."""
|
||||
sys.stdout.write(json.dumps(msg) + "\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def handle(msg):
|
||||
method = msg.get("method", "")
|
||||
mid = msg.get("id")
|
||||
|
||||
if method == "initialize":
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "test-echo", "version": "1.0.0"},
|
||||
},
|
||||
})
|
||||
elif method == "notifications/initialized":
|
||||
pass # no response needed
|
||||
elif method == "tools/list":
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "Echoes the input text back.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "Text to echo"}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "greet",
|
||||
"description": "Returns a greeting.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "Name to greet"}
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
})
|
||||
elif method == "tools/call":
|
||||
tool_name = msg["params"]["name"]
|
||||
args = msg["params"].get("arguments", {})
|
||||
if tool_name == "echo":
|
||||
text = args.get("text", "")
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"content": [{"type": "text", "text": text}]
|
||||
},
|
||||
})
|
||||
elif tool_name == "greet":
|
||||
name = args.get("name", "World")
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"content": [{"type": "text", "text": f"Hello, {name}!"}]
|
||||
},
|
||||
})
|
||||
else:
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
|
||||
})
|
||||
elif method == "ping":
|
||||
write_message({"jsonrpc": "2.0", "id": mid, "result": {}})
|
||||
else:
|
||||
if mid is not None:
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"error": {"code": -32601, "message": f"Unknown method: {method}"},
|
||||
})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
while True:
|
||||
msg = read_message()
|
||||
if msg is None:
|
||||
break
|
||||
handle(msg)
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"charm.land/lipgloss/v2"
|
||||
"golang.org/x/term"
|
||||
|
||||
@@ -173,33 +172,6 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) {
|
||||
fmt.Println(c.renderer.RenderDebugConfigMessage(config, time.Now()).Content)
|
||||
}
|
||||
|
||||
// UpdateUsageFromResponse records token usage using metadata from the fantasy
|
||||
// response. Only actual API-reported tokens are used for cost tracking.
|
||||
// If the provider doesn't report token counts, no usage is recorded.
|
||||
func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText string) {
|
||||
if c.usageTracker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage := response.Usage
|
||||
inputTokens := int(usage.InputTokens)
|
||||
outputTokens := int(usage.OutputTokens)
|
||||
|
||||
// Only use actual API-reported tokens for cost tracking.
|
||||
// We intentionally do NOT estimate tokens - estimation is inaccurate
|
||||
// and should never be used for cost calculations.
|
||||
if inputTokens > 0 {
|
||||
cacheReadTokens := int(usage.CacheReadTokens)
|
||||
cacheWriteTokens := int(usage.CacheCreationTokens)
|
||||
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
// Per-response usage is a single API call, so it represents the
|
||||
// actual context window fill level.
|
||||
c.usageTracker.SetContextTokens(inputTokens + outputTokens)
|
||||
}
|
||||
// If inputTokens is 0, the provider didn't report usage - we skip recording
|
||||
// rather than estimating, to ensure cost accuracy.
|
||||
}
|
||||
|
||||
// DisplayUsageAfterResponse renders and displays token usage information immediately
|
||||
// following an AI response. This provides real-time feedback about the cost and
|
||||
// token consumption of each interaction.
|
||||
|
||||
@@ -139,7 +139,9 @@ func (h *CLIEventHandler) Handle(msg tea.Msg) {
|
||||
case "block":
|
||||
h.cli.DisplayExtensionBlock(e.Text, e.BorderColor, e.Subtitle)
|
||||
default:
|
||||
fmt.Println(e.Text)
|
||||
// Route unstyled extension prints through the system message
|
||||
// renderer so they get consistent formatting and timestamps.
|
||||
h.cli.DisplayInfo(e.Text)
|
||||
}
|
||||
|
||||
case app.StepCompleteEvent:
|
||||
|
||||
@@ -109,9 +109,7 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("")
|
||||
|
||||
// Display model info
|
||||
// Display model info (the system message block provides its own spacing).
|
||||
if provider != "unknown" && model != "unknown" {
|
||||
cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", provider, model))
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ type InputComponent struct {
|
||||
hideHint bool
|
||||
|
||||
// agentBusy indicates the agent is currently working. When true, the
|
||||
// hint text shows steering shortcut (Ctrl+S) instead of submit.
|
||||
// hint text shows steering shortcut (Ctrl+X s) instead of submit.
|
||||
agentBusy bool
|
||||
|
||||
// pendingImages holds clipboard images attached to the next submission.
|
||||
@@ -109,7 +109,7 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
|
||||
ta.Placeholder = "Type your message..."
|
||||
ta.ShowLineNumbers = false
|
||||
ta.Prompt = ""
|
||||
ta.CharLimit = 5000
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(3) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
@@ -514,12 +514,12 @@ func (s *InputComponent) View() tea.View {
|
||||
availableHintWidth := s.width - 3
|
||||
if s.agentBusy {
|
||||
// When the agent is working, show steering shortcut.
|
||||
if availableHintWidth >= 55 {
|
||||
hint = "enter queue • ctrl+s steer • esc esc cancel"
|
||||
} else if availableHintWidth >= 35 {
|
||||
hint = "↵ queue • ^S steer • esc×2 cancel"
|
||||
if availableHintWidth >= 60 {
|
||||
hint = "enter queue • ctrl+x s steer • esc esc cancel"
|
||||
} else if availableHintWidth >= 40 {
|
||||
hint = "↵ queue • ^X s steer • esc×2 cancel"
|
||||
} else {
|
||||
hint = "^S steer"
|
||||
hint = "^X s steer"
|
||||
}
|
||||
} else if availableHintWidth >= 67 {
|
||||
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
|
||||
|
||||
@@ -152,7 +152,7 @@ func (r *MessageRenderer) SetWidth(width int) {
|
||||
|
||||
// RenderUserMessage renders a user's input message using herald Tip alert
|
||||
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
rendered := render.UserBlock(content, r.ty, style.GetTheme())
|
||||
rendered := render.UserBlock(content, r.width, r.ty, style.GetTheme())
|
||||
|
||||
return UIMessage{
|
||||
Type: UserMessage,
|
||||
|
||||
+149
-59
@@ -477,7 +477,7 @@ type AppModel struct {
|
||||
queuedMessages []string
|
||||
|
||||
// steeringMessages stores the text of prompts that were sent as steer
|
||||
// messages (injected mid-turn via Ctrl+S). Rendered with a "STEERING"
|
||||
// messages (injected mid-turn via Ctrl+X s). Rendered with a "STEERING"
|
||||
// badge above the input. Cleared when the steer is consumed.
|
||||
steeringMessages []string
|
||||
|
||||
@@ -498,6 +498,11 @@ type AppModel struct {
|
||||
// A second ESC within 2 seconds will cancel the current step.
|
||||
canceling bool
|
||||
|
||||
// leaderKeyActive tracks whether the Ctrl+X leader key prefix has been
|
||||
// pressed. The next keypress is interpreted as a chord suffix (e.g. "s"
|
||||
// for steer). Cleared on any subsequent keypress.
|
||||
leaderKeyActive bool
|
||||
|
||||
// providerName is the LLM provider for the startup message.
|
||||
providerName string
|
||||
|
||||
@@ -1268,6 +1273,71 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
// ── Leader key chord handling (Ctrl+X prefix) ──────────────
|
||||
// If the leader key was previously pressed, the current key
|
||||
// completes the chord. We consume it regardless of match so
|
||||
// the prefix doesn't leak to child components.
|
||||
if m.leaderKeyActive {
|
||||
m.leaderKeyActive = false
|
||||
switch msg.String() {
|
||||
case "s":
|
||||
// Ctrl+X s → Steer: inject the current input as a steering
|
||||
// message into the running agent turn.
|
||||
if m.state == stateWorking && m.appCtrl != nil {
|
||||
var text string
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
text = strings.TrimSpace(ic.textarea.Value())
|
||||
}
|
||||
if text != "" {
|
||||
// Clear the input, collect pending images, and push to history.
|
||||
var images []uicore.ImageAttachment
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.pushHistory(text)
|
||||
ic.textarea.SetValue("")
|
||||
images = ic.ClearPendingImages()
|
||||
}
|
||||
|
||||
// Preprocess @file references.
|
||||
processedText := text
|
||||
if m.cwd != "" {
|
||||
processedText = fileutil.ProcessFileAttachments(text, m.cwd)
|
||||
}
|
||||
|
||||
// Convert image attachments to kit.LLMFilePart for the app layer.
|
||||
var fileParts []kit.LLMFilePart
|
||||
for _, img := range images {
|
||||
fileParts = append(fileParts, kit.LLMFilePart{
|
||||
Data: img.Data,
|
||||
MediaType: img.MediaType,
|
||||
})
|
||||
}
|
||||
|
||||
// Build display text (include image count if any).
|
||||
displayText := text
|
||||
if len(images) > 0 {
|
||||
displayText = fmt.Sprintf("%s\n[%d image(s) attached]", text, len(images))
|
||||
}
|
||||
|
||||
// Inject the steer message.
|
||||
sLen := m.appCtrl.SteerWithFiles(processedText, fileParts)
|
||||
if sLen > 0 {
|
||||
m.steeringMessages = append(m.steeringMessages, displayText)
|
||||
m.layoutDirty = true
|
||||
} else {
|
||||
// Started immediately (agent was idle).
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Chord consumed — don't propagate to children.
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
if m.state == stateWorking {
|
||||
@@ -1286,61 +1356,10 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
// In other states pass ESC through to children below.
|
||||
|
||||
case "ctrl+s":
|
||||
// Steer: inject the current input as a steering message into the
|
||||
// running agent turn. Only active during stateWorking — in input
|
||||
// state, Ctrl+S is passed through to children (no-op by default).
|
||||
if m.state == stateWorking && m.appCtrl != nil {
|
||||
var text string
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
text = strings.TrimSpace(ic.textarea.Value())
|
||||
}
|
||||
if text != "" {
|
||||
// Clear the input, collect pending images, and push to history.
|
||||
var images []uicore.ImageAttachment
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.pushHistory(text)
|
||||
ic.textarea.SetValue("")
|
||||
images = ic.ClearPendingImages()
|
||||
}
|
||||
|
||||
// Preprocess @file references.
|
||||
processedText := text
|
||||
if m.cwd != "" {
|
||||
processedText = fileutil.ProcessFileAttachments(text, m.cwd)
|
||||
}
|
||||
|
||||
// Convert image attachments to kit.LLMFilePart for the app layer.
|
||||
var fileParts []kit.LLMFilePart
|
||||
for _, img := range images {
|
||||
fileParts = append(fileParts, kit.LLMFilePart{
|
||||
Data: img.Data,
|
||||
MediaType: img.MediaType,
|
||||
})
|
||||
}
|
||||
|
||||
// Build display text (include image count if any).
|
||||
displayText := text
|
||||
if len(images) > 0 {
|
||||
displayText = fmt.Sprintf("%s\n[%d image(s) attached]", text, len(images))
|
||||
}
|
||||
|
||||
// Inject the steer message.
|
||||
sLen := m.appCtrl.SteerWithFiles(processedText, fileParts)
|
||||
if sLen > 0 {
|
||||
m.steeringMessages = append(m.steeringMessages, displayText)
|
||||
m.layoutDirty = true
|
||||
} else {
|
||||
// Started immediately (agent was idle).
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
case "ctrl+x":
|
||||
// Activate leader key prefix — the next keypress completes the chord.
|
||||
m.leaderKeyActive = true
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
// Route key events to the focused child. Check for editor
|
||||
@@ -1801,6 +1820,12 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Refresh content to show the finalized message.
|
||||
m.refreshContent()
|
||||
|
||||
// Reset context token display — the pre-compaction count is stale.
|
||||
// The next API call will set the accurate post-compaction value.
|
||||
if m.usageTracker != nil {
|
||||
m.usageTracker.SetContextTokens(0)
|
||||
}
|
||||
|
||||
// Print stats as a separate system message.
|
||||
saved := msg.OriginalTokens - msg.CompactedTokens
|
||||
statsMsg := fmt.Sprintf(
|
||||
@@ -2462,22 +2487,34 @@ func (m *AppModel) renderHeaderFooter(getter func() *WidgetData) string {
|
||||
return renderContentBlock(data.Text, m.width, opts...)
|
||||
}
|
||||
|
||||
// maxQueuedMessageLines is the maximum number of visible content lines
|
||||
// rendered for each queued or steering message block. Messages exceeding
|
||||
// this limit are truncated with an ellipsis to prevent large pastes from
|
||||
// overflowing the screen and squeezing the stream region to zero.
|
||||
const maxQueuedMessageLines = 3
|
||||
|
||||
// renderQueuedMessages renders queued and steering prompts as styled content
|
||||
// blocks with badges, anchored between the separator and input. Steering
|
||||
// messages use a distinct "STEERING" badge to differentiate from queued ones.
|
||||
// Long messages are visually truncated to maxQueuedMessageLines.
|
||||
func (m *AppModel) renderQueuedMessages() string {
|
||||
if len(m.queuedMessages) == 0 && len(m.steeringMessages) == 0 {
|
||||
return ""
|
||||
}
|
||||
theme := style.GetTheme()
|
||||
|
||||
// Available content width inside the block: container minus border (1)
|
||||
// minus left padding (2). Used to estimate line wrapping for truncation.
|
||||
contentWidth := max(m.width-3, 10)
|
||||
|
||||
var blocks []string
|
||||
|
||||
// Render steering messages first (higher priority).
|
||||
if len(m.steeringMessages) > 0 {
|
||||
badge := style.CreateBadge("STEERING", theme.Warning)
|
||||
for _, msg := range m.steeringMessages {
|
||||
content := msg + "\n" + badge
|
||||
display := truncateMessageForBlock(msg, maxQueuedMessageLines, contentWidth)
|
||||
content := display + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
@@ -2492,7 +2529,8 @@ func (m *AppModel) renderQueuedMessages() string {
|
||||
if len(m.queuedMessages) > 0 {
|
||||
badge := style.CreateBadge("QUEUED", theme.Accent)
|
||||
for _, msg := range m.queuedMessages {
|
||||
content := msg + "\n" + badge
|
||||
display := truncateMessageForBlock(msg, maxQueuedMessageLines, contentWidth)
|
||||
content := display + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
@@ -2506,6 +2544,58 @@ func (m *AppModel) renderQueuedMessages() string {
|
||||
return strings.Join(blocks, "\n")
|
||||
}
|
||||
|
||||
// truncateMessageForBlock truncates a message to at most maxLines visible
|
||||
// lines, accounting for soft-wrapping at the given width. If the message is
|
||||
// truncated, the last visible line is replaced with an ellipsis ("…").
|
||||
func truncateMessageForBlock(msg string, maxLines, width int) string {
|
||||
if width <= 0 {
|
||||
width = 1
|
||||
}
|
||||
|
||||
lines := strings.Split(msg, "\n")
|
||||
|
||||
// Count visible lines (each hard line may wrap into multiple visual lines).
|
||||
var kept []string
|
||||
visibleCount := 0
|
||||
truncated := false
|
||||
|
||||
for _, line := range lines {
|
||||
// Calculate how many visual lines this hard line occupies.
|
||||
lineWidth := lipgloss.Width(line)
|
||||
wrapped := 1
|
||||
if lineWidth > width {
|
||||
wrapped = (lineWidth + width - 1) / width // ceil division
|
||||
}
|
||||
|
||||
if visibleCount+wrapped > maxLines {
|
||||
// This line would exceed the limit. Keep a partial if we
|
||||
// still have room for at least one more visual line.
|
||||
remaining := maxLines - visibleCount
|
||||
if remaining > 0 {
|
||||
// Truncate the line to fit the remaining visual lines.
|
||||
runes := []rune(line)
|
||||
maxRunes := remaining * width
|
||||
if maxRunes < len(runes) {
|
||||
kept = append(kept, string(runes[:maxRunes]))
|
||||
} else {
|
||||
kept = append(kept, line)
|
||||
}
|
||||
}
|
||||
truncated = true
|
||||
break
|
||||
}
|
||||
|
||||
kept = append(kept, line)
|
||||
visibleCount += wrapped
|
||||
}
|
||||
|
||||
if !truncated {
|
||||
return msg
|
||||
}
|
||||
|
||||
return strings.Join(kept, "\n") + "…"
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Print helpers — add content to ScrollList
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -2876,7 +2966,7 @@ func (m *AppModel) printHelpMessage() {
|
||||
"**Keys:**\n" +
|
||||
"- `Ctrl+C`: Exit at any time\n" +
|
||||
"- `ESC` (x2): Cancel ongoing LLM generation\n" +
|
||||
"- `Ctrl+S`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
|
||||
"- `Ctrl+X s`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
|
||||
"- `Enter` (while working): Queue message for after the agent finishes\n\n" +
|
||||
"You can also just type your message to chat with the AI assistant."
|
||||
m.printSystemMessage(help)
|
||||
|
||||
@@ -2,6 +2,7 @@ package ui
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
@@ -892,3 +893,107 @@ func TestSubmit_duringWorking_stays(t *testing.T) {
|
||||
t.Fatalf("expected Run('queued prompt') called, got %v", ctrl.runCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// truncateMessageForBlock
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TestTruncateMessageForBlock_shortMessage verifies that short messages are
|
||||
// returned unchanged.
|
||||
func TestTruncateMessageForBlock_shortMessage(t *testing.T) {
|
||||
msg := "hello world"
|
||||
got := truncateMessageForBlock(msg, 3, 80)
|
||||
if got != msg {
|
||||
t.Fatalf("expected unchanged message, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_exactLines verifies that a message with exactly
|
||||
// maxLines hard lines is returned unchanged.
|
||||
func TestTruncateMessageForBlock_exactLines(t *testing.T) {
|
||||
msg := "line1\nline2\nline3"
|
||||
got := truncateMessageForBlock(msg, 3, 80)
|
||||
if got != msg {
|
||||
t.Fatalf("expected unchanged message, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_tooManyLines verifies that messages exceeding
|
||||
// maxLines are truncated with an ellipsis.
|
||||
func TestTruncateMessageForBlock_tooManyLines(t *testing.T) {
|
||||
msg := "line1\nline2\nline3\nline4\nline5"
|
||||
got := truncateMessageForBlock(msg, 3, 80)
|
||||
want := "line1\nline2\nline3…"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_longWrappingLine verifies that a single long
|
||||
// line that would wrap beyond maxLines is truncated.
|
||||
func TestTruncateMessageForBlock_longWrappingLine(t *testing.T) {
|
||||
// 100 chars at width 20 = 5 visual lines, exceeds maxLines=3
|
||||
msg := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
got := truncateMessageForBlock(msg, 3, 20)
|
||||
// Should be truncated to 3*20=60 runes + "…"
|
||||
if len([]rune(got)) != 61 { // 60 runes + "…"
|
||||
t.Fatalf("expected 61 runes (60 + ellipsis), got %d runes: %q", len([]rune(got)), got)
|
||||
}
|
||||
if got[len(got)-3:] != "…" { // "…" is 3 bytes in UTF-8
|
||||
t.Fatal("expected trailing ellipsis")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_emptyMessage verifies that empty messages are
|
||||
// returned unchanged.
|
||||
func TestTruncateMessageForBlock_emptyMessage(t *testing.T) {
|
||||
got := truncateMessageForBlock("", 3, 80)
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty string, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_mixedWrapAndHardLines verifies truncation when
|
||||
// some hard lines wrap and the total exceeds maxLines.
|
||||
func TestTruncateMessageForBlock_mixedWrapAndHardLines(t *testing.T) {
|
||||
// First line: 40 chars at width 20 = 2 visual lines
|
||||
// Second line: "short" = 1 visual line (total: 3, exactly at limit)
|
||||
// Third line: would exceed
|
||||
msg := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nshort\nextra"
|
||||
got := truncateMessageForBlock(msg, 3, 20)
|
||||
want := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nshort…"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRenderQueuedMessages_truncatesLongMessages verifies that the rendered
|
||||
// queued message view truncates long messages instead of showing them in full.
|
||||
func TestRenderQueuedMessages_truncatesLongMessages(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.width = 80
|
||||
|
||||
// Queue a very long message (20 lines).
|
||||
var b strings.Builder
|
||||
for i := range 20 {
|
||||
if i > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString("This is a long line of text for testing purposes")
|
||||
}
|
||||
m.queuedMessages = []string{b.String()}
|
||||
|
||||
rendered := m.renderQueuedMessages()
|
||||
if rendered == "" {
|
||||
t.Fatal("expected non-empty rendered output")
|
||||
}
|
||||
|
||||
// The full message would be ~20+ lines. With truncation to 3 content
|
||||
// lines + badge + padding, it should be much shorter.
|
||||
lines := len(strings.Split(rendered, "\n"))
|
||||
// 3 content lines + 1 badge + 2 padding + border overhead ≈ ~7 lines max
|
||||
if lines > 10 {
|
||||
t.Fatalf("expected truncated output to be ≤10 lines, got %d lines", lines)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func newInputPrompt(message, placeholder, defaultValue string, width, height int
|
||||
ta.Placeholder = placeholder
|
||||
ta.ShowLineNumbers = false
|
||||
ta.Prompt = ""
|
||||
ta.CharLimit = 1000
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 12) // account for border + padding
|
||||
ta.SetHeight(1)
|
||||
ta.Focus()
|
||||
|
||||
@@ -14,11 +14,19 @@ import (
|
||||
)
|
||||
|
||||
// UserBlock renders a user message with herald Tip styling.
|
||||
func UserBlock(content string, ty *herald.Typography, theme style.Theme) string {
|
||||
// The width parameter controls line wrapping so long messages don't overflow.
|
||||
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
// Wrap content before passing to herald Alert so long lines break
|
||||
// inside the alert box. Subtract 4 to account for the alert bar
|
||||
// prefix ("│ ") and a small margin.
|
||||
if width > 4 {
|
||||
content = lipgloss.Wrap(content, width-4, "")
|
||||
}
|
||||
|
||||
rendered := ty.Tip(content)
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
@@ -85,11 +85,13 @@ func GetMarkdownTypography() *herald.Typography {
|
||||
return ty
|
||||
}
|
||||
|
||||
// ToMarkdown renders markdown content using herald-md.
|
||||
// The width parameter is currently unused as herald handles wrapping
|
||||
// based on terminal width internally.
|
||||
// ToMarkdown renders markdown content using herald-md and wraps the result
|
||||
// to the given width so that long lines do not overflow the terminal.
|
||||
func ToMarkdown(content string, width int) string {
|
||||
ty := GetMarkdownTypography()
|
||||
rendered := heraldmd.Render(ty, []byte(content))
|
||||
if width > 0 {
|
||||
rendered = lipgloss.Wrap(rendered, width, "")
|
||||
}
|
||||
return rendered
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInp
|
||||
ta := textarea.New()
|
||||
ta.Placeholder = ""
|
||||
ta.ShowLineNumbers = false
|
||||
ta.CharLimit = 1000
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(4) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
|
||||
@@ -134,23 +134,28 @@ func (ut *UsageTracker) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
}
|
||||
|
||||
// SetContextTokens records the approximate current context window utilization.
|
||||
// This should be set from FinalUsage.InputTokens, which already includes the
|
||||
// full conversation history (system prompt + all previous messages). Do NOT
|
||||
// add OutputTokens as that would double-count (output becomes input next turn).
|
||||
// Use FinalResponse.Usage rather than aggregate TotalUsage, because TotalUsage
|
||||
// sums across all tool-calling steps and overstates the actual window fill level.
|
||||
//
|
||||
// The value should include ALL token categories from the last API call:
|
||||
//
|
||||
// InputTokens + CacheReadTokens + CacheCreationTokens + OutputTokens
|
||||
//
|
||||
// With Anthropic prompt caching, InputTokens can be near-zero while
|
||||
// CacheReadTokens holds the bulk of the context. All four must be summed
|
||||
// to get the true context window fill level.
|
||||
//
|
||||
// OutputTokens is included because the assistant's output becomes part of
|
||||
// the context on the next turn.
|
||||
//
|
||||
// Use FinalResponse.Usage (last step only) rather than aggregate TotalUsage,
|
||||
// because TotalUsage sums across all tool-calling steps and overstates the
|
||||
// actual window fill level.
|
||||
//
|
||||
// The value is set unconditionally (not max-only) so that context shrinks
|
||||
// correctly after compaction.
|
||||
func (ut *UsageTracker) SetContextTokens(tokens int) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
// Track the maximum context seen so far. In multi-step tool calls,
|
||||
// FinalUsage.InputTokens may reflect only the last step's input, which
|
||||
// can be smaller than previous steps. We want to show the largest context
|
||||
// the model has processed in this session.
|
||||
if tokens > ut.contextTokens {
|
||||
ut.contextTokens = tokens
|
||||
}
|
||||
// If tokens < current, we keep the larger value (no-op)
|
||||
// This prevents the display from dropping during multi-step tool calls.
|
||||
ut.contextTokens = tokens
|
||||
}
|
||||
|
||||
// RenderUsageInfo generates a formatted string displaying current usage statistics
|
||||
|
||||
@@ -31,6 +31,11 @@ func (m *Kit) EstimateContextTokens() int {
|
||||
// limit and should be compacted.
|
||||
// Formula: contextTokens > contextWindow − reserveTokens.
|
||||
// Returns false if the model's context limit is unknown.
|
||||
//
|
||||
// When API-reported token counts are available (after at least one turn),
|
||||
// the real count is used instead of the text-based heuristic. This is
|
||||
// significantly more accurate because it includes system prompts, tool
|
||||
// definitions, and other overhead that the heuristic cannot account for.
|
||||
func (m *Kit) ShouldCompact() bool {
|
||||
info := m.GetModelInfo()
|
||||
if info == nil || info.Limit.Context <= 0 {
|
||||
@@ -42,6 +47,16 @@ func (m *Kit) ShouldCompact() bool {
|
||||
reserveTokens = m.compactionOpts.ReserveTokens
|
||||
}
|
||||
|
||||
// Prefer the real API-reported token count when available.
|
||||
m.lastInputTokensMu.RLock()
|
||||
realTokens := m.lastInputTokens
|
||||
m.lastInputTokensMu.RUnlock()
|
||||
|
||||
if realTokens > 0 {
|
||||
return realTokens > info.Limit.Context-reserveTokens
|
||||
}
|
||||
|
||||
// Fall back to text-based heuristic before first turn completes.
|
||||
messages := m.session.GetMessages()
|
||||
return compaction.ShouldCompact(convertKitMessagesToFantasy(messages), info.Limit.Context, reserveTokens)
|
||||
}
|
||||
@@ -245,6 +260,14 @@ func (m *Kit) persistAndEmitCompaction(
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to persist compaction entry: %w", err)
|
||||
}
|
||||
|
||||
// Reset the API-reported token count so GetContextStats() and
|
||||
// ShouldCompact() don't use stale pre-compaction values. The next
|
||||
// API call will set the accurate post-compaction count.
|
||||
m.lastInputTokensMu.Lock()
|
||||
m.lastInputTokens = 0
|
||||
m.lastInputTokensMu.Unlock()
|
||||
|
||||
m.events.emit(CompactionEvent{
|
||||
Summary: summary,
|
||||
OriginalTokens: originalTokens,
|
||||
|
||||
+310
-80
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -11,7 +12,6 @@ import (
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
charmlog "github.com/charmbracelet/log"
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
@@ -51,6 +51,12 @@ type Kit struct {
|
||||
authHandler MCPAuthHandler // OAuth handler for remote MCP servers (may need Close)
|
||||
opts *Options // stored for reload operations (skills, etc.)
|
||||
|
||||
// hasCustomSystemPrompt is true when the user explicitly configured a
|
||||
// system prompt (via --system-prompt flag, config file, or SDK option).
|
||||
// When false, per-model system prompts from modelSettings/customModels
|
||||
// can replace the default prompt on model switch.
|
||||
hasCustomSystemPrompt bool
|
||||
|
||||
// Hook registries — interception layer (see hooks.go).
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
|
||||
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult]
|
||||
@@ -140,6 +146,79 @@ func (m *Kit) MCPToolsReady() bool {
|
||||
return m.agent.MCPToolsReady()
|
||||
}
|
||||
|
||||
// MCPServerStatus describes the runtime state of a loaded MCP server.
|
||||
type MCPServerStatus struct {
|
||||
// Name is the configured server name.
|
||||
Name string
|
||||
// ToolCount is the number of tools loaded from this server.
|
||||
ToolCount int
|
||||
}
|
||||
|
||||
// AddMCPServer connects to a new MCP server at runtime and makes its tools
|
||||
// available to the agent immediately. The server's tools are prefixed with the
|
||||
// server name (e.g. "myserver__tool_name") to avoid naming conflicts, matching
|
||||
// the behaviour of servers loaded at initialization.
|
||||
//
|
||||
// Returns the number of tools loaded from the server.
|
||||
//
|
||||
// AddMCPServer is safe to call while the agent is idle. If a turn is in
|
||||
// progress ([Kit.IsGenerating] returns true), the new tools will be visible
|
||||
// starting from the next LLM step.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// n, err := k.AddMCPServer(ctx, "github", kit.MCPServerConfig{
|
||||
// Command: []string{"npx", "-y", "@modelcontextprotocol/server-github"},
|
||||
// Environment: map[string]string{"GITHUB_TOKEN": os.Getenv("GITHUB_TOKEN")},
|
||||
// })
|
||||
func (m *Kit) AddMCPServer(ctx context.Context, name string, cfg MCPServerConfig) (int, error) {
|
||||
return m.agent.AddMCPServer(ctx, name, cfg)
|
||||
}
|
||||
|
||||
// RemoveMCPServer disconnects an MCP server and removes all its tools from
|
||||
// the agent. After this call the agent will no longer see or be able to call
|
||||
// tools from the named server.
|
||||
//
|
||||
// RemoveMCPServer is safe to call while the agent is idle. If a turn is in
|
||||
// progress, the tools are removed at the next LLM step. Any in-flight tool
|
||||
// calls to the removed server will fail gracefully.
|
||||
//
|
||||
// Returns an error if the named server is not currently loaded.
|
||||
func (m *Kit) RemoveMCPServer(name string) error {
|
||||
return m.agent.RemoveMCPServer(name)
|
||||
}
|
||||
|
||||
// ListMCPServers returns the status of all currently loaded MCP servers.
|
||||
// The returned slice is a snapshot; it is safe to read concurrently.
|
||||
func (m *Kit) ListMCPServers() []MCPServerStatus {
|
||||
names := m.agent.GetLoadedServerNames()
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build a tool count per server by scanning tool names for the prefix.
|
||||
toolNames := m.GetToolNames()
|
||||
countByServer := make(map[string]int, len(names))
|
||||
for _, tn := range toolNames {
|
||||
for _, sn := range names {
|
||||
prefix := sn + "__"
|
||||
if len(tn) > len(prefix) && tn[:len(prefix)] == prefix {
|
||||
countByServer[sn]++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]MCPServerStatus, 0, len(names))
|
||||
for _, n := range names {
|
||||
result = append(result, MCPServerStatus{
|
||||
Name: n,
|
||||
ToolCount: countByServer[n],
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetExtensionToolCount returns the number of tools registered by extensions.
|
||||
func (m *Kit) GetExtensionToolCount() int {
|
||||
return m.agent.GetExtensionToolCount()
|
||||
@@ -221,9 +300,12 @@ func iterBranchMessages[T any](tm *session.TreeManager, fn func(*session.Message
|
||||
return results
|
||||
}
|
||||
|
||||
// SetModel changes the active model at runtime. The existing tools, system
|
||||
// prompt, and session are preserved. The model string should be in
|
||||
// "provider/model" format (e.g. "anthropic/claude-sonnet-4-5-20250929").
|
||||
// SetModel changes the active model at runtime. The existing tools and
|
||||
// session are preserved. When the new model has a per-model system prompt
|
||||
// (from modelSettings or customModels params), it is composed with the
|
||||
// current AGENTS.md context and skills before being applied.
|
||||
// The model string should be in "provider/model" format
|
||||
// (e.g. "anthropic/claude-sonnet-4-5-20250929").
|
||||
// Returns an error if the model string is invalid or the provider cannot
|
||||
// be created.
|
||||
func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
@@ -239,7 +321,7 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
|
||||
// With message-level caching, thinking and caching can work together.
|
||||
// No need to disable caching when thinking is enabled.
|
||||
config := &models.ProviderConfig{
|
||||
cfg := &models.ProviderConfig{
|
||||
ModelString: modelString,
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
@@ -249,18 +331,50 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
ThinkingLevel: thinkingLevel,
|
||||
DisableCaching: false, // Caching enabled by default, works with thinking
|
||||
}
|
||||
temperature := float32(viper.GetFloat64("temperature"))
|
||||
config.Temperature = &temperature
|
||||
topP := float32(viper.GetFloat64("top-p"))
|
||||
config.TopP = &topP
|
||||
topK := int32(viper.GetInt("top-k"))
|
||||
config.TopK = &topK
|
||||
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
|
||||
config.FrequencyPenalty = &frequencyPenalty
|
||||
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
|
||||
config.PresencePenalty = &presencePenalty
|
||||
|
||||
if err := m.agent.SetModel(ctx, config); err != nil {
|
||||
// Only set generation parameter pointers when the user has explicitly
|
||||
// provided a value. This leaves nil pointers for unset params, allowing
|
||||
// per-model defaults (modelSettings / customModels params) to apply.
|
||||
if viper.IsSet("temperature") {
|
||||
v := float32(viper.GetFloat64("temperature"))
|
||||
cfg.Temperature = &v
|
||||
}
|
||||
if viper.IsSet("top-p") {
|
||||
v := float32(viper.GetFloat64("top-p"))
|
||||
cfg.TopP = &v
|
||||
}
|
||||
if viper.IsSet("top-k") {
|
||||
v := int32(viper.GetInt("top-k"))
|
||||
cfg.TopK = &v
|
||||
}
|
||||
if viper.IsSet("frequency-penalty") {
|
||||
v := float32(viper.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &v
|
||||
}
|
||||
if viper.IsSet("presence-penalty") {
|
||||
v := float32(viper.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &v
|
||||
}
|
||||
|
||||
// When the user hasn't set a custom global system prompt, check for a
|
||||
// per-model system prompt. Pre-apply model settings to discover it,
|
||||
// then compose with AGENTS.md context and skills if found.
|
||||
if !m.hasCustomSystemPrompt {
|
||||
// Temporarily clear the system prompt so ApplyModelSettings can
|
||||
// detect that no explicit prompt is set and apply the per-model one.
|
||||
cfg.SystemPrompt = ""
|
||||
models.ApplyModelSettings(cfg, models.LookupModelForSettings(modelString))
|
||||
|
||||
if cfg.SystemPrompt != "" {
|
||||
// Per-model system prompt found — compose with runtime context.
|
||||
cfg.SystemPrompt = m.composeSystemPrompt(cfg.SystemPrompt)
|
||||
} else {
|
||||
// No per-model prompt — restore the global composed prompt.
|
||||
cfg.SystemPrompt = systemPrompt
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.agent.SetModel(ctx, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -276,6 +390,32 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// composeSystemPrompt takes a base system prompt and composes it with the
|
||||
// current runtime context: AGENTS.md content, skills metadata, and date/cwd.
|
||||
// This mirrors the composition done during Kit.New() initialization.
|
||||
func (m *Kit) composeSystemPrompt(basePrompt string) string {
|
||||
cwd, _ := os.Getwd()
|
||||
pb := skills.NewPromptBuilder(basePrompt)
|
||||
|
||||
// Inject AGENTS.md content as project context.
|
||||
for _, cf := range m.contextFiles {
|
||||
pb.WithSection("", fmt.Sprintf("Instructions from: %s\n\n%s", cf.Path, cf.Content))
|
||||
}
|
||||
|
||||
// Inject skills metadata.
|
||||
if len(m.skills) > 0 {
|
||||
pb.WithSkills(m.skills)
|
||||
}
|
||||
|
||||
// Append current date/time and working directory.
|
||||
pb.WithSection("", fmt.Sprintf(
|
||||
"Current date and time: %s\nCurrent working directory: %s",
|
||||
time.Now().Format("Monday, January 2, 2006, 3:04:05 PM MST"), cwd,
|
||||
))
|
||||
|
||||
return pb.Build()
|
||||
}
|
||||
|
||||
// GetAvailableModels returns a list of known models from the registry. Each
|
||||
// entry includes provider, model ID, context limit, and whether the model
|
||||
// supports reasoning. This is an advisory list — models not in the registry
|
||||
@@ -477,6 +617,14 @@ type Options struct {
|
||||
// Skills
|
||||
Skills []string // Explicit skill files/dirs to load (empty = auto-discover)
|
||||
SkillsDir string // Override default project-local skills directory
|
||||
NoSkills bool // Disable skill loading entirely (auto-discovery and explicit)
|
||||
|
||||
// NoExtensions disables Yaegi extension loading entirely.
|
||||
NoExtensions bool
|
||||
|
||||
// NoContextFiles disables automatic loading of project context files
|
||||
// (e.g. AGENTS.md) from the working directory.
|
||||
NoContextFiles bool
|
||||
|
||||
// Compaction
|
||||
AutoCompact bool // Auto-compact when near context limit
|
||||
@@ -497,6 +645,17 @@ type Options struct {
|
||||
// display a URL in a custom UI, redirect to a web app, etc.).
|
||||
MCPAuthHandler MCPAuthHandler
|
||||
|
||||
// MCPTokenStoreFactory, if non-nil, is called to create a token store for
|
||||
// each remote MCP server that requires OAuth. The factory receives the
|
||||
// server's URL and returns a [MCPTokenStore] implementation.
|
||||
//
|
||||
// When nil (default), tokens are persisted to a JSON file at
|
||||
// $XDG_CONFIG_HOME/.kit/mcp_tokens.json (or ~/.config/.kit/mcp_tokens.json).
|
||||
//
|
||||
// Use this to store tokens in a database, encrypt them, keep them
|
||||
// in-memory, or write them to a custom file path.
|
||||
MCPTokenStoreFactory MCPTokenStoreFactory
|
||||
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading during Kit initialization. The callback receives the server name,
|
||||
// tool count, and any error. Called from a background goroutine; safe to
|
||||
@@ -582,16 +741,17 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
// provider creation, session init) then runs outside the lock, allowing
|
||||
// parallel subagent spawns to proceed concurrently.
|
||||
var (
|
||||
providerConfig *models.ProviderConfig
|
||||
modelString string
|
||||
cwd string
|
||||
contextFiles []*ContextFile
|
||||
loadedSkills []*Skill
|
||||
mcpConfig *config.Config
|
||||
debug bool
|
||||
noExtensions bool
|
||||
maxSteps int
|
||||
streaming bool
|
||||
providerConfig *models.ProviderConfig
|
||||
modelString string
|
||||
cwd string
|
||||
contextFiles []*ContextFile
|
||||
loadedSkills []*Skill
|
||||
mcpConfig *config.Config
|
||||
debug bool
|
||||
noExtensions bool
|
||||
maxSteps int
|
||||
streaming bool
|
||||
hasCustomSystemPrompt bool
|
||||
)
|
||||
|
||||
if err := func() error {
|
||||
@@ -636,19 +796,56 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
}
|
||||
|
||||
// Load context files (AGENTS.md) from the project root.
|
||||
contextFiles = loadContextFiles(cwd)
|
||||
if !opts.NoContextFiles {
|
||||
contextFiles = loadContextFiles(cwd)
|
||||
}
|
||||
|
||||
// Load skills — either from explicit paths or via auto-discovery.
|
||||
var err error
|
||||
loadedSkills, err = loadSkills(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load skills: %w", err)
|
||||
if !opts.NoSkills {
|
||||
var err error
|
||||
loadedSkills, err = loadSkills(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load skills: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Always compose the system prompt with runtime context: base prompt +
|
||||
// AGENTS.md context + skills metadata + date/cwd.
|
||||
//
|
||||
// If the configured model has a per-model system prompt (via
|
||||
// modelSettings or customModels params) and the user hasn't
|
||||
// explicitly set system-prompt, use the per-model prompt as the
|
||||
// base instead of the global default.
|
||||
{
|
||||
basePrompt := viper.GetString("system-prompt")
|
||||
|
||||
// Track whether the user explicitly configured a custom system
|
||||
// prompt. When they haven't (basePrompt is the built-in default
|
||||
// or empty), per-model system prompts can replace it on switch.
|
||||
userSetSystemPrompt := basePrompt != "" && basePrompt != defaultSystemPrompt
|
||||
hasCustomSystemPrompt = userSetSystemPrompt
|
||||
|
||||
// Check for per-model system prompt override when no explicit
|
||||
// global system-prompt was configured by the user.
|
||||
if !userSetSystemPrompt {
|
||||
modelStr := viper.GetString("model")
|
||||
if modelStr != "" {
|
||||
if mi := models.LookupModelForSettings(modelStr); mi != nil {
|
||||
var perModelParams *models.GenerationParams
|
||||
// modelSettings takes priority over custom model params.
|
||||
if ms := models.LoadModelSettingsFromConfig(); ms != nil {
|
||||
perModelParams = ms[modelStr]
|
||||
}
|
||||
if perModelParams == nil && mi.Params != nil {
|
||||
perModelParams = mi.Params
|
||||
}
|
||||
if perModelParams != nil && perModelParams.SystemPrompt != "" {
|
||||
basePrompt = models.LoadSystemPromptValue(perModelParams.SystemPrompt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pb := skills.NewPromptBuilder(basePrompt)
|
||||
|
||||
// Inject AGENTS.md content as project context.
|
||||
@@ -679,7 +876,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
}
|
||||
modelString = viper.GetString("model")
|
||||
debug = viper.GetBool("debug")
|
||||
noExtensions = viper.GetBool("no-extensions")
|
||||
noExtensions = opts.NoExtensions || viper.GetBool("no-extensions")
|
||||
maxSteps = viper.GetInt("max-steps")
|
||||
streaming = viper.GetBool("stream")
|
||||
|
||||
@@ -739,12 +936,19 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
defaultHandler, authErr := NewDefaultMCPAuthHandler()
|
||||
if authErr != nil {
|
||||
// Non-fatal: OAuth just won't be available for remote servers.
|
||||
charmlog.Warn("Failed to create OAuth handler; remote MCP servers requiring auth will fail", "error", authErr)
|
||||
log.Printf("WARN Failed to create OAuth handler; remote MCP servers requiring auth will fail: %v", authErr)
|
||||
} else {
|
||||
setupOpts.AuthHandler = defaultHandler
|
||||
}
|
||||
}
|
||||
|
||||
// Set up custom token store factory for MCP OAuth tokens.
|
||||
// The SDK MCPTokenStoreFactory is structurally identical to
|
||||
// tools.TokenStoreFactory, so it can be assigned directly.
|
||||
if opts.MCPTokenStoreFactory != nil {
|
||||
setupOpts.TokenStoreFactory = tools.TokenStoreFactory(opts.MCPTokenStoreFactory)
|
||||
}
|
||||
|
||||
if opts.CLI != nil {
|
||||
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
|
||||
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
|
||||
@@ -774,24 +978,25 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
}
|
||||
|
||||
k := &Kit{
|
||||
agent: agentResult.Agent,
|
||||
session: sessionManager,
|
||||
modelString: modelString,
|
||||
events: newEventBus(),
|
||||
autoCompact: opts.AutoCompact,
|
||||
compactionOpts: opts.CompactionOptions,
|
||||
contextFiles: contextFiles,
|
||||
skills: loadedSkills,
|
||||
extRunner: agentResult.ExtRunner,
|
||||
bufferedLogger: agentResult.BufferedLogger,
|
||||
authHandler: setupOpts.AuthHandler,
|
||||
opts: opts,
|
||||
beforeToolCall: beforeToolCall,
|
||||
afterToolResult: afterToolResult,
|
||||
beforeTurn: beforeTurn,
|
||||
afterTurn: afterTurn,
|
||||
contextPrepare: contextPrepare,
|
||||
beforeCompact: beforeCompact,
|
||||
agent: agentResult.Agent,
|
||||
session: sessionManager,
|
||||
modelString: modelString,
|
||||
events: newEventBus(),
|
||||
autoCompact: opts.AutoCompact,
|
||||
compactionOpts: opts.CompactionOptions,
|
||||
contextFiles: contextFiles,
|
||||
skills: loadedSkills,
|
||||
extRunner: agentResult.ExtRunner,
|
||||
bufferedLogger: agentResult.BufferedLogger,
|
||||
authHandler: setupOpts.AuthHandler,
|
||||
opts: opts,
|
||||
hasCustomSystemPrompt: hasCustomSystemPrompt,
|
||||
beforeToolCall: beforeToolCall,
|
||||
afterToolResult: afterToolResult,
|
||||
beforeTurn: beforeTurn,
|
||||
afterTurn: afterTurn,
|
||||
contextPrepare: contextPrepare,
|
||||
beforeCompact: beforeCompact,
|
||||
}
|
||||
|
||||
// Bridge extension events to SDK hooks.
|
||||
@@ -978,9 +1183,11 @@ type TurnResult struct {
|
||||
// report usage.
|
||||
TotalUsage *LLMUsage
|
||||
|
||||
// FinalUsage is the token usage from the last API call only. Use this
|
||||
// for context window fill estimation (InputTokens + OutputTokens ≈
|
||||
// current context size). Nil if unavailable.
|
||||
// FinalUsage is the token usage from the last API call only. For context
|
||||
// window fill, sum all categories: InputTokens + CacheReadTokens +
|
||||
// CacheCreationTokens + OutputTokens. With prompt caching, InputTokens
|
||||
// alone understates the context (cached tokens are reported separately).
|
||||
// Nil if unavailable.
|
||||
FinalUsage *LLMUsage
|
||||
|
||||
// Messages is the full updated conversation after the turn, including
|
||||
@@ -1309,14 +1516,22 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
IsStderr: isStderr,
|
||||
})
|
||||
},
|
||||
// Persist step messages incrementally so that progress survives
|
||||
// crashes and long-running turns don't lose work. Each step's
|
||||
// messages are persisted as a unit: for tool-calling steps this is
|
||||
// the assistant message (with tool_use parts) + tool-role message
|
||||
// (with tool_result parts) as a pair; for the final step it's the
|
||||
// assistant text/reasoning message alone.
|
||||
func(stepMessages []fantasy.Message) {
|
||||
for _, msg := range stepMessages {
|
||||
_, _ = m.session.AppendMessage(msg)
|
||||
}
|
||||
},
|
||||
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
|
||||
// Emit step usage event for real-time cost tracking
|
||||
if viper.GetBool("debug") {
|
||||
charmlog.Debug("Kit.generate emitting StepUsageEvent",
|
||||
"input", inputTokens,
|
||||
"output", outputTokens,
|
||||
"cacheRead", cacheReadTokens,
|
||||
"cacheCreate", cacheCreationTokens,
|
||||
log.Printf("DEBUG Kit.generate emitting StepUsageEvent: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens,
|
||||
)
|
||||
}
|
||||
m.events.emit(StepUsageEvent{
|
||||
@@ -1334,11 +1549,17 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
// 2. Persist pre-generation messages to the tree session.
|
||||
// 3. Build context from the tree (walks leaf-to-root for current branch).
|
||||
// 4. Emit turn/message start events.
|
||||
// 5. Run generation.
|
||||
// 6. Emit turn/message end events.
|
||||
// 7. Persist post-generation messages (tool calls, results, assistant).
|
||||
// 5. Run generation (messages are persisted incrementally per step).
|
||||
// 6. Persist any remaining messages not covered by incremental persistence.
|
||||
// 7. Emit turn/message end events.
|
||||
// 8. Run AfterTurn hooks.
|
||||
//
|
||||
// During generation, each completed step's messages are persisted immediately
|
||||
// via the onStepMessages callback. Tool calls are always persisted as
|
||||
// call/response pairs (assistant + tool messages together). Reasoning and
|
||||
// text-only assistant messages are persisted as soon as their step completes.
|
||||
// This ensures long-running turns don't lose progress on crash or cancellation.
|
||||
//
|
||||
// promptLabel is the human-readable label emitted in TurnStartEvent.Prompt.
|
||||
// prompt is the raw user text passed to BeforeTurn hooks.
|
||||
func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, preMessages []fantasy.Message) (*TurnResult, error) {
|
||||
@@ -1408,16 +1629,18 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
|
||||
result, err := m.generate(ctx, messages)
|
||||
if err != nil {
|
||||
// Persist any messages from completed steps (tool call/result
|
||||
// pairs) so partial progress is not lost. The agent layer only
|
||||
// includes fully-paired tool_use + tool_result messages in
|
||||
// completedStepMessages, so there are no orphaned entries that
|
||||
// would break subsequent API requests. The user message and any
|
||||
// completed work remain in the session; only the in-progress
|
||||
// (pending) message or tool call is discarded.
|
||||
if result != nil && len(result.ConversationMessages) > sentCount {
|
||||
for _, msg := range result.ConversationMessages[sentCount:] {
|
||||
_, _ = m.session.AppendMessage(msg)
|
||||
// Persist any messages from completed steps that were NOT already
|
||||
// persisted incrementally by the onStepMessages callback. The agent
|
||||
// layer only includes fully-paired tool_use + tool_result messages
|
||||
// in completedStepMessages, so there are no orphaned entries that
|
||||
// would break subsequent API requests.
|
||||
if result != nil {
|
||||
newMessages := result.ConversationMessages[sentCount:]
|
||||
alreadyPersisted := result.PersistedMessageCount
|
||||
if alreadyPersisted < len(newMessages) {
|
||||
for _, msg := range newMessages[alreadyPersisted:] {
|
||||
_, _ = m.session.AppendMessage(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
m.events.emit(TurnEndEvent{Error: err})
|
||||
@@ -1428,22 +1651,29 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
|
||||
responseText := result.FinalResponse.Content.Text()
|
||||
|
||||
// Persist new messages (tool calls, tool results, assistant response)
|
||||
// BEFORE emitting events so that extension handlers calling
|
||||
// GetContextStats() see up-to-date token counts.
|
||||
// Persist any new messages that were NOT already persisted incrementally
|
||||
// by the onStepMessages callback during generation. This handles the
|
||||
// non-streaming path (where onStepMessages is not called) and any edge
|
||||
// cases where the final response messages weren't covered by step callbacks.
|
||||
if len(result.ConversationMessages) > sentCount {
|
||||
for _, msg := range result.ConversationMessages[sentCount:] {
|
||||
_, _ = m.session.AppendMessage(msg)
|
||||
newMessages := result.ConversationMessages[sentCount:]
|
||||
alreadyPersisted := result.PersistedMessageCount
|
||||
if alreadyPersisted < len(newMessages) {
|
||||
for _, msg := range newMessages[alreadyPersisted:] {
|
||||
_, _ = m.session.AppendMessage(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store the API-reported token count so GetContextStats() matches the
|
||||
// built-in status bar (which uses input + output tokens). The
|
||||
// text-based heuristic misses system prompts, tool definitions, etc.
|
||||
// built-in status bar. The context window is filled by all token
|
||||
// categories: non-cached input, cache reads, cache writes, and output.
|
||||
// With Anthropic prompt caching, InputTokens can be near-zero while
|
||||
// CacheReadTokens/CacheCreationTokens hold the bulk of the context.
|
||||
if result.FinalResponse != nil {
|
||||
u := result.FinalResponse.Usage
|
||||
m.lastInputTokensMu.Lock()
|
||||
m.lastInputTokens = int(u.InputTokens) + int(u.OutputTokens)
|
||||
m.lastInputTokens = int(u.InputTokens) + int(u.CacheReadTokens) + int(u.CacheCreationTokens) + int(u.OutputTokens)
|
||||
m.lastInputTokensMu.Unlock()
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package kit_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// TestMCPServerStatus_TypeSurface verifies the MCPServerStatus type is
|
||||
// accessible and has the expected fields.
|
||||
func TestMCPServerStatus_TypeSurface(t *testing.T) {
|
||||
s := kit.MCPServerStatus{
|
||||
Name: "test-server",
|
||||
ToolCount: 5,
|
||||
}
|
||||
if s.Name != "test-server" {
|
||||
t.Errorf("Expected Name 'test-server', got %q", s.Name)
|
||||
}
|
||||
if s.ToolCount != 5 {
|
||||
t.Errorf("Expected ToolCount 5, got %d", s.ToolCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPServerConfig_ForDynamicAdd verifies that MCPServerConfig can be
|
||||
// constructed with the expected fields for dynamic server management.
|
||||
func TestMCPServerConfig_ForDynamicAdd(t *testing.T) {
|
||||
// Stdio server config.
|
||||
stdio := kit.MCPServerConfig{
|
||||
Command: []string{"npx", "-y", "@modelcontextprotocol/server-github"},
|
||||
Environment: map[string]string{"GITHUB_TOKEN": "test-token"},
|
||||
}
|
||||
if len(stdio.Command) != 3 {
|
||||
t.Errorf("Expected 3 command parts, got %d", len(stdio.Command))
|
||||
}
|
||||
if stdio.Environment["GITHUB_TOKEN"] != "test-token" {
|
||||
t.Error("Expected GITHUB_TOKEN in environment")
|
||||
}
|
||||
|
||||
// Remote server config.
|
||||
remote := kit.MCPServerConfig{
|
||||
URL: "https://mcp.example.com/sse",
|
||||
Headers: []string{"Authorization: Bearer test"},
|
||||
}
|
||||
if remote.URL != "https://mcp.example.com/sse" {
|
||||
t.Errorf("Unexpected URL: %s", remote.URL)
|
||||
}
|
||||
|
||||
// Config with tool filtering.
|
||||
filtered := kit.MCPServerConfig{
|
||||
Command: []string{"some-server"},
|
||||
AllowedTools: []string{"read", "write"},
|
||||
}
|
||||
if len(filtered.AllowedTools) != 2 {
|
||||
t.Errorf("Expected 2 allowed tools, got %d", len(filtered.AllowedTools))
|
||||
}
|
||||
}
|
||||
@@ -6,9 +6,21 @@ import (
|
||||
|
||||
// SessionManager defines the contract for conversation storage backends.
|
||||
// Implementations can use files (default), databases, cloud storage, etc.
|
||||
//
|
||||
// Implementations must be safe for concurrent use. During generation,
|
||||
// AppendMessage is called incrementally from the agent's step-completion
|
||||
// callback while read methods (GetMessages, GetCurrentBranch, etc.) may be
|
||||
// called concurrently from the UI or extension goroutines.
|
||||
type SessionManager interface {
|
||||
// AppendMessage adds a message to the current branch and returns its entry ID.
|
||||
// The entry ID is used for tree navigation and must be unique within the session.
|
||||
//
|
||||
// During generation, AppendMessage is called incrementally after each
|
||||
// completed agent step rather than in a batch at the end of the turn.
|
||||
// For tool-calling steps, the assistant message (containing tool_use parts)
|
||||
// and the tool-role message (containing tool_result parts) are appended
|
||||
// together as a pair. This ensures the session never contains an orphaned
|
||||
// tool call without its result, which would break subsequent LLM requests.
|
||||
AppendMessage(msg LLMMessage) (entryID string, err error)
|
||||
|
||||
// GetMessages returns all messages on the current branch (from root to leaf),
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/core"
|
||||
@@ -16,6 +18,123 @@ type ToolOption = core.ToolOption
|
||||
// If empty, os.Getwd() is used at execution time.
|
||||
var WithWorkDir = core.WithWorkDir
|
||||
|
||||
// --- Custom tool creation ---
|
||||
|
||||
// ToolOutput is the return value from custom tool handlers created with
|
||||
// [NewTool] or [NewParallelTool]. It provides a dependency-free way to
|
||||
// return results without importing the underlying LLM framework.
|
||||
type ToolOutput struct {
|
||||
// Content is the text content returned to the LLM.
|
||||
Content string
|
||||
|
||||
// IsError, when true, signals to the LLM that the tool call failed.
|
||||
IsError bool
|
||||
|
||||
// Data contains optional binary data (images, audio, etc.).
|
||||
Data []byte
|
||||
|
||||
// MediaType is the MIME type for binary Data (e.g. "image/png").
|
||||
MediaType string
|
||||
|
||||
// Metadata is optional opaque metadata attached to the response.
|
||||
// It is not sent to the LLM but may be consumed by hooks or the UI.
|
||||
Metadata any
|
||||
}
|
||||
|
||||
// TextResult creates a successful text [ToolOutput].
|
||||
func TextResult(content string) ToolOutput {
|
||||
return ToolOutput{Content: content}
|
||||
}
|
||||
|
||||
// ErrorResult creates an error [ToolOutput]. The LLM will see the content
|
||||
// as a tool error, allowing it to retry or adjust its approach.
|
||||
func ErrorResult(content string) ToolOutput {
|
||||
return ToolOutput{Content: content, IsError: true}
|
||||
}
|
||||
|
||||
// toolCallIDKey is the context key for the tool call ID.
|
||||
type toolCallIDKey struct{}
|
||||
|
||||
// ToolCallIDFromContext extracts the tool call ID from the context.
|
||||
// The call ID is set automatically by [NewTool] and [NewParallelTool]
|
||||
// before invoking the handler. Returns an empty string if no ID is present.
|
||||
func ToolCallIDFromContext(ctx context.Context) string {
|
||||
s, _ := ctx.Value(toolCallIDKey{}).(string)
|
||||
return s
|
||||
}
|
||||
|
||||
// NewTool creates a custom [Tool] with automatic JSON schema generation from
|
||||
// the TInput struct type. The handler receives a typed input (deserialized
|
||||
// from the LLM's JSON arguments) and returns a [ToolResult].
|
||||
//
|
||||
// Struct tags on TInput control the generated schema:
|
||||
//
|
||||
// json:"name" → parameter name
|
||||
// description:"..." → parameter description shown to the LLM
|
||||
// enum:"a,b,c" → restrict valid values
|
||||
// omitempty → marks the parameter as optional
|
||||
//
|
||||
// The tool call ID is injected into the context and can be retrieved with
|
||||
// [ToolCallIDFromContext].
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type WeatherInput struct {
|
||||
// City string `json:"city" description:"City name"`
|
||||
// }
|
||||
//
|
||||
// tool := kit.NewTool("get_weather", "Get weather for a city",
|
||||
// func(ctx context.Context, input WeatherInput) (kit.ToolResult, error) {
|
||||
// return kit.TextResult("72°F, sunny in " + input.City), nil
|
||||
// },
|
||||
// )
|
||||
func NewTool[TInput any](name, description string, fn func(ctx context.Context, input TInput) (ToolOutput, error)) Tool {
|
||||
return fantasy.NewAgentTool(name, description,
|
||||
func(ctx context.Context, input TInput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
ctx = context.WithValue(ctx, toolCallIDKey{}, call.ID)
|
||||
result, err := fn(ctx, input)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
resp := fantasy.ToolResponse{
|
||||
Content: result.Content,
|
||||
IsError: result.IsError,
|
||||
Data: result.Data,
|
||||
MediaType: result.MediaType,
|
||||
}
|
||||
if result.Metadata != nil {
|
||||
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// NewParallelTool is like [NewTool] but marks the tool as safe for concurrent
|
||||
// execution alongside other tools. Use this when the tool has no side effects
|
||||
// or when concurrent calls are safe.
|
||||
func NewParallelTool[TInput any](name, description string, fn func(ctx context.Context, input TInput) (ToolOutput, error)) Tool {
|
||||
return fantasy.NewParallelAgentTool(name, description,
|
||||
func(ctx context.Context, input TInput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
ctx = context.WithValue(ctx, toolCallIDKey{}, call.ID)
|
||||
result, err := fn(ctx, input)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
resp := fantasy.ToolResponse{
|
||||
Content: result.Content,
|
||||
IsError: result.IsError,
|
||||
Data: result.Data,
|
||||
MediaType: result.MediaType,
|
||||
}
|
||||
if result.Metadata != nil {
|
||||
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// --- Individual tool constructors ---
|
||||
|
||||
// NewReadTool creates a file-reading tool.
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
package kit_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// TestNewTool_BasicTextResult verifies that NewTool creates a working tool
|
||||
// that returns text content via ToolOutput.
|
||||
func TestNewTool_BasicTextResult(t *testing.T) {
|
||||
type Input struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
tool := kit.NewTool("greet", "Greet someone",
|
||||
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("hello " + input.Name), nil
|
||||
},
|
||||
)
|
||||
|
||||
info := tool.Info()
|
||||
if info.Name != "greet" {
|
||||
t.Errorf("Info().Name = %q, want %q", info.Name, "greet")
|
||||
}
|
||||
if info.Description != "Greet someone" {
|
||||
t.Errorf("Info().Description = %q, want %q", info.Description, "Greet someone")
|
||||
}
|
||||
if info.Parallel {
|
||||
t.Error("NewTool should not mark tool as parallel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewParallelTool_MarkedParallel verifies that NewParallelTool marks the
|
||||
// tool as safe for concurrent execution.
|
||||
func TestNewParallelTool_MarkedParallel(t *testing.T) {
|
||||
type Input struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
tool := kit.NewParallelTool("search", "Search for things",
|
||||
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("found: " + input.Query), nil
|
||||
},
|
||||
)
|
||||
|
||||
info := tool.Info()
|
||||
if info.Name != "search" {
|
||||
t.Errorf("Info().Name = %q, want %q", info.Name, "search")
|
||||
}
|
||||
if !info.Parallel {
|
||||
t.Error("NewParallelTool should mark tool as parallel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextResult verifies the TextResult convenience constructor.
|
||||
func TestTextResult(t *testing.T) {
|
||||
r := kit.TextResult("ok")
|
||||
if r.Content != "ok" {
|
||||
t.Errorf("Content = %q, want %q", r.Content, "ok")
|
||||
}
|
||||
if r.IsError {
|
||||
t.Error("TextResult should not set IsError")
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorResult verifies the ErrorResult convenience constructor.
|
||||
func TestErrorResult(t *testing.T) {
|
||||
r := kit.ErrorResult("bad input")
|
||||
if r.Content != "bad input" {
|
||||
t.Errorf("Content = %q, want %q", r.Content, "bad input")
|
||||
}
|
||||
if !r.IsError {
|
||||
t.Error("ErrorResult should set IsError")
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallIDFromContext verifies round-trip context injection.
|
||||
func TestToolCallIDFromContext(t *testing.T) {
|
||||
// Empty context returns empty string.
|
||||
if id := kit.ToolCallIDFromContext(context.Background()); id != "" {
|
||||
t.Errorf("expected empty string from bare context, got %q", id)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolOutput_Metadata verifies that metadata can be set on ToolOutput.
|
||||
func TestToolOutput_Metadata(t *testing.T) {
|
||||
r := kit.ToolOutput{
|
||||
Content: "data",
|
||||
Metadata: map[string]string{"key": "value"},
|
||||
}
|
||||
if r.Metadata == nil {
|
||||
t.Error("expected non-nil Metadata")
|
||||
}
|
||||
m, ok := r.Metadata.(map[string]string)
|
||||
if !ok {
|
||||
t.Fatalf("expected map[string]string, got %T", r.Metadata)
|
||||
}
|
||||
if m["key"] != "value" {
|
||||
t.Errorf("Metadata[key] = %q, want %q", m["key"], "value")
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolOutput_BinaryData verifies that binary data fields work correctly.
|
||||
func TestToolOutput_BinaryData(t *testing.T) {
|
||||
data := []byte{0x89, 0x50, 0x4E, 0x47}
|
||||
r := kit.ToolOutput{
|
||||
Content: "image result",
|
||||
Data: data,
|
||||
MediaType: "image/png",
|
||||
}
|
||||
if len(r.Data) != 4 {
|
||||
t.Errorf("Data len = %d, want 4", len(r.Data))
|
||||
}
|
||||
if r.MediaType != "image/png" {
|
||||
t.Errorf("MediaType = %q, want %q", r.MediaType, "image/png")
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
)
|
||||
|
||||
// ==== Message Types (internal/message/content.go) ====
|
||||
@@ -204,6 +205,29 @@ type CompactionResult = compaction.CompactionResult
|
||||
// CompactionOptions configures compaction behaviour.
|
||||
type CompactionOptions = compaction.CompactionOptions
|
||||
|
||||
// ==== MCP OAuth Types ====
|
||||
|
||||
// MCPTokenStore persists OAuth tokens for a single MCP server. Implementations
|
||||
// must be safe for concurrent use.
|
||||
//
|
||||
// This is a type alias for the mcp-go transport.TokenStore interface. SDK
|
||||
// consumers can implement this interface to provide custom storage backends
|
||||
// (database, encrypted file, in-memory, etc.).
|
||||
type MCPTokenStore = transport.TokenStore
|
||||
|
||||
// MCPToken represents an OAuth token for an MCP server, containing access
|
||||
// and refresh tokens along with expiration metadata.
|
||||
type MCPToken = transport.Token
|
||||
|
||||
// MCPTokenStoreFactory creates an [MCPTokenStore] for a given MCP server URL.
|
||||
// It is called once per remote MCP server during connection setup.
|
||||
type MCPTokenStoreFactory func(serverURL string) (MCPTokenStore, error)
|
||||
|
||||
// ErrMCPNoToken is the sentinel error that [MCPTokenStore] implementations
|
||||
// should return from GetToken when no token is stored for the server.
|
||||
// Callers can check for this with errors.Is.
|
||||
var ErrMCPNoToken = transport.ErrNoToken
|
||||
|
||||
// ==== Constructor & Helper Functions ====
|
||||
|
||||
// ParseModelString parses a model string in "provider/model" format.
|
||||
|
||||
@@ -347,6 +347,77 @@ Lower values run first. Within the same priority, registration order applies. Fi
|
||||
|
||||
## Tools
|
||||
|
||||
### Creating custom tools
|
||||
|
||||
Use `kit.NewTool` to create custom tools. The JSON schema is auto-generated from the input struct — no external dependencies required:
|
||||
|
||||
```go
|
||||
type WeatherInput struct {
|
||||
City string `json:"city" description:"City name, e.g. 'San Francisco'"`
|
||||
}
|
||||
|
||||
weatherTool := kit.NewTool("get_weather", "Get current weather for a city",
|
||||
func(ctx context.Context, input WeatherInput) (kit.ToolOutput, error) {
|
||||
// Your logic here (API calls, database lookups, etc.)
|
||||
return kit.TextResult("72°F, sunny in " + input.City), nil
|
||||
},
|
||||
)
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{weatherTool},
|
||||
})
|
||||
```
|
||||
|
||||
**Struct tags** control the generated schema:
|
||||
|
||||
| Tag | Purpose | Example |
|
||||
|-----|---------|---------|
|
||||
| `json:"name"` | Parameter name | `json:"city"` |
|
||||
| `description:"..."` | Description shown to the LLM | `description:"City name"` |
|
||||
| `enum:"a,b,c"` | Restrict valid values | `enum:"json,text,csv"` |
|
||||
| `omitempty` | Marks parameter as optional | `json:"limit,omitempty"` |
|
||||
|
||||
**Return helpers:**
|
||||
|
||||
| Function | Description |
|
||||
|----------|-------------|
|
||||
| `kit.TextResult(content)` | Successful text result |
|
||||
| `kit.ErrorResult(content)` | Error result (LLM sees it as a tool error) |
|
||||
|
||||
**ToolOutput fields** (for advanced use):
|
||||
|
||||
```go
|
||||
kit.ToolOutput{
|
||||
Content: "result text", // text returned to the LLM
|
||||
IsError: false, // true = LLM sees this as an error
|
||||
Data: pngBytes, // optional binary data (images, audio)
|
||||
MediaType: "image/png", // MIME type for binary Data
|
||||
Metadata: map[string]any{}, // opaque metadata for hooks/UI (not sent to LLM)
|
||||
}
|
||||
```
|
||||
|
||||
**Parallel tools** — mark as safe for concurrent execution:
|
||||
|
||||
```go
|
||||
searchTool := kit.NewParallelTool("search", "Search the web",
|
||||
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("results..."), nil
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
**Tool call ID** — available in context for logging/tracing:
|
||||
|
||||
```go
|
||||
tool := kit.NewTool("my_tool", "...",
|
||||
func(ctx context.Context, input MyInput) (kit.ToolOutput, error) {
|
||||
callID := kit.ToolCallIDFromContext(ctx) // correlation ID from the LLM
|
||||
log.Printf("[%s] my_tool called", callID)
|
||||
return kit.TextResult("ok"), nil
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
### Built-in tool constructors
|
||||
|
||||
```go
|
||||
|
||||
+49
-21
@@ -7,17 +7,16 @@ description: Monitor tool calls and streaming output with the Kit Go SDK.
|
||||
|
||||
## Event-based monitoring
|
||||
|
||||
For more granular control, use the event subscription API:
|
||||
Subscribe to events for real-time monitoring. Each method returns an unsubscribe function:
|
||||
|
||||
```go
|
||||
// Subscribe returns an unsubscribe function
|
||||
unsub := host.OnToolCall(func(event kit.ToolCallEvent) {
|
||||
fmt.Printf("Tool: %s, Args: %s\n", event.Name, event.Args)
|
||||
fmt.Printf("Tool: %s, Args: %s\n", event.ToolName, event.ToolArgs)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
unsub2 := host.OnToolResult(func(event kit.ToolResultEvent) {
|
||||
fmt.Printf("Result: %s (error: %v)\n", event.Name, event.IsError)
|
||||
fmt.Printf("Result: %s (error: %v)\n", event.ToolName, event.IsError)
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
@@ -44,33 +43,62 @@ defer unsub6()
|
||||
|
||||
## Hook system
|
||||
|
||||
Hooks allow you to intercept and modify behavior. Unlike events, hooks can modify or cancel operations:
|
||||
Hooks can **modify or cancel** operations. Unlike events (read-only), hooks are read-write interceptors.
|
||||
|
||||
### BeforeToolCall — block tool execution
|
||||
|
||||
```go
|
||||
// Intercept tool calls before execution
|
||||
host.OnBeforeToolCall(0, func(ctx context.Context, name string, args string) (string, error) {
|
||||
if name == "bash" {
|
||||
log.Println("Bash command:", args)
|
||||
host.OnBeforeToolCall(kit.HookPriorityNormal, func(h kit.BeforeToolCallHook) *kit.BeforeToolCallResult {
|
||||
// h.ToolCallID, h.ToolName, h.ToolArgs
|
||||
if h.ToolName == "bash" && strings.Contains(h.ToolArgs, "rm -rf") {
|
||||
return &kit.BeforeToolCallResult{Block: true, Reason: "dangerous command"}
|
||||
}
|
||||
return args, nil // return modified args or error to cancel
|
||||
return nil // allow
|
||||
})
|
||||
```
|
||||
|
||||
// Process results after tool execution
|
||||
host.OnAfterToolResult(0, func(ctx context.Context, name string, result string) (string, error) {
|
||||
return result, nil
|
||||
})
|
||||
### AfterToolResult — modify tool output
|
||||
|
||||
// Before/after each agent turn
|
||||
host.OnBeforeTurn(0, func(ctx context.Context) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
host.OnAfterTurn(0, func(ctx context.Context) error {
|
||||
```go
|
||||
host.OnAfterToolResult(kit.HookPriorityNormal, func(h kit.AfterToolResultHook) *kit.AfterToolResultResult {
|
||||
// h.ToolCallID, h.ToolName, h.ToolArgs, h.Result, h.IsError
|
||||
if h.ToolName == "read" {
|
||||
filtered := redactSecrets(h.Result)
|
||||
return &kit.AfterToolResultResult{Result: &filtered}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
The first argument is a priority (lower = runs first).
|
||||
### BeforeTurn — modify prompt, inject messages
|
||||
|
||||
```go
|
||||
host.OnBeforeTurn(kit.HookPriorityNormal, func(h kit.BeforeTurnHook) *kit.BeforeTurnResult {
|
||||
// h.Prompt
|
||||
newPrompt := h.Prompt + "\nAlways respond in JSON."
|
||||
return &kit.BeforeTurnResult{Prompt: &newPrompt}
|
||||
// Also available: SystemPrompt *string, InjectText *string
|
||||
})
|
||||
```
|
||||
|
||||
### AfterTurn — observation only
|
||||
|
||||
```go
|
||||
host.OnAfterTurn(kit.HookPriorityNormal, func(h kit.AfterTurnHook) {
|
||||
// h.Response, h.Error
|
||||
log.Printf("Turn completed: %d chars", len(h.Response))
|
||||
})
|
||||
```
|
||||
|
||||
### Hook priorities
|
||||
|
||||
```go
|
||||
kit.HookPriorityHigh = 0 // runs first
|
||||
kit.HookPriorityNormal = 50 // default
|
||||
kit.HookPriorityLow = 100 // runs last
|
||||
```
|
||||
|
||||
Lower values run first. First non-nil result wins.
|
||||
|
||||
## Subagent event monitoring
|
||||
|
||||
|
||||
@@ -68,3 +68,28 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
| `CompactionOptions` | `*CompactionOptions` | — | Configuration for auto-compaction |
|
||||
| `Skills` | `[]string` | — | Explicit skill files/dirs to load |
|
||||
| `SkillsDir` | `string` | — | Override default skills directory |
|
||||
|
||||
## Tool configuration
|
||||
|
||||
**`Tools`** replaces ALL default tools (core + MCP + extension). **`ExtraTools`** adds tools alongside the defaults. Use `Tools` to restrict capabilities; use `ExtraTools` to extend them.
|
||||
|
||||
Create custom tools with `kit.NewTool` — no external dependencies needed:
|
||||
|
||||
```go
|
||||
type LookupInput struct {
|
||||
ID string `json:"id" description:"Record ID to look up"`
|
||||
}
|
||||
|
||||
lookupTool := kit.NewTool("lookup", "Look up a record by ID",
|
||||
func(ctx context.Context, input LookupInput) (kit.ToolOutput, error) {
|
||||
record := db.Find(input.ID)
|
||||
return kit.TextResult(record.String()), nil
|
||||
},
|
||||
)
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{lookupTool},
|
||||
})
|
||||
```
|
||||
|
||||
See [Overview](/sdk/overview#custom-tools) for full custom tool documentation.
|
||||
|
||||
@@ -68,6 +68,44 @@ The SDK provides several prompt variants:
|
||||
| `Steer(ctx, instruction)` | System-level steering without user message |
|
||||
| `FollowUp(ctx, text)` | Continue without new user input |
|
||||
|
||||
## Custom tools
|
||||
|
||||
Create custom tools with `kit.NewTool`. The JSON schema is auto-generated from the input struct — no external dependencies required:
|
||||
|
||||
```go
|
||||
type WeatherInput struct {
|
||||
City string `json:"city" description:"City name"`
|
||||
}
|
||||
|
||||
weatherTool := kit.NewTool("get_weather", "Get current weather for a city",
|
||||
func(ctx context.Context, input WeatherInput) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("72°F, sunny in " + input.City), nil
|
||||
},
|
||||
)
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{weatherTool},
|
||||
})
|
||||
```
|
||||
|
||||
Struct tags control the schema:
|
||||
|
||||
- `json:"name"` — parameter name
|
||||
- `description:"..."` — description shown to the LLM
|
||||
- `enum:"a,b,c"` — restrict valid values
|
||||
- `omitempty` — marks the parameter as optional
|
||||
|
||||
Return values:
|
||||
|
||||
| Helper | Description |
|
||||
|--------|-------------|
|
||||
| `kit.TextResult(s)` | Successful text result |
|
||||
| `kit.ErrorResult(s)` | Error result (LLM sees it as a tool error) |
|
||||
|
||||
For advanced use, return a `kit.ToolOutput` struct directly with `Data`, `MediaType`, and `Metadata` fields.
|
||||
|
||||
Use `kit.NewParallelTool` for tools that are safe to run concurrently. Use `kit.ToolCallIDFromContext(ctx)` to retrieve the LLM-assigned call ID for logging or tracing.
|
||||
|
||||
## Event system
|
||||
|
||||
Subscribe to events for monitoring:
|
||||
|
||||
@@ -901,6 +901,126 @@ a:hover { text-decoration: underline; }
|
||||
color: var(--text-muted);
|
||||
}
|
||||
|
||||
/* ============================================================
|
||||
Compaction Card
|
||||
============================================================ */
|
||||
.compaction-card {
|
||||
margin: 16px 0;
|
||||
border: 1px solid var(--border);
|
||||
border-radius: var(--radius);
|
||||
background: var(--surface);
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.compaction-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
padding: 12px 16px;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
transition: background var(--transition);
|
||||
background: var(--surface-raised);
|
||||
}
|
||||
|
||||
.compaction-header:hover {
|
||||
background: var(--surface-overlay);
|
||||
}
|
||||
|
||||
.compaction-icon {
|
||||
width: 18px;
|
||||
height: 18px;
|
||||
color: var(--yellow);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.compaction-title {
|
||||
font-size: 13px;
|
||||
font-weight: 600;
|
||||
color: var(--text-secondary);
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.compaction-badge {
|
||||
font-size: 11px;
|
||||
font-weight: 500;
|
||||
color: var(--text-muted);
|
||||
background: var(--surface);
|
||||
padding: 2px 8px;
|
||||
border-radius: 10px;
|
||||
border: 1px solid var(--border);
|
||||
}
|
||||
|
||||
.compaction-chevron {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
color: var(--text-faint);
|
||||
transition: transform var(--transition);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.compaction-card.expanded .compaction-chevron {
|
||||
transform: rotate(180deg);
|
||||
}
|
||||
|
||||
.compaction-content {
|
||||
max-height: 0;
|
||||
overflow: hidden;
|
||||
transition: max-height var(--transition);
|
||||
border-top: 1px solid transparent;
|
||||
}
|
||||
|
||||
.compaction-card.expanded .compaction-content {
|
||||
max-height: 2000px;
|
||||
overflow-y: auto;
|
||||
border-top-color: var(--border);
|
||||
}
|
||||
|
||||
.compaction-summary {
|
||||
padding: 16px;
|
||||
font-size: 13.5px;
|
||||
line-height: 1.7;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.compaction-summary .md-content h1,
|
||||
.compaction-summary .md-content h2,
|
||||
.compaction-summary .md-content h3 {
|
||||
color: var(--text);
|
||||
margin: 16px 0 8px;
|
||||
}
|
||||
|
||||
.compaction-summary .md-content h1 { font-size: 1.3em; }
|
||||
.compaction-summary .md-content h2 { font-size: 1.15em; }
|
||||
.compaction-summary .md-content h3 { font-size: 1.05em; }
|
||||
|
||||
.compaction-summary .md-content ul,
|
||||
.compaction-summary .md-content ol {
|
||||
padding-left: 20px;
|
||||
}
|
||||
|
||||
.compaction-stats {
|
||||
display: flex;
|
||||
gap: 16px;
|
||||
padding: 12px 16px;
|
||||
background: var(--surface-raised);
|
||||
border-top: 1px solid var(--border-subtle);
|
||||
font-size: 11.5px;
|
||||
color: var(--text-muted);
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.compaction-stat {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.compaction-stat strong {
|
||||
color: var(--text-secondary);
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
/* ============================================================
|
||||
System Prompt Display
|
||||
============================================================ */
|
||||
@@ -1460,6 +1580,7 @@ a:hover { text-decoration: underline; }
|
||||
let userMsgCount = 0;
|
||||
let assistantMsgCount = 0;
|
||||
let toolCallCount = 0;
|
||||
let compactionCount = 0;
|
||||
|
||||
// Render each entry
|
||||
for (const entry of path) {
|
||||
@@ -1491,6 +1612,9 @@ a:hover { text-decoration: underline; }
|
||||
renderSystemNotice('Label', entry.label || '', 'label');
|
||||
} else if (entry.type === 'session_info') {
|
||||
// Already handled above for header
|
||||
} else if (entry.type === 'compaction') {
|
||||
compactionCount++;
|
||||
renderCompaction(entry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1501,6 +1625,7 @@ a:hover { text-decoration: underline; }
|
||||
<div class="stat-item"><strong>${userMsgCount}</strong> user message${userMsgCount !== 1 ? 's' : ''}</div>
|
||||
<div class="stat-item"><strong>${assistantMsgCount}</strong> assistant message${assistantMsgCount !== 1 ? 's' : ''}</div>
|
||||
${toolCallCount > 0 ? `<div class="stat-item"><strong>${toolCallCount}</strong> tool call${toolCallCount !== 1 ? 's' : ''}</div>` : ''}
|
||||
${compactionCount > 0 ? `<div class="stat-item"><strong>${compactionCount}</strong> compaction${compactionCount !== 1 ? 's' : ''}</div>` : ''}
|
||||
${header && header.cwd ? `<div class="stat-item">📁 ${escapeHtml(header.cwd)}</div>` : ''}
|
||||
</div>`;
|
||||
$conversation.insertAdjacentHTML('beforeend', statsHtml);
|
||||
@@ -2030,6 +2155,60 @@ a:hover { text-decoration: underline; }
|
||||
$conversation.appendChild(el);
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Compaction Display
|
||||
// ============================================================
|
||||
function renderCompaction(entry) {
|
||||
const el = document.createElement('div');
|
||||
el.className = 'compaction-card fade-in';
|
||||
|
||||
const cardId = 'compaction-' + Math.random().toString(36).substr(2, 9);
|
||||
|
||||
// Build stats
|
||||
const stats = [];
|
||||
if (entry.messages_removed > 0) {
|
||||
stats.push(`<div class="compaction-stat"><strong>${entry.messages_removed}</strong> messages compacted</div>`);
|
||||
}
|
||||
if (entry.tokens_before > 0 && entry.tokens_after > 0) {
|
||||
const saved = entry.tokens_before - entry.tokens_after;
|
||||
stats.push(`<div class="compaction-stat"><strong>${saved}</strong> tokens saved</div>`);
|
||||
}
|
||||
|
||||
// Format timestamp
|
||||
const timeStr = formatTime(entry.timestamp);
|
||||
|
||||
el.innerHTML = `
|
||||
<div class="compaction-header" onclick="toggleCompaction('${cardId}')">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" fill="currentColor" class="compaction-icon">
|
||||
<path d="M2.5 3.5a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 .5.5v1a.5.5 0 0 1-.5.5h-10a.5.5 0 0 1-.5-.5v-1Zm0 4a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 .5.5v1a.5.5 0 0 1-.5.5h-10a.5.5 0 0 1-.5-.5v-1Zm0 4a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 .5.5v1a.5.5 0 0 1-.5.5h-10a.5.5 0 0 1-.5-.5v-1Z"/>
|
||||
</svg>
|
||||
<span class="compaction-title">Context Compacted</span>
|
||||
${timeStr ? `<span class="compaction-badge">${timeStr}</span>` : ''}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" fill="currentColor" class="compaction-chevron" id="${cardId}-chevron">
|
||||
<path d="M4.427 9.427a.25.25 0 0 0 0 .353l3 3a.25.25 0 0 0 .353 0l3-3a.25.25 0 0 0-.353-.353L8 11.646V4.75a.75.75 0 0 0-1.5 0v6.896L4.78 9.427a.25.25 0 0 0-.353 0Z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<div class="compaction-content" id="${cardId}">
|
||||
${entry.summary ? `<div class="compaction-summary"><div class="md-content">${renderMarkdown(entry.summary)}</div></div>` : ''}
|
||||
${stats.length > 0 ? `<div class="compaction-stats">${stats.join('')}</div>` : ''}
|
||||
</div>`;
|
||||
|
||||
$conversation.appendChild(el);
|
||||
}
|
||||
|
||||
// Toggle compaction card expansion
|
||||
window.toggleCompaction = function(cardId) {
|
||||
const card = document.getElementById(cardId).closest('.compaction-card');
|
||||
const chevron = document.getElementById(cardId + '-chevron');
|
||||
const isExpanded = card.classList.contains('expanded');
|
||||
|
||||
if (isExpanded) {
|
||||
card.classList.remove('expanded');
|
||||
} else {
|
||||
card.classList.add('expanded');
|
||||
}
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// System Prompt Display (collapsible)
|
||||
// ============================================================
|
||||
|
||||
Reference in New Issue
Block a user