Files
Nuno do Carmo febdc530e1 Feat/copilot login (#49)
* feat(auth): add Copilot login

Add experimental GitHub Copilot device login and copilot/* provider support for users with Copilot access but no OpenAI account.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix(copilot): use responses for GPT-5

Route Copilot GPT-5 models through the Responses API because gpt-5.5 is not available on /chat/completions.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix(copilot): honor device flow timing

* docs(copilot): add auth helper docstrings

* fix(auth): address copilot review feedback

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-06-08 00:21:20 +03:00

509 lines
15 KiB
Go

package models
import (
	_ "embed"
	"encoding/json"
	"fmt"
	"maps"
	"os"
	"slices"
	"strings"

	"github.com/mark3labs/kit/internal/auth"
)
// embeddedModelsJSON holds the contents of embedded_models.json, compiled into
// the binary. It is parsed by loadEmbeddedProviders and forms the base layer
// of the registry, underneath any on-disk cached data.
//
//go:embed embedded_models.json
var embeddedModelsJSON []byte
// ModelInfo represents information about a specific model: identity,
// capability flags, pricing, limits, and optional per-model overrides.
// Entries are produced from models.dev data (embedded or cached) or from the
// user's custom model configuration.
type ModelInfo struct {
	ID          string // Model identifier taken from the source entry
	Name        string // Human-readable display name
	Family      string // Model family (e.g., "claude", "gpt", "gemini")
	Attachment  bool   // Capability flag copied from the source entry
	Reasoning   bool   // Capability flag copied from the source entry
	Temperature bool   // Capability flag copied from the source entry
	Cost        Cost   // Pricing information
	Limit       Limit  // Context and output limits
	ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
	BaseURL     string // Per-model base URL override (custom models only)
	APIKey      string // Per-model API key override (custom models only)
	// Params holds per-model generation parameter defaults. These are applied
	// when the user hasn't explicitly set the corresponding CLI flag or global
	// config value. Nil pointer fields mean "no model-level default".
	Params *GenerationParams
}
// SupportsCaching reports whether this model's family supports prompt caching,
// which enables automatic cost savings regardless of provider.
//
// It is defined in terms of CacheType so that the two methods can never
// disagree: a model supports caching exactly when CacheType names a cache
// mechanism for it (i.e. returns a non-empty string).
func (m *ModelInfo) SupportsCaching() bool {
	return m.CacheType() != ""
}
// CacheType names the prompt-cache mechanism used by this model's family,
// chosen by matching the start of Family against known prefixes:
//
//	"claude"                          -> "anthropic-ephemeral"
//	"gpt", "o1", "o3", "o4", "codex"  -> "openai-prompt-cache"
//	"gemini"                          -> "google-cached-content"
//
// An empty string means the family has no supported cache mechanism.
func (m *ModelInfo) CacheType() string {
	rules := []struct {
		prefixes  []string
		mechanism string
	}{
		{prefixes: []string{"claude"}, mechanism: "anthropic-ephemeral"},
		{prefixes: []string{"gpt", "o1", "o3", "o4", "codex"}, mechanism: "openai-prompt-cache"},
		{prefixes: []string{"gemini"}, mechanism: "google-cached-content"},
	}
	family := m.Family
	for _, rule := range rules {
		for _, prefix := range rule.prefixes {
			if strings.HasPrefix(family, prefix) {
				return rule.mechanism
			}
		}
	}
	return ""
}
// Cost represents the pricing information for a model.
//
// CacheRead and CacheWrite are pointers, presumably so that "no cache price
// published" (nil) can be distinguished from a price of zero — confirm with
// the data source. NOTE(review): units follow whatever the source data uses
// (likely per million tokens) — verify before combining prices.
type Cost struct {
	Input      float64
	Output     float64
	CacheRead  *float64
	CacheWrite *float64
}
// Limit represents the context and output limits for a model.
// Custom models configured with a non-positive value for either field trigger
// a warning when the registry is built. NOTE(review): values are presumably
// token counts — confirm against the source data.
type Limit struct {
	Context int
	Output  int
}
// ProviderInfo represents information about a model provider.
type ProviderInfo struct {
	ID  string   // Provider identifier (the registry key)
	Env []string // Environment variables, any one of which satisfies ValidateEnvironment
	NPM string   // npm package identifier from models.dev (e.g. "@ai-sdk/openai-compatible")
	API string   // base API URL for openai-compatible providers
	// Name is the human-readable provider name.
	Name string
	// Models maps model ID to model metadata for this provider.
	Models map[string]ModelInfo
}
// ModelsRegistry provides validation and information about models.
// It maintains a registry of all supported LLM providers and their models,
// including capabilities, pricing, and configuration requirements.
// The registry data comes from models.dev, plus built-in "ollama" and
// "custom" providers and any custom models from the user's config.
type ModelsRegistry struct {
	// providers maps a catalog provider ID to its metadata and models.
	providers map[string]ProviderInfo
}
// NewModelsRegistry constructs a registry from the embedded models.dev
// snapshot, overlaid with any cached `kit update-models` data, plus the
// built-in providers and custom models from the user's config.
func NewModelsRegistry() *ModelsRegistry {
	registry := new(ModelsRegistry)
	registry.providers = buildFromModelsDB()
	return registry
}
// buildFromModelsDB assembles the provider table backing ModelsRegistry by
// stacking these layers:
//
//  1. The models.dev snapshot embedded at compile time.
//  2. Provider data cached on disk by `kit update-models`. A cached provider
//     replaces the embedded provider's metadata, while its model map is merged
//     over the embedded one (cached entries win when IDs clash). Newly synced
//     models therefore appear without hiding embedded-only models.
//  3. An "ollama" provider with no models, added only when no earlier layer
//     supplied one (ollama is a local server and is not in models.dev).
//  4. A "custom" provider, always installed (replacing any earlier "custom"
//     entry), holding a "custom" placeholder model for --provider-url without
//     --model, plus every model from the user's config file. A config model
//     named "custom" replaces the placeholder. Config models are only added to
//     this provider; they do not override entries of other providers.
func buildFromModelsDB() map[string]ProviderInfo {
	// Layer 1: compile-time snapshot.
	layered := loadEmbeddedProviders()
	if layered == nil {
		layered = make(ModelsDBProviders)
	}

	// Layer 2: on-disk cache. Load errors are ignored on purpose; without a
	// usable cache the embedded snapshot is used on its own.
	cached, _ := LoadCachedProviders()
	for id, fresh := range cached {
		if base, seen := layered[id]; seen {
			combined := make(map[string]modelsDBModel, len(base.Models)+len(fresh.Models))
			maps.Copy(combined, base.Models)
			maps.Copy(combined, fresh.Models) // cached entries win on ID clashes
			fresh.Models = combined
		}
		layered[id] = fresh
	}

	// toModelInfo converts one database model entry into the internal form.
	toModelInfo := func(dm modelsDBModel) ModelInfo {
		info := ModelInfo{
			ID:          dm.ID,
			Name:        dm.Name,
			Family:      dm.Family,
			Attachment:  dm.Attachment,
			Reasoning:   dm.Reasoning,
			Temperature: dm.Temperature,
			Cost: Cost{
				Input:      dm.Cost.Input,
				Output:     dm.Cost.Output,
				CacheRead:  dm.Cost.CacheRead,
				CacheWrite: dm.Cost.CacheWrite,
			},
			Limit: Limit{
				Context: dm.Limit.Context,
				Output:  dm.Limit.Output,
			},
		}
		if dm.Provider != nil {
			info.ProviderNPM = dm.Provider.NPM
		}
		return info
	}

	providers := make(map[string]ProviderInfo, len(layered))
	for id, src := range layered {
		models := make(map[string]ModelInfo, len(src.Models))
		for modelID, dm := range src.Models {
			models[modelID] = toModelInfo(dm)
		}
		providers[id] = ProviderInfo{
			ID:     id,
			Env:    src.Env,
			NPM:    src.NPM,
			API:    src.API,
			Name:   src.Name,
			Models: models,
		}
	}

	// Layer 3: local ollama server, unless a data source already defines it.
	if _, present := providers["ollama"]; !present {
		providers["ollama"] = ProviderInfo{
			ID:     "ollama",
			Name:   "Ollama",
			Models: map[string]ModelInfo{},
		}
	}

	// Layer 4: the "custom" provider for arbitrary OpenAI-compatible endpoints.
	// The placeholder accepts no attachments and is priced at zero.
	customModels := map[string]ModelInfo{
		"custom": {
			ID:          "custom",
			Name:        "Custom",
			Reasoning:   true,
			Temperature: true,
			Limit: Limit{
				Context: 262_144,
				Output:  65_536,
			},
		},
	}
	providers["custom"] = ProviderInfo{
		ID:     "custom",
		Name:   "Custom",
		Models: customModels,
	}

	// Config-file models are added as-is; invalid limits only produce warnings.
	for modelID, info := range loadCustomModelsFromConfig() {
		if info.Limit.Context <= 0 {
			fmt.Fprintf(os.Stderr, "Warning: custom model %q has invalid context limit: %d\n", modelID, info.Limit.Context)
		}
		if info.Limit.Output <= 0 {
			fmt.Fprintf(os.Stderr, "Warning: custom model %q has invalid output limit: %d\n", modelID, info.Limit.Output)
		}
		customModels[modelID] = info
	}

	return providers
}
// loadEmbeddedProviders decodes the models.dev snapshot compiled into the
// binary, keyed by provider ID. It returns nil if the snapshot cannot be
// decoded (or decodes to JSON null); callers treat that as "no embedded data".
func loadEmbeddedProviders() map[string]modelsDBProvider {
	var decoded map[string]modelsDBProvider
	if json.Unmarshal(embeddedModelsJSON, &decoded) != nil {
		return nil
	}
	return decoded
}
// LookupModel returns a copy of the metadata for modelID under provider
// (normalized via catalogProviderID), or nil when either the provider or the
// model is not registered. A nil result is expected for new models, custom
// fine-tunes, or providers the database does not track; callers should treat
// it as "unknown model" and continue with sensible defaults.
func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
	// A missing provider yields a zero ProviderInfo whose nil Models map
	// reports every key as absent, so one comma-ok check covers both cases.
	info, found := r.providers[catalogProviderID(provider)].Models[modelID]
	if !found {
		return nil
	}
	return &info
}
// LookupModelForSettings parses a "provider/model" string and looks the model
// up in the global registry. It returns nil if the string cannot be parsed or
// the model is unknown. Kit.SetModel uses it to pre-apply per-model settings
// before CreateProvider.
func LookupModelForSettings(modelString string) *ModelInfo {
	if provider, modelName, err := ParseModelString(modelString); err == nil {
		return GetGlobalRegistry().LookupModel(provider, modelName)
	}
	return nil
}
// getRequiredEnvVars returns the environment variable names registered for a
// provider (after catalogProviderID normalization), or an error when the
// provider is not in the registry. The returned slice is the registry's own.
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
	id := catalogProviderID(provider)
	if info, ok := r.providers[id]; ok {
		return info.Env, nil
	}
	return nil, fmt.Errorf("unsupported provider: %s", id)
}
// ValidateEnvironment checks whether credentials for provider appear to be
// available. It returns nil when they are, and otherwise an error listing the
// environment variables that were looked for.
//
// The checks run in this order:
//   - A non-empty apiKey (e.g. from --provider-api-key) always passes.
//   - For "anthropic", "openai", and GitHub Copilot, credentials stored in the
//     credential manager (OAuth / API key) pass. Credential-manager errors are
//     treated as "no stored credentials".
//   - Providers not in the registry, or with no listed environment variables,
//     pass: they are assumed to handle auth themselves.
//   - Otherwise at least one of the provider's environment variables (plus a
//     few known aliases for Google providers) must be set to a non-empty value.
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
	provider = catalogProviderID(provider)
	if apiKey != "" {
		return nil
	}

	// Auth for these providers resolves through the credential manager, not
	// just environment variables.
	switch provider {
	case "anthropic":
		if cm, err := auth.NewCredentialManager(); err == nil {
			if has, _ := cm.HasAnthropicCredentials(); has {
				return nil
			}
		}
	case "openai":
		if cm, err := auth.NewCredentialManager(); err == nil {
			if has, _ := cm.HasOpenAICredentials(); has {
				return nil
			}
		}
	case copilotProviderID:
		// Stored GitHub OAuth credentials.
		if cm, err := auth.NewCredentialManager(); err == nil {
			if has, _ := cm.HasCopilotCredentials(); has {
				return nil
			}
		}
	}

	envVars, err := r.getRequiredEnvVars(provider)
	if err != nil {
		// Unknown provider — nothing to validate.
		return nil
	}
	if len(envVars) == 0 {
		return nil
	}

	// envVars is the registry's shared Env slice. Cap it at its length so the
	// appends below allocate a fresh array. Without this, they could write into
	// spare capacity of the shared backing array, which races between
	// concurrent callers and corrupts data for anyone else appending to it.
	envVars = envVars[:len(envVars):len(envVars)]

	// Accept alternative variable names for some providers.
	switch provider {
	case "google-vertex-anthropic":
		envVars = append(envVars,
			"ANTHROPIC_VERTEX_PROJECT_ID",
			"GOOGLE_CLOUD_PROJECT",
			"GCLOUD_PROJECT",
			"CLOUDSDK_CORE_PROJECT",
			"ANTHROPIC_VERTEX_REGION",
			"CLOUD_ML_REGION",
		)
	case "google", "gemini":
		envVars = append(envVars, "GOOGLE_API_KEY")
	}

	for _, envVar := range envVars {
		if os.Getenv(envVar) != "" {
			return nil
		}
	}

	return fmt.Errorf("missing required environment variables for %s: %s (at least one required)",
		provider, strings.Join(envVars, ", "))
}
// SuggestModels returns up to five model IDs from provider that resemble
// invalidModel, for "did you mean" messages.
//
// A model matches when any of these is true (all comparisons case-insensitive):
//   - Its ID contains invalidModel.
//   - Its display name contains invalidModel.
//   - invalidModel contains the ID's leading dash-separated segment (e.g.
//     "claude" for "claude-sonnet-4").
//
// Matches are sorted before truncation, so the same input always yields the
// same suggestions. It returns nil for an unknown provider or when nothing
// matches.
func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
	const maxSuggestions = 5

	provider = catalogProviderID(provider)
	providerInfo, exists := r.providers[provider]
	if !exists {
		return nil
	}

	var suggestions []string
	invalidLower := strings.ToLower(invalidModel)
	for modelID, modelInfo := range providerInfo.Models {
		modelIDLower := strings.ToLower(modelID)
		// An empty leading segment (empty ID or ID starting with "-") would be
		// contained in every input and match everything, so it is skipped.
		leading, _, _ := strings.Cut(modelIDLower, "-")
		if strings.Contains(modelIDLower, invalidLower) ||
			strings.Contains(strings.ToLower(modelInfo.Name), invalidLower) ||
			(leading != "" && strings.Contains(invalidLower, leading)) {
			suggestions = append(suggestions, modelID)
		}
	}

	// Map iteration order is random; sort before truncating so the returned
	// subset is stable across calls.
	slices.Sort(suggestions)
	if len(suggestions) > maxSuggestions {
		suggestions = suggestions[:maxSuggestions]
	}
	return suggestions
}
// GetSupportedProviders returns the IDs of all providers in the registry,
// sorted alphabetically. The result is a new slice owned by the caller.
// Sorting keeps listings and error messages (such as ValidateModelString's
// "Known providers" list) stable across runs.
func (r *ModelsRegistry) GetSupportedProviders() []string {
	providers := make([]string, 0, len(r.providers))
	for providerID := range r.providers {
		providers = append(providers, providerID)
	}
	// Map iteration order is randomized; sort for deterministic output.
	slices.Sort(providers)
	return providers
}
// GetLLMProviders returns, sorted alphabetically, the IDs of providers usable
// with the LLM layer, either through a native provider or via openaicompat
// auto-routing (see isProviderLLMSupported). It returns nil if none qualify.
func (r *ModelsRegistry) GetLLMProviders() []string {
	var providers []string
	for providerID, info := range r.providers {
		if isProviderLLMSupported(providerID, &info) {
			providers = append(providers, providerID)
		}
	}
	// Map iteration order is randomized; sort for deterministic output.
	slices.Sort(providers)
	return providers
}
// isProviderLLMSupported reports whether a provider can be used with the LLM
// layer. "ollama" and "custom" always qualify because their model names are
// user-defined. Any other provider qualifies if its npm package maps to a
// known wire protocol, or if it has an API URL that openaicompat can route to.
func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
	switch providerID {
	case "ollama", "custom":
		return true
	}
	_, knownWire := npmToWireProtocol[info.NPM]
	return knownWire || info.API != ""
}
// GetModelsForProvider returns the models registered for a provider (after
// catalogProviderID normalization), or an error if the provider is unknown.
// NOTE(review): the returned map is the registry's own, not a copy; callers
// that modify it modify the registry.
func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]ModelInfo, error) {
	id := catalogProviderID(provider)
	if info, ok := r.providers[id]; ok {
		return info.Models, nil
	}
	return nil, fmt.Errorf("unsupported provider: %s", id)
}
// GetProviderInfo returns a copy of the provider's registry entry (after
// catalogProviderID normalization), or nil if the provider is not registered.
// The copy shares its Env slice and Models map with the registry.
func (r *ModelsRegistry) GetProviderInfo(provider string) *ProviderInfo {
	if info, ok := r.providers[catalogProviderID(provider)]; ok {
		return &info
	}
	return nil
}
// ValidateModelString performs a quick, offline sanity check of a
// "provider/model" string. It lets callers such as subagent spawning fail fast
// on obvious mistakes: typos, a missing provider prefix, or a non-existent
// provider. It does not check credentials, so passing does not guarantee that
// API authentication will succeed.
//
// Rules, applied in order:
//   - A string that ParseModelString rejects returns that parse error.
//   - Providers "ollama" and "custom" always pass; their model names are
//     user-defined.
//   - A provider missing from the registry is rejected, and the error lists
//     the known providers.
//   - A model missing from a known provider is rejected only when
//     SuggestModels finds similar names, so the caller can self-correct from
//     the "did you mean" list. With no similar names it passes, and the
//     provider API has the final say.
func (r *ModelsRegistry) ValidateModelString(modelString string) error {
	provider, modelName, err := ParseModelString(modelString)
	switch {
	case err != nil:
		return err
	case provider == "ollama" || provider == "custom":
		return nil
	case r.GetProviderInfo(provider) == nil:
		return fmt.Errorf(
			"unknown provider %q in model string %q. Known providers: %s",
			provider, modelString, strings.Join(r.GetSupportedProviders(), ", "),
		)
	case r.LookupModel(provider, modelName) != nil:
		return nil
	}

	suggestions := r.SuggestModels(provider, modelName)
	if len(suggestions) == 0 {
		return nil
	}
	return fmt.Errorf(
		"model %q not found for provider %s. Did you mean one of: %s",
		modelName, provider, strings.Join(suggestions, ", "),
	)
}
// globalRegistry is the process-wide registry instance. It is built during
// package initialization and replaced wholesale by ReloadGlobalRegistry.
// NOTE(review): reads and replacement are not synchronized.
var globalRegistry = NewModelsRegistry()
// GetGlobalRegistry returns the process-wide models registry currently in
// use. Package-level helpers such as LookupModelForSettings consult it.
func GetGlobalRegistry() (registry *ModelsRegistry) {
	registry = globalRegistry
	return registry
}
// ReloadGlobalRegistry rebuilds the registry from its data sources and
// installs it as the global instance. Those sources are the embedded snapshot
// overlaid with the on-disk cache, plus the built-in and custom providers.
// Call it after the cache has been updated.
//
// NOTE(review): the swap is a plain, unsynchronized pointer write. Calling
// this while other goroutines call GetGlobalRegistry would be a data race.
func ReloadGlobalRegistry() {
	rebuilt := NewModelsRegistry()
	globalRegistry = rebuilt
}