2025-06-18 13:48:08 +03:00
|
|
|
package models
|
|
|
|
|
|
|
|
|
|
import (
	"fmt"
	"os"
	"sort"
	"strings"

	"charm.land/catwalk/pkg/embedded"
)
|
|
|
|
|
|
2026-02-25 18:17:25 +03:00
|
|
|
// ModelInfo represents information about a specific model.
|
|
|
|
|
type ModelInfo struct {
|
|
|
|
|
ID string
|
|
|
|
|
Name string
|
|
|
|
|
Attachment bool
|
|
|
|
|
Reasoning bool
|
|
|
|
|
Temperature bool
|
|
|
|
|
Cost Cost
|
|
|
|
|
Limit Limit
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Cost holds the pricing information for a model.
type Cost struct {
	// Input and Output prices are always present.
	Input, Output float64
	// CacheRead and CacheWrite are optional; nil means no value is recorded.
	CacheRead, CacheWrite *float64
}
|
|
|
|
|
|
|
|
|
|
// Limit holds the context and output limits for a model.
type Limit struct {
	Context, Output int
}
|
|
|
|
|
|
|
|
|
|
// ProviderInfo represents information about a model provider.
|
|
|
|
|
type ProviderInfo struct {
|
|
|
|
|
ID string
|
|
|
|
|
Env []string
|
|
|
|
|
Name string
|
|
|
|
|
Models map[string]ModelInfo
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// providerEnvVars maps provider IDs to the names of the environment variables
// that can hold their credentials. Catwalk only provides APIKey field names,
// so the actual variable names are listed here.
var providerEnvVars = map[string][]string{
	// Providers with a single API key variable.
	"alibaba":    {"DASHSCOPE_API_KEY"},
	"anthropic":  {"ANTHROPIC_API_KEY"},
	"azure":      {"AZURE_OPENAI_API_KEY"},
	"cohere":     {"COHERE_API_KEY"},
	"deepseek":   {"DEEPSEEK_API_KEY"},
	"fireworks":  {"FIREWORKS_API_KEY"},
	"groq":       {"GROQ_API_KEY"},
	"mistral":    {"MISTRAL_API_KEY"},
	"openai":     {"OPENAI_API_KEY"},
	"openrouter": {"OPENROUTER_API_KEY"},
	"perplexity": {"PERPLEXITY_API_KEY"},
	"together":   {"TOGETHER_API_KEY"},
	"xai":        {"XAI_API_KEY"},

	// Providers with several accepted variable names.
	"google": {"GOOGLE_API_KEY", "GEMINI_API_KEY", "GOOGLE_GENERATIVE_AI_API_KEY"},

	// Cloud-credential based providers.
	"bedrock":                 {"AWS_ACCESS_KEY_ID"},
	"google-vertex-anthropic": {"GOOGLE_APPLICATION_CREDENTIALS"},

	// No variables listed; the entry is deliberately an empty, non-nil slice.
	"ollama": {},
}
|
|
|
|
|
|
2025-11-12 16:48:46 +03:00
|
|
|
// ModelsRegistry provides validation and information about models.
|
|
|
|
|
// It maintains a registry of all supported LLM providers and their models,
|
|
|
|
|
// including capabilities, pricing, and configuration requirements.
|
2026-02-25 18:17:25 +03:00
|
|
|
// The registry data comes from the catwalk embedded database.
|
2025-06-18 13:48:08 +03:00
|
|
|
type ModelsRegistry struct {
|
|
|
|
|
providers map[string]ProviderInfo
|
|
|
|
|
}
|
|
|
|
|
|
2026-02-25 18:17:25 +03:00
|
|
|
// NewModelsRegistry creates a new models registry populated from the catwalk embedded database.
|
2025-06-18 13:48:08 +03:00
|
|
|
func NewModelsRegistry() *ModelsRegistry {
|
|
|
|
|
return &ModelsRegistry{
|
2026-02-25 18:17:25 +03:00
|
|
|
providers: buildFromCatwalk(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// buildFromCatwalk converts catwalk embedded data into our internal format.
|
|
|
|
|
func buildFromCatwalk() map[string]ProviderInfo {
|
|
|
|
|
providers := make(map[string]ProviderInfo)
|
|
|
|
|
|
|
|
|
|
for _, cp := range embedded.GetAll() {
|
|
|
|
|
providerID := string(cp.ID)
|
|
|
|
|
|
|
|
|
|
modelsMap := make(map[string]ModelInfo, len(cp.Models))
|
|
|
|
|
for _, cm := range cp.Models {
|
|
|
|
|
var cacheRead, cacheWrite *float64
|
|
|
|
|
if cm.CostPer1MInCached > 0 {
|
|
|
|
|
v := cm.CostPer1MInCached
|
|
|
|
|
cacheRead = &v
|
|
|
|
|
}
|
|
|
|
|
if cm.CostPer1MOutCached > 0 {
|
|
|
|
|
v := cm.CostPer1MOutCached
|
|
|
|
|
cacheWrite = &v
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
hasTemperature := true // most models support temperature
|
|
|
|
|
if cm.Options.Temperature != nil && *cm.Options.Temperature == 0 {
|
|
|
|
|
hasTemperature = false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
modelsMap[cm.ID] = ModelInfo{
|
|
|
|
|
ID: cm.ID,
|
|
|
|
|
Name: cm.Name,
|
|
|
|
|
Attachment: cm.SupportsImages,
|
|
|
|
|
Reasoning: cm.CanReason,
|
|
|
|
|
Temperature: hasTemperature,
|
|
|
|
|
Cost: Cost{
|
|
|
|
|
Input: cm.CostPer1MIn,
|
|
|
|
|
Output: cm.CostPer1MOut,
|
|
|
|
|
CacheRead: cacheRead,
|
|
|
|
|
CacheWrite: cacheWrite,
|
|
|
|
|
},
|
|
|
|
|
Limit: Limit{
|
|
|
|
|
Context: int(cm.ContextWindow),
|
|
|
|
|
Output: int(cm.DefaultMaxTokens),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
envVars := providerEnvVars[providerID]
|
|
|
|
|
if envVars == nil {
|
|
|
|
|
// Derive from the catwalk APIKey field if available
|
|
|
|
|
if cp.APIKey != "" {
|
|
|
|
|
envVars = []string{cp.APIKey}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
providers[providerID] = ProviderInfo{
|
|
|
|
|
ID: providerID,
|
|
|
|
|
Env: envVars,
|
|
|
|
|
Name: cp.Name,
|
|
|
|
|
Models: modelsMap,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Ensure providers that mcphost explicitly supports are always present
|
|
|
|
|
// even if catwalk doesn't list them (e.g. ollama, google-vertex-anthropic)
|
|
|
|
|
ensureProvider(providers, "ollama", "Ollama", nil)
|
|
|
|
|
ensureProvider(providers, "google-vertex-anthropic", "Google Vertex (Anthropic)",
|
|
|
|
|
providerEnvVars["google-vertex-anthropic"])
|
|
|
|
|
|
|
|
|
|
return providers
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ensureProvider ensures a provider entry exists in the map.
|
|
|
|
|
func ensureProvider(providers map[string]ProviderInfo, id, name string, env []string) {
|
|
|
|
|
if _, exists := providers[id]; !exists {
|
|
|
|
|
providers[id] = ProviderInfo{
|
|
|
|
|
ID: id,
|
|
|
|
|
Env: env,
|
|
|
|
|
Name: name,
|
|
|
|
|
Models: make(map[string]ModelInfo),
|
|
|
|
|
}
|
2025-06-18 13:48:08 +03:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2025-11-12 16:48:46 +03:00
|
|
|
// ValidateModel validates if a model exists and returns detailed information.
|
2025-06-18 13:48:08 +03:00
|
|
|
func (r *ModelsRegistry) ValidateModel(provider, modelID string) (*ModelInfo, error) {
|
|
|
|
|
providerInfo, exists := r.providers[provider]
|
|
|
|
|
if !exists {
|
|
|
|
|
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
modelInfo, exists := providerInfo.Models[modelID]
|
|
|
|
|
if !exists {
|
|
|
|
|
return nil, fmt.Errorf("model %s not found for provider %s", modelID, provider)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return &modelInfo, nil
|
|
|
|
|
}
|
|
|
|
|
|
2025-11-12 16:48:46 +03:00
|
|
|
// GetRequiredEnvVars returns the required environment variables for a provider.
|
2025-06-18 13:48:08 +03:00
|
|
|
func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) {
|
|
|
|
|
providerInfo, exists := r.providers[provider]
|
|
|
|
|
if !exists {
|
|
|
|
|
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return providerInfo.Env, nil
|
|
|
|
|
}
|
|
|
|
|
|
2025-11-12 16:48:46 +03:00
|
|
|
// ValidateEnvironment checks if required environment variables are set.
|
2025-06-18 13:48:08 +03:00
|
|
|
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
|
|
|
|
|
envVars, err := r.GetRequiredEnvVars(provider)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if apiKey != "" {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
2026-01-10 05:01:18 -05:00
|
|
|
// Add alternative environment variable names for google-vertex-anthropic
|
|
|
|
|
if provider == "google-vertex-anthropic" {
|
|
|
|
|
envVars = append(envVars,
|
|
|
|
|
"ANTHROPIC_VERTEX_PROJECT_ID",
|
|
|
|
|
"GOOGLE_CLOUD_PROJECT",
|
|
|
|
|
"GCLOUD_PROJECT",
|
|
|
|
|
"CLOUDSDK_CORE_PROJECT",
|
|
|
|
|
"ANTHROPIC_VERTEX_REGION",
|
|
|
|
|
"CLOUD_ML_REGION",
|
|
|
|
|
)
|
|
|
|
|
}
|
|
|
|
|
|
2025-06-27 11:40:11 +03:00
|
|
|
var foundVar bool
|
2025-06-18 13:48:08 +03:00
|
|
|
for _, envVar := range envVars {
|
2025-06-27 11:40:11 +03:00
|
|
|
if os.Getenv(envVar) != "" {
|
|
|
|
|
foundVar = true
|
|
|
|
|
break
|
2025-06-18 13:48:08 +03:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2025-06-27 11:40:11 +03:00
|
|
|
if !foundVar {
|
|
|
|
|
return fmt.Errorf("missing required environment variables for %s: %s (at least one required)",
|
|
|
|
|
provider, strings.Join(envVars, ", "))
|
2025-06-18 13:48:08 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
2025-11-12 16:48:46 +03:00
|
|
|
// SuggestModels returns similar model names when an invalid model is provided.
|
2025-06-18 13:48:08 +03:00
|
|
|
func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
|
|
|
|
|
providerInfo, exists := r.providers[provider]
|
|
|
|
|
if !exists {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var suggestions []string
|
|
|
|
|
invalidLower := strings.ToLower(invalidModel)
|
|
|
|
|
|
|
|
|
|
for modelID, modelInfo := range providerInfo.Models {
|
|
|
|
|
modelIDLower := strings.ToLower(modelID)
|
|
|
|
|
modelNameLower := strings.ToLower(modelInfo.Name)
|
|
|
|
|
|
|
|
|
|
if strings.Contains(modelIDLower, invalidLower) ||
|
|
|
|
|
strings.Contains(modelNameLower, invalidLower) ||
|
|
|
|
|
strings.Contains(invalidLower, strings.ToLower(strings.Split(modelID, "-")[0])) {
|
|
|
|
|
suggestions = append(suggestions, modelID)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(suggestions) > 5 {
|
|
|
|
|
suggestions = suggestions[:5]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return suggestions
|
|
|
|
|
}
|
|
|
|
|
|
2025-11-12 16:48:46 +03:00
|
|
|
// GetSupportedProviders returns a list of all supported providers.
|
2025-06-18 13:48:08 +03:00
|
|
|
func (r *ModelsRegistry) GetSupportedProviders() []string {
|
|
|
|
|
var providers []string
|
|
|
|
|
for providerID := range r.providers {
|
|
|
|
|
providers = append(providers, providerID)
|
|
|
|
|
}
|
|
|
|
|
return providers
|
|
|
|
|
}
|
|
|
|
|
|
2025-11-12 16:48:46 +03:00
|
|
|
// GetModelsForProvider returns all models for a specific provider.
|
2025-06-18 13:48:08 +03:00
|
|
|
func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]ModelInfo, error) {
|
|
|
|
|
providerInfo, exists := r.providers[provider]
|
|
|
|
|
if !exists {
|
|
|
|
|
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return providerInfo.Models, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Global registry instance
|
|
|
|
|
var globalRegistry = NewModelsRegistry()
|
|
|
|
|
|
2025-11-12 16:48:46 +03:00
|
|
|
// GetGlobalRegistry returns the global models registry instance.
|
2025-06-18 13:48:08 +03:00
|
|
|
func GetGlobalRegistry() *ModelsRegistry {
|
|
|
|
|
return globalRegistry
|
|
|
|
|
}
|