mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
expose auth and model management APIs in SDK, migrate CLI to consume them (Plan 06)
This commit is contained in:
+4
-3
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -117,7 +118,7 @@ func runAuthLogout(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func runAuthStatus(cmd *cobra.Command, args []string) error {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
@@ -163,7 +164,7 @@ func runAuthStatus(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func loginAnthropic() error {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
@@ -237,7 +238,7 @@ func loginAnthropic() error {
|
||||
}
|
||||
|
||||
func logoutAnthropic() error {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
+11
-13
@@ -4,7 +4,7 @@ import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -39,28 +39,26 @@ func init() {
|
||||
}
|
||||
|
||||
func runModels(_ *cobra.Command, args []string) error {
|
||||
registry := models.GetGlobalRegistry()
|
||||
|
||||
if len(args) == 1 {
|
||||
return printProvider(registry, args[0])
|
||||
return printProvider(args[0])
|
||||
}
|
||||
|
||||
return printAllProviders(registry, modelsAllFlag)
|
||||
return printAllProviders(modelsAllFlag)
|
||||
}
|
||||
|
||||
func printAllProviders(registry *models.ModelsRegistry, showAll bool) error {
|
||||
func printAllProviders(showAll bool) error {
|
||||
var providerIDs []string
|
||||
if showAll {
|
||||
providerIDs = registry.GetSupportedProviders()
|
||||
providerIDs = kit.GetSupportedProviders()
|
||||
} else {
|
||||
providerIDs = registry.GetFantasyProviders()
|
||||
providerIDs = kit.GetFantasyProviders()
|
||||
}
|
||||
sort.Strings(providerIDs)
|
||||
|
||||
// Filter to providers that have models
|
||||
var withModels []string
|
||||
for _, id := range providerIDs {
|
||||
m, _ := registry.GetModelsForProvider(id)
|
||||
m, _ := kit.GetModelsForProvider(id)
|
||||
if len(m) > 0 {
|
||||
withModels = append(withModels, id)
|
||||
}
|
||||
@@ -72,7 +70,7 @@ func printAllProviders(registry *models.ModelsRegistry, showAll bool) error {
|
||||
}
|
||||
|
||||
for i, id := range withModels {
|
||||
m, _ := registry.GetModelsForProvider(id)
|
||||
m, _ := kit.GetModelsForProvider(id)
|
||||
modelIDs := sortedModelIDs(m)
|
||||
|
||||
isLast := i == len(withModels)-1
|
||||
@@ -99,8 +97,8 @@ func printAllProviders(registry *models.ModelsRegistry, showAll bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func printProvider(registry *models.ModelsRegistry, provider string) error {
|
||||
m, err := registry.GetModelsForProvider(provider)
|
||||
func printProvider(provider string) error {
|
||||
m, err := kit.GetModelsForProvider(provider)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unknown provider %q. Run 'kit models' to see all providers", provider)
|
||||
}
|
||||
@@ -118,7 +116,7 @@ func printProvider(registry *models.ModelsRegistry, provider string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func sortedModelIDs(m map[string]models.ModelInfo) []string {
|
||||
func sortedModelIDs(m map[string]kit.ModelInfo) []string {
|
||||
ids := make([]string, 0, len(m))
|
||||
for id := range m {
|
||||
ids = append(ids, id)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// TestDeepSeekChatScriptMode tests the regression where deepseek-chat model
|
||||
@@ -55,7 +55,7 @@ Calculate 3 times 4 equal to?
|
||||
}
|
||||
|
||||
// Now test the actual model creation - this should NOT fail when provider-url is set
|
||||
providerConfig := &models.ProviderConfig{
|
||||
providerConfig := &kit.ProviderConfig{
|
||||
ModelString: scriptConfig.Model,
|
||||
ProviderAPIKey: scriptConfig.ProviderAPIKey,
|
||||
ProviderURL: scriptConfig.ProviderURL,
|
||||
@@ -68,7 +68,7 @@ Calculate 3 times 4 equal to?
|
||||
|
||||
// This should succeed because provider-url is set, which should skip model validation
|
||||
ctx := context.Background()
|
||||
_, err = models.CreateProvider(ctx, providerConfig)
|
||||
_, err = kit.CreateProvider(ctx, providerConfig)
|
||||
|
||||
// We expect this to fail with a connection error (since we're using a fake API key),
|
||||
// NOT with a "model not found" error. The "model not found" error indicates
|
||||
@@ -86,7 +86,7 @@ Calculate 3 times 4 equal to?
|
||||
// TestDeepSeekChatCLIMode tests that the CLI mode works correctly with custom provider URL
|
||||
func TestDeepSeekChatCLIMode(t *testing.T) {
|
||||
// Test the CLI mode behavior - this should work
|
||||
providerConfig := &models.ProviderConfig{
|
||||
providerConfig := &kit.ProviderConfig{
|
||||
ModelString: "openai/deepseek-chat",
|
||||
ProviderAPIKey: "sk-test-key",
|
||||
ProviderURL: "https://api.deepseek.com/v1", // This should skip validation
|
||||
@@ -94,7 +94,7 @@ func TestDeepSeekChatCLIMode(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := models.CreateProvider(ctx, providerConfig)
|
||||
_, err := kit.CreateProvider(ctx, providerConfig)
|
||||
|
||||
// We expect this to fail with a connection error (since we're using a fake API key),
|
||||
// NOT with a "model not found" error
|
||||
@@ -143,14 +143,14 @@ func TestProviderURLValidationSkip(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
providerConfig := &models.ProviderConfig{
|
||||
providerConfig := &kit.ProviderConfig{
|
||||
ModelString: tc.model,
|
||||
ProviderAPIKey: "test-key",
|
||||
ProviderURL: tc.providerURL,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := models.CreateProvider(ctx, providerConfig)
|
||||
_, err := kit.CreateProvider(ctx, providerConfig)
|
||||
|
||||
// Should never get a "not found for provider" error — unknown
|
||||
// models are passed through to the provider API.
|
||||
|
||||
+2
-3
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/ui"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
"github.com/spf13/viper"
|
||||
@@ -23,7 +22,7 @@ type AgentSetupResult = kit.AgentSetupResult
|
||||
|
||||
// BuildProviderConfig delegates to the SDK to build a ProviderConfig from
|
||||
// the current viper state.
|
||||
func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
func BuildProviderConfig() (*kit.ProviderConfig, string, error) {
|
||||
return kit.BuildProviderConfig()
|
||||
}
|
||||
|
||||
@@ -40,7 +39,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
// app.Options and UI setup.
|
||||
func CollectAgentMetadata(mcpAgent *agent.Agent, mcpConfig *config.Config) (provider, modelName string, serverNames, toolNames []string) {
|
||||
modelString := viper.GetString("model")
|
||||
provider, modelName, _ = models.ParseModelString(modelString)
|
||||
provider, modelName, _ = kit.ParseModelString(modelString)
|
||||
if modelName == "" {
|
||||
modelName = "Unknown"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package kit
|
||||
|
||||
import "github.com/mark3labs/kit/internal/auth"
|
||||
|
||||
// CredentialManager manages API keys and OAuth credentials.
|
||||
type CredentialManager = auth.CredentialManager
|
||||
|
||||
// AnthropicCredentials holds Anthropic API credentials supporting both OAuth
|
||||
// and API key authentication methods.
|
||||
type AnthropicCredentials = auth.AnthropicCredentials
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
type CredentialStore = auth.CredentialStore
|
||||
|
||||
// NewCredentialManager creates a credential manager for secure storage and
|
||||
// retrieval of authentication credentials.
|
||||
func NewCredentialManager() (*CredentialManager, error) {
|
||||
return auth.NewCredentialManager()
|
||||
}
|
||||
|
||||
// HasAnthropicCredentials checks if valid Anthropic credentials are stored
|
||||
// (either OAuth token or API key).
|
||||
func HasAnthropicCredentials() bool {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
has, err := cm.HasAnthropicCredentials()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return has
|
||||
}
|
||||
|
||||
// GetAnthropicAPIKey resolves the Anthropic API key using the standard
|
||||
// resolution order: stored credentials -> ANTHROPIC_API_KEY env var.
|
||||
// Returns an empty string if no key is found.
|
||||
func GetAnthropicAPIKey() string {
|
||||
key, _, err := auth.GetAnthropicAPIKey("")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return key
|
||||
}
|
||||
@@ -345,6 +345,17 @@ func (m *Kit) GetModelString() string {
|
||||
return m.modelString
|
||||
}
|
||||
|
||||
// GetModelInfo returns detailed information about the current model
|
||||
// (capabilities, pricing, limits). Returns nil if the model is not in the
|
||||
// registry — this is expected for new models or custom fine-tunes.
|
||||
func (m *Kit) GetModelInfo() *ModelInfo {
|
||||
provider, modelID, err := ParseModelString(m.modelString)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return LookupModel(provider, modelID)
|
||||
}
|
||||
|
||||
// GetTools returns all tools available to the agent (core + MCP + extensions).
|
||||
func (m *Kit) GetTools() []Tool {
|
||||
return m.agent.GetTools()
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
)
|
||||
|
||||
// LookupModel returns information about a model, or nil if unknown.
|
||||
func LookupModel(provider, modelID string) *ModelInfo {
|
||||
return models.GetGlobalRegistry().LookupModel(provider, modelID)
|
||||
}
|
||||
|
||||
// GetSupportedProviders returns all known provider names in the registry.
|
||||
func GetSupportedProviders() []string {
|
||||
return models.GetGlobalRegistry().GetSupportedProviders()
|
||||
}
|
||||
|
||||
// GetFantasyProviders returns provider IDs that can be used with fantasy,
|
||||
// either through a native provider or via openaicompat auto-routing.
|
||||
func GetFantasyProviders() []string {
|
||||
return models.GetGlobalRegistry().GetFantasyProviders()
|
||||
}
|
||||
|
||||
// GetModelsForProvider returns all known models for a provider.
|
||||
func GetModelsForProvider(provider string) (map[string]ModelInfo, error) {
|
||||
return models.GetGlobalRegistry().GetModelsForProvider(provider)
|
||||
}
|
||||
|
||||
// GetProviderInfo returns information about a provider (env vars, API URL, etc.).
|
||||
// Returns nil if the provider is not in the registry.
|
||||
func GetProviderInfo(provider string) *ProviderInfo {
|
||||
return models.GetGlobalRegistry().GetProviderInfo(provider)
|
||||
}
|
||||
|
||||
// ValidateEnvironment checks if required API keys are set for a provider.
|
||||
// Returns nil for providers not in the registry (unknown providers are
|
||||
// assumed to handle auth themselves or via --provider-api-key).
|
||||
func ValidateEnvironment(provider string, apiKey string) error {
|
||||
return models.GetGlobalRegistry().ValidateEnvironment(provider, apiKey)
|
||||
}
|
||||
|
||||
// SuggestModels returns model names similar to an invalid model string.
|
||||
func SuggestModels(provider, invalidModel string) []string {
|
||||
return models.GetGlobalRegistry().SuggestModels(provider, invalidModel)
|
||||
}
|
||||
|
||||
// RefreshModelRegistry reloads the global model database from the current
|
||||
// data sources (cache -> embedded). Call after updating the cache.
|
||||
func RefreshModelRegistry() {
|
||||
models.ReloadGlobalRegistry()
|
||||
}
|
||||
|
||||
// CheckProviderReady validates that a provider is properly configured
|
||||
// by checking that it exists in the registry and has required environment
|
||||
// variables set.
|
||||
func CheckProviderReady(provider string) error {
|
||||
info := models.GetGlobalRegistry().GetProviderInfo(provider)
|
||||
if info == nil {
|
||||
return fmt.Errorf("unknown provider: %s", provider)
|
||||
}
|
||||
return models.GetGlobalRegistry().ValidateEnvironment(provider, "")
|
||||
}
|
||||
Reference in New Issue
Block a user