mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
add up to date model data
This commit is contained in:
@@ -0,0 +1,196 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"text/template"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ModelInfo represents information about a specific model
|
||||
type ModelInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Attachment bool `json:"attachment"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
Temperature bool `json:"temperature"`
|
||||
Cost Cost `json:"cost"`
|
||||
Limit Limit `json:"limit"`
|
||||
}
|
||||
|
||||
// Cost represents the pricing information for a model
|
||||
type Cost struct {
|
||||
Input float64 `json:"input"`
|
||||
Output float64 `json:"output"`
|
||||
CacheRead *float64 `json:"cache_read,omitempty"`
|
||||
CacheWrite *float64 `json:"cache_write,omitempty"`
|
||||
}
|
||||
|
||||
// Limit represents the context and output limits for a model
|
||||
type Limit struct {
|
||||
Context int `json:"context"`
|
||||
Output int `json:"output"`
|
||||
}
|
||||
|
||||
// ProviderInfo represents information about a model provider
|
||||
type ProviderInfo struct {
|
||||
ID string `json:"id"`
|
||||
Env []string `json:"env"`
|
||||
NPM string `json:"npm"`
|
||||
Name string `json:"name"`
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
}
|
||||
|
||||
const codeTemplate = `// Code generated by go generate; DO NOT EDIT.
|
||||
// Generated at: {{.Timestamp}}
|
||||
|
||||
package models
|
||||
|
||||
// ModelInfo represents information about a specific model
|
||||
type ModelInfo struct {
|
||||
ID string
|
||||
Name string
|
||||
Attachment bool
|
||||
Reasoning bool
|
||||
Temperature bool
|
||||
Cost Cost
|
||||
Limit Limit
|
||||
}
|
||||
|
||||
// Cost represents the pricing information for a model
|
||||
type Cost struct {
|
||||
Input float64
|
||||
Output float64
|
||||
CacheRead *float64
|
||||
CacheWrite *float64
|
||||
}
|
||||
|
||||
// Limit represents the context and output limits for a model
|
||||
type Limit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// ProviderInfo represents information about a model provider
|
||||
type ProviderInfo struct {
|
||||
ID string
|
||||
Env []string
|
||||
NPM string
|
||||
Name string
|
||||
Models map[string]ModelInfo
|
||||
}
|
||||
|
||||
// GetModelsData returns the static models data from models.dev
|
||||
func GetModelsData() map[string]ProviderInfo {
|
||||
return map[string]ProviderInfo{
|
||||
{{- range $providerID, $provider := .Providers}}
|
||||
"{{$providerID}}": {
|
||||
ID: "{{$provider.ID}}",
|
||||
Env: []string{ {{- range $i, $env := $provider.Env}}{{if $i}}, {{end}}"{{$env}}"{{end}} },
|
||||
NPM: "{{$provider.NPM}}",
|
||||
Name: "{{$provider.Name}}",
|
||||
Models: map[string]ModelInfo{
|
||||
{{- range $modelID, $model := $provider.Models}}
|
||||
"{{$modelID}}": {
|
||||
ID: "{{$model.ID}}",
|
||||
Name: "{{$model.Name}}",
|
||||
Attachment: {{$model.Attachment}},
|
||||
Reasoning: {{$model.Reasoning}},
|
||||
Temperature: {{$model.Temperature}},
|
||||
Cost: Cost{
|
||||
Input: {{$model.Cost.Input}},
|
||||
Output: {{$model.Cost.Output}},
|
||||
{{- if $model.Cost.CacheRead}}
|
||||
CacheRead: &[]float64{{"{"}}{{$model.Cost.CacheRead}}{{"}"}}[0],
|
||||
{{- else}}
|
||||
CacheRead: nil,
|
||||
{{- end}}
|
||||
{{- if $model.Cost.CacheWrite}}
|
||||
CacheWrite: &[]float64{{"{"}}{{$model.Cost.CacheWrite}}{{"}"}}[0],
|
||||
{{- else}}
|
||||
CacheWrite: nil,
|
||||
{{- end}}
|
||||
},
|
||||
Limit: Limit{
|
||||
Context: {{$model.Limit.Context}},
|
||||
Output: {{$model.Limit.Output}},
|
||||
},
|
||||
},
|
||||
{{- end}}
|
||||
},
|
||||
},
|
||||
{{- end}}
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
func main() {
|
||||
fmt.Println("Fetching models data from models.dev...")
|
||||
|
||||
// Fetch data from API
|
||||
resp, err := http.Get("https://models.dev/api.json")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error fetching data: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
fmt.Fprintf(os.Stderr, "API returned status %d\n", resp.StatusCode)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading response: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Parse JSON
|
||||
var providers map[string]ProviderInfo
|
||||
if err := json.Unmarshal(body, &providers); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing JSON: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Generate Go code
|
||||
tmpl, err := template.New("models").Parse(codeTemplate)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create output file
|
||||
file, err := os.Create("models_data.go")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating output file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Execute template
|
||||
data := struct {
|
||||
Providers map[string]ProviderInfo
|
||||
Timestamp string
|
||||
}{
|
||||
Providers: providers,
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(file, data); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error executing template: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Generated models_data.go with %d providers\n", len(providers))
|
||||
|
||||
// Print summary
|
||||
for providerID, provider := range providers {
|
||||
fmt.Printf(" %s: %d models\n", providerID, len(provider.Models))
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -45,6 +45,32 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall
|
||||
provider := parts[0]
|
||||
modelName := parts[1]
|
||||
|
||||
// Get the global registry for validation
|
||||
registry := GetGlobalRegistry()
|
||||
|
||||
// Validate the model exists (skip for ollama as it's not in models.dev)
|
||||
if provider != "ollama" {
|
||||
modelInfo, err := registry.ValidateModel(provider, modelName)
|
||||
if err != nil {
|
||||
// Provide helpful suggestions
|
||||
suggestions := registry.SuggestModels(provider, modelName)
|
||||
if len(suggestions) > 0 {
|
||||
return nil, fmt.Errorf("%v. Did you mean one of: %s", err, strings.Join(suggestions, ", "))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate environment variables
|
||||
if err := registry.ValidateEnvironment(provider, config.ProviderAPIKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate configuration parameters against model capabilities
|
||||
if err := validateModelConfig(config, modelInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return createAnthropicProvider(ctx, config, modelName)
|
||||
@@ -59,6 +85,22 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall
|
||||
}
|
||||
}
|
||||
|
||||
// validateModelConfig validates configuration parameters against model capabilities
|
||||
func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) error {
|
||||
// Check if temperature is supported
|
||||
if config.Temperature != nil && !modelInfo.Temperature {
|
||||
return fmt.Errorf("model %s does not support temperature parameter", modelInfo.ID)
|
||||
}
|
||||
|
||||
// Warn about context limits if MaxTokens is set too high
|
||||
if config.MaxTokens > modelInfo.Limit.Output {
|
||||
return fmt.Errorf("max_tokens (%d) exceeds model's output limit (%d) for %s",
|
||||
config.MaxTokens, modelInfo.Limit.Output, modelInfo.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
|
||||
apiKey := config.ProviderAPIKey
|
||||
if apiKey == "" {
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
//go:generate go run generate_models.go
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ModelsRegistry provides validation and information about models
|
||||
type ModelsRegistry struct {
|
||||
providers map[string]ProviderInfo
|
||||
}
|
||||
|
||||
// NewModelsRegistry creates a new models registry with static data
|
||||
func NewModelsRegistry() *ModelsRegistry {
|
||||
return &ModelsRegistry{
|
||||
providers: GetModelsData(),
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateModel validates if a model exists and returns detailed information
|
||||
func (r *ModelsRegistry) ValidateModel(provider, modelID string) (*ModelInfo, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
||||
}
|
||||
|
||||
modelInfo, exists := providerInfo.Models[modelID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("model %s not found for provider %s", modelID, provider)
|
||||
}
|
||||
|
||||
return &modelInfo, nil
|
||||
}
|
||||
|
||||
// GetRequiredEnvVars returns the required environment variables for a provider
|
||||
func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
||||
}
|
||||
|
||||
return providerInfo.Env, nil
|
||||
}
|
||||
|
||||
// ValidateEnvironment checks if required environment variables are set
|
||||
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
|
||||
envVars, err := r.GetRequiredEnvVars(provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If API key is provided via config, we don't need to check env vars
|
||||
if apiKey != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var missingVars []string
|
||||
for _, envVar := range envVars {
|
||||
if os.Getenv(envVar) == "" {
|
||||
missingVars = append(missingVars, envVar)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingVars) > 0 {
|
||||
return fmt.Errorf("missing required environment variables for %s: %s",
|
||||
provider, strings.Join(missingVars, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SuggestModels returns similar model names when an invalid model is provided
|
||||
func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
var suggestions []string
|
||||
invalidLower := strings.ToLower(invalidModel)
|
||||
|
||||
// Look for models that contain parts of the invalid model name
|
||||
for modelID, modelInfo := range providerInfo.Models {
|
||||
modelIDLower := strings.ToLower(modelID)
|
||||
modelNameLower := strings.ToLower(modelInfo.Name)
|
||||
|
||||
// Check if the invalid model is a substring of existing models
|
||||
if strings.Contains(modelIDLower, invalidLower) ||
|
||||
strings.Contains(modelNameLower, invalidLower) ||
|
||||
strings.Contains(invalidLower, strings.ToLower(strings.Split(modelID, "-")[0])) {
|
||||
suggestions = append(suggestions, modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// Limit suggestions to avoid overwhelming output
|
||||
if len(suggestions) > 5 {
|
||||
suggestions = suggestions[:5]
|
||||
}
|
||||
|
||||
return suggestions
|
||||
}
|
||||
|
||||
// GetSupportedProviders returns a list of all supported providers
|
||||
func (r *ModelsRegistry) GetSupportedProviders() []string {
|
||||
var providers []string
|
||||
for providerID := range r.providers {
|
||||
providers = append(providers, providerID)
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// GetModelsForProvider returns all models for a specific provider
|
||||
func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]ModelInfo, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
||||
}
|
||||
|
||||
return providerInfo.Models, nil
|
||||
}
|
||||
|
||||
// Global registry instance
|
||||
var globalRegistry = NewModelsRegistry()
|
||||
|
||||
// GetGlobalRegistry returns the global models registry instance
|
||||
func GetGlobalRegistry() *ModelsRegistry {
|
||||
return globalRegistry
|
||||
}
|
||||
Reference in New Issue
Block a user