mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
feat: add frequency-penalty and presence-penalty parameters
- Add --frequency-penalty and --presence-penalty CLI flags (0.0-2.0) - Wire through config, viper, ProviderConfig, and fantasy agent options - Support in config file, env vars (KIT_FREQUENCY_PENALTY), and SDK - Pass to Ollama via options map (frequency_penalty, presence_penalty) - Apply on both initial agent creation and runtime model swap
This commit is contained in:
+12
-6
@@ -48,12 +48,14 @@ var (
|
||||
noSessionFlag bool // --no-session: ephemeral mode, no persistence
|
||||
|
||||
// Model generation parameters
|
||||
maxTokens int
|
||||
temperature float32
|
||||
topP float32
|
||||
topK int32
|
||||
stopSequences []string
|
||||
thinkingLevel string
|
||||
maxTokens int
|
||||
temperature float32
|
||||
topP float32
|
||||
topK int32
|
||||
frequencyPenalty float32
|
||||
presencePenalty float32
|
||||
stopSequences []string
|
||||
thinkingLevel string
|
||||
|
||||
// Ollama-specific parameters
|
||||
numGPU int32
|
||||
@@ -291,6 +293,8 @@ func init() {
|
||||
flags.Float32Var(&temperature, "temperature", 0.7, "controls randomness in responses (0.0-1.0)")
|
||||
flags.Float32Var(&topP, "top-p", 0.95, "controls diversity via nucleus sampling (0.0-1.0)")
|
||||
flags.Int32Var(&topK, "top-k", 40, "controls diversity by limiting top K tokens to sample from")
|
||||
flags.Float32Var(&frequencyPenalty, "frequency-penalty", 0.0, "penalizes tokens based on frequency of appearance (0.0-2.0)")
|
||||
flags.Float32Var(&presencePenalty, "presence-penalty", 0.0, "penalizes tokens based on whether they have appeared (0.0-2.0)")
|
||||
flags.StringSliceVar(&stopSequences, "stop-sequences", nil, "custom stop sequences (comma-separated)")
|
||||
flags.StringVar(&thinkingLevel, "thinking-level", "off", "extended thinking level: off, minimal, low, medium, high")
|
||||
|
||||
@@ -313,6 +317,8 @@ func init() {
|
||||
_ = viper.BindPFlag("temperature", rootCmd.PersistentFlags().Lookup("temperature"))
|
||||
_ = viper.BindPFlag("top-p", rootCmd.PersistentFlags().Lookup("top-p"))
|
||||
_ = viper.BindPFlag("top-k", rootCmd.PersistentFlags().Lookup("top-k"))
|
||||
_ = viper.BindPFlag("frequency-penalty", rootCmd.PersistentFlags().Lookup("frequency-penalty"))
|
||||
_ = viper.BindPFlag("presence-penalty", rootCmd.PersistentFlags().Lookup("presence-penalty"))
|
||||
_ = viper.BindPFlag("stop-sequences", rootCmd.PersistentFlags().Lookup("stop-sequences"))
|
||||
_ = viper.BindPFlag("thinking-level", rootCmd.PersistentFlags().Lookup("thinking-level"))
|
||||
_ = viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers"))
|
||||
|
||||
@@ -209,6 +209,12 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
if agentConfig.ModelConfig.TopK != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*agentConfig.ModelConfig.TopK)))
|
||||
}
|
||||
if agentConfig.ModelConfig.FrequencyPenalty != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithFrequencyPenalty(float64(*agentConfig.ModelConfig.FrequencyPenalty)))
|
||||
}
|
||||
if agentConfig.ModelConfig.PresencePenalty != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithPresencePenalty(float64(*agentConfig.ModelConfig.PresencePenalty)))
|
||||
}
|
||||
}
|
||||
|
||||
// Create the agent
|
||||
@@ -753,6 +759,12 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
if config.TopK != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*config.TopK)))
|
||||
}
|
||||
if config.FrequencyPenalty != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithFrequencyPenalty(float64(*config.FrequencyPenalty)))
|
||||
}
|
||||
if config.PresencePenalty != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithPresencePenalty(float64(*config.PresencePenalty)))
|
||||
}
|
||||
|
||||
newFantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
|
||||
@@ -199,11 +199,13 @@ type Config struct {
|
||||
Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"`
|
||||
Theme any `json:"theme" yaml:"theme"`
|
||||
// Model generation parameters
|
||||
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
|
||||
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
|
||||
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
|
||||
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
|
||||
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequency-penalty,omitempty" yaml:"frequency-penalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presence-penalty,omitempty" yaml:"presence-penalty,omitempty"`
|
||||
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
|
||||
|
||||
// Thinking / extended reasoning
|
||||
ThinkingLevel string `json:"thinking-level,omitempty" yaml:"thinking-level,omitempty"`
|
||||
@@ -370,6 +372,8 @@ mcpServers:
|
||||
# temperature: 0.7 # Randomness (0.0-1.0)
|
||||
# top-p: 0.95 # Nucleus sampling (0.0-1.0)
|
||||
# top-k: 40 # Top K sampling
|
||||
# frequency-penalty: 0.0 # Penalize frequent tokens (0.0-2.0)
|
||||
# presence-penalty: 0.0 # Penalize present tokens (0.0-2.0)
|
||||
# stop-sequences: ["Human:", "Assistant:"] # Custom stop sequences
|
||||
|
||||
# API Configuration (can also use environment variables)
|
||||
|
||||
+17
-13
@@ -84,23 +84,27 @@ func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
temperature := float32(viper.GetFloat64("temperature"))
|
||||
topP := float32(viper.GetFloat64("top-p"))
|
||||
topK := int32(viper.GetInt("top-k"))
|
||||
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
|
||||
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
|
||||
numGPU := int32(viper.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(viper.GetInt("main-gpu"))
|
||||
|
||||
cfg := &models.ProviderConfig{
|
||||
ModelString: viper.GetString("model"),
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
Temperature: &temperature,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
ModelString: viper.GetString("model"),
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
Temperature: &temperature,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
FrequencyPenalty: &frequencyPenalty,
|
||||
PresencePenalty: &presencePenalty,
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
}
|
||||
|
||||
return cfg, systemPrompt, nil
|
||||
|
||||
@@ -143,20 +143,22 @@ func ParseThinkingLevel(s string) ThinkingLevel {
|
||||
|
||||
// ProviderConfig holds configuration for creating LLM providers.
|
||||
type ProviderConfig struct {
|
||||
ModelString string
|
||||
SystemPrompt string
|
||||
ProviderAPIKey string
|
||||
ProviderURL string
|
||||
MaxTokens int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
StopSequences []string
|
||||
NumGPU *int32
|
||||
MainGPU *int32
|
||||
TLSSkipVerify bool
|
||||
ThinkingLevel ThinkingLevel
|
||||
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
|
||||
ModelString string
|
||||
SystemPrompt string
|
||||
ProviderAPIKey string
|
||||
ProviderURL string
|
||||
MaxTokens int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
FrequencyPenalty *float32
|
||||
PresencePenalty *float32
|
||||
StopSequences []string
|
||||
NumGPU *int32
|
||||
MainGPU *int32
|
||||
TLSSkipVerify bool
|
||||
ThinkingLevel ThinkingLevel
|
||||
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
|
||||
}
|
||||
|
||||
// ProviderResult contains the result of provider creation.
|
||||
@@ -1164,6 +1166,12 @@ func buildOllamaOptions(config *ProviderConfig) map[string]any {
|
||||
if config.TopK != nil {
|
||||
options["top_k"] = int(*config.TopK)
|
||||
}
|
||||
if config.FrequencyPenalty != nil {
|
||||
options["frequency_penalty"] = *config.FrequencyPenalty
|
||||
}
|
||||
if config.PresencePenalty != nil {
|
||||
options["presence_penalty"] = *config.PresencePenalty
|
||||
}
|
||||
if len(config.StopSequences) > 0 {
|
||||
options["stop"] = config.StopSequences
|
||||
}
|
||||
|
||||
@@ -48,6 +48,8 @@ func setSDKDefaults() {
|
||||
viper.SetDefault("temperature", 0.7)
|
||||
viper.SetDefault("top-p", 0.95)
|
||||
viper.SetDefault("top-k", 40)
|
||||
viper.SetDefault("frequency-penalty", 0.0)
|
||||
viper.SetDefault("presence-penalty", 0.0)
|
||||
viper.SetDefault("stream", true)
|
||||
viper.SetDefault("thinking-level", "off")
|
||||
viper.SetDefault("num-gpu-layers", -1)
|
||||
|
||||
@@ -225,6 +225,10 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
config.TopP = &topP
|
||||
topK := int32(viper.GetInt("top-k"))
|
||||
config.TopK = &topK
|
||||
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
|
||||
config.FrequencyPenalty = &frequencyPenalty
|
||||
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
|
||||
config.PresencePenalty = &presencePenalty
|
||||
|
||||
if err := m.agent.SetModel(ctx, config); err != nil {
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user