mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
febdc530e1
* feat(auth): add Copilot login Add experimental GitHub Copilot device login and copilot/* provider support for users with Copilot access but no OpenAI account. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix(copilot): use responses for GPT-5 Route Copilot GPT-5 models through the Responses API because gpt-5.5 is not available on /chat/completions. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix(copilot): honor device flow timing * docs(copilot): add auth helper docstrings * fix(auth): address copilot review feedback --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
509 lines
15 KiB
Go
509 lines
15 KiB
Go
package models
|
|
|
|
import (
	_ "embed"
	"encoding/json"
	"fmt"
	"maps"
	"os"
	"sort"
	"strings"

	"github.com/mark3labs/kit/internal/auth"
)
|
|
|
|
//go:embed embedded_models.json
|
|
var embeddedModelsJSON []byte
|
|
|
|
// ModelInfo represents information about a specific model.
|
|
type ModelInfo struct {
|
|
ID string
|
|
Name string
|
|
Family string // Model family (e.g., "claude", "gpt", "gemini")
|
|
Attachment bool
|
|
Reasoning bool
|
|
Temperature bool
|
|
Cost Cost
|
|
Limit Limit
|
|
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
|
BaseURL string // Per-model base URL override (custom models only)
|
|
APIKey string // Per-model API key override (custom models only)
|
|
|
|
// Params holds per-model generation parameter defaults. These are applied
|
|
// when the user hasn't explicitly set the corresponding CLI flag or global
|
|
// config value. Nil pointer fields mean "no model-level default".
|
|
Params *GenerationParams
|
|
}
|
|
|
|
// SupportsCaching returns true if this model family supports prompt caching.
|
|
// This enables automatic cost savings for supported models regardless of provider.
|
|
func (m *ModelInfo) SupportsCaching() bool {
|
|
switch {
|
|
case strings.HasPrefix(m.Family, "claude"):
|
|
return true
|
|
case strings.HasPrefix(m.Family, "gpt"),
|
|
strings.HasPrefix(m.Family, "o1"),
|
|
strings.HasPrefix(m.Family, "o3"),
|
|
strings.HasPrefix(m.Family, "o4"),
|
|
strings.HasPrefix(m.Family, "codex"):
|
|
return true
|
|
case strings.HasPrefix(m.Family, "gemini"):
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// CacheType returns the appropriate cache mechanism for this model family.
|
|
// Returns empty string if caching is not supported.
|
|
func (m *ModelInfo) CacheType() string {
|
|
switch {
|
|
case strings.HasPrefix(m.Family, "claude"):
|
|
return "anthropic-ephemeral"
|
|
case strings.HasPrefix(m.Family, "gpt"),
|
|
strings.HasPrefix(m.Family, "o1"),
|
|
strings.HasPrefix(m.Family, "o3"),
|
|
strings.HasPrefix(m.Family, "o4"),
|
|
strings.HasPrefix(m.Family, "codex"):
|
|
return "openai-prompt-cache"
|
|
case strings.HasPrefix(m.Family, "gemini"):
|
|
return "google-cached-content"
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
// Cost holds a model's pricing figures.
type Cost struct {
	// Input and Output are the prices for input and output usage.
	Input, Output float64
	// CacheRead and CacheWrite are optional cache prices; either may be nil.
	CacheRead, CacheWrite *float64
}
|
|
|
|
// Limit holds a model's size limits.
type Limit struct {
	// Context is the context limit; Output is the output limit.
	Context, Output int
}
|
|
|
|
// ProviderInfo represents information about a model provider.
|
|
type ProviderInfo struct {
|
|
ID string
|
|
Env []string
|
|
NPM string // npm package identifier from models.dev (e.g. "@ai-sdk/openai-compatible")
|
|
API string // base API URL for openai-compatible providers
|
|
Name string
|
|
Models map[string]ModelInfo
|
|
}
|
|
|
|
// ModelsRegistry provides validation and information about models.
|
|
// It maintains a registry of all supported LLM providers and their models,
|
|
// including capabilities, pricing, and configuration requirements.
|
|
// The registry data comes from models.dev.
|
|
type ModelsRegistry struct {
|
|
providers map[string]ProviderInfo
|
|
}
|
|
|
|
// NewModelsRegistry creates a new models registry populated from models.dev data.
|
|
func NewModelsRegistry() *ModelsRegistry {
|
|
return &ModelsRegistry{
|
|
providers: buildFromModelsDB(),
|
|
}
|
|
}
|
|
|
|
// buildFromModelsDB converts models.dev provider data into our internal format.
|
|
// It starts from the compile-time embedded database and merges on-disk cached
|
|
// data from `kit update-models` on top. Cached provider metadata replaces
|
|
// embedded metadata, and model entries are merged with cached models taking
|
|
// precedence. This means newly synced models are available while embedded
|
|
// models that haven't been synced yet are still reachable.
|
|
func buildFromModelsDB() map[string]ProviderInfo {
|
|
// Start with compile-time embedded data as the base.
|
|
dbProviders := loadEmbeddedProviders()
|
|
if dbProviders == nil {
|
|
dbProviders = make(ModelsDBProviders)
|
|
}
|
|
|
|
// Merge on-disk cached data on top (cached takes precedence).
|
|
if cached, _ := LoadCachedProviders(); len(cached) > 0 {
|
|
for providerID, cp := range cached {
|
|
if existing, ok := dbProviders[providerID]; ok {
|
|
// Merge models: embedded base + cached overrides.
|
|
mergedModels := make(map[string]modelsDBModel, len(existing.Models)+len(cp.Models))
|
|
maps.Copy(mergedModels, existing.Models)
|
|
maps.Copy(mergedModels, cp.Models)
|
|
cp.Models = mergedModels
|
|
}
|
|
dbProviders[providerID] = cp
|
|
}
|
|
}
|
|
|
|
providers := make(map[string]ProviderInfo, len(dbProviders))
|
|
|
|
for providerID, dp := range dbProviders {
|
|
modelsMap := make(map[string]ModelInfo, len(dp.Models))
|
|
for modelID, dm := range dp.Models {
|
|
providerNPM := ""
|
|
if dm.Provider != nil {
|
|
providerNPM = dm.Provider.NPM
|
|
}
|
|
modelsMap[modelID] = ModelInfo{
|
|
ID: dm.ID,
|
|
Name: dm.Name,
|
|
Family: dm.Family,
|
|
Attachment: dm.Attachment,
|
|
Reasoning: dm.Reasoning,
|
|
Temperature: dm.Temperature,
|
|
Cost: Cost{
|
|
Input: dm.Cost.Input,
|
|
Output: dm.Cost.Output,
|
|
CacheRead: dm.Cost.CacheRead,
|
|
CacheWrite: dm.Cost.CacheWrite,
|
|
},
|
|
Limit: Limit{
|
|
Context: dm.Limit.Context,
|
|
Output: dm.Limit.Output,
|
|
},
|
|
ProviderNPM: providerNPM,
|
|
}
|
|
}
|
|
|
|
providers[providerID] = ProviderInfo{
|
|
ID: providerID,
|
|
Env: dp.Env,
|
|
NPM: dp.NPM,
|
|
API: dp.API,
|
|
Name: dp.Name,
|
|
Models: modelsMap,
|
|
}
|
|
}
|
|
|
|
// Ensure ollama is always present (not in models.dev — it's a local server)
|
|
if _, exists := providers["ollama"]; !exists {
|
|
providers["ollama"] = ProviderInfo{
|
|
ID: "ollama",
|
|
Name: "Ollama",
|
|
Models: make(map[string]ModelInfo),
|
|
}
|
|
}
|
|
|
|
// Register the "custom" provider stub for --provider-url without --model.
|
|
// This allows users to point kit at any OpenAI-compatible endpoint without
|
|
// needing to specify a model from the database.
|
|
providers["custom"] = ProviderInfo{
|
|
ID: "custom",
|
|
Name: "Custom",
|
|
Models: map[string]ModelInfo{
|
|
"custom": {
|
|
ID: "custom",
|
|
Name: "Custom",
|
|
Attachment: false,
|
|
Reasoning: true,
|
|
Temperature: true,
|
|
Cost: Cost{
|
|
Input: 0,
|
|
Output: 0,
|
|
},
|
|
Limit: Limit{
|
|
Context: 262_144,
|
|
Output: 65_536,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
// Load custom models from config file and merge into custom provider.
|
|
// Config file models take precedence - if a model ID exists in both
|
|
// models.dev and config, the config version wins.
|
|
if customModels := loadCustomModelsFromConfig(); customModels != nil {
|
|
for modelID, info := range customModels {
|
|
// Validate custom model config
|
|
if info.Limit.Context <= 0 {
|
|
fmt.Fprintf(os.Stderr, "Warning: custom model %q has invalid context limit: %d\n", modelID, info.Limit.Context)
|
|
}
|
|
if info.Limit.Output <= 0 {
|
|
fmt.Fprintf(os.Stderr, "Warning: custom model %q has invalid output limit: %d\n", modelID, info.Limit.Output)
|
|
}
|
|
providers["custom"].Models[modelID] = info
|
|
}
|
|
}
|
|
|
|
return providers
|
|
}
|
|
|
|
// loadEmbeddedProviders parses the compile-time embedded models.dev snapshot.
|
|
func loadEmbeddedProviders() map[string]modelsDBProvider {
|
|
var providers map[string]modelsDBProvider
|
|
if err := json.Unmarshal(embeddedModelsJSON, &providers); err != nil {
|
|
return nil
|
|
}
|
|
return providers
|
|
}
|
|
|
|
// LookupModel returns model metadata from the database if available.
|
|
// Returns nil when the model or provider is not in the database — this is
|
|
// expected for new models, custom fine-tunes, or providers the database
|
|
// doesn't track yet. Callers should treat a nil return as "unknown model"
|
|
// and continue with sensible defaults.
|
|
func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
|
|
provider = catalogProviderID(provider)
|
|
providerInfo, exists := r.providers[provider]
|
|
if !exists {
|
|
return nil
|
|
}
|
|
|
|
modelInfo, exists := providerInfo.Models[modelID]
|
|
if !exists {
|
|
return nil
|
|
}
|
|
|
|
return &modelInfo
|
|
}
|
|
|
|
// LookupModelForSettings is a convenience function that parses a
|
|
// "provider/model" string and looks up the ModelInfo in the global registry.
|
|
// Returns nil when the model string is invalid or the model is unknown.
|
|
// Used by Kit.SetModel to pre-apply per-model settings before CreateProvider.
|
|
func LookupModelForSettings(modelString string) *ModelInfo {
|
|
provider, modelName, err := ParseModelString(modelString)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return GetGlobalRegistry().LookupModel(provider, modelName)
|
|
}
|
|
|
|
// getRequiredEnvVars returns the required environment variables for a provider.
|
|
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
|
|
provider = catalogProviderID(provider)
|
|
providerInfo, exists := r.providers[provider]
|
|
if !exists {
|
|
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
|
}
|
|
|
|
return providerInfo.Env, nil
|
|
}
|
|
|
|
// ValidateEnvironment checks if required credentials are available for a
|
|
// provider. It checks the explicit API key, stored credentials (for
|
|
// providers that support them, such as Anthropic OAuth), and environment
|
|
// variables. Returns nil for providers not in the registry (unknown
|
|
// providers are assumed to handle auth themselves or via --provider-api-key).
|
|
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
|
|
provider = catalogProviderID(provider)
|
|
if apiKey != "" {
|
|
return nil
|
|
}
|
|
|
|
// For anthropic, also check stored credentials (OAuth / API key)
|
|
// since auth resolution goes through the credential manager, not
|
|
// just environment variables.
|
|
if provider == "anthropic" {
|
|
if cm, err := auth.NewCredentialManager(); err == nil {
|
|
if has, _ := cm.HasAnthropicCredentials(); has {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// For openai, check stored credentials (OAuth / API key)
|
|
if provider == "openai" {
|
|
if cm, err := auth.NewCredentialManager(); err == nil {
|
|
if has, _ := cm.HasOpenAICredentials(); has {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// For GitHub Copilot, check stored GitHub OAuth credentials.
|
|
if provider == copilotProviderID {
|
|
if cm, err := auth.NewCredentialManager(); err == nil {
|
|
if has, _ := cm.HasCopilotCredentials(); has {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
envVars, err := r.getRequiredEnvVars(provider)
|
|
if err != nil {
|
|
// Unknown provider — nothing to validate
|
|
return nil
|
|
}
|
|
|
|
if len(envVars) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Add alternative environment variable names for google-vertex-anthropic
|
|
if provider == "google-vertex-anthropic" {
|
|
envVars = append(envVars,
|
|
"ANTHROPIC_VERTEX_PROJECT_ID",
|
|
"GOOGLE_CLOUD_PROJECT",
|
|
"GCLOUD_PROJECT",
|
|
"CLOUDSDK_CORE_PROJECT",
|
|
"ANTHROPIC_VERTEX_REGION",
|
|
"CLOUD_ML_REGION",
|
|
)
|
|
}
|
|
|
|
// Add GOOGLE_API_KEY as an alternative for google
|
|
if provider == "google" || provider == "gemini" {
|
|
envVars = append(envVars, "GOOGLE_API_KEY")
|
|
}
|
|
|
|
for _, envVar := range envVars {
|
|
if os.Getenv(envVar) != "" {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("missing required environment variables for %s: %s (at least one required)",
|
|
provider, strings.Join(envVars, ", "))
|
|
}
|
|
|
|
// SuggestModels returns similar model names when an invalid model is provided.
|
|
func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
|
|
provider = catalogProviderID(provider)
|
|
providerInfo, exists := r.providers[provider]
|
|
if !exists {
|
|
return nil
|
|
}
|
|
|
|
var suggestions []string
|
|
invalidLower := strings.ToLower(invalidModel)
|
|
|
|
for modelID, modelInfo := range providerInfo.Models {
|
|
modelIDLower := strings.ToLower(modelID)
|
|
modelNameLower := strings.ToLower(modelInfo.Name)
|
|
|
|
if strings.Contains(modelIDLower, invalidLower) ||
|
|
strings.Contains(modelNameLower, invalidLower) ||
|
|
strings.Contains(invalidLower, strings.ToLower(strings.Split(modelID, "-")[0])) {
|
|
suggestions = append(suggestions, modelID)
|
|
}
|
|
}
|
|
|
|
if len(suggestions) > 5 {
|
|
suggestions = suggestions[:5]
|
|
}
|
|
|
|
return suggestions
|
|
}
|
|
|
|
// GetSupportedProviders returns a list of all provider IDs in the registry.
|
|
func (r *ModelsRegistry) GetSupportedProviders() []string {
|
|
providers := make([]string, 0, len(r.providers))
|
|
for providerID := range r.providers {
|
|
providers = append(providers, providerID)
|
|
}
|
|
return providers
|
|
}
|
|
|
|
// GetLLMProviders returns provider IDs that have LLM support,
|
|
// either through a native provider or via openaicompat auto-routing.
|
|
func (r *ModelsRegistry) GetLLMProviders() []string {
|
|
var providers []string
|
|
for providerID, info := range r.providers {
|
|
if isProviderLLMSupported(providerID, &info) {
|
|
providers = append(providers, providerID)
|
|
}
|
|
}
|
|
return providers
|
|
}
|
|
|
|
// isProviderLLMSupported checks if a provider can be used with the LLM layer.
|
|
func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
|
// Ollama and custom are always supported (model names are user-defined).
|
|
if providerID == "ollama" || providerID == "custom" {
|
|
return true
|
|
}
|
|
|
|
// Check if npm maps to a known wire protocol
|
|
if _, ok := npmToWireProtocol[info.NPM]; ok {
|
|
return true
|
|
}
|
|
|
|
// Any provider with an API URL can be auto-routed through openaicompat
|
|
return info.API != ""
|
|
}
|
|
|
|
// GetModelsForProvider returns all models for a specific provider.
|
|
func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]ModelInfo, error) {
|
|
provider = catalogProviderID(provider)
|
|
providerInfo, exists := r.providers[provider]
|
|
if !exists {
|
|
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
|
}
|
|
|
|
return providerInfo.Models, nil
|
|
}
|
|
|
|
// GetProviderInfo returns the full provider info, or nil if not found.
|
|
func (r *ModelsRegistry) GetProviderInfo(provider string) *ProviderInfo {
|
|
provider = catalogProviderID(provider)
|
|
info, exists := r.providers[provider]
|
|
if !exists {
|
|
return nil
|
|
}
|
|
return &info
|
|
}
|
|
|
|
// ValidateModelString checks whether a model string is well-formed and refers
|
|
// to a known provider. It returns a user-friendly error with suggestions when
|
|
// the model or provider is unrecognised. Passing validation does not guarantee
|
|
// that API authentication will succeed — it only catches obvious mistakes
|
|
// (typos, missing provider prefix, non-existent provider names) early so that
|
|
// callers such as subagent spawning can return fast feedback.
|
|
//
|
|
// Unknown models under a known provider are allowed (the provider API is the
|
|
// authority), but a completely unknown provider is rejected.
|
|
func (r *ModelsRegistry) ValidateModelString(modelString string) error {
|
|
provider, modelName, err := ParseModelString(modelString)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Ollama and custom are always valid — model names are user-defined.
|
|
if provider == "ollama" || provider == "custom" {
|
|
return nil
|
|
}
|
|
|
|
// Check if the provider exists in the registry.
|
|
providerInfo := r.GetProviderInfo(provider)
|
|
if providerInfo == nil {
|
|
known := r.GetSupportedProviders()
|
|
return fmt.Errorf(
|
|
"unknown provider %q in model string %q. Known providers: %s",
|
|
provider, modelString, strings.Join(known, ", "),
|
|
)
|
|
}
|
|
|
|
// Provider exists — check if the model is known. An unknown model is
|
|
// only a warning (the provider API decides), but we surface suggestions
|
|
// so the caller can self-correct.
|
|
if r.LookupModel(provider, modelName) == nil {
|
|
if suggestions := r.SuggestModels(provider, modelName); len(suggestions) > 0 {
|
|
return fmt.Errorf(
|
|
"model %q not found for provider %s. Did you mean one of: %s",
|
|
modelName, provider, strings.Join(suggestions, ", "),
|
|
)
|
|
}
|
|
// No suggestions — let it through; the provider API is the authority.
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Global registry instance
|
|
var globalRegistry = NewModelsRegistry()
|
|
|
|
// GetGlobalRegistry returns the global models registry instance.
|
|
func GetGlobalRegistry() *ModelsRegistry {
|
|
return globalRegistry
|
|
}
|
|
|
|
// ReloadGlobalRegistry rebuilds the global registry from the current
|
|
// data sources (cache → embedded). Call after updating the cache.
|
|
func ReloadGlobalRegistry() {
|
|
globalRegistry = NewModelsRegistry()
|
|
}
|