Compare commits

..

5 Commits

Author SHA1 Message Date
Ed Zynda e07c94f49d feat(mcp): add dynamic MCP server loading and unloading
- Add AddServer/RemoveServer to MCPToolManager for runtime server management
- Add RemoveConnection to MCPConnectionPool for per-server teardown
- Add AddMCPServer/RemoveMCPServer/ListMCPServers to Agent and SDK Kit
- Lazily create connection pool so AddServer works without prior LoadTools
- Wire onToolsChanged callback to trigger agent tool list rebuild
- Make MCPToolManager.Close nil-safe when pool was never initialized

Tests:
- Integration tests with real stdio MCP server (Python echo server)
- Agent-level tests using mock LLM model (no API key needed)
- Unit tests for error paths, callbacks, idempotency, nil safety
- SDK type surface tests
2026-04-09 13:54:11 +03:00
Ed Zynda b87146a284 feat(sdk): add MCPTokenStoreFactory for custom OAuth token storage
- Add MCPTokenStoreFactory option to kit.Options allowing SDK consumers
  to provide custom token storage backends for remote MCP servers
- Thread TokenStoreFactory through the full chain: kit.Options →
  kitsetup → agent → MCPToolManager → MCPConnectionPool
- Add createTokenStore() helper on connection pool that delegates to the
  factory or falls back to the default FileTokenStore
- Export MCPTokenStore, MCPToken, MCPTokenStoreFactory, and ErrMCPNoToken
  in pkg/kit/types.go following SDK naming conventions
- Default behavior (file-based storage) is preserved when factory is nil
2026-04-09 13:27:40 +03:00
Ed Zynda 186d9f7f44 fix(ui): route raw fmt.Print calls through proper renderers
- event_handler: route default extension print level through DisplayInfo
  instead of bare fmt.Println for consistent styling and timestamps
- factory: remove orphan fmt.Println("") before system messages; the
  renderer already manages its own spacing
- app: PrintFromExtension non-interactive fallback now respects level,
  writing errors/info to stderr with prefix to keep stdout clean
- app: PrintBlockFromExtension non-interactive fallback writes framed
  blocks to stderr instead of raw text to stdout
2026-04-09 13:00:23 +03:00
Ed Zynda 3a8ffc2104 feat(models): add per-model system prompt support
- Add systemPrompt field to GenerationParams and config structs
- On init, replace default system prompt with per-model prompt when
  user hasn't explicitly set one (via flag, config, or SDK option)
- On model switch, detect per-model prompt and compose it with
  AGENTS.md, skills, and date/cwd context
- Fix viper.IsSet bug: BindPFlag causes IsSet to return true for
  unset flags, so compare against defaultSystemPrompt instead
- Agent.SetModel now updates stored system prompt from config
- Export LoadModelSettingsFromConfig, LoadSystemPromptValue, and
  LookupModelForSettings for use by Kit.SetModel
- Add tests for prompt apply, precedence, file path, and
  modelSettings override
2026-04-09 12:35:00 +03:00
Ed Zynda e54570162e feat(models): add per-model generation parameter defaults
- Add modelSettings config section for attaching generation params
  (temperature, topP, topK, frequencyPenalty, presencePenalty,
  maxTokens, stopSequences, thinkingLevel) to any model by
  provider/model key
- Add params field to customModels definitions for inline defaults
- Change BuildProviderConfig and SetModel to use viper.IsSet so
  unset params remain nil, allowing model-level defaults to apply
- Wire ApplyModelSettings into CreateProvider with priority order:
  CLI flags > global config > modelSettings > customModels params
- Add GenerationParams to ModelInfo in the registry
- Update default config template with modelSettings and customModels
  params examples
2026-04-09 12:07:42 +03:00
21 changed files with 2236 additions and 116 deletions
+73 -3
View File
@@ -30,6 +30,11 @@ type AgentConfig struct {
// If nil, remote MCP servers that require OAuth will fail to connect.
AuthHandler tools.MCPAuthHandler
// TokenStoreFactory, if non-nil, creates a custom token store for each
// remote MCP server's OAuth tokens. When nil, the default file-based
// token store is used.
TokenStoreFactory tools.TokenStoreFactory
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used. This allows SDK users to provide a custom tool set (e.g.
// CodingTools or tools with a custom WorkDir).
@@ -236,6 +241,9 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
if agentConfig.AuthHandler != nil {
toolManager.SetAuthHandler(agentConfig.AuthHandler)
}
if agentConfig.TokenStoreFactory != nil {
toolManager.SetTokenStoreFactory(agentConfig.TokenStoreFactory)
}
if agentConfig.DebugLogger != nil {
toolManager.SetDebugLogger(agentConfig.DebugLogger)
}
@@ -826,6 +834,59 @@ func (a *Agent) SetExtraTools(extraTools []fantasy.AgentTool) {
a.rebuildFantasyAgent()
}
// AddMCPServer attaches an additional MCP server to a running agent and
// exposes that server's tools. The returned int is the tool count reported by
// the tool manager for the new server; on failure it is 0 together with the
// manager's error.
//
// An agent that was started without any MCP servers has no tool manager yet;
// in that case one is created here on demand and wired to rebuild the agent
// whenever its tool set changes.
func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
	// Settle MCP tools from the initial load before touching the tool set.
	a.ensureMCPTools()

	if a.toolManager == nil {
		mgr := tools.NewMCPToolManager()
		a.toolManager = mgr
		mgr.SetModel(a.model)
		mgr.SetOnToolsChanged(func() {
			a.rebuildFantasyAgent()
		})
	}

	loaded, addErr := a.toolManager.AddServer(ctx, name, cfg)
	if addErr != nil {
		return 0, addErr
	}

	// The onToolsChanged callback normally triggers this rebuild already, but
	// only when it has been wired; rebuilding explicitly covers both cases.
	a.rebuildFantasyAgent()
	return loaded, nil
}
// RemoveMCPServer tears down the named MCP server and drops its tools from
// the agent. It fails with "no MCP servers loaded" when the agent has no tool
// manager, and otherwise forwards any error from the manager's RemoveServer.
func (a *Agent) RemoveMCPServer(name string) error {
	if a.toolManager == nil {
		return fmt.Errorf("no MCP servers loaded")
	}

	// Settle MCP tools from the initial load before touching the tool set.
	a.ensureMCPTools()

	if removeErr := a.toolManager.RemoveServer(name); removeErr != nil {
		return removeErr
	}

	// The onToolsChanged callback normally triggers this rebuild already;
	// rebuilding explicitly guarantees it happens regardless.
	a.rebuildFantasyAgent()
	return nil
}
// GetMCPToolManager returns the agent's MCP tool manager exactly as stored;
// no copy is made. The result is nil while the agent has no tool manager,
// e.g. before AddMCPServer has created one on demand for an agent that was
// set up without MCP servers.
func (a *Agent) GetMCPToolManager() *tools.MCPToolManager {
return a.toolManager
}
// GetLoadingMessage returns the loading message from provider creation.
func (a *Agent) GetLoadingMessage() string {
return a.loadingMessage
@@ -839,9 +900,11 @@ func (a *Agent) GetLoadedServerNames() []string {
return a.toolManager.GetLoadedServerNames()
}
// SetModel swaps the agent's LLM provider to a new model. The existing tools,
// system prompt, and configuration are preserved. The old provider is closed
// if it has a closer. Returns the previous model string for notification.
// SetModel swaps the agent's LLM provider to a new model. The existing tools
// and configuration are preserved. When the new model's ProviderConfig carries
// a system prompt (from per-model settings), it replaces the agent's stored
// prompt so the rebuilt fantasy agent uses it. The old provider is closed if
// it has a closer.
func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) error {
// Ensure MCP tools are loaded before rebuilding (SetModel may be called
// before the first LLM call).
@@ -868,6 +931,13 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
a.skipMaxOutputTokens = providerResult.SkipMaxOutputTokens
a.modelConfig = config
// Update system prompt when the config carries one (from per-model
// settings or the global config). This allows model-specific system
// prompts to take effect on model switch.
if config.SystemPrompt != "" {
a.systemPrompt = config.SystemPrompt
}
// Update provider type.
if config.ModelString != "" {
if p, _, err := models.ParseModelString(config.ModelString); err == nil {
+242
View File
@@ -0,0 +1,242 @@
package agent
import (
"context"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/config"
)
// mockModel is a stub language model for tests. None of its methods perform
// any API call: the Generate variants return empty responses, the Stream
// variants return nil streams, and the identity methods return fixed strings.
// It exists so tool-management wiring can be exercised without credentials.
type mockModel struct{}

// Generate returns an empty response and never fails.
func (*mockModel) Generate(context.Context, fantasy.Call) (*fantasy.Response, error) {
	return &fantasy.Response{}, nil
}

// Stream returns a nil stream and a nil error.
func (*mockModel) Stream(context.Context, fantasy.Call) (fantasy.StreamResponse, error) {
	return nil, nil
}

// GenerateObject returns an empty object response and never fails.
func (*mockModel) GenerateObject(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
	return &fantasy.ObjectResponse{}, nil
}

// StreamObject returns a nil object stream and a nil error.
func (*mockModel) StreamObject(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	return nil, nil
}

// Provider reports the fixed provider name "mock".
func (*mockModel) Provider() string { return "mock" }

// Model reports the fixed model name "mock-model".
func (*mockModel) Model() string { return "mock-model" }
// testdataDir resolves the sibling "tools/testdata" directory relative to the
// directory containing this source file. The test is aborted if the runtime
// cannot report the caller's file path.
func testdataDir(t *testing.T) string {
	t.Helper()

	_, thisFile, _, ok := runtime.Caller(0)
	if !ok {
		t.Fatal("cannot determine test file path")
	}

	pkgDir := filepath.Dir(thisFile)
	return filepath.Join(pkgDir, "..", "tools", "testdata")
}
// echoServerConfig builds the MCP server configuration that launches the
// echo_server.py fixture through python3. The calling test is skipped when
// the script cannot be stat'ed in the testdata directory.
func echoServerConfig(t *testing.T) config.MCPServerConfig {
	t.Helper()

	scriptPath := filepath.Join(testdataDir(t), "echo_server.py")
	_, statErr := os.Stat(scriptPath)
	if statErr != nil {
		t.Skipf("echo_server.py not found: %v", statErr)
	}

	var serverCfg config.MCPServerConfig
	serverCfg.Command = []string{"python3", scriptPath}
	return serverCfg
}
// newTestAgent builds a bare Agent around mockModel. Core and extra tools are
// left unset, so MCP server management can be tested without an API key.
func newTestAgent() *Agent {
	stub := &mockModel{}
	return &Agent{
		model:        stub,
		maxSteps:     10,
		systemPrompt: "test",
		fantasyAgent: fantasy.NewAgent(stub),
	}
}
// TestAgent_AddMCPServer adds the echo MCP server to a fresh agent and checks
// the reported tool count, the agent's tool list, and the loaded server names.
func TestAgent_AddMCPServer(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test in short mode")
	}

	agent := newTestAgent()
	defer func() { _ = agent.Close() }()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	serverCfg := echoServerConfig(t)

	// A fresh agent starts without any MCP tools.
	if before := agent.GetMCPToolCount(); before != 0 {
		t.Fatalf("Expected 0 MCP tools initially, got %d", before)
	}

	loaded, err := agent.AddMCPServer(ctx, "echo", serverCfg)
	if err != nil {
		t.Fatalf("AddMCPServer failed: %v", err)
	}
	if loaded != 2 {
		t.Errorf("Expected 2 tools, got %d", loaded)
	}

	// The new tools must be visible through the agent as well.
	if after := agent.GetMCPToolCount(); after != 2 {
		t.Errorf("Expected 2 MCP tools, got %d", after)
	}

	registered := make(map[string]bool)
	for _, tool := range agent.GetTools() {
		registered[tool.Info().Name] = true
	}
	if !registered["echo__echo"] {
		t.Error("Expected tool 'echo__echo' in agent tools")
	}
	if !registered["echo__greet"] {
		t.Error("Expected tool 'echo__greet' in agent tools")
	}

	// The server itself must be reported as loaded.
	serverNames := agent.GetLoadedServerNames()
	hasEcho := false
	for _, serverName := range serverNames {
		if serverName == "echo" {
			hasEcho = true
			break
		}
	}
	if !hasEcho {
		t.Errorf("Expected 'echo' in loaded server names: %v", serverNames)
	}
}
// TestAgent_RemoveMCPServer adds the echo server, removes it again, and checks
// that neither the MCP tool count nor the agent's tool list still shows it.
func TestAgent_RemoveMCPServer(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test in short mode")
	}

	agent := newTestAgent()
	defer func() { _ = agent.Close() }()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	serverCfg := echoServerConfig(t)

	if _, err := agent.AddMCPServer(ctx, "echo", serverCfg); err != nil {
		t.Fatalf("AddMCPServer failed: %v", err)
	}
	if err := agent.RemoveMCPServer("echo"); err != nil {
		t.Fatalf("RemoveMCPServer failed: %v", err)
	}

	if remaining := agent.GetMCPToolCount(); remaining != 0 {
		t.Errorf("Expected 0 MCP tools after removal, got %d", remaining)
	}

	// No tool carrying the server's prefix may survive in the agent.
	for _, tool := range agent.GetTools() {
		toolName := tool.Info().Name
		if strings.Contains(toolName, "echo__") {
			t.Errorf("Found leftover tool after removal: %s", toolName)
		}
	}
}
// TestAgent_RemoveMCPServer_NoToolManager checks that removing a server from
// an agent without a tool manager fails with the "no MCP servers loaded" error.
func TestAgent_RemoveMCPServer_NoToolManager(t *testing.T) {
	agent := newTestAgent()
	defer func() { _ = agent.Close() }()

	removeErr := agent.RemoveMCPServer("nonexistent")
	if removeErr == nil {
		t.Fatal("Expected error when no tool manager exists")
	}

	if msg := removeErr.Error(); !strings.Contains(msg, "no MCP servers loaded") {
		t.Errorf("Expected 'no MCP servers loaded' error, got: %v", removeErr)
	}
}
// TestAgent_AddMCPServer_CreatesToolManager checks that an agent starting
// without a tool manager gets one as a side effect of AddMCPServer.
func TestAgent_AddMCPServer_CreatesToolManager(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test in short mode")
	}

	agent := newTestAgent()
	defer func() { _ = agent.Close() }()

	if mgr := agent.GetMCPToolManager(); mgr != nil {
		t.Fatal("Expected nil tool manager initially")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	serverCfg := echoServerConfig(t)

	if _, err := agent.AddMCPServer(ctx, "echo", serverCfg); err != nil {
		t.Fatalf("AddMCPServer failed: %v", err)
	}

	if mgr := agent.GetMCPToolManager(); mgr == nil {
		t.Fatal("Expected tool manager to be created by AddMCPServer")
	}
}
// TestAgent_AddRemoveAdd_MCP runs an add, remove, add cycle for the same
// server name and checks that the second add loads the full tool set again.
func TestAgent_AddRemoveAdd_MCP(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test in short mode")
	}

	agent := newTestAgent()
	defer func() { _ = agent.Close() }()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	serverCfg := echoServerConfig(t)

	if _, err := agent.AddMCPServer(ctx, "echo", serverCfg); err != nil {
		t.Fatalf("First add failed: %v", err)
	}

	if err := agent.RemoveMCPServer("echo"); err != nil {
		t.Fatalf("Remove failed: %v", err)
	}

	reloaded, err := agent.AddMCPServer(ctx, "echo", serverCfg)
	if err != nil {
		t.Fatalf("Re-add failed: %v", err)
	}
	if reloaded != 2 {
		t.Errorf("Expected 2 tools on re-add, got %d", reloaded)
	}

	if total := agent.GetMCPToolCount(); total != 2 {
		t.Errorf("Expected 2 MCP tools after re-add, got %d", total)
	}
}
+5
View File
@@ -38,6 +38,10 @@ type AgentCreationOptions struct {
DebugLogger tools.DebugLogger // Optional debug logger
// AuthHandler handles OAuth authorization for remote MCP servers
AuthHandler tools.MCPAuthHandler
// TokenStoreFactory, if non-nil, creates a custom token store for each
// remote MCP server's OAuth tokens. When nil, the default file-based
// token store is used.
TokenStoreFactory tools.TokenStoreFactory
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used.
CoreTools []fantasy.AgentTool
@@ -66,6 +70,7 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
StreamingEnabled: opts.StreamingEnabled,
DebugLogger: opts.DebugLogger,
AuthHandler: opts.AuthHandler,
TokenStoreFactory: opts.TokenStoreFactory,
CoreTools: opts.CoreTools,
DisableCoreTools: opts.DisableCoreTools,
ToolWrapper: opts.ToolWrapper,
+16 -6
View File
@@ -930,7 +930,8 @@ func (a *App) QuitFromExtension() {
// controls styling: "" for plain text, "info" for a system message block,
// "error" for an error block. In interactive mode it sends an
// ExtensionPrintEvent through the program so the TUI can render it with the
// appropriate renderer. In non-interactive mode it falls back to stdout.
// appropriate renderer. In non-interactive mode, "error" and "info" messages
// are written to stderr with a level prefix so they are distinguishable from
// plain output; any other level is printed to stdout unchanged.
func (a *App) PrintFromExtension(level, text string) {
a.mu.Lock()
prog := a.program
@@ -939,8 +940,16 @@ func (a *App) PrintFromExtension(level, text string) {
prog.Send(ExtensionPrintEvent{Text: text, Level: level})
return
}
// Non-interactive fallback: write directly to stdout.
fmt.Println(text)
// Non-interactive fallback: write to stderr with a level prefix so that
// errors and info messages are distinguishable from plain output.
switch level {
case "error":
fmt.Fprintf(os.Stderr, "[ERROR] %s\n", text)
case "info":
fmt.Fprintf(os.Stderr, "[INFO] %s\n", text)
default:
fmt.Println(text)
}
}
// SetEditorTextFromExtension sends an EditorTextSetEvent to the TUI to
@@ -1122,11 +1131,12 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
})
return
}
// Non-interactive fallback.
// Non-interactive fallback: render a simple framed block to stderr so
// it is visually distinct from plain stdout output.
if opts.Subtitle != "" {
fmt.Printf("%s\n — %s\n", opts.Text, opts.Subtitle)
fmt.Fprintf(os.Stderr, "--- %s ---\n%s\n", opts.Subtitle, opts.Text)
} else {
fmt.Println(opts.Text)
fmt.Fprintf(os.Stderr, "---\n%s\n---\n", opts.Text)
}
}
+64 -1
View File
@@ -157,6 +157,21 @@ type Theme struct {
Markdown MarkdownThemeConfig `json:"markdown,omitzero" yaml:"markdown,omitempty"`
}
// GenerationParams defines generation parameter defaults that can be attached
// to individual models. These act as model-level defaults — CLI flags and
// global config values take precedence when explicitly set.
//
// The numeric fields are pointers and every field is tagged omitempty, so a
// parameter that is not configured is omitted when the struct is serialized.
// NOTE(review): a nil pointer presumably means "no model-level default";
// confirm against the code that consumes this struct.
type GenerationParams struct {
// MaxTokens is the maximum number of tokens in a response.
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
// Temperature controls sampling randomness.
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
// TopP is the nucleus sampling threshold.
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
// TopK is the top-k sampling limit.
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
// FrequencyPenalty is the frequency penalty applied during generation.
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
// PresencePenalty is the presence penalty applied during generation.
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
// StopSequences lists custom stop sequences.
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
// ThinkingLevel is the thinking level as a string (for example "high").
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
// SystemPrompt is a per-model system prompt: inline text or a file path.
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
}
// CustomModelConfig defines a custom model that can be used with custom/custom
// or other custom/ prefixed models. These models are loaded from the config file
// and merged into the custom provider in the model registry.
@@ -171,6 +186,11 @@ type CustomModelConfig struct {
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
Cost CostConfig `json:"cost" yaml:"cost"`
Limit LimitConfig `json:"limit" yaml:"limit"`
// Generation parameter defaults for this model.
// These are applied when the user hasn't explicitly set the corresponding
// CLI flag or global config value.
Params GenerationParams `json:"params,omitzero" yaml:"params,omitempty"`
}
// CostConfig defines the pricing for a custom model.
@@ -219,6 +239,12 @@ type Config struct {
// Custom model definitions (under custom/ provider)
CustomModels map[string]CustomModelConfig `json:"customModels,omitempty" yaml:"customModels,omitempty"`
// Per-model generation parameter overrides. Keys are "provider/model" strings
// (e.g. "anthropic/claude-sonnet-4-5-20250929", "openai/gpt-4o"). These
// settings act as model-level defaults — CLI flags and global config values
// take precedence when explicitly set.
ModelSettings map[string]GenerationParams `json:"modelSettings,omitempty" yaml:"modelSettings,omitempty"`
}
// GetTransportType returns the transport type for the server config, mapping
@@ -367,7 +393,7 @@ mcpServers:
# debug: false # Enable debug logging
# system-prompt: "/path/to/system-prompt.txt" # System prompt text file
# Model generation parameters (all optional)
# Model generation parameters (all optional, apply globally to all models)
# max-tokens: 4096 # Maximum tokens in response
# temperature: 0.7 # Randomness (0.0-1.0)
# top-p: 0.95 # Nucleus sampling (0.0-1.0)
@@ -376,9 +402,46 @@ mcpServers:
# presence-penalty: 0.0 # Penalize present tokens (0.0-2.0)
# stop-sequences: ["Human:", "Assistant:"] # Custom stop sequences
# Per-model generation parameter overrides (apply to specific models)
# These act as model-level defaults — CLI flags and global settings above take precedence.
# Keys are "provider/model" strings matching the model you use.
# modelSettings:
# anthropic/claude-sonnet-4-5-20250929:
# temperature: 0.3
# maxTokens: 8192
# openai/gpt-4o:
# temperature: 0.7
# topP: 0.95
# topK: 40
# frequencyPenalty: 0.1
# presencePenalty: 0.1
# anthropic/claude-opus-4-6:
# thinkingLevel: "high"
# maxTokens: 16384
# systemPrompt: "You are a deep reasoning assistant." # or a file path
# API Configuration (can also use environment variables)
# provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google
# provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama
# Custom model definitions (under custom/ provider)
# customModels:
# my-local-llama:
# name: "Local Llama 3"
# baseUrl: "http://localhost:8080/v1"
# family: "llama"
# temperature: true
# cost:
# input: 0.0
# output: 0.0
# limit:
# context: 131072
# output: 8192
# params: # Generation parameter defaults for this model
# temperature: 0.8
# topP: 0.95
# topK: 40
# systemPrompt: "You are a helpful local assistant."
`
_, err = file.WriteString(content)
+44 -20
View File
@@ -65,6 +65,10 @@ type AgentSetupOptions struct {
// AuthHandler handles OAuth authorization for remote MCP servers.
// When set, remote transports are configured with OAuth support.
AuthHandler tools.MCPAuthHandler
// TokenStoreFactory, if non-nil, creates a custom token store for each
// remote MCP server's OAuth tokens. When nil, the default file-based
// token store is used.
TokenStoreFactory tools.TokenStoreFactory
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
@@ -82,36 +86,55 @@ type AgentSetupResult struct {
// BuildProviderConfig creates a *models.ProviderConfig from the current viper
// state. All entry points (root, script, SDK) converge through this function.
//
// Generation parameter pointers (Temperature, TopP, etc.) are only set when
// the user has explicitly configured them via CLI flag, environment variable,
// or global config file. This allows per-model defaults from modelSettings
// and customModels to fill in unset parameters downstream.
func BuildProviderConfig() (*models.ProviderConfig, string, error) {
systemPrompt, err := config.LoadSystemPrompt(viper.GetString("system-prompt"))
if err != nil {
return nil, "", fmt.Errorf("failed to load system prompt: %w", err)
}
temperature := float32(viper.GetFloat64("temperature"))
topP := float32(viper.GetFloat64("top-p"))
topK := int32(viper.GetInt("top-k"))
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
numGPU := int32(viper.GetInt("num-gpu-layers"))
mainGPU := int32(viper.GetInt("main-gpu"))
cfg := &models.ProviderConfig{
ModelString: viper.GetString("model"),
SystemPrompt: systemPrompt,
ProviderAPIKey: viper.GetString("provider-api-key"),
ProviderURL: viper.GetString("provider-url"),
MaxTokens: viper.GetInt("max-tokens"),
Temperature: &temperature,
TopP: &topP,
TopK: &topK,
FrequencyPenalty: &frequencyPenalty,
PresencePenalty: &presencePenalty,
StopSequences: viper.GetStringSlice("stop-sequences"),
NumGPU: &numGPU,
MainGPU: &mainGPU,
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
ModelString: viper.GetString("model"),
SystemPrompt: systemPrompt,
ProviderAPIKey: viper.GetString("provider-api-key"),
ProviderURL: viper.GetString("provider-url"),
MaxTokens: viper.GetInt("max-tokens"),
StopSequences: viper.GetStringSlice("stop-sequences"),
NumGPU: &numGPU,
MainGPU: &mainGPU,
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
}
// Only set generation parameter pointers when the user has explicitly
// provided a value. This leaves nil pointers for unset params, allowing
// per-model defaults (modelSettings / customModels params) to apply.
if viper.IsSet("temperature") {
v := float32(viper.GetFloat64("temperature"))
cfg.Temperature = &v
}
if viper.IsSet("top-p") {
v := float32(viper.GetFloat64("top-p"))
cfg.TopP = &v
}
if viper.IsSet("top-k") {
v := int32(viper.GetInt("top-k"))
cfg.TopK = &v
}
if viper.IsSet("frequency-penalty") {
v := float32(viper.GetFloat64("frequency-penalty"))
cfg.FrequencyPenalty = &v
}
if viper.IsSet("presence-penalty") {
v := float32(viper.GetFloat64("presence-penalty"))
cfg.PresencePenalty = &v
}
return cfg, systemPrompt, nil
@@ -200,6 +223,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
SpinnerFunc: opts.SpinnerFunc,
DebugLogger: debugLogger,
AuthHandler: opts.AuthHandler,
TokenStoreFactory: opts.TokenStoreFactory,
CoreTools: opts.CoreTools,
DisableCoreTools: opts.DisableCoreTools,
ToolWrapper: toolWrapper,
+234 -11
View File
@@ -2,6 +2,8 @@ package models
import (
"log"
"os"
"strings"
"github.com/spf13/viper"
)
@@ -31,7 +33,7 @@ func loadCustomModelsFromConfig() map[string]ModelInfo {
// modelConfigToModelInfo converts a CustomModelConfig to a ModelInfo.
func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
return ModelInfo{
info := ModelInfo{
ID: modelID,
Name: cfg.Name,
Attachment: cfg.Attachment,
@@ -48,21 +50,242 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
Output: cfg.Limit.Output,
},
}
// Convert custom model generation params if any are set.
if p := convertGenerationParams(cfg.Params); p != nil {
info.Params = p
}
return info
}
// LoadModelSettingsFromConfig reads the "modelSettings" section through viper
// and converts every entry into a *GenerationParams keyed by its
// "provider/model" string.
//
// It returns nil when the section is not set or cannot be parsed (a parse
// failure is logged as a warning). Entries that configure no parameter at all
// are left out of the returned map.
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
	if !viper.IsSet("modelSettings") {
		return nil
	}

	var raw map[string]GenerationParamsConfig
	if parseErr := viper.UnmarshalKey("modelSettings", &raw); parseErr != nil {
		log.Printf("Warning: Failed to parse modelSettings: %v", parseErr)
		return nil
	}

	parsed := make(map[string]*GenerationParams, len(raw))
	for modelKey, entry := range raw {
		converted := convertGenerationParams(entry)
		if converted == nil {
			// Nothing configured for this model; skip the empty entry.
			continue
		}
		parsed[modelKey] = converted
	}
	return parsed
}
// convertGenerationParams maps the serializable GenerationParamsConfig onto
// the runtime GenerationParams type. Pointer fields are carried over as-is,
// the thinking level string is parsed with ParseThinkingLevel, and stop
// sequences are copied only when the list is non-empty.
//
// The result is nil when cfg configures nothing, so callers can tell
// "no model-level defaults" apart from an all-zero struct.
func convertGenerationParams(cfg GenerationParamsConfig) *GenerationParams {
	// Pointer and string fields can be copied unconditionally: an unset
	// source value is the same zero value the destination already holds.
	out := &GenerationParams{
		MaxTokens:        cfg.MaxTokens,
		Temperature:      cfg.Temperature,
		TopP:             cfg.TopP,
		TopK:             cfg.TopK,
		FrequencyPenalty: cfg.FrequencyPenalty,
		PresencePenalty:  cfg.PresencePenalty,
		SystemPrompt:     cfg.SystemPrompt,
	}

	hasValue := cfg.MaxTokens != nil ||
		cfg.Temperature != nil ||
		cfg.TopP != nil ||
		cfg.TopK != nil ||
		cfg.FrequencyPenalty != nil ||
		cfg.PresencePenalty != nil ||
		cfg.SystemPrompt != ""

	if len(cfg.StopSequences) > 0 {
		out.StopSequences = cfg.StopSequences
		hasValue = true
	}
	// Only parse a thinking level that was actually provided.
	if cfg.ThinkingLevel != "" {
		out.ThinkingLevel = ParseThinkingLevel(cfg.ThinkingLevel)
		hasValue = true
	}

	if !hasValue {
		return nil
	}
	return out
}
// ApplyModelSettings merges per-model generation parameter defaults into a
// ProviderConfig. A model-level value is applied only to fields the user has
// not set explicitly.
//
// The source of the model-level parameters is chosen in this order:
//  1. modelSettings["provider/model"] from config (highest model-level priority)
//  2. modelInfo.Params from custom model definitions
//
// Only one source is used; the two are not merged field by field. Both are
// overridden by explicit CLI flags / global config values.
//
// config is modified in place. The function is a no-op when config is nil,
// when config.ModelString cannot be parsed, or when neither source provides
// parameters for the model. modelInfo may be nil.
func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
	if config == nil {
		return
	}

	provider, modelName, err := ParseModelString(config.ModelString)
	if err != nil {
		// Without a provider/model pair there is no key to look up.
		return
	}

	// modelSettings takes priority over custom model params because it is
	// the more specific/intentional config.
	var params *GenerationParams
	if settings := LoadModelSettingsFromConfig(); len(settings) > 0 {
		modelKey := provider + "/" + modelName
		p, ok := settings[modelKey]
		if !ok {
			// viper treats configuration keys as case-insensitive and stores
			// them lower-cased, so a model name containing upper-case
			// letters never matches its own entry by exact comparison.
			p, ok = settings[strings.ToLower(modelKey)]
		}
		if ok {
			params = p
		}
	}

	// Fall back to ModelInfo.Params (from custom model definitions).
	if params == nil && modelInfo != nil && modelInfo.Params != nil {
		params = modelInfo.Params
	}

	if params == nil {
		return
	}

	// Apply each parameter only when isExplicitlySet reports that the user
	// has not configured the corresponding global key.
	if params.MaxTokens != nil && !isExplicitlySet("max-tokens") {
		config.MaxTokens = *params.MaxTokens
	}
	if params.Temperature != nil && !isExplicitlySet("temperature") {
		config.Temperature = params.Temperature
	}
	if params.TopP != nil && !isExplicitlySet("top-p") {
		config.TopP = params.TopP
	}
	if params.TopK != nil && !isExplicitlySet("top-k") {
		config.TopK = params.TopK
	}
	if params.FrequencyPenalty != nil && !isExplicitlySet("frequency-penalty") {
		config.FrequencyPenalty = params.FrequencyPenalty
	}
	if params.PresencePenalty != nil && !isExplicitlySet("presence-penalty") {
		config.PresencePenalty = params.PresencePenalty
	}
	if len(params.StopSequences) > 0 && !isExplicitlySet("stop-sequences") {
		config.StopSequences = params.StopSequences
	}
	if params.ThinkingLevel != "" && !isExplicitlySet("thinking-level") {
		config.ThinkingLevel = params.ThinkingLevel
	}
	if params.SystemPrompt != "" && config.SystemPrompt == "" {
		// Resolve file paths: if the value points to an existing file, read it.
		// We check config.SystemPrompt == "" rather than isExplicitlySet because
		// viper.BindPFlag causes IsSet to return true even for unset flags.
		config.SystemPrompt = LoadSystemPromptValue(params.SystemPrompt)
	}
}
// LoadSystemPromptValue turns a configured system prompt into its final text.
// The input is either the prompt itself or the path of a file holding it:
//
//   - an empty input yields an empty string;
//   - an input naming an existing non-directory file yields that file's
//     contents with surrounding whitespace trimmed;
//   - anything else, including a file that exists but cannot be read (which
//     is logged as a warning), yields the input unchanged.
//
// This mirrors config.LoadSystemPrompt but lives in the models package to
// avoid circular dependencies.
func LoadSystemPromptValue(input string) string {
	if input == "" {
		return ""
	}

	info, statErr := os.Stat(input)
	if statErr != nil || info.IsDir() {
		// Not a readable file path: treat the value as inline prompt text.
		return input
	}

	data, readErr := os.ReadFile(input)
	if readErr != nil {
		log.Printf("Warning: failed to read system prompt file %q: %v", input, readErr)
		return input
	}
	return strings.TrimSpace(string(data))
}
// isExplicitlySet reports whether the given global config key (hyphenated,
// e.g. "max-tokens", "top-p") counts as set by the user, in which case a
// model-level default must not override it. It is a direct wrapper around
// viper.IsSet.
func isExplicitlySet(key string) bool {
// viper merges all of its sources, so a value from the global section of
// the config file (e.g. a top-level "temperature: 0.7") makes IsSet
// return true and therefore wins over model-level defaults, which is the
// desired precedence.
//
// NOTE(review): IsSet is not limited to values the user typed.
// ApplyModelSettings avoids this helper for the system prompt because
// viper.BindPFlag is said to make IsSet return true even for unset
// flags — confirm whether the generation-parameter keys are affected too.
return viper.IsSet(key)
}
// GenerationParams holds per-model generation parameter defaults.
// These are stored on ModelInfo and applied during provider creation.
// Nil pointer fields mean "no model-level default" — the global config
// or CLI flag value (if any) will be used instead. For the non-pointer
// fields, an empty slice or empty string plays the same role.
type GenerationParams struct {
// MaxTokens is the default value for the "max-tokens" setting.
MaxTokens *int
// Temperature is the default value for the "temperature" setting.
Temperature *float32
// TopP is the default value for the "top-p" setting.
TopP *float32
// TopK is the default value for the "top-k" setting.
TopK *int32
// FrequencyPenalty is the default value for the "frequency-penalty" setting.
FrequencyPenalty *float32
// PresencePenalty is the default value for the "presence-penalty" setting.
PresencePenalty *float32
// StopSequences is the default list for the "stop-sequences" setting.
StopSequences []string
// ThinkingLevel is the parsed default for the "thinking-level" setting.
ThinkingLevel ThinkingLevel
SystemPrompt string // Per-model system prompt (inline text or file path)
}
// CustomModelConfig defines a custom model configuration loaded from the config file.
// It is declared here as a duplicate of the config-side type to avoid a
// circular dependency with internal/config.
//
// Each field is declared exactly once; a repeated field name would be a
// compile error.
type CustomModelConfig struct {
	Name        string      `json:"name" yaml:"name"`
	BaseURL     string      `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
	APIKey      string      `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
	Family      string      `json:"family,omitempty" yaml:"family,omitempty"`
	Attachment  bool        `json:"attachment,omitempty" yaml:"attachment,omitempty"`
	Reasoning   bool        `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
	Temperature bool        `json:"temperature,omitempty" yaml:"temperature,omitempty"`
	Knowledge   string      `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
	Cost        CostConfig  `json:"cost" yaml:"cost"`
	Limit       LimitConfig `json:"limit" yaml:"limit"`

	// Params carries the optional per-model generation parameter defaults
	// (the customModels[].params section of the config).
	Params GenerationParamsConfig `json:"params,omitzero" yaml:"params,omitempty"`
}
// GenerationParamsConfig is the JSON/YAML-serializable form of generation
// parameter defaults. Used in both customModels[].params and modelSettings[].
// Every field is tagged omitempty for both encodings.
type GenerationParamsConfig struct {
// Optional numeric settings; a nil pointer means the key was not provided.
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
// ThinkingLevel is a plain string in this serializable form (e.g. "high",
// "medium"); the runtime GenerationParams uses the typed ThinkingLevel.
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
}
// CostConfig defines the pricing for a custom model.
+422
View File
@@ -0,0 +1,422 @@
package models
import (
"os"
"testing"
"github.com/spf13/viper"
)
func TestConvertGenerationParams(t *testing.T) {
t.Run("empty config returns nil", func(t *testing.T) {
cfg := GenerationParamsConfig{}
p := convertGenerationParams(cfg)
if p != nil {
t.Errorf("expected nil, got %+v", p)
}
})
t.Run("temperature only", func(t *testing.T) {
temp := float32(0.7)
cfg := GenerationParamsConfig{Temperature: &temp}
p := convertGenerationParams(cfg)
if p == nil {
t.Fatal("expected non-nil")
}
if p.Temperature == nil || *p.Temperature != 0.7 {
t.Errorf("expected temperature 0.7, got %v", p.Temperature)
}
if p.TopP != nil {
t.Errorf("expected nil TopP, got %v", p.TopP)
}
})
t.Run("all params set", func(t *testing.T) {
maxTokens := 8192
temp := float32(0.5)
topP := float32(0.9)
topK := int32(50)
freqPenalty := float32(0.1)
presPenalty := float32(0.2)
cfg := GenerationParamsConfig{
MaxTokens: &maxTokens,
Temperature: &temp,
TopP: &topP,
TopK: &topK,
FrequencyPenalty: &freqPenalty,
PresencePenalty: &presPenalty,
StopSequences: []string{"STOP"},
ThinkingLevel: "high",
}
p := convertGenerationParams(cfg)
if p == nil {
t.Fatal("expected non-nil")
}
if p.MaxTokens == nil || *p.MaxTokens != 8192 {
t.Errorf("expected maxTokens 8192, got %v", p.MaxTokens)
}
if p.Temperature == nil || *p.Temperature != 0.5 {
t.Errorf("expected temperature 0.5, got %v", p.Temperature)
}
if p.TopP == nil || *p.TopP != 0.9 {
t.Errorf("expected topP 0.9, got %v", p.TopP)
}
if p.TopK == nil || *p.TopK != 50 {
t.Errorf("expected topK 50, got %v", p.TopK)
}
if p.FrequencyPenalty == nil || *p.FrequencyPenalty != 0.1 {
t.Errorf("expected frequencyPenalty 0.1, got %v", p.FrequencyPenalty)
}
if p.PresencePenalty == nil || *p.PresencePenalty != 0.2 {
t.Errorf("expected presencePenalty 0.2, got %v", p.PresencePenalty)
}
if len(p.StopSequences) != 1 || p.StopSequences[0] != "STOP" {
t.Errorf("expected stop sequences [STOP], got %v", p.StopSequences)
}
if p.ThinkingLevel != ThinkingHigh {
t.Errorf("expected thinking level high, got %v", p.ThinkingLevel)
}
})
t.Run("thinking level parsing", func(t *testing.T) {
cfg := GenerationParamsConfig{ThinkingLevel: "medium"}
p := convertGenerationParams(cfg)
if p == nil {
t.Fatal("expected non-nil")
}
if p.ThinkingLevel != ThinkingMedium {
t.Errorf("expected thinking level medium, got %v", p.ThinkingLevel)
}
})
t.Run("system prompt only", func(t *testing.T) {
cfg := GenerationParamsConfig{SystemPrompt: "You are helpful."}
p := convertGenerationParams(cfg)
if p == nil {
t.Fatal("expected non-nil")
}
if p.SystemPrompt != "You are helpful." {
t.Errorf("expected system prompt, got %q", p.SystemPrompt)
}
})
}
func TestModelConfigToModelInfoWithParams(t *testing.T) {
temp := float32(0.8)
topP := float32(0.95)
cfg := CustomModelConfig{
Name: "Test Model",
BaseURL: "http://localhost:8080/v1",
Temperature: true,
Params: GenerationParamsConfig{
Temperature: &temp,
TopP: &topP,
},
}
info := modelConfigToModelInfo("test-model", cfg)
if info.Params == nil {
t.Fatal("expected non-nil Params")
}
if info.Params.Temperature == nil || *info.Params.Temperature != 0.8 {
t.Errorf("expected temperature 0.8, got %v", info.Params.Temperature)
}
if info.Params.TopP == nil || *info.Params.TopP != 0.95 {
t.Errorf("expected topP 0.95, got %v", info.Params.TopP)
}
}
func TestModelConfigToModelInfoWithoutParams(t *testing.T) {
cfg := CustomModelConfig{
Name: "Test Model",
BaseURL: "http://localhost:8080/v1",
}
info := modelConfigToModelInfo("test-model", cfg)
if info.Params != nil {
t.Errorf("expected nil Params, got %+v", info.Params)
}
}
// TestApplyModelSettings exercises ApplyModelSettings against the global viper
// instance: which per-model defaults get applied, and which already-set values
// are left alone. Every subtest starts from viper.Reset(); the settings that
// were present before the test are written back when it finishes.
func TestApplyModelSettings(t *testing.T) {
	// Save and restore viper state.
	originalViper := viper.AllSettings()
	defer func() {
		viper.Reset()
		for k, v := range originalViper {
			viper.Set(k, v)
		}
	}()

	t.Run("applies model params when not explicitly set", func(t *testing.T) {
		viper.Reset()

		temp := float32(0.8)
		topK := int32(50)
		maxTokens := 4096
		modelInfo := &ModelInfo{
			ID: "test-model",
			Params: &GenerationParams{
				Temperature: &temp,
				TopK:        &topK,
				MaxTokens:   &maxTokens,
			},
		}
		config := &ProviderConfig{
			ModelString: "custom/test-model",
		}

		ApplyModelSettings(config, modelInfo)

		if config.Temperature == nil || *config.Temperature != 0.8 {
			t.Errorf("expected temperature 0.8, got %v", config.Temperature)
		}
		if config.TopK == nil || *config.TopK != 50 {
			t.Errorf("expected topK 50, got %v", config.TopK)
		}
		if config.MaxTokens != 4096 {
			t.Errorf("expected maxTokens 4096, got %d", config.MaxTokens)
		}
	})

	t.Run("explicit viper values take precedence", func(t *testing.T) {
		viper.Reset()
		viper.Set("temperature", 0.3)

		temp := float32(0.8)
		modelInfo := &ModelInfo{
			ID: "test-model",
			Params: &GenerationParams{
				Temperature: &temp,
			},
		}
		explicitTemp := float32(0.3)
		config := &ProviderConfig{
			ModelString: "custom/test-model",
			Temperature: &explicitTemp,
		}

		ApplyModelSettings(config, modelInfo)

		// Temperature should NOT be overridden because it's explicitly set in viper
		if config.Temperature == nil || *config.Temperature != 0.3 {
			t.Errorf("expected temperature 0.3 (explicit), got %v", config.Temperature)
		}
	})

	t.Run("nil model info is safe", func(t *testing.T) {
		viper.Reset()

		config := &ProviderConfig{
			ModelString: "custom/test-model",
		}

		// Should not panic
		ApplyModelSettings(config, nil)

		if config.Temperature != nil {
			t.Errorf("expected nil temperature, got %v", config.Temperature)
		}
	})

	t.Run("model info without params is safe", func(t *testing.T) {
		viper.Reset()

		modelInfo := &ModelInfo{ID: "test-model"}
		config := &ProviderConfig{
			ModelString: "custom/test-model",
		}

		ApplyModelSettings(config, modelInfo)

		if config.Temperature != nil {
			t.Errorf("expected nil temperature, got %v", config.Temperature)
		}
	})

	t.Run("modelSettings from viper takes priority over ModelInfo.Params", func(t *testing.T) {
		viper.Reset()

		// Set up modelSettings in viper (simulating config file)
		viper.Set("modelSettings", map[string]any{
			"custom/test-model": map[string]any{
				"temperature": 0.5,
				"topK":        30,
			},
		})

		// ModelInfo has different params
		temp := float32(0.8)
		topK := int32(50)
		modelInfo := &ModelInfo{
			ID: "test-model",
			Params: &GenerationParams{
				Temperature: &temp,
				TopK:        &topK,
			},
		}
		config := &ProviderConfig{
			ModelString: "custom/test-model",
		}

		ApplyModelSettings(config, modelInfo)

		// modelSettings should win over ModelInfo.Params
		if config.Temperature == nil || *config.Temperature != 0.5 {
			t.Errorf("expected temperature 0.5 (from modelSettings), got %v", config.Temperature)
		}
		if config.TopK == nil || *config.TopK != 30 {
			t.Errorf("expected topK 30 (from modelSettings), got %v", config.TopK)
		}
	})

	t.Run("stop sequences applied from model params", func(t *testing.T) {
		viper.Reset()

		modelInfo := &ModelInfo{
			ID: "test-model",
			Params: &GenerationParams{
				StopSequences: []string{"STOP", "END"},
			},
		}
		config := &ProviderConfig{
			ModelString: "custom/test-model",
		}

		ApplyModelSettings(config, modelInfo)

		if len(config.StopSequences) != 2 || config.StopSequences[0] != "STOP" {
			t.Errorf("expected stop sequences [STOP END], got %v", config.StopSequences)
		}
	})

	t.Run("thinking level applied from model params", func(t *testing.T) {
		viper.Reset()

		modelInfo := &ModelInfo{
			ID: "test-model",
			Params: &GenerationParams{
				ThinkingLevel: ThinkingHigh,
			},
		}
		config := &ProviderConfig{
			ModelString: "custom/test-model",
		}

		ApplyModelSettings(config, modelInfo)

		if config.ThinkingLevel != ThinkingHigh {
			t.Errorf("expected thinking level high, got %v", config.ThinkingLevel)
		}
	})

	t.Run("system prompt applied from model params", func(t *testing.T) {
		viper.Reset()

		modelInfo := &ModelInfo{
			ID: "test-model",
			Params: &GenerationParams{
				SystemPrompt: "You are a coding assistant.",
			},
		}
		config := &ProviderConfig{
			ModelString: "custom/test-model",
		}

		ApplyModelSettings(config, modelInfo)

		if config.SystemPrompt != "You are a coding assistant." {
			t.Errorf("expected system prompt to be set, got %q", config.SystemPrompt)
		}
	})

	t.Run("explicit system prompt takes precedence", func(t *testing.T) {
		viper.Reset()

		modelInfo := &ModelInfo{
			ID: "test-model",
			Params: &GenerationParams{
				SystemPrompt: "Model-specific prompt",
			},
		}
		config := &ProviderConfig{
			ModelString:  "custom/test-model",
			SystemPrompt: "Global prompt",
		}

		ApplyModelSettings(config, modelInfo)

		// Global system prompt should NOT be overridden because config
		// already has a non-empty SystemPrompt.
		if config.SystemPrompt != "Global prompt" {
			t.Errorf("expected global prompt preserved, got %q", config.SystemPrompt)
		}
	})

	t.Run("system prompt from file path", func(t *testing.T) {
		viper.Reset()

		// Create the prompt file inside a per-test directory: the testing
		// package removes t.TempDir() automatically, so no manual cleanup is
		// needed and nothing is left behind in the shared temp directory.
		tmpFile, err := os.CreateTemp(t.TempDir(), "kit-test-prompt-*.txt")
		if err != nil {
			t.Fatal(err)
		}
		if _, err := tmpFile.WriteString(" Prompt from file "); err != nil {
			_ = tmpFile.Close() // best effort; the write error is what gets reported
			t.Fatal(err)
		}
		// A failed Close can mean the content never reached disk, so it must
		// fail the test instead of being ignored.
		if err := tmpFile.Close(); err != nil {
			t.Fatal(err)
		}

		modelInfo := &ModelInfo{
			ID: "test-model",
			Params: &GenerationParams{
				SystemPrompt: tmpFile.Name(),
			},
		}
		config := &ProviderConfig{
			ModelString: "custom/test-model",
		}

		ApplyModelSettings(config, modelInfo)

		if config.SystemPrompt != "Prompt from file" {
			t.Errorf("expected trimmed file content, got %q", config.SystemPrompt)
		}
	})

	t.Run("modelSettings system prompt overrides custom model params", func(t *testing.T) {
		viper.Reset()

		viper.Set("modelSettings", map[string]any{
			"custom/test-model": map[string]any{
				"systemPrompt": "From modelSettings",
			},
		})

		modelInfo := &ModelInfo{
			ID: "test-model",
			Params: &GenerationParams{
				SystemPrompt: "From custom model",
			},
		}
		config := &ProviderConfig{
			ModelString: "custom/test-model",
		}

		ApplyModelSettings(config, modelInfo)

		if config.SystemPrompt != "From modelSettings" {
			t.Errorf("expected modelSettings prompt, got %q", config.SystemPrompt)
		}
	})
}
+5
View File
@@ -241,6 +241,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
validateModelConfig(config, modelInfo)
}
// Apply per-model generation parameter defaults. Model-level params are
// only applied for fields where the user hasn't explicitly set a value
// via CLI flag or global config.
ApplyModelSettings(config, modelInfo)
// Create the base provider
var result *ProviderResult
var createErr error
+17
View File
@@ -26,6 +26,11 @@ type ModelInfo struct {
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
BaseURL string // Per-model base URL override (custom models only)
APIKey string // Per-model API key override (custom models only)
// Params holds per-model generation parameter defaults. These are applied
// when the user hasn't explicitly set the corresponding CLI flag or global
// config value. Nil pointer fields mean "no model-level default".
Params *GenerationParams
}
// SupportsCaching returns true if this model family supports prompt caching.
@@ -236,6 +241,18 @@ func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
return &modelInfo
}
// LookupModelForSettings resolves a "provider/model" string to its ModelInfo
// by parsing it with ParseModelString and querying the global registry.
// It returns nil when the string cannot be parsed, and otherwise whatever the
// registry lookup yields (nil for an unknown model).
// Used by Kit.SetModel to pre-apply per-model settings before CreateProvider.
func LookupModelForSettings(modelString string) *ModelInfo {
	providerID, modelID, parseErr := ParseModelString(modelString)
	if parseErr == nil {
		return GetGlobalRegistry().LookupModel(providerID, modelID)
	}
	return nil
}
// getRequiredEnvVars returns the required environment variables for a provider.
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
providerInfo, exists := r.providers[provider]
+51 -18
View File
@@ -60,15 +60,16 @@ type MCPConnection struct {
// creation, health monitoring, and cleanup. The pool runs background health checks
// to proactively identify and remove unhealthy connections.
type MCPConnectionPool struct {
	// connections holds the pooled connections keyed by server name;
	// mu must be held while reading or modifying it.
	connections map[string]*MCPConnection
	config      *ConnectionPoolConfig
	mu          sync.RWMutex

	// model is the language model used for MCP servers that require
	// sampling support.
	model fantasy.LanguageModel

	// ctx and cancel come from the cancellable context created in
	// NewMCPConnectionPool.
	ctx    context.Context
	cancel context.CancelFunc

	debug       bool
	debugLogger DebugLogger

	// oauthFlow is only set when an auth handler was supplied; remote
	// clients use it together with a token store when it is non-nil.
	oauthFlow         *OAuthFlowRunner
	tokenStoreFactory TokenStoreFactory // custom factory for per-server token stores (nil = default FileTokenStore)
}
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
@@ -76,19 +77,20 @@ type MCPConnectionPool struct {
// goroutine for periodic health checks that runs until Close is called.
// The model parameter is used for MCP servers that require sampling support.
// Thread-safe for concurrent use immediately after creation.
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler) *MCPConnectionPool {
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler, tokenStoreFactory TokenStoreFactory) *MCPConnectionPool {
if config == nil {
config = DefaultConnectionPoolConfig()
}
ctx, cancel := context.WithCancel(context.Background())
pool := &MCPConnectionPool{
connections: make(map[string]*MCPConnection),
config: config,
model: model,
ctx: ctx,
cancel: cancel,
debug: debug,
connections: make(map[string]*MCPConnection),
config: config,
model: model,
ctx: ctx,
cancel: cancel,
debug: debug,
tokenStoreFactory: tokenStoreFactory,
}
if authHandler != nil {
@@ -367,7 +369,7 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
// scopes are discovered automatically via dynamic client registration and
// server metadata (RFC 9728).
if p.oauthFlow != nil {
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
@@ -414,7 +416,7 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
// scopes are discovered automatically via dynamic client registration and
// server metadata (RFC 9728).
if p.oauthFlow != nil {
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
@@ -437,6 +439,16 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
return streamableClient, nil
}
// createTokenStore returns the OAuth token store to use for serverURL.
// Without a configured TokenStoreFactory it falls back to the file-backed
// store from NewFileTokenStore; otherwise the factory decides.
func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenStore, error) {
	factory := p.tokenStoreFactory
	if factory == nil {
		return NewFileTokenStore(serverURL)
	}
	return factory(serverURL)
}
// initializeClient initializes the client
func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error {
initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
@@ -583,6 +595,27 @@ func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient {
return clients
}
// RemoveConnection closes the client of the named server and drops its entry
// from the pool. The pool's write lock is held for the whole operation.
//
// It returns an error when no connection with that name is pooled. When the
// connection exists it is always removed from the pool, and the error
// returned by closing the client (if any) is passed back to the caller.
func (p *MCPConnectionPool) RemoveConnection(serverName string) error {
	p.mu.Lock()
	defer p.mu.Unlock()

	pooled, ok := p.connections[serverName]
	if !ok {
		return fmt.Errorf("connection %q not found in pool", serverName)
	}

	// Close first, then forget the entry regardless of the close result.
	closeErr := pooled.client.Close()
	delete(p.connections, serverName)

	if logger := p.debugLogger; logger != nil && logger.IsDebugEnabled() {
		logger.LogDebug(fmt.Sprintf("[POOL] Removed connection %s", serverName))
	}
	return closeErr
}
// Close gracefully shuts down the connection pool, closing all client connections
// and stopping the background health check goroutine. It attempts to close all
// connections even if some fail, logging any errors encountered.
+147 -10
View File
@@ -20,19 +20,25 @@ import (
// pooling, health checks, tool name prefixing to avoid conflicts, and sampling support for LLM interactions.
// Thread-safe for concurrent tool invocations.
type MCPToolManager struct {
	// connectionPool stays nil until LoadTools or AddServer creates it.
	connectionPool    *MCPConnectionPool
	tools             []fantasy.AgentTool
	toolMap           map[string]*toolMapping // maps prefixed tool names to their server and original name
	mu                sync.Mutex              // protects tools and toolMap during parallel loading
	model             fantasy.LanguageModel   // LLM model for sampling
	authHandler       MCPAuthHandler          // OAuth handler for remote servers (nil = no OAuth)
	tokenStoreFactory TokenStoreFactory       // factory for creating per-server token stores (nil = default FileTokenStore)
	config            *config.Config
	debug             bool
	debugLogger       DebugLogger

	// onServerLoaded, if non-nil, is called when each server finishes loading.
	// Called with server name, tool count, and error (nil on success).
	onServerLoaded func(serverName string, toolCount int, err error)

	// onToolsChanged, if non-nil, is called after AddServer or RemoveServer
	// mutates the tool list. The agent layer uses this to trigger a
	// rebuildFantasyAgent so the LLM sees the updated tools.
	onToolsChanged func()
}
// toolMapping stores the mapping between prefixed tool names and their original details
@@ -69,6 +75,14 @@ func (m *MCPToolManager) SetAuthHandler(handler MCPAuthHandler) {
m.authHandler = handler
}
// SetTokenStoreFactory installs a custom factory for per-server OAuth token
// stores. The manager passes the factory on to the connection pool when the
// pool is created, so call this before LoadTools; a nil factory keeps the
// default file-based token store.
func (m *MCPToolManager) SetTokenStoreFactory(factory TokenStoreFactory) {
	m.tokenStoreFactory = factory
}
// SetDebugLogger sets the debug logger for the tool manager.
// The logger will be used to output detailed debugging information about MCP connections,
// tool loading, and execution. If a connection pool exists, it will also be configured
@@ -87,6 +101,126 @@ func (m *MCPToolManager) SetOnServerLoaded(cb func(serverName string, toolCount
m.onServerLoaded = cb
}
// SetOnToolsChanged registers the callback that AddServer and RemoveServer
// invoke once they have changed the tool list. The agent layer uses it to
// rebuild the fantasy agent so the LLM sees the updated tool set. Passing nil
// disables the notification.
func (m *MCPToolManager) SetOnToolsChanged(cb func()) {
	m.onToolsChanged = cb
}
// AddServer connects to a new MCP server at runtime and loads its tools.
// The server's tools are immediately available to the agent after this call.
// Returns the number of tools loaded from the server.
//
// If the connection pool has not been initialised yet (i.e. LoadTools was never
// called), AddServer creates one automatically using the manager's current
// configuration.
//
// Returns an error if a server with the same name is already loaded, or if
// the connection or tool loading fails. On success the onServerLoaded and
// onToolsChanged callbacks (when set) are invoked, in that order.
func (m *MCPToolManager) AddServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
	// A server counts as loaded when any registered tool name carries its
	// "<name>__" prefix. One scan is enough: it also matches a key that is
	// exactly the prefix, so no separate exact-key lookup is needed.
	prefix := name + "__"

	m.mu.Lock()
	for toolName := range m.toolMap {
		if len(toolName) >= len(prefix) && toolName[:len(prefix)] == prefix {
			m.mu.Unlock()
			return 0, fmt.Errorf("MCP server %q is already loaded", name)
		}
	}
	// The lock is released before connecting so the tool maps are not
	// locked while the server is being loaded.
	m.mu.Unlock()

	// Lazily create the connection pool if LoadTools was never called.
	m.ensureConnectionPool()

	count, err := m.loadServerTools(ctx, name, cfg)
	if err != nil {
		return 0, fmt.Errorf("failed to add MCP server %q: %w", name, err)
	}

	// Notify listeners.
	if m.onServerLoaded != nil {
		m.onServerLoaded(name, count, nil)
	}
	if m.onToolsChanged != nil {
		m.onToolsChanged()
	}

	return count, nil
}
// RemoveServer unloads the named MCP server: every tool whose name starts
// with "<name>__" is dropped from the tool map and the tool list, the
// server's pooled connection is closed on a best-effort basis, and the
// onToolsChanged callback (when set) is invoked.
//
// It returns an error, without changing anything, when no tool of that
// server is registered.
func (m *MCPToolManager) RemoveServer(name string) error {
	prefix := name + "__"
	ownedByServer := func(toolName string) bool {
		return len(toolName) >= len(prefix) && toolName[:len(prefix)] == prefix
	}

	m.mu.Lock()

	// Drop the server's tool mappings; finding none means it is not loaded.
	removed := 0
	for key := range m.toolMap {
		if ownedByServer(key) {
			delete(m.toolMap, key)
			removed++
		}
	}
	if removed == 0 {
		m.mu.Unlock()
		return fmt.Errorf("MCP server %q is not loaded", name)
	}

	// Keep only the tools that belong to other servers.
	kept := make([]fantasy.AgentTool, 0, len(m.tools))
	for _, tool := range m.tools {
		if !ownedByServer(tool.Info().Name) {
			kept = append(kept, tool)
		}
	}
	m.tools = kept

	m.mu.Unlock()

	// Close the connection in the pool (best-effort).
	if pool := m.connectionPool; pool != nil {
		_ = pool.RemoveConnection(name)
	}

	if notify := m.onToolsChanged; notify != nil {
		notify()
	}
	return nil
}
// ensureConnectionPool creates the connection pool on first use so that
// AddServer works even when LoadTools was never called. It is a no-op when a
// pool already exists. The debug flag is taken from the manager's config when
// one is present (false otherwise), and a simple debug logger is created if
// none has been set.
func (m *MCPToolManager) ensureConnectionPool() {
	if m.connectionPool != nil {
		return
	}

	debugEnabled := m.config != nil && m.config.Debug
	if m.debugLogger == nil {
		m.debugLogger = NewSimpleDebugLogger(debugEnabled)
	}

	pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, debugEnabled, m.authHandler, m.tokenStoreFactory)
	pool.SetDebugLogger(m.debugLogger)
	m.connectionPool = pool
}
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
// It initializes the connection pool, connects to each configured server, and loads their tools.
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
@@ -99,7 +233,7 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, cfg *config.Config) erro
if m.debugLogger == nil {
m.debugLogger = NewSimpleDebugLogger(cfg.Debug)
}
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, cfg.Debug, m.authHandler)
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, cfg.Debug, m.authHandler, m.tokenStoreFactory)
m.connectionPool.SetDebugLogger(m.debugLogger)
// Load all servers in parallel. Each server connection (subprocess
@@ -290,6 +424,9 @@ func (m *MCPToolManager) GetLoadedServerNames() []string {
// proper cleanup of stdio processes, network connections, and other resources.
// It is safe to call Close multiple times.
func (m *MCPToolManager) Close() error {
	// Without a pool (neither LoadTools nor AddServer created one) there is
	// nothing to shut down.
	if pool := m.connectionPool; pool != nil {
		return pool.Close()
	}
	return nil
}
@@ -0,0 +1,323 @@
package tools
import (
"context"
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/mark3labs/kit/internal/config"
)
// testdataDir returns the path of the "testdata" directory that sits next to
// this source file, as reported by runtime.Caller. The test is aborted when
// the caller information is unavailable.
func testdataDir(t *testing.T) string {
	t.Helper()

	_, thisFile, _, ok := runtime.Caller(0)
	if ok {
		return filepath.Join(filepath.Dir(thisFile), "testdata")
	}
	t.Fatal("cannot determine test file path")
	return "" // unreachable: t.Fatal stops the test
}
// echoServerConfig builds the MCPServerConfig that launches the echo MCP test
// server (testdata/echo_server.py, run through python3). The calling test is
// skipped when the script cannot be found.
func echoServerConfig(t *testing.T) config.MCPServerConfig {
	t.Helper()

	scriptPath := filepath.Join(testdataDir(t), "echo_server.py")
	_, statErr := os.Stat(scriptPath)
	if statErr != nil {
		t.Skipf("echo_server.py not found: %v", statErr)
	}

	return config.MCPServerConfig{Command: []string{"python3", scriptPath}}
}
// TestMCPToolManager_AddServer_Integration adds the real echo MCP server at
// runtime and checks the reported tool count, both callbacks, the prefixed
// tool names and the loaded-server listing.
func TestMCPToolManager_AddServer_Integration(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test in short mode")
	}

	manager := NewMCPToolManager()
	defer func() { _ = manager.Close() }()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	cfg := echoServerConfig(t)

	// Callback bookkeeping, guarded by mu.
	var (
		mu                sync.Mutex
		loadedServer      string
		loadedCount       int
		toolsChangedCount int
	)
	manager.SetOnServerLoaded(func(name string, count int, err error) {
		mu.Lock()
		defer mu.Unlock()
		loadedServer, loadedCount = name, count
	})
	manager.SetOnToolsChanged(func() {
		mu.Lock()
		defer mu.Unlock()
		toolsChangedCount++
	})

	// Add the server.
	count, err := manager.AddServer(ctx, "echo", cfg)
	if err != nil {
		t.Fatalf("AddServer failed: %v", err)
	}
	if count != 2 {
		t.Errorf("Expected 2 tools from echo server, got %d", count)
	}

	// Snapshot what the callbacks recorded, then assert outside the lock.
	mu.Lock()
	gotServer, gotCount, gotChanges := loadedServer, loadedCount, toolsChangedCount
	mu.Unlock()
	if gotServer != "echo" {
		t.Errorf("Expected onServerLoaded for 'echo', got %q", gotServer)
	}
	if gotCount != 2 {
		t.Errorf("Expected onServerLoaded count=2, got %d", gotCount)
	}
	if gotChanges != 1 {
		t.Errorf("Expected onToolsChanged called once, got %d", gotChanges)
	}

	// The tools must be exposed by the manager.
	tools := manager.GetTools()
	if len(tools) != 2 {
		t.Fatalf("Expected 2 tools, got %d", len(tools))
	}

	// Tool names carry the "<server>__" prefix.
	seen := make(map[string]bool, len(tools))
	for _, tool := range tools {
		seen[tool.Info().Name] = true
	}
	if !seen["echo__echo"] {
		t.Error("Expected tool 'echo__echo'")
	}
	if !seen["echo__greet"] {
		t.Error("Expected tool 'echo__greet'")
	}

	// The server must be listed as loaded.
	if names := manager.GetLoadedServerNames(); !slices.Contains(names, "echo") {
		t.Errorf("Expected 'echo' in loaded server names, got: %v", names)
	}
}
// TestMCPToolManager_RemoveServer_Integration tests removing a real MCP server
// and verifying tools are cleaned up: the tool list is empty afterwards, the
// onToolsChanged callback fires exactly once for the removal, the server no
// longer appears among the loaded names, and a second removal is rejected.
func TestMCPToolManager_RemoveServer_Integration(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test in short mode")
	}

	manager := NewMCPToolManager()
	defer func() { _ = manager.Close() }()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	cfg := echoServerConfig(t)

	// Add the server first.
	count, err := manager.AddServer(ctx, "echo", cfg)
	if err != nil {
		t.Fatalf("AddServer failed: %v", err)
	}
	if count != 2 {
		t.Fatalf("Expected 2 tools, got %d", count)
	}

	// Register the callback only now so the AddServer notification above is
	// not counted.
	var mu sync.Mutex
	toolsChangedCount := 0
	manager.SetOnToolsChanged(func() {
		mu.Lock()
		toolsChangedCount++
		mu.Unlock()
	})

	// Remove the server.
	err = manager.RemoveServer("echo")
	if err != nil {
		t.Fatalf("RemoveServer failed: %v", err)
	}

	// Verify tools are gone.
	tools := manager.GetTools()
	if len(tools) != 0 {
		t.Errorf("Expected 0 tools after removal, got %d", len(tools))
	}

	// Verify callback fired.
	mu.Lock()
	if toolsChangedCount != 1 {
		t.Errorf("Expected onToolsChanged called once, got %d", toolsChangedCount)
	}
	mu.Unlock()

	// Verify server is gone from loaded names (slices.Contains mirrors the
	// membership check used by the AddServer integration test).
	if slices.Contains(manager.GetLoadedServerNames(), "echo") {
		t.Error("Server 'echo' should not appear in loaded names after removal")
	}

	// Removing again should error.
	err = manager.RemoveServer("echo")
	if err == nil {
		t.Fatal("Expected error removing already-removed server")
	}
	if !strings.Contains(err.Error(), "not loaded") {
		t.Errorf("Expected 'not loaded' error, got: %v", err)
	}
}
// TestMCPToolManager_AddRemoveMultiple_Integration loads the echo server
// twice under different names and checks that removing one server leaves the
// other server's tools untouched.
func TestMCPToolManager_AddRemoveMultiple_Integration(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test in short mode")
	}

	manager := NewMCPToolManager()
	defer func() { _ = manager.Close() }()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	cfg := echoServerConfig(t)

	// Same binary, two distinct server names.
	countA, err := manager.AddServer(ctx, "server-a", cfg)
	if err != nil {
		t.Fatalf("AddServer server-a failed: %v", err)
	}
	countB, err := manager.AddServer(ctx, "server-b", cfg)
	if err != nil {
		t.Fatalf("AddServer server-b failed: %v", err)
	}
	if total := countA + countB; total != 4 {
		t.Fatalf("Expected 4 total tools (2+2), got %d", total)
	}
	if got := manager.GetTools(); len(got) != 4 {
		t.Fatalf("Expected 4 tools, got %d", len(got))
	}

	// Dropping server-a must leave exactly server-b's tools.
	if err := manager.RemoveServer("server-a"); err != nil {
		t.Fatalf("RemoveServer server-a failed: %v", err)
	}
	remaining := manager.GetTools()
	if len(remaining) != 2 {
		t.Fatalf("Expected 2 tools after removing server-a, got %d", len(remaining))
	}
	for _, tool := range remaining {
		if toolName := tool.Info().Name; !strings.HasPrefix(toolName, "server-b__") {
			t.Errorf("Expected tool from server-b, got: %s", toolName)
		}
	}

	// Dropping server-b as well empties the tool list.
	if err := manager.RemoveServer("server-b"); err != nil {
		t.Fatalf("RemoveServer server-b failed: %v", err)
	}
	if got := manager.GetTools(); len(got) != 0 {
		t.Errorf("Expected 0 tools after removing all servers, got %d", len(got))
	}
}
// TestMCPToolManager_AddServer_DuplicateDetection_Integration checks that a
// second AddServer call with an already loaded server name is rejected with
// an "already loaded" error.
func TestMCPToolManager_AddServer_DuplicateDetection_Integration(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test in short mode")
	}

	manager := NewMCPToolManager()
	defer func() { _ = manager.Close() }()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	cfg := echoServerConfig(t)

	// The first registration must succeed.
	if _, err := manager.AddServer(ctx, "echo", cfg); err != nil {
		t.Fatalf("First AddServer failed: %v", err)
	}

	// Registering the same name again must fail.
	_, dupErr := manager.AddServer(ctx, "echo", cfg)
	if dupErr == nil {
		t.Fatal("Expected error adding duplicate server")
	}
	if !strings.Contains(dupErr.Error(), "already loaded") {
		t.Errorf("Expected 'already loaded' error, got: %v", dupErr)
	}
}
// TestMCPToolManager_AddAfterRemove_Integration checks that a server name can
// be reused: after add followed by remove, adding the same server again
// succeeds and exposes its two tools once more.
func TestMCPToolManager_AddAfterRemove_Integration(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test in short mode")
	}

	mgr := NewMCPToolManager()
	defer func() { _ = mgr.Close() }()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	echoCfg := echoServerConfig(t)

	// Initial add.
	if _, err := mgr.AddServer(ctx, "echo", echoCfg); err != nil {
		t.Fatalf("First AddServer failed: %v", err)
	}

	// Remove it again.
	if err := mgr.RemoveServer("echo"); err != nil {
		t.Fatalf("RemoveServer failed: %v", err)
	}

	// Re-add under the same name and confirm the tools are back.
	readded, err := mgr.AddServer(ctx, "echo", echoCfg)
	if err != nil {
		t.Fatalf("Re-AddServer failed: %v", err)
	}
	if readded != 2 {
		t.Errorf("Expected 2 tools on re-add, got %d", readded)
	}
	if visible := len(mgr.GetTools()); visible != 2 {
		t.Errorf("Expected 2 tools after re-add, got %d", visible)
	}
}
+155
View File
@@ -0,0 +1,155 @@
package tools
import (
"context"
"strings"
"sync"
"testing"
"time"
"github.com/mark3labs/kit/internal/config"
)
// TestMCPToolManager_AddServer_DuplicateName verifies that a server whose
// initial load failed is not treated as already loaded: a later AddServer
// call with the same name must fail with a connection error, not with a
// duplicate-name ("already loaded") error.
func TestMCPToolManager_AddServer_DuplicateName(t *testing.T) {
	manager := NewMCPToolManager()
	// LoadTools below creates the connection pool even though loading fails,
	// so release it when the test finishes (matches the other tests here).
	defer func() { _ = manager.Close() }()

	cfg := config.MCPServerConfig{
		Command: []string{"non-existent-command"},
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Attempt to load the server through LoadTools first. The command does
	// not exist, so nothing is registered under "test-server".
	loadCfg := &config.Config{
		MCPServers: map[string]config.MCPServerConfig{
			"test-server": cfg,
		},
	}

	// The load error is intentionally ignored: failure is the expected
	// outcome and only its side effect (pool creation) matters here.
	_ = manager.LoadTools(ctx, loadCfg)

	// Now try to add the same server name — the tools didn't load (bad command),
	// so AddServer should not find a duplicate and should fail with connection error.
	_, err := manager.AddServer(ctx, "test-server", cfg)
	if err == nil {
		t.Fatal("Expected error when adding server with bad command, got nil")
	}

	// It should be a connection error, not a duplicate error.
	if strings.Contains(err.Error(), "already loaded") {
		t.Fatalf("Should not report duplicate since server failed to load initially: %v", err)
	}
}
// TestMCPToolManager_RemoveServer_NotLoaded checks that RemoveServer rejects
// an unknown server name with a "not loaded" error.
func TestMCPToolManager_RemoveServer_NotLoaded(t *testing.T) {
	mgr := NewMCPToolManager()

	removeErr := mgr.RemoveServer("nonexistent")
	switch {
	case removeErr == nil:
		t.Fatal("Expected error when removing non-existent server, got nil")
	case !strings.Contains(removeErr.Error(), "not loaded"):
		t.Errorf("Expected 'not loaded' error, got: %v", removeErr)
	}
}
// TestMCPToolManager_AddServer_CreatesConnectionPool verifies that AddServer
// lazily creates a connection pool when LoadTools was never called, even if
// the AddServer call itself fails.
func TestMCPToolManager_AddServer_CreatesConnectionPool(t *testing.T) {
	manager := NewMCPToolManager()
	// The test expects AddServer to create a pool; release it on exit so the
	// test does not leave pool resources behind.
	defer func() { _ = manager.Close() }()

	// Connection pool should be nil initially.
	if manager.connectionPool != nil {
		t.Fatal("Expected nil connection pool before any operation")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// AddServer with a bad command — should fail, but the pool should be created.
	_, err := manager.AddServer(ctx, "lazy-server", config.MCPServerConfig{
		Command: []string{"non-existent-command"},
	})
	if err == nil {
		t.Fatal("Expected error for bad command")
	}

	// Connection pool should have been created.
	if manager.connectionPool == nil {
		t.Fatal("Expected connection pool to be created lazily by AddServer")
	}
}
// TestMCPToolManager_OnToolsChanged_Callback exercises the onToolsChanged
// wiring on the failure path only: calling RemoveServer for a server that was
// never loaded must not invoke the registered callback. The success path
// needs a real MCP server and is not covered by this test.
func TestMCPToolManager_OnToolsChanged_Callback(t *testing.T) {
	manager := NewMCPToolManager()

	var (
		guard sync.Mutex
		fired int
	)
	manager.SetOnToolsChanged(func() {
		guard.Lock()
		defer guard.Unlock()
		fired++
	})

	// The error is deliberately discarded; only the callback count matters.
	_ = manager.RemoveServer("nonexistent")

	guard.Lock()
	defer guard.Unlock()
	if fired != 0 {
		t.Errorf("Expected 0 callback calls for failed remove, got %d", fired)
	}
}
// TestMCPToolManager_Close_NilPool checks that Close on a freshly constructed
// manager, before any operation could have initialized its connection pool,
// returns a nil error.
func TestMCPToolManager_Close_NilPool(t *testing.T) {
	mgr := NewMCPToolManager()

	if closeErr := mgr.Close(); closeErr != nil {
		t.Fatalf("Expected nil error from Close with nil pool, got: %v", closeErr)
	}
}
// TestMCPConnectionPool_RemoveConnection_NotFound checks that asking an empty
// pool to remove an unknown connection yields a "not found" error.
func TestMCPConnectionPool_RemoveConnection_NotFound(t *testing.T) {
	emptyPool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), nil, false, nil, nil)
	defer func() { _ = emptyPool.Close() }()

	removeErr := emptyPool.RemoveConnection("nonexistent")
	switch {
	case removeErr == nil:
		t.Fatal("Expected error for non-existent connection")
	case !strings.Contains(removeErr.Error(), "not found"):
		t.Errorf("Expected 'not found' error, got: %v", removeErr)
	}
}
// TestMCPToolManager_EnsureConnectionPool_Idempotent verifies that
// ensureConnectionPool creates the pool on first use and leaves an existing
// pool untouched on subsequent calls.
func TestMCPToolManager_EnsureConnectionPool_Idempotent(t *testing.T) {
	manager := NewMCPToolManager()
	// The pool created below must be released when the test finishes.
	defer func() { _ = manager.Close() }()

	// First call creates the pool.
	manager.ensureConnectionPool()
	pool1 := manager.connectionPool
	if pool1 == nil {
		t.Fatal("Expected pool to be created")
	}

	// Second call should be a no-op: the same pool instance must remain.
	manager.ensureConnectionPool()
	pool2 := manager.connectionPool
	if pool1 != pool2 {
		t.Fatal("Expected ensureConnectionPool to be idempotent")
	}
}
+7
View File
@@ -6,6 +6,7 @@ import (
"net/url"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
)
// MCPAuthHandler is the internal interface for handling MCP OAuth flows.
@@ -21,6 +22,12 @@ type MCPAuthHandler interface {
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
}
// TokenStoreFactory creates a transport.TokenStore for a given MCP server URL.
// When provided to the connection pool, it is called once per remote MCP server
// instead of using the default file-based token store. Implementations can
// return any transport.TokenStore — in-memory, database-backed, encrypted, etc.
//
// The factory receives the server URL as a string and returns either a store
// or an error.
// NOTE(review): how the connection pool reacts to a non-nil error is not
// visible in this file — confirm against the pool's token-store setup.
type TokenStoreFactory func(serverURL string) (transport.TokenStore, error)
// OAuthFlowRunner handles the OAuth authorization flow when an MCP server
// returns an OAuthAuthorizationRequiredError. It coordinates dynamic client
// registration, PKCE generation, user authorization (via MCPAuthHandler),
+111
View File
@@ -0,0 +1,111 @@
#!/usr/bin/env python3
"""Minimal MCP server over stdio for testing. Exposes one tool: echo."""
import json
import sys
def read_message():
    """Read the next JSON-RPC message from stdin.

    Messages are expected one per line. Blank (whitespace-only) lines are
    skipped instead of being handed to the JSON parser, so a stray empty
    line from the client does not crash the server.

    Returns:
        The decoded JSON value, or None once stdin reaches end of file.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    while True:
        line = sys.stdin.readline()
        if not line:
            # readline() returns an empty string only at EOF.
            return None
        payload = line.strip()
        if payload:
            return json.loads(payload)
def write_message(msg):
    """Serialize msg as JSON and emit it on stdout as one newline-terminated line.

    The stream is flushed immediately after the line is written.
    """
    print(json.dumps(msg), file=sys.stdout, flush=True)
def handle(msg):
    """Dispatch one decoded JSON-RPC message and emit a response if one is due.

    Recognized methods are "initialize", "notifications/initialized" (no
    reply), "tools/list", "tools/call" (tools "echo" and "greet") and "ping".
    Any other method is answered with a -32601 error, but only when the
    message carries an id. Responses are sent through write_message.
    """
    method = msg.get("method", "")
    mid = msg.get("id")

    def reply(result):
        # Successful response envelope.
        write_message({"jsonrpc": "2.0", "id": mid, "result": result})

    def fail(message):
        # Error response envelope; every error in this server uses -32601.
        write_message({
            "jsonrpc": "2.0",
            "id": mid,
            "error": {"code": -32601, "message": message},
        })

    def text_content(text):
        # Tool results are a single text content item.
        return {"content": [{"type": "text", "text": text}]}

    if method == "notifications/initialized":
        return  # notification: no response needed

    if method == "initialize":
        reply({
            "protocolVersion": "2024-11-05",
            "capabilities": {"tools": {}},
            "serverInfo": {"name": "test-echo", "version": "1.0.0"},
        })
        return

    if method == "tools/list":
        echo_tool = {
            "name": "echo",
            "description": "Echoes the input text back.",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "text": {"type": "string", "description": "Text to echo"}
                },
                "required": ["text"],
            },
        }
        greet_tool = {
            "name": "greet",
            "description": "Returns a greeting.",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "description": "Name to greet"}
                },
                "required": ["name"],
            },
        }
        reply({"tools": [echo_tool, greet_tool]})
        return

    if method == "tools/call":
        params = msg["params"]
        tool_name = params["name"]
        args = params.get("arguments", {})
        if tool_name == "echo":
            reply(text_content(args.get("text", "")))
        elif tool_name == "greet":
            name = args.get("name", "World")
            reply(text_content(f"Hello, {name}!"))
        else:
            fail(f"Unknown tool: {tool_name}")
        return

    if method == "ping":
        reply({})
        return

    # Unknown method: only requests (messages with an id) get an error back.
    if mid is not None:
        fail(f"Unknown method: {method}")
if __name__ == "__main__":
    # Serve messages until read_message signals end of input by returning None.
    for msg in iter(read_message, None):
        handle(msg)
+3 -1
View File
@@ -139,7 +139,9 @@ func (h *CLIEventHandler) Handle(msg tea.Msg) {
case "block":
h.cli.DisplayExtensionBlock(e.Text, e.BorderColor, e.Subtitle)
default:
fmt.Println(e.Text)
// Route unstyled extension prints through the system message
// renderer so they get consistent formatting and timestamps.
h.cli.DisplayInfo(e.Text)
}
case app.StepCompleteEvent:
+1 -3
View File
@@ -109,9 +109,7 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
}
}
fmt.Println("")
// Display model info
// Display model info (the system message block provides its own spacing).
if provider != "unknown" && model != "unknown" {
cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", provider, model))
}
+236 -43
View File
@@ -51,6 +51,12 @@ type Kit struct {
authHandler MCPAuthHandler // OAuth handler for remote MCP servers (may need Close)
opts *Options // stored for reload operations (skills, etc.)
// hasCustomSystemPrompt is true when the user explicitly configured a
// system prompt (via --system-prompt flag, config file, or SDK option).
// When false, per-model system prompts from modelSettings/customModels
// can replace the default prompt on model switch.
hasCustomSystemPrompt bool
// Hook registries — interception layer (see hooks.go).
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult]
@@ -140,6 +146,79 @@ func (m *Kit) MCPToolsReady() bool {
return m.agent.MCPToolsReady()
}
// MCPServerStatus describes the runtime state of a loaded MCP server.
// Values of this type are returned by [Kit.ListMCPServers].
type MCPServerStatus struct {
	// Name is the configured server name.
	Name string
	// ToolCount is the number of tools loaded from this server.
	ToolCount int
}
// AddMCPServer connects to a new MCP server at runtime and makes its tools
// available to the agent immediately. The server's tools are prefixed with the
// server name (e.g. "myserver__tool_name") to avoid naming conflicts, matching
// the behaviour of servers loaded at initialization.
//
// Returns the number of tools loaded from the server.
//
// AddMCPServer is safe to call while the agent is idle. If a turn is in
// progress ([Kit.IsGenerating] returns true), the new tools will be visible
// starting from the next LLM step.
//
// This method is a thin wrapper: it forwards ctx, name and cfg unchanged to
// the underlying agent and returns its result as-is.
//
// Example:
//
//	n, err := k.AddMCPServer(ctx, "github", kit.MCPServerConfig{
//		Command:     []string{"npx", "-y", "@modelcontextprotocol/server-github"},
//		Environment: map[string]string{"GITHUB_TOKEN": os.Getenv("GITHUB_TOKEN")},
//	})
func (m *Kit) AddMCPServer(ctx context.Context, name string, cfg MCPServerConfig) (int, error) {
	return m.agent.AddMCPServer(ctx, name, cfg)
}
// RemoveMCPServer disconnects an MCP server and removes all its tools from
// the agent. After this call the agent will no longer see or be able to call
// tools from the named server.
//
// RemoveMCPServer is safe to call while the agent is idle. If a turn is in
// progress, the tools are removed at the next LLM step. Any in-flight tool
// calls to the removed server will fail gracefully.
//
// Returns an error if the named server is not currently loaded.
//
// This method is a thin wrapper: it forwards name to the underlying agent and
// returns the agent's error unchanged.
func (m *Kit) RemoveMCPServer(name string) error {
	return m.agent.RemoveMCPServer(name)
}
// ListMCPServers reports every currently loaded MCP server together with the
// number of tools attributed to it. A tool is attributed to the first loaded
// server whose "<name>__" prefix its name starts with (the name must be
// longer than the prefix). The result is a freshly built slice, or nil when
// no servers are loaded.
func (m *Kit) ListMCPServers() []MCPServerStatus {
	servers := m.agent.GetLoadedServerNames()
	if len(servers) == 0 {
		return nil
	}

	// Tally tools per server by matching the server-name prefix.
	tally := make(map[string]int, len(servers))
	for _, toolName := range m.GetToolNames() {
		for _, server := range servers {
			prefix := server + "__"
			if len(toolName) <= len(prefix) || toolName[:len(prefix)] != prefix {
				continue
			}
			tally[server]++
			break
		}
	}

	// Emit one status per server, in the order the agent reported them.
	statuses := make([]MCPServerStatus, len(servers))
	for i, server := range servers {
		statuses[i] = MCPServerStatus{Name: server, ToolCount: tally[server]}
	}
	return statuses
}
// GetExtensionToolCount returns the number of tools registered by extensions.
func (m *Kit) GetExtensionToolCount() int {
return m.agent.GetExtensionToolCount()
@@ -221,9 +300,12 @@ func iterBranchMessages[T any](tm *session.TreeManager, fn func(*session.Message
return results
}
// SetModel changes the active model at runtime. The existing tools, system
// prompt, and session are preserved. The model string should be in
// "provider/model" format (e.g. "anthropic/claude-sonnet-4-5-20250929").
// SetModel changes the active model at runtime. The existing tools and
// session are preserved. When the new model has a per-model system prompt
// (from modelSettings or customModels params), it is composed with the
// current AGENTS.md context and skills before being applied.
// The model string should be in "provider/model" format
// (e.g. "anthropic/claude-sonnet-4-5-20250929").
// Returns an error if the model string is invalid or the provider cannot
// be created.
func (m *Kit) SetModel(ctx context.Context, modelString string) error {
@@ -239,7 +321,7 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
// With message-level caching, thinking and caching can work together.
// No need to disable caching when thinking is enabled.
config := &models.ProviderConfig{
cfg := &models.ProviderConfig{
ModelString: modelString,
SystemPrompt: systemPrompt,
ProviderAPIKey: viper.GetString("provider-api-key"),
@@ -249,18 +331,50 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
ThinkingLevel: thinkingLevel,
DisableCaching: false, // Caching enabled by default, works with thinking
}
temperature := float32(viper.GetFloat64("temperature"))
config.Temperature = &temperature
topP := float32(viper.GetFloat64("top-p"))
config.TopP = &topP
topK := int32(viper.GetInt("top-k"))
config.TopK = &topK
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
config.FrequencyPenalty = &frequencyPenalty
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
config.PresencePenalty = &presencePenalty
if err := m.agent.SetModel(ctx, config); err != nil {
// Only set generation parameter pointers when the user has explicitly
// provided a value. This leaves nil pointers for unset params, allowing
// per-model defaults (modelSettings / customModels params) to apply.
if viper.IsSet("temperature") {
v := float32(viper.GetFloat64("temperature"))
cfg.Temperature = &v
}
if viper.IsSet("top-p") {
v := float32(viper.GetFloat64("top-p"))
cfg.TopP = &v
}
if viper.IsSet("top-k") {
v := int32(viper.GetInt("top-k"))
cfg.TopK = &v
}
if viper.IsSet("frequency-penalty") {
v := float32(viper.GetFloat64("frequency-penalty"))
cfg.FrequencyPenalty = &v
}
if viper.IsSet("presence-penalty") {
v := float32(viper.GetFloat64("presence-penalty"))
cfg.PresencePenalty = &v
}
// When the user hasn't set a custom global system prompt, check for a
// per-model system prompt. Pre-apply model settings to discover it,
// then compose with AGENTS.md context and skills if found.
if !m.hasCustomSystemPrompt {
// Temporarily clear the system prompt so ApplyModelSettings can
// detect that no explicit prompt is set and apply the per-model one.
cfg.SystemPrompt = ""
models.ApplyModelSettings(cfg, models.LookupModelForSettings(modelString))
if cfg.SystemPrompt != "" {
// Per-model system prompt found — compose with runtime context.
cfg.SystemPrompt = m.composeSystemPrompt(cfg.SystemPrompt)
} else {
// No per-model prompt — restore the global composed prompt.
cfg.SystemPrompt = systemPrompt
}
}
if err := m.agent.SetModel(ctx, cfg); err != nil {
return err
}
@@ -276,6 +390,32 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
return nil
}
// composeSystemPrompt builds the full system prompt from basePrompt plus the
// Kit's runtime context, in this order: one section per loaded context file,
// the skills metadata (only when skills are present), and a trailing section
// with the current date/time and working directory.
func (m *Kit) composeSystemPrompt(basePrompt string) string {
	// Best effort: on error the working directory is rendered as an empty string.
	workDir, _ := os.Getwd()

	builder := skills.NewPromptBuilder(basePrompt)

	// Project context: each context file becomes its own untitled section.
	for _, ctxFile := range m.contextFiles {
		section := fmt.Sprintf("Instructions from: %s\n\n%s", ctxFile.Path, ctxFile.Content)
		builder.WithSection("", section)
	}

	// Skills metadata, when any skills are loaded.
	if len(m.skills) > 0 {
		builder.WithSkills(m.skills)
	}

	// Environment footer: timestamp and working directory.
	stamp := time.Now().Format("Monday, January 2, 2006, 3:04:05 PM MST")
	footer := fmt.Sprintf(
		"Current date and time: %s\nCurrent working directory: %s",
		stamp, workDir,
	)
	builder.WithSection("", footer)

	return builder.Build()
}
// GetAvailableModels returns a list of known models from the registry. Each
// entry includes provider, model ID, context limit, and whether the model
// supports reasoning. This is an advisory list — models not in the registry
@@ -497,6 +637,17 @@ type Options struct {
// display a URL in a custom UI, redirect to a web app, etc.).
MCPAuthHandler MCPAuthHandler
// MCPTokenStoreFactory, if non-nil, is called to create a token store for
// each remote MCP server that requires OAuth. The factory receives the
// server's URL and returns a [MCPTokenStore] implementation.
//
// When nil (default), tokens are persisted to a JSON file at
// $XDG_CONFIG_HOME/.kit/mcp_tokens.json (or ~/.config/.kit/mcp_tokens.json).
//
// Use this to store tokens in a database, encrypt them, keep them
// in-memory, or write them to a custom file path.
MCPTokenStoreFactory MCPTokenStoreFactory
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading during Kit initialization. The callback receives the server name,
// tool count, and any error. Called from a background goroutine; safe to
@@ -582,16 +733,17 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
// provider creation, session init) then runs outside the lock, allowing
// parallel subagent spawns to proceed concurrently.
var (
providerConfig *models.ProviderConfig
modelString string
cwd string
contextFiles []*ContextFile
loadedSkills []*Skill
mcpConfig *config.Config
debug bool
noExtensions bool
maxSteps int
streaming bool
providerConfig *models.ProviderConfig
modelString string
cwd string
contextFiles []*ContextFile
loadedSkills []*Skill
mcpConfig *config.Config
debug bool
noExtensions bool
maxSteps int
streaming bool
hasCustomSystemPrompt bool
)
if err := func() error {
@@ -647,8 +799,41 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
// Always compose the system prompt with runtime context: base prompt +
// AGENTS.md context + skills metadata + date/cwd.
//
// If the configured model has a per-model system prompt (via
// modelSettings or customModels params) and the user hasn't
// explicitly set system-prompt, use the per-model prompt as the
// base instead of the global default.
{
basePrompt := viper.GetString("system-prompt")
// Track whether the user explicitly configured a custom system
// prompt. When they haven't (basePrompt is the built-in default
// or empty), per-model system prompts can replace it on switch.
userSetSystemPrompt := basePrompt != "" && basePrompt != defaultSystemPrompt
hasCustomSystemPrompt = userSetSystemPrompt
// Check for per-model system prompt override when no explicit
// global system-prompt was configured by the user.
if !userSetSystemPrompt {
modelStr := viper.GetString("model")
if modelStr != "" {
if mi := models.LookupModelForSettings(modelStr); mi != nil {
var perModelParams *models.GenerationParams
// modelSettings takes priority over custom model params.
if ms := models.LoadModelSettingsFromConfig(); ms != nil {
perModelParams = ms[modelStr]
}
if perModelParams == nil && mi.Params != nil {
perModelParams = mi.Params
}
if perModelParams != nil && perModelParams.SystemPrompt != "" {
basePrompt = models.LoadSystemPromptValue(perModelParams.SystemPrompt)
}
}
}
}
pb := skills.NewPromptBuilder(basePrompt)
// Inject AGENTS.md content as project context.
@@ -745,6 +930,13 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
}
}
// Set up custom token store factory for MCP OAuth tokens.
// The SDK MCPTokenStoreFactory has the same underlying function type as
// tools.TokenStoreFactory, so an explicit conversion is all that is needed.
if opts.MCPTokenStoreFactory != nil {
setupOpts.TokenStoreFactory = tools.TokenStoreFactory(opts.MCPTokenStoreFactory)
}
if opts.CLI != nil {
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
@@ -774,24 +966,25 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
}
k := &Kit{
agent: agentResult.Agent,
session: sessionManager,
modelString: modelString,
events: newEventBus(),
autoCompact: opts.AutoCompact,
compactionOpts: opts.CompactionOptions,
contextFiles: contextFiles,
skills: loadedSkills,
extRunner: agentResult.ExtRunner,
bufferedLogger: agentResult.BufferedLogger,
authHandler: setupOpts.AuthHandler,
opts: opts,
beforeToolCall: beforeToolCall,
afterToolResult: afterToolResult,
beforeTurn: beforeTurn,
afterTurn: afterTurn,
contextPrepare: contextPrepare,
beforeCompact: beforeCompact,
agent: agentResult.Agent,
session: sessionManager,
modelString: modelString,
events: newEventBus(),
autoCompact: opts.AutoCompact,
compactionOpts: opts.CompactionOptions,
contextFiles: contextFiles,
skills: loadedSkills,
extRunner: agentResult.ExtRunner,
bufferedLogger: agentResult.BufferedLogger,
authHandler: setupOpts.AuthHandler,
opts: opts,
hasCustomSystemPrompt: hasCustomSystemPrompt,
beforeToolCall: beforeToolCall,
afterToolResult: afterToolResult,
beforeTurn: beforeTurn,
afterTurn: afterTurn,
contextPrepare: contextPrepare,
beforeCompact: beforeCompact,
}
// Bridge extension events to SDK hooks.
+56
View File
@@ -0,0 +1,56 @@
package kit_test
import (
"testing"
kit "github.com/mark3labs/kit/pkg/kit"
)
// TestMCPServerStatus_TypeSurface verifies the MCPServerStatus type is
// accessible and has the expected fields.
func TestMCPServerStatus_TypeSurface(t *testing.T) {
s := kit.MCPServerStatus{
Name: "test-server",
ToolCount: 5,
}
if s.Name != "test-server" {
t.Errorf("Expected Name 'test-server', got %q", s.Name)
}
if s.ToolCount != 5 {
t.Errorf("Expected ToolCount 5, got %d", s.ToolCount)
}
}
// TestMCPServerConfig_ForDynamicAdd checks that kit.MCPServerConfig exposes
// the fields used for dynamic server management: Command and Environment for
// stdio servers, URL and Headers for remote servers, and AllowedTools for
// tool filtering.
func TestMCPServerConfig_ForDynamicAdd(t *testing.T) {
	// Stdio-launched server.
	stdioCfg := kit.MCPServerConfig{
		Command:     []string{"npx", "-y", "@modelcontextprotocol/server-github"},
		Environment: map[string]string{"GITHUB_TOKEN": "test-token"},
	}
	if parts := len(stdioCfg.Command); parts != 3 {
		t.Errorf("Expected 3 command parts, got %d", parts)
	}
	if token := stdioCfg.Environment["GITHUB_TOKEN"]; token != "test-token" {
		t.Error("Expected GITHUB_TOKEN in environment")
	}

	// Remote server reached by URL.
	remoteCfg := kit.MCPServerConfig{
		URL:     "https://mcp.example.com/sse",
		Headers: []string{"Authorization: Bearer test"},
	}
	if remoteCfg.URL != "https://mcp.example.com/sse" {
		t.Errorf("Unexpected URL: %s", remoteCfg.URL)
	}

	// Server restricted to an allow-list of tools.
	filteredCfg := kit.MCPServerConfig{
		Command:      []string{"some-server"},
		AllowedTools: []string{"read", "write"},
	}
	if allowed := len(filteredCfg.AllowedTools); allowed != 2 {
		t.Errorf("Expected 2 allowed tools, got %d", allowed)
	}
}
+24
View File
@@ -11,6 +11,7 @@ import (
"github.com/mark3labs/kit/internal/message"
"github.com/mark3labs/kit/internal/models"
"github.com/mark3labs/kit/internal/session"
"github.com/mark3labs/mcp-go/client/transport"
)
// ==== Message Types (internal/message/content.go) ====
@@ -204,6 +205,29 @@ type CompactionResult = compaction.CompactionResult
// CompactionOptions configures compaction behaviour.
type CompactionOptions = compaction.CompactionOptions
// ==== MCP OAuth Types ====

// MCPTokenStore persists OAuth tokens for a single MCP server. Implementations
// must be safe for concurrent use.
//
// This is a type alias for the mcp-go transport.TokenStore interface. SDK
// consumers can implement this interface to provide custom storage backends
// (database, encrypted file, in-memory, etc.).
type MCPTokenStore = transport.TokenStore

// MCPToken represents an OAuth token for an MCP server, containing access
// and refresh tokens along with expiration metadata.
//
// This is a type alias for the mcp-go transport.Token type.
type MCPToken = transport.Token

// MCPTokenStoreFactory creates an [MCPTokenStore] for a given MCP server URL.
// It is called once per remote MCP server during connection setup.
//
// Unlike the aliases above, this is a distinct named function type, so using
// it where another named function type with the same signature is expected
// requires an explicit conversion.
type MCPTokenStoreFactory func(serverURL string) (MCPTokenStore, error)

// ErrMCPNoToken is the sentinel error that [MCPTokenStore] implementations
// should return from GetToken when no token is stored for the server.
// Callers can check for this with errors.Is.
//
// It is the same error value as mcp-go's transport.ErrNoToken.
var ErrMCPNoToken = transport.ErrNoToken
// ==== Constructor & Helper Functions ====
// ParseModelString parses a model string in "provider/model" format.