From 4ecdee7e254e8fc0e087e43eab1c970011de2ba5 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Fri, 27 Feb 2026 12:25:33 +0300 Subject: [PATCH] expose auth and model management APIs in SDK, migrate CLI to consume them (Plan 06) --- cmd/auth.go | 7 +++-- cmd/models.go | 24 +++++++------- cmd/script_deepseek_test.go | 14 ++++----- cmd/setup.go | 5 ++- pkg/kit/auth.go | 44 ++++++++++++++++++++++++++ pkg/kit/kit.go | 11 +++++++ pkg/kit/models.go | 63 +++++++++++++++++++++++++++++++++++++ 7 files changed, 142 insertions(+), 26 deletions(-) create mode 100644 pkg/kit/auth.go create mode 100644 pkg/kit/models.go diff --git a/cmd/auth.go b/cmd/auth.go index a75800a7..811d27e1 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/mark3labs/kit/internal/auth" + kit "github.com/mark3labs/kit/pkg/kit" "github.com/spf13/cobra" ) @@ -117,7 +118,7 @@ func runAuthLogout(cmd *cobra.Command, args []string) error { } func runAuthStatus(cmd *cobra.Command, args []string) error { - cm, err := auth.NewCredentialManager() + cm, err := kit.NewCredentialManager() if err != nil { return fmt.Errorf("failed to initialize credential manager: %w", err) } @@ -163,7 +164,7 @@ func runAuthStatus(cmd *cobra.Command, args []string) error { } func loginAnthropic() error { - cm, err := auth.NewCredentialManager() + cm, err := kit.NewCredentialManager() if err != nil { return fmt.Errorf("failed to initialize credential manager: %w", err) } @@ -237,7 +238,7 @@ func loginAnthropic() error { } func logoutAnthropic() error { - cm, err := auth.NewCredentialManager() + cm, err := kit.NewCredentialManager() if err != nil { return fmt.Errorf("failed to initialize credential manager: %w", err) } diff --git a/cmd/models.go b/cmd/models.go index 4115a20a..6fe0131f 100644 --- a/cmd/models.go +++ b/cmd/models.go @@ -4,7 +4,7 @@ import ( "fmt" "sort" - "github.com/mark3labs/kit/internal/models" + kit "github.com/mark3labs/kit/pkg/kit" "github.com/spf13/cobra" ) @@ -39,28 +39,26 @@ func init() { } func runModels(_ *cobra.Command, args []string) error { - registry := models.GetGlobalRegistry() - if len(args) == 1 { - return printProvider(registry, args[0]) + return printProvider(args[0]) } - return printAllProviders(registry, modelsAllFlag) + return printAllProviders(modelsAllFlag) } -func printAllProviders(registry *models.ModelsRegistry, showAll bool) error { +func printAllProviders(showAll bool) error { var providerIDs []string if showAll { - providerIDs = registry.GetSupportedProviders() + providerIDs = kit.GetSupportedProviders() } else { - providerIDs = registry.GetFantasyProviders() + providerIDs = kit.GetFantasyProviders() } sort.Strings(providerIDs) // Filter to providers that have models var withModels []string for _, id := range providerIDs { - m, _ := registry.GetModelsForProvider(id) + m, _ := kit.GetModelsForProvider(id) if len(m) > 0 { withModels = append(withModels, id) } @@ -72,7 +70,7 @@ func printAllProviders(registry *models.ModelsRegistry, showAll bool) error { } for i, id := range withModels { - m, _ := registry.GetModelsForProvider(id) + m, _ := kit.GetModelsForProvider(id) modelIDs := sortedModelIDs(m) isLast := i == len(withModels)-1 @@ -99,8 +97,8 @@ func printAllProviders(registry *models.ModelsRegistry, showAll bool) error { return nil } -func printProvider(registry *models.ModelsRegistry, provider string) error { - m, err := registry.GetModelsForProvider(provider) +func printProvider(provider string) error { + m, err := kit.GetModelsForProvider(provider) if err != nil { return fmt.Errorf("unknown provider %q. Run 'kit models' to see all providers", provider) } @@ -118,7 +116,7 @@ func printProvider(registry *models.ModelsRegistry, provider string) error { return nil } -func sortedModelIDs(m map[string]models.ModelInfo) []string { +func sortedModelIDs(m map[string]kit.ModelInfo) []string { ids := make([]string, 0, len(m)) for id := range m { ids = append(ids, id) diff --git a/cmd/script_deepseek_test.go b/cmd/script_deepseek_test.go index 935798a6..9e2624b0 100644 --- a/cmd/script_deepseek_test.go +++ b/cmd/script_deepseek_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - "github.com/mark3labs/kit/internal/models" + kit "github.com/mark3labs/kit/pkg/kit" ) // TestDeepSeekChatScriptMode tests the regression where deepseek-chat model @@ -55,7 +55,7 @@ Calculate 3 times 4 equal to? } // Now test the actual model creation - this should NOT fail when provider-url is set - providerConfig := &models.ProviderConfig{ + providerConfig := &kit.ProviderConfig{ ModelString: scriptConfig.Model, ProviderAPIKey: scriptConfig.ProviderAPIKey, ProviderURL: scriptConfig.ProviderURL, @@ -68,7 +68,7 @@ Calculate 3 times 4 equal to? // This should succeed because provider-url is set, which should skip model validation ctx := context.Background() - _, err = models.CreateProvider(ctx, providerConfig) + _, err = kit.CreateProvider(ctx, providerConfig) // We expect this to fail with a connection error (since we're using a fake API key), // NOT with a "model not found" error. The "model not found" error indicates @@ -86,7 +86,7 @@ Calculate 3 times 4 equal to? // TestDeepSeekChatCLIMode tests that the CLI mode works correctly with custom provider URL func TestDeepSeekChatCLIMode(t *testing.T) { // Test the CLI mode behavior - this should work - providerConfig := &models.ProviderConfig{ + providerConfig := &kit.ProviderConfig{ ModelString: "openai/deepseek-chat", ProviderAPIKey: "sk-test-key", ProviderURL: "https://api.deepseek.com/v1", // This should skip validation @@ -94,7 +94,7 @@ func TestDeepSeekChatCLIMode(t *testing.T) { } ctx := context.Background() - _, err := models.CreateProvider(ctx, providerConfig) + _, err := kit.CreateProvider(ctx, providerConfig) // We expect this to fail with a connection error (since we're using a fake API key), // NOT with a "model not found" error @@ -143,14 +143,14 @@ func TestProviderURLValidationSkip(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - providerConfig := &models.ProviderConfig{ + providerConfig := &kit.ProviderConfig{ ModelString: tc.model, ProviderAPIKey: "test-key", ProviderURL: tc.providerURL, } ctx := context.Background() - _, err := models.CreateProvider(ctx, providerConfig) + _, err := kit.CreateProvider(ctx, providerConfig) // Should never get a "not found for provider" error — unknown // models are passed through to the provider API. diff --git a/cmd/setup.go b/cmd/setup.go index cf0ae999..59a76a15 100644 --- a/cmd/setup.go +++ b/cmd/setup.go @@ -8,7 +8,6 @@ import ( "github.com/mark3labs/kit/internal/app" "github.com/mark3labs/kit/internal/config" "github.com/mark3labs/kit/internal/extensions" - "github.com/mark3labs/kit/internal/models" "github.com/mark3labs/kit/internal/ui" kit "github.com/mark3labs/kit/pkg/kit" "github.com/spf13/viper" @@ -23,7 +22,7 @@ type AgentSetupResult = kit.AgentSetupResult // BuildProviderConfig delegates to the SDK to build a ProviderConfig from // the current viper state. -func BuildProviderConfig() (*models.ProviderConfig, string, error) { +func BuildProviderConfig() (*kit.ProviderConfig, string, error) { return kit.BuildProviderConfig() } @@ -40,7 +39,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, // app.Options and UI setup. func CollectAgentMetadata(mcpAgent *agent.Agent, mcpConfig *config.Config) (provider, modelName string, serverNames, toolNames []string) { modelString := viper.GetString("model") - provider, modelName, _ = models.ParseModelString(modelString) + provider, modelName, _ = kit.ParseModelString(modelString) if modelName == "" { modelName = "Unknown" } diff --git a/pkg/kit/auth.go b/pkg/kit/auth.go new file mode 100644 index 00000000..0b2329b5 --- /dev/null +++ b/pkg/kit/auth.go @@ -0,0 +1,44 @@ +package kit + +import "github.com/mark3labs/kit/internal/auth" + +// CredentialManager manages API keys and OAuth credentials. +type CredentialManager = auth.CredentialManager + +// AnthropicCredentials holds Anthropic API credentials supporting both OAuth +// and API key authentication methods. +type AnthropicCredentials = auth.AnthropicCredentials + +// CredentialStore holds all stored credentials for various providers. +type CredentialStore = auth.CredentialStore + +// NewCredentialManager creates a credential manager for secure storage and +// retrieval of authentication credentials. +func NewCredentialManager() (*CredentialManager, error) { + return auth.NewCredentialManager() +} + +// HasAnthropicCredentials checks if valid Anthropic credentials are stored +// (either OAuth token or API key). +func HasAnthropicCredentials() bool { + cm, err := auth.NewCredentialManager() + if err != nil { + return false + } + has, err := cm.HasAnthropicCredentials() + if err != nil { + return false + } + return has +} + +// GetAnthropicAPIKey resolves the Anthropic API key using the standard +// resolution order: stored credentials -> ANTHROPIC_API_KEY env var. +// Returns an empty string if no key is found. +func GetAnthropicAPIKey() string { + key, _, err := auth.GetAnthropicAPIKey("") + if err != nil { + return "" + } + return key +} diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index d92d2da8..e5ab171d 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -345,6 +345,17 @@ func (m *Kit) GetModelString() string { return m.modelString } +// GetModelInfo returns detailed information about the current model +// (capabilities, pricing, limits). Returns nil if the model is not in the +// registry — this is expected for new models or custom fine-tunes. +func (m *Kit) GetModelInfo() *ModelInfo { + provider, modelID, err := ParseModelString(m.modelString) + if err != nil { + return nil + } + return LookupModel(provider, modelID) +} + // GetTools returns all tools available to the agent (core + MCP + extensions). func (m *Kit) GetTools() []Tool { return m.agent.GetTools() diff --git a/pkg/kit/models.go b/pkg/kit/models.go new file mode 100644 index 00000000..ad8fc66f --- /dev/null +++ b/pkg/kit/models.go @@ -0,0 +1,63 @@ +package kit + +import ( + "fmt" + + "github.com/mark3labs/kit/internal/models" +) + +// LookupModel returns information about a model, or nil if unknown. +func LookupModel(provider, modelID string) *ModelInfo { + return models.GetGlobalRegistry().LookupModel(provider, modelID) +} + +// GetSupportedProviders returns all known provider names in the registry. +func GetSupportedProviders() []string { + return models.GetGlobalRegistry().GetSupportedProviders() +} + +// GetFantasyProviders returns provider IDs that can be used with fantasy, +// either through a native provider or via openaicompat auto-routing. +func GetFantasyProviders() []string { + return models.GetGlobalRegistry().GetFantasyProviders() +} + +// GetModelsForProvider returns all known models for a provider. +func GetModelsForProvider(provider string) (map[string]ModelInfo, error) { + return models.GetGlobalRegistry().GetModelsForProvider(provider) +} + +// GetProviderInfo returns information about a provider (env vars, API URL, etc.). +// Returns nil if the provider is not in the registry. +func GetProviderInfo(provider string) *ProviderInfo { + return models.GetGlobalRegistry().GetProviderInfo(provider) +} + +// ValidateEnvironment checks if required API keys are set for a provider. +// Returns nil for providers not in the registry (unknown providers are +// assumed to handle auth themselves or via --provider-api-key). +func ValidateEnvironment(provider string, apiKey string) error { + return models.GetGlobalRegistry().ValidateEnvironment(provider, apiKey) +} + +// SuggestModels returns model names similar to an invalid model string. +func SuggestModels(provider, invalidModel string) []string { + return models.GetGlobalRegistry().SuggestModels(provider, invalidModel) +} + +// RefreshModelRegistry reloads the global model database from the current +// data sources (cache -> embedded). Call after updating the cache. +func RefreshModelRegistry() { + models.ReloadGlobalRegistry() +} + +// CheckProviderReady validates that a provider is properly configured +// by checking that it exists in the registry and has required environment +// variables set. +func CheckProviderReady(provider string) error { + info := models.GetGlobalRegistry().GetProviderInfo(provider) + if info == nil { + return fmt.Errorf("unknown provider: %s", provider) + } + return models.GetGlobalRegistry().ValidateEnvironment(provider, "") +}