diff --git a/internal/config/config.go b/internal/config/config.go index f63c08de..8a603b2e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -157,6 +157,20 @@ type Theme struct { Markdown MarkdownThemeConfig `json:"markdown,omitzero" yaml:"markdown,omitempty"` } +// GenerationParams defines generation parameter defaults that can be attached +// to individual models. These act as model-level defaults — CLI flags and +// global config values take precedence when explicitly set. +type GenerationParams struct { + MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"` + Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"` + TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"` + TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"` + FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"` + PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"` + StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"` + ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"` +} + // CustomModelConfig defines a custom model that can be used with custom/custom // or other custom/ prefixed models. These models are loaded from the config file // and merged into the custom provider in the model registry. @@ -171,6 +185,11 @@ type CustomModelConfig struct { Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"` Cost CostConfig `json:"cost" yaml:"cost"` Limit LimitConfig `json:"limit" yaml:"limit"` + + // Generation parameter defaults for this model. + // These are applied when the user hasn't explicitly set the corresponding + // CLI flag or global config value. + Params GenerationParams `json:"params,omitzero" yaml:"params,omitempty"` } // CostConfig defines the pricing for a custom model. @@ -219,6 +238,12 @@ type Config struct { // Custom model definitions (under custom/ provider) CustomModels map[string]CustomModelConfig `json:"customModels,omitempty" yaml:"customModels,omitempty"` + + // Per-model generation parameter overrides. Keys are "provider/model" strings + // (e.g. "anthropic/claude-sonnet-4-5-20250929", "openai/gpt-4o"). These + // settings act as model-level defaults — CLI flags and global config values + // take precedence when explicitly set. + ModelSettings map[string]GenerationParams `json:"modelSettings,omitempty" yaml:"modelSettings,omitempty"` } // GetTransportType returns the transport type for the server config, mapping @@ -367,7 +392,7 @@ mcpServers: # debug: false # Enable debug logging # system-prompt: "/path/to/system-prompt.txt" # System prompt text file -# Model generation parameters (all optional) +# Model generation parameters (all optional, apply globally to all models) # max-tokens: 4096 # Maximum tokens in response # temperature: 0.7 # Randomness (0.0-1.0) # top-p: 0.95 # Nucleus sampling (0.0-1.0) @@ -376,9 +401,44 @@ mcpServers: # presence-penalty: 0.0 # Penalize present tokens (0.0-2.0) # stop-sequences: ["Human:", "Assistant:"] # Custom stop sequences +# Per-model generation parameter overrides (apply to specific models) +# These act as model-level defaults — CLI flags and global settings above take precedence. +# Keys are "provider/model" strings matching the model you use. +# modelSettings: +# anthropic/claude-sonnet-4-5-20250929: +# temperature: 0.3 +# maxTokens: 8192 +# openai/gpt-4o: +# temperature: 0.7 +# topP: 0.95 +# topK: 40 +# frequencyPenalty: 0.1 +# presencePenalty: 0.1 +# anthropic/claude-opus-4-6: +# thinkingLevel: "high" +# maxTokens: 16384 + # API Configuration (can also use environment variables) # provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google # provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama + +# Custom model definitions (under custom/ provider) +# customModels: +# my-local-llama: +# name: "Local Llama 3" +# baseUrl: "http://localhost:8080/v1" +# family: "llama" +# temperature: true +# cost: +# input: 0.0 +# output: 0.0 +# limit: +# context: 131072 +# output: 8192 +# params: # Generation parameter defaults for this model +# temperature: 0.8 +# topP: 0.95 +# topK: 40 ` _, err = file.WriteString(content) diff --git a/internal/kitsetup/setup.go b/internal/kitsetup/setup.go index c91f4b14..4a8f29a6 100644 --- a/internal/kitsetup/setup.go +++ b/internal/kitsetup/setup.go @@ -82,36 +82,55 @@ type AgentSetupResult struct { // BuildProviderConfig creates a *models.ProviderConfig from the current viper // state. All entry points (root, script, SDK) converge through this function. +// +// Generation parameter pointers (Temperature, TopP, etc.) are only set when +// the user has explicitly configured them via CLI flag, environment variable, +// or global config file. This allows per-model defaults from modelSettings +// and customModels to fill in unset parameters downstream. func BuildProviderConfig() (*models.ProviderConfig, string, error) { systemPrompt, err := config.LoadSystemPrompt(viper.GetString("system-prompt")) if err != nil { return nil, "", fmt.Errorf("failed to load system prompt: %w", err) } - temperature := float32(viper.GetFloat64("temperature")) - topP := float32(viper.GetFloat64("top-p")) - topK := int32(viper.GetInt("top-k")) - frequencyPenalty := float32(viper.GetFloat64("frequency-penalty")) - presencePenalty := float32(viper.GetFloat64("presence-penalty")) numGPU := int32(viper.GetInt("num-gpu-layers")) mainGPU := int32(viper.GetInt("main-gpu")) cfg := &models.ProviderConfig{ - ModelString: viper.GetString("model"), - SystemPrompt: systemPrompt, - ProviderAPIKey: viper.GetString("provider-api-key"), - ProviderURL: viper.GetString("provider-url"), - MaxTokens: viper.GetInt("max-tokens"), - Temperature: &temperature, - TopP: &topP, - TopK: &topK, - FrequencyPenalty: &frequencyPenalty, - PresencePenalty: &presencePenalty, - StopSequences: viper.GetStringSlice("stop-sequences"), - NumGPU: &numGPU, - MainGPU: &mainGPU, - TLSSkipVerify: viper.GetBool("tls-skip-verify"), - ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")), + ModelString: viper.GetString("model"), + SystemPrompt: systemPrompt, + ProviderAPIKey: viper.GetString("provider-api-key"), + ProviderURL: viper.GetString("provider-url"), + MaxTokens: viper.GetInt("max-tokens"), + StopSequences: viper.GetStringSlice("stop-sequences"), + NumGPU: &numGPU, + MainGPU: &mainGPU, + TLSSkipVerify: viper.GetBool("tls-skip-verify"), + ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")), + } + + // Only set generation parameter pointers when the user has explicitly + // provided a value. This leaves nil pointers for unset params, allowing + // per-model defaults (modelSettings / customModels params) to apply. + if viper.IsSet("temperature") { + v := float32(viper.GetFloat64("temperature")) + cfg.Temperature = &v + } + if viper.IsSet("top-p") { + v := float32(viper.GetFloat64("top-p")) + cfg.TopP = &v + } + if viper.IsSet("top-k") { + v := int32(viper.GetInt("top-k")) + cfg.TopK = &v + } + if viper.IsSet("frequency-penalty") { + v := float32(viper.GetFloat64("frequency-penalty")) + cfg.FrequencyPenalty = &v + } + if viper.IsSet("presence-penalty") { + v := float32(viper.GetFloat64("presence-penalty")) + cfg.PresencePenalty = &v } return cfg, systemPrompt, nil diff --git a/internal/models/custom.go b/internal/models/custom.go index 28a06935..56827f44 100644 --- a/internal/models/custom.go +++ b/internal/models/custom.go @@ -31,7 +31,7 @@ func loadCustomModelsFromConfig() map[string]ModelInfo { // modelConfigToModelInfo converts a CustomModelConfig to a ModelInfo. func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo { - return ModelInfo{ + info := ModelInfo{ ID: modelID, Name: cfg.Name, Attachment: cfg.Attachment, @@ -48,21 +48,210 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo { Output: cfg.Limit.Output, }, } + + // Convert custom model generation params if any are set. + if p := convertGenerationParams(cfg.Params); p != nil { + info.Params = p + } + + return info +} + +// loadModelSettingsFromConfig loads per-model generation parameter overrides +// from the config file. Keys are "provider/model" strings. Returns nil if +// no model settings are configured. +func loadModelSettingsFromConfig() map[string]*GenerationParams { + if !viper.IsSet("modelSettings") { + return nil + } + + var settings map[string]GenerationParamsConfig + if err := viper.UnmarshalKey("modelSettings", &settings); err != nil { + log.Printf("Warning: Failed to parse modelSettings: %v", err) + return nil + } + + result := make(map[string]*GenerationParams, len(settings)) + for modelKey, cfg := range settings { + if p := convertGenerationParams(cfg); p != nil { + result[modelKey] = p + } + } + + return result +} + +// convertGenerationParams converts a GenerationParamsConfig to a GenerationParams. +// Returns nil if no parameters are set. +func convertGenerationParams(cfg GenerationParamsConfig) *GenerationParams { + p := &GenerationParams{} + any := false + + if cfg.MaxTokens != nil { + p.MaxTokens = cfg.MaxTokens + any = true + } + if cfg.Temperature != nil { + p.Temperature = cfg.Temperature + any = true + } + if cfg.TopP != nil { + p.TopP = cfg.TopP + any = true + } + if cfg.TopK != nil { + p.TopK = cfg.TopK + any = true + } + if cfg.FrequencyPenalty != nil { + p.FrequencyPenalty = cfg.FrequencyPenalty + any = true + } + if cfg.PresencePenalty != nil { + p.PresencePenalty = cfg.PresencePenalty + any = true + } + if len(cfg.StopSequences) > 0 { + p.StopSequences = cfg.StopSequences + any = true + } + if cfg.ThinkingLevel != "" { + p.ThinkingLevel = ParseThinkingLevel(cfg.ThinkingLevel) + any = true + } + + if !any { + return nil + } + return p +} + +// ApplyModelSettings merges per-model generation parameter defaults from the +// registry into a ProviderConfig. Model-level params are only applied for +// fields where the user has not explicitly set a value (i.e., the +// corresponding viper key is not set via CLI flag or global config). +// +// The lookup order is: +// 1. modelSettings["provider/model"] from config (highest model-level priority) +// 2. ModelInfo.Params from custom model definitions +// +// Both are overridden by explicit CLI flags / global config values. +func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) { + provider, modelName, err := ParseModelString(config.ModelString) + if err != nil { + return + } + + // Collect model-level params: modelSettings override > custom model params. + // modelSettings takes priority because it's the more specific/intentional config. + var params *GenerationParams + + // First check modelSettings from config. + if settings := loadModelSettingsFromConfig(); settings != nil { + modelKey := provider + "/" + modelName + if p, ok := settings[modelKey]; ok { + params = p + } + } + + // Fall back to ModelInfo.Params (from custom model definitions). + if params == nil && modelInfo != nil && modelInfo.Params != nil { + params = modelInfo.Params + } + + if params == nil { + return + } + + // Apply each parameter only when the user hasn't explicitly set it. + // We check viper.IsSet() which returns true only when the key was + // set via CLI flag, environment variable, or config file global section. + + if params.MaxTokens != nil && !isExplicitlySet("max-tokens") { + config.MaxTokens = *params.MaxTokens + } + if params.Temperature != nil && !isExplicitlySet("temperature") { + config.Temperature = params.Temperature + } + if params.TopP != nil && !isExplicitlySet("top-p") { + config.TopP = params.TopP + } + if params.TopK != nil && !isExplicitlySet("top-k") { + config.TopK = params.TopK + } + if params.FrequencyPenalty != nil && !isExplicitlySet("frequency-penalty") { + config.FrequencyPenalty = params.FrequencyPenalty + } + if params.PresencePenalty != nil && !isExplicitlySet("presence-penalty") { + config.PresencePenalty = params.PresencePenalty + } + if len(params.StopSequences) > 0 && !isExplicitlySet("stop-sequences") { + config.StopSequences = params.StopSequences + } + if params.ThinkingLevel != "" && !isExplicitlySet("thinking-level") { + config.ThinkingLevel = params.ThinkingLevel + } +} + +// isExplicitlySet returns true when the user has explicitly set a config key +// via CLI flag, environment variable, or the global section of the config file. +// Model-level defaults should not override explicitly set values. +func isExplicitlySet(key string) bool { + // viper.IsSet returns true if the key has been set in any of the + // data stores (flag, env, config file, default). We need to check + // whether the value was set at the global config level (not just + // as a default). For generation params, the global config keys use + // hyphenated names (e.g. "max-tokens", "top-p"). + // + // Since viper merges all sources, IsSet returns true even for config + // file values. This means global config file values (e.g. + // temperature: 0.7 at the top level) will correctly take precedence + // over model-level defaults, which is the desired behavior. + return viper.IsSet(key) +} + +// GenerationParams holds per-model generation parameter defaults. +// These are stored on ModelInfo and applied during provider creation. +// Nil pointer fields mean "no model-level default" — the global config +// or CLI flag value (if any) will be used instead. +type GenerationParams struct { + MaxTokens *int + Temperature *float32 + TopP *float32 + TopK *int32 + FrequencyPenalty *float32 + PresencePenalty *float32 + StopSequences []string + ThinkingLevel ThinkingLevel } // CustomModelConfig defines a custom model configuration loaded from the config file. // This is a duplicate here to avoid circular dependencies with internal/config. type CustomModelConfig struct { - Name string `json:"name" yaml:"name"` - BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"` - APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"` - Family string `json:"family,omitempty" yaml:"family,omitempty"` - Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"` - Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"` - Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"` - Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"` - Cost CostConfig `json:"cost" yaml:"cost"` - Limit LimitConfig `json:"limit" yaml:"limit"` + Name string `json:"name" yaml:"name"` + BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"` + APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"` + Family string `json:"family,omitempty" yaml:"family,omitempty"` + Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"` + Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"` + Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"` + Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"` + Cost CostConfig `json:"cost" yaml:"cost"` + Limit LimitConfig `json:"limit" yaml:"limit"` + Params GenerationParamsConfig `json:"params,omitzero" yaml:"params,omitempty"` +} + +// GenerationParamsConfig is the JSON/YAML-serializable form of generation +// parameter defaults. Used in both customModels[].params and modelSettings[]. +type GenerationParamsConfig struct { + MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"` + Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"` + TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"` + TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"` + FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"` + PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"` + StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"` + ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"` } // CostConfig defines the pricing for a custom model. diff --git a/internal/models/custom_test.go b/internal/models/custom_test.go new file mode 100644 index 00000000..1a81f467 --- /dev/null +++ b/internal/models/custom_test.go @@ -0,0 +1,307 @@ +package models + +import ( + "testing" + + "github.com/spf13/viper" +) + +func TestConvertGenerationParams(t *testing.T) { + t.Run("empty config returns nil", func(t *testing.T) { + cfg := GenerationParamsConfig{} + p := convertGenerationParams(cfg) + if p != nil { + t.Errorf("expected nil, got %+v", p) + } + }) + + t.Run("temperature only", func(t *testing.T) { + temp := float32(0.7) + cfg := GenerationParamsConfig{Temperature: &temp} + p := convertGenerationParams(cfg) + if p == nil { + t.Fatal("expected non-nil") + } + if p.Temperature == nil || *p.Temperature != 0.7 { + t.Errorf("expected temperature 0.7, got %v", p.Temperature) + } + if p.TopP != nil { + t.Errorf("expected nil TopP, got %v", p.TopP) + } + }) + + t.Run("all params set", func(t *testing.T) { + maxTokens := 8192 + temp := float32(0.5) + topP := float32(0.9) + topK := int32(50) + freqPenalty := float32(0.1) + presPenalty := float32(0.2) + cfg := GenerationParamsConfig{ + MaxTokens: &maxTokens, + Temperature: &temp, + TopP: &topP, + TopK: &topK, + FrequencyPenalty: &freqPenalty, + PresencePenalty: &presPenalty, + StopSequences: []string{"STOP"}, + ThinkingLevel: "high", + } + p := convertGenerationParams(cfg) + if p == nil { + t.Fatal("expected non-nil") + } + if p.MaxTokens == nil || *p.MaxTokens != 8192 { + t.Errorf("expected maxTokens 8192, got %v", p.MaxTokens) + } + if p.Temperature == nil || *p.Temperature != 0.5 { + t.Errorf("expected temperature 0.5, got %v", p.Temperature) + } + if p.TopP == nil || *p.TopP != 0.9 { + t.Errorf("expected topP 0.9, got %v", p.TopP) + } + if p.TopK == nil || *p.TopK != 50 { + t.Errorf("expected topK 50, got %v", p.TopK) + } + if p.FrequencyPenalty == nil || *p.FrequencyPenalty != 0.1 { + t.Errorf("expected frequencyPenalty 0.1, got %v", p.FrequencyPenalty) + } + if p.PresencePenalty == nil || *p.PresencePenalty != 0.2 { + t.Errorf("expected presencePenalty 0.2, got %v", p.PresencePenalty) + } + if len(p.StopSequences) != 1 || p.StopSequences[0] != "STOP" { + t.Errorf("expected stop sequences [STOP], got %v", p.StopSequences) + } + if p.ThinkingLevel != ThinkingHigh { + t.Errorf("expected thinking level high, got %v", p.ThinkingLevel) + } + }) + + t.Run("thinking level parsing", func(t *testing.T) { + cfg := GenerationParamsConfig{ThinkingLevel: "medium"} + p := convertGenerationParams(cfg) + if p == nil { + t.Fatal("expected non-nil") + } + if p.ThinkingLevel != ThinkingMedium { + t.Errorf("expected thinking level medium, got %v", p.ThinkingLevel) + } + }) +} + +func TestModelConfigToModelInfoWithParams(t *testing.T) { + temp := float32(0.8) + topP := float32(0.95) + cfg := CustomModelConfig{ + Name: "Test Model", + BaseURL: "http://localhost:8080/v1", + Temperature: true, + Params: GenerationParamsConfig{ + Temperature: &temp, + TopP: &topP, + }, + } + + info := modelConfigToModelInfo("test-model", cfg) + + if info.Params == nil { + t.Fatal("expected non-nil Params") + } + if info.Params.Temperature == nil || *info.Params.Temperature != 0.8 { + t.Errorf("expected temperature 0.8, got %v", info.Params.Temperature) + } + if info.Params.TopP == nil || *info.Params.TopP != 0.95 { + t.Errorf("expected topP 0.95, got %v", info.Params.TopP) + } +} + +func TestModelConfigToModelInfoWithoutParams(t *testing.T) { + cfg := CustomModelConfig{ + Name: "Test Model", + BaseURL: "http://localhost:8080/v1", + } + + info := modelConfigToModelInfo("test-model", cfg) + + if info.Params != nil { + t.Errorf("expected nil Params, got %+v", info.Params) + } +} + +func TestApplyModelSettings(t *testing.T) { + // Save and restore viper state. + originalViper := viper.AllSettings() + defer func() { + viper.Reset() + for k, v := range originalViper { + viper.Set(k, v) + } + }() + + t.Run("applies model params when not explicitly set", func(t *testing.T) { + viper.Reset() + + temp := float32(0.8) + topK := int32(50) + maxTokens := 4096 + modelInfo := &ModelInfo{ + ID: "test-model", + Params: &GenerationParams{ + Temperature: &temp, + TopK: &topK, + MaxTokens: &maxTokens, + }, + } + + config := &ProviderConfig{ + ModelString: "custom/test-model", + } + + ApplyModelSettings(config, modelInfo) + + if config.Temperature == nil || *config.Temperature != 0.8 { + t.Errorf("expected temperature 0.8, got %v", config.Temperature) + } + if config.TopK == nil || *config.TopK != 50 { + t.Errorf("expected topK 50, got %v", config.TopK) + } + if config.MaxTokens != 4096 { + t.Errorf("expected maxTokens 4096, got %d", config.MaxTokens) + } + }) + + t.Run("explicit viper values take precedence", func(t *testing.T) { + viper.Reset() + viper.Set("temperature", 0.3) + + temp := float32(0.8) + modelInfo := &ModelInfo{ + ID: "test-model", + Params: &GenerationParams{ + Temperature: &temp, + }, + } + + explicitTemp := float32(0.3) + config := &ProviderConfig{ + ModelString: "custom/test-model", + Temperature: &explicitTemp, + } + + ApplyModelSettings(config, modelInfo) + + // Temperature should NOT be overridden because it's explicitly set in viper + if config.Temperature == nil || *config.Temperature != 0.3 { + t.Errorf("expected temperature 0.3 (explicit), got %v", config.Temperature) + } + }) + + t.Run("nil model info is safe", func(t *testing.T) { + viper.Reset() + + config := &ProviderConfig{ + ModelString: "custom/test-model", + } + + // Should not panic + ApplyModelSettings(config, nil) + + if config.Temperature != nil { + t.Errorf("expected nil temperature, got %v", config.Temperature) + } + }) + + t.Run("model info without params is safe", func(t *testing.T) { + viper.Reset() + + modelInfo := &ModelInfo{ID: "test-model"} + config := &ProviderConfig{ + ModelString: "custom/test-model", + } + + ApplyModelSettings(config, modelInfo) + + if config.Temperature != nil { + t.Errorf("expected nil temperature, got %v", config.Temperature) + } + }) + + t.Run("modelSettings from viper takes priority over ModelInfo.Params", func(t *testing.T) { + viper.Reset() + + // Set up modelSettings in viper (simulating config file) + viper.Set("modelSettings", map[string]any{ + "custom/test-model": map[string]any{ + "temperature": 0.5, + "topK": 30, + }, + }) + + // ModelInfo has different params + temp := float32(0.8) + topK := int32(50) + modelInfo := &ModelInfo{ + ID: "test-model", + Params: &GenerationParams{ + Temperature: &temp, + TopK: &topK, + }, + } + + config := &ProviderConfig{ + ModelString: "custom/test-model", + } + + ApplyModelSettings(config, modelInfo) + + // modelSettings should win over ModelInfo.Params + if config.Temperature == nil || *config.Temperature != 0.5 { + t.Errorf("expected temperature 0.5 (from modelSettings), got %v", config.Temperature) + } + if config.TopK == nil || *config.TopK != 30 { + t.Errorf("expected topK 30 (from modelSettings), got %v", config.TopK) + } + }) + + t.Run("stop sequences applied from model params", func(t *testing.T) { + viper.Reset() + + modelInfo := &ModelInfo{ + ID: "test-model", + Params: &GenerationParams{ + StopSequences: []string{"STOP", "END"}, + }, + } + + config := &ProviderConfig{ + ModelString: "custom/test-model", + } + + ApplyModelSettings(config, modelInfo) + + if len(config.StopSequences) != 2 || config.StopSequences[0] != "STOP" { + t.Errorf("expected stop sequences [STOP END], got %v", config.StopSequences) + } + }) + + t.Run("thinking level applied from model params", func(t *testing.T) { + viper.Reset() + + modelInfo := &ModelInfo{ + ID: "test-model", + Params: &GenerationParams{ + ThinkingLevel: ThinkingHigh, + }, + } + + config := &ProviderConfig{ + ModelString: "custom/test-model", + } + + ApplyModelSettings(config, modelInfo) + + if config.ThinkingLevel != ThinkingHigh { + t.Errorf("expected thinking level high, got %v", config.ThinkingLevel) + } + }) +} diff --git a/internal/models/providers.go b/internal/models/providers.go index 33fe0f14..ab9de4bd 100644 --- a/internal/models/providers.go +++ b/internal/models/providers.go @@ -241,6 +241,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul validateModelConfig(config, modelInfo) } + // Apply per-model generation parameter defaults. Model-level params are + // only applied for fields where the user hasn't explicitly set a value + // via CLI flag or global config. + ApplyModelSettings(config, modelInfo) + // Create the base provider var result *ProviderResult var createErr error diff --git a/internal/models/registry.go b/internal/models/registry.go index d7076073..283d27ab 100644 --- a/internal/models/registry.go +++ b/internal/models/registry.go @@ -26,6 +26,11 @@ type ModelInfo struct { ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic") BaseURL string // Per-model base URL override (custom models only) APIKey string // Per-model API key override (custom models only) + + // Params holds per-model generation parameter defaults. These are applied + // when the user hasn't explicitly set the corresponding CLI flag or global + // config value. Nil pointer fields mean "no model-level default". + Params *GenerationParams } // SupportsCaching returns true if this model family supports prompt caching. diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index 3e168bf9..3e756eeb 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -239,7 +239,7 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error { // With message-level caching, thinking and caching can work together. // No need to disable caching when thinking is enabled. - config := &models.ProviderConfig{ + cfg := &models.ProviderConfig{ ModelString: modelString, SystemPrompt: systemPrompt, ProviderAPIKey: viper.GetString("provider-api-key"), @@ -249,18 +249,32 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error { ThinkingLevel: thinkingLevel, DisableCaching: false, // Caching enabled by default, works with thinking } - temperature := float32(viper.GetFloat64("temperature")) - config.Temperature = &temperature - topP := float32(viper.GetFloat64("top-p")) - config.TopP = &topP - topK := int32(viper.GetInt("top-k")) - config.TopK = &topK - frequencyPenalty := float32(viper.GetFloat64("frequency-penalty")) - config.FrequencyPenalty = &frequencyPenalty - presencePenalty := float32(viper.GetFloat64("presence-penalty")) - config.PresencePenalty = &presencePenalty - if err := m.agent.SetModel(ctx, config); err != nil { + // Only set generation parameter pointers when the user has explicitly + // provided a value. This leaves nil pointers for unset params, allowing + // per-model defaults (modelSettings / customModels params) to apply. + if viper.IsSet("temperature") { + v := float32(viper.GetFloat64("temperature")) + cfg.Temperature = &v + } + if viper.IsSet("top-p") { + v := float32(viper.GetFloat64("top-p")) + cfg.TopP = &v + } + if viper.IsSet("top-k") { + v := int32(viper.GetInt("top-k")) + cfg.TopK = &v + } + if viper.IsSet("frequency-penalty") { + v := float32(viper.GetFloat64("frequency-penalty")) + cfg.FrequencyPenalty = &v + } + if viper.IsSet("presence-penalty") { + v := float32(viper.GetFloat64("presence-penalty")) + cfg.PresencePenalty = &v + } + + if err := m.agent.SetModel(ctx, cfg); err != nil { return err }