add up to date model data

This commit is contained in:
Ed Zynda
2025-06-18 13:48:08 +03:00
parent 78a12147d7
commit 9e87977822
4 changed files with 2898 additions and 0 deletions
+196
View File
@@ -0,0 +1,196 @@
//go:build ignore
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"text/template"
"time"
)
// ModelInfo represents information about a specific model
type ModelInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Attachment bool `json:"attachment"`
Reasoning bool `json:"reasoning"`
Temperature bool `json:"temperature"`
Cost Cost `json:"cost"`
Limit Limit `json:"limit"`
}
// Cost represents the pricing information for a model
type Cost struct {
Input float64 `json:"input"`
Output float64 `json:"output"`
CacheRead *float64 `json:"cache_read,omitempty"`
CacheWrite *float64 `json:"cache_write,omitempty"`
}
// Limit represents the context and output limits for a model
type Limit struct {
Context int `json:"context"`
Output int `json:"output"`
}
// ProviderInfo represents information about a model provider
type ProviderInfo struct {
ID string `json:"id"`
Env []string `json:"env"`
NPM string `json:"npm"`
Name string `json:"name"`
Models map[string]ModelInfo `json:"models"`
}
const codeTemplate = `// Code generated by go generate; DO NOT EDIT.
// Generated at: {{.Timestamp}}
package models
// ModelInfo represents information about a specific model
type ModelInfo struct {
ID string
Name string
Attachment bool
Reasoning bool
Temperature bool
Cost Cost
Limit Limit
}
// Cost represents the pricing information for a model
type Cost struct {
Input float64
Output float64
CacheRead *float64
CacheWrite *float64
}
// Limit represents the context and output limits for a model
type Limit struct {
Context int
Output int
}
// ProviderInfo represents information about a model provider
type ProviderInfo struct {
ID string
Env []string
NPM string
Name string
Models map[string]ModelInfo
}
// GetModelsData returns the static models data from models.dev
func GetModelsData() map[string]ProviderInfo {
return map[string]ProviderInfo{
{{- range $providerID, $provider := .Providers}}
"{{$providerID}}": {
ID: "{{$provider.ID}}",
Env: []string{ {{- range $i, $env := $provider.Env}}{{if $i}}, {{end}}"{{$env}}"{{end}} },
NPM: "{{$provider.NPM}}",
Name: "{{$provider.Name}}",
Models: map[string]ModelInfo{
{{- range $modelID, $model := $provider.Models}}
"{{$modelID}}": {
ID: "{{$model.ID}}",
Name: "{{$model.Name}}",
Attachment: {{$model.Attachment}},
Reasoning: {{$model.Reasoning}},
Temperature: {{$model.Temperature}},
Cost: Cost{
Input: {{$model.Cost.Input}},
Output: {{$model.Cost.Output}},
{{- if $model.Cost.CacheRead}}
CacheRead: &[]float64{{"{"}}{{$model.Cost.CacheRead}}{{"}"}}[0],
{{- else}}
CacheRead: nil,
{{- end}}
{{- if $model.Cost.CacheWrite}}
CacheWrite: &[]float64{{"{"}}{{$model.Cost.CacheWrite}}{{"}"}}[0],
{{- else}}
CacheWrite: nil,
{{- end}}
},
Limit: Limit{
Context: {{$model.Limit.Context}},
Output: {{$model.Limit.Output}},
},
},
{{- end}}
},
},
{{- end}}
}
}
`
func main() {
fmt.Println("Fetching models data from models.dev...")
// Fetch data from API
resp, err := http.Get("https://models.dev/api.json")
if err != nil {
fmt.Fprintf(os.Stderr, "Error fetching data: %v\n", err)
os.Exit(1)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
fmt.Fprintf(os.Stderr, "API returned status %d\n", resp.StatusCode)
os.Exit(1)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading response: %v\n", err)
os.Exit(1)
}
// Parse JSON
var providers map[string]ProviderInfo
if err := json.Unmarshal(body, &providers); err != nil {
fmt.Fprintf(os.Stderr, "Error parsing JSON: %v\n", err)
os.Exit(1)
}
// Generate Go code
tmpl, err := template.New("models").Parse(codeTemplate)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
os.Exit(1)
}
// Create output file
file, err := os.Create("models_data.go")
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating output file: %v\n", err)
os.Exit(1)
}
defer file.Close()
// Execute template
data := struct {
Providers map[string]ProviderInfo
Timestamp string
}{
Providers: providers,
Timestamp: time.Now().Format(time.RFC3339),
}
if err := tmpl.Execute(file, data); err != nil {
fmt.Fprintf(os.Stderr, "Error executing template: %v\n", err)
os.Exit(1)
}
fmt.Printf("Generated models_data.go with %d providers\n", len(providers))
// Print summary
for providerID, provider := range providers {
fmt.Printf(" %s: %d models\n", providerID, len(provider.Models))
}
}
File diff suppressed because it is too large Load Diff
+42
View File
@@ -45,6 +45,32 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall
provider := parts[0]
modelName := parts[1]
// Get the global registry for validation
registry := GetGlobalRegistry()
// Validate the model exists (skip for ollama as it's not in models.dev)
if provider != "ollama" {
modelInfo, err := registry.ValidateModel(provider, modelName)
if err != nil {
// Provide helpful suggestions
suggestions := registry.SuggestModels(provider, modelName)
if len(suggestions) > 0 {
return nil, fmt.Errorf("%v. Did you mean one of: %s", err, strings.Join(suggestions, ", "))
}
return nil, err
}
// Validate environment variables
if err := registry.ValidateEnvironment(provider, config.ProviderAPIKey); err != nil {
return nil, err
}
// Validate configuration parameters against model capabilities
if err := validateModelConfig(config, modelInfo); err != nil {
return nil, err
}
}
switch provider {
case "anthropic":
return createAnthropicProvider(ctx, config, modelName)
@@ -59,6 +85,22 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall
}
}
// validateModelConfig validates configuration parameters against model capabilities
func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) error {
// Check if temperature is supported
if config.Temperature != nil && !modelInfo.Temperature {
return fmt.Errorf("model %s does not support temperature parameter", modelInfo.ID)
}
// Warn about context limits if MaxTokens is set too high
if config.MaxTokens > modelInfo.Limit.Output {
return fmt.Errorf("max_tokens (%d) exceeds model's output limit (%d) for %s",
config.MaxTokens, modelInfo.Limit.Output, modelInfo.ID)
}
return nil
}
func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
apiKey := config.ProviderAPIKey
if apiKey == "" {
+131
View File
@@ -0,0 +1,131 @@
//go:generate go run generate_models.go
package models
import (
"fmt"
"os"
"strings"
)
// ModelsRegistry provides validation and information about models
type ModelsRegistry struct {
providers map[string]ProviderInfo
}
// NewModelsRegistry creates a new models registry with static data
func NewModelsRegistry() *ModelsRegistry {
return &ModelsRegistry{
providers: GetModelsData(),
}
}
// ValidateModel validates if a model exists and returns detailed information
func (r *ModelsRegistry) ValidateModel(provider, modelID string) (*ModelInfo, error) {
providerInfo, exists := r.providers[provider]
if !exists {
return nil, fmt.Errorf("unsupported provider: %s", provider)
}
modelInfo, exists := providerInfo.Models[modelID]
if !exists {
return nil, fmt.Errorf("model %s not found for provider %s", modelID, provider)
}
return &modelInfo, nil
}
// GetRequiredEnvVars returns the required environment variables for a provider
func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) {
providerInfo, exists := r.providers[provider]
if !exists {
return nil, fmt.Errorf("unsupported provider: %s", provider)
}
return providerInfo.Env, nil
}
// ValidateEnvironment checks if required environment variables are set
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
envVars, err := r.GetRequiredEnvVars(provider)
if err != nil {
return err
}
// If API key is provided via config, we don't need to check env vars
if apiKey != "" {
return nil
}
var missingVars []string
for _, envVar := range envVars {
if os.Getenv(envVar) == "" {
missingVars = append(missingVars, envVar)
}
}
if len(missingVars) > 0 {
return fmt.Errorf("missing required environment variables for %s: %s",
provider, strings.Join(missingVars, ", "))
}
return nil
}
// SuggestModels returns similar model names when an invalid model is provided
func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
providerInfo, exists := r.providers[provider]
if !exists {
return nil
}
var suggestions []string
invalidLower := strings.ToLower(invalidModel)
// Look for models that contain parts of the invalid model name
for modelID, modelInfo := range providerInfo.Models {
modelIDLower := strings.ToLower(modelID)
modelNameLower := strings.ToLower(modelInfo.Name)
// Check if the invalid model is a substring of existing models
if strings.Contains(modelIDLower, invalidLower) ||
strings.Contains(modelNameLower, invalidLower) ||
strings.Contains(invalidLower, strings.ToLower(strings.Split(modelID, "-")[0])) {
suggestions = append(suggestions, modelID)
}
}
// Limit suggestions to avoid overwhelming output
if len(suggestions) > 5 {
suggestions = suggestions[:5]
}
return suggestions
}
// GetSupportedProviders returns a list of all supported providers
func (r *ModelsRegistry) GetSupportedProviders() []string {
var providers []string
for providerID := range r.providers {
providers = append(providers, providerID)
}
return providers
}
// GetModelsForProvider returns all models for a specific provider
func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]ModelInfo, error) {
providerInfo, exists := r.providers[provider]
if !exists {
return nil, fmt.Errorf("unsupported provider: %s", provider)
}
return providerInfo.Models, nil
}
// Global registry instance
var globalRegistry = NewModelsRegistry()
// GetGlobalRegistry returns the global models registry instance
func GetGlobalRegistry() *ModelsRegistry {
return globalRegistry
}