package ui

import (
	"fmt"

	"github.com/mark3labs/kit/internal/auth"
	"github.com/mark3labs/kit/internal/models"
)

// AgentInterface defines the minimal interface required from the agent package
// to avoid circular dependencies while still accessing necessary agent functionality.
type AgentInterface interface {
	// GetLoadingMessage returns an optional message to show once at startup;
	// SetupCLI skips it when it is empty.
	GetLoadingMessage() string

	// GetTools returns the agent's tools. The element type is any so that this
	// package does not have to import the tool types.
	GetTools() []any

	// GetLoadedServerNames returns the names of the loaded servers (intended
	// for the debug config display).
	GetLoadedServerNames() []string

	// GetMCPToolCount returns the number of tools loaded from external MCP servers.
	GetMCPToolCount() int

	// GetExtensionToolCount returns the number of tools registered by extensions.
	GetExtensionToolCount() int
}

// CLISetupOptions encapsulates all configuration parameters needed to initialize
// and set up a CLI instance, including display preferences, model information,
// and debugging settings.
type CLISetupOptions struct {
	// Agent supplies the loading message and tool counts shown at startup.
	// A nil Agent is tolerated by SetupCLI: those lines are simply omitted.
	Agent AgentInterface

	// ModelString is the model identifier handed to models.ParseModelString.
	ModelString string

	// Debug is forwarded to NewCLI.
	Debug bool

	// Quiet disables the CLI entirely; SetupCLI returns (nil, nil).
	Quiet bool

	// ShowDebug reports whether to show the debug config. It is not read by
	// any function in this file.
	ShowDebug bool

	// ProviderAPIKey is inspected for Anthropic OAuth detection.
	ProviderAPIKey string
}

// parseModelName extracts the provider and model name from a model string.
// When the string cannot be parsed, both results are the literal "unknown".
func parseModelName(modelString string) (provider, model string) {
	parsedProvider, parsedModel, err := models.ParseModelString(modelString)
	if err != nil {
		return "unknown", "unknown"
	}
	return parsedProvider, parsedModel
}

// CreateUsageTracker creates a UsageTracker for the given model string and
// provider API key. It returns nil when usage tracking is unavailable (e.g.
// ollama or unrecognised models). Besides SetupCLI, this is used by the
// interactive TUI path, which doesn't go through SetupCLI.
func CreateUsageTracker(modelString, providerAPIKey string) *UsageTracker {
	modelInfo, provider := lookupTrackableModel(modelString)
	if modelInfo == nil {
		return nil
	}

	// OAuth detection only applies to the anthropic provider.
	isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)

	// NOTE(review): 80 is presumably a display width; confirm against the
	// NewUsageTracker signature.
	return NewUsageTracker(modelInfo, provider, 80, isOAuth)
}

// UpdateUsageTrackerForModel refreshes an existing tracker after a model
// switch so token counting and cost reporting use the new model's metadata.
// No-op for a nil tracker or untrackable models (unknown/ollama).
func UpdateUsageTrackerForModel(t *UsageTracker, modelString, providerAPIKey string) {
	if t == nil {
		return
	}

	modelInfo, provider := lookupTrackableModel(modelString)
	if modelInfo == nil {
		// Leave the tracker on its previous model rather than clearing it.
		return
	}

	// OAuth detection only applies to the anthropic provider.
	isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)
	t.UpdateModelInfo(modelInfo, provider, isOAuth)
}

// lookupTrackableModel resolves a model string to registry metadata, returning
// nil for models without usage tracking support (unknown or ollama models).
// The parsed provider is returned alongside in every case.
func lookupTrackableModel(modelString string) (*models.ModelInfo, string) {
	provider, model := parseModelName(modelString)

	switch {
	case provider == "unknown", model == "unknown", provider == "ollama":
		return nil, provider
	}

	return models.GetGlobalRegistry().LookupModel(provider, model), provider
}

// SetupCLI creates, configures, and initializes a CLI instance with the provided
// options. It sets up model display, usage tracking for supported providers, and
// shows initial loading information.
//
// It returns (nil, nil) in quiet mode, so callers must check the returned CLI
// for nil even when the error is nil. It returns an error when opts is nil or
// when NewCLI fails; the NewCLI error is wrapped so callers can inspect it
// with errors.Is / errors.As. When opts.Agent is nil the agent-derived info
// lines (loading message and tool counts) are skipped.
func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
	if opts == nil {
		return nil, fmt.Errorf("failed to create CLI: nil setup options")
	}
	if opts.Quiet {
		return nil, nil // No CLI in quiet mode
	}

	cli, err := NewCLI(opts.Debug)
	if err != nil {
		return nil, fmt.Errorf("failed to create CLI: %w", err)
	}

	// Parse model string for display and usage tracking.
	provider, model := parseModelName(opts.ModelString)

	// Set the model name for consistent display.
	if model != "unknown" {
		cli.SetModelName(model)
	}

	// Set up usage tracking for supported providers.
	if usageTracker := CreateUsageTracker(opts.ModelString, opts.ProviderAPIKey); usageTracker != nil {
		cli.SetUsageTracker(usageTracker)
	}

	// Display model info (the system message block provides its own spacing).
	if provider != "unknown" && model != "unknown" {
		cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", provider, model))
	}

	// Everything below reads from the agent; without one there is nothing
	// more to report.
	agent := opts.Agent
	if agent == nil {
		return cli, nil
	}

	// Display loading message if available (e.g., GPU fallback info).
	if loadingMessage := agent.GetLoadingMessage(); loadingMessage != "" {
		cli.DisplayInfo(loadingMessage)
	}

	// Display extension tool count (only when > 0).
	if extCount := agent.GetExtensionToolCount(); extCount > 0 {
		cli.DisplayInfo(fmt.Sprintf("Loaded %d extension tools", extCount))
	}

	// Display MCP tool count (only when > 0).
	if mcpCount := agent.GetMCPToolCount(); mcpCount > 0 {
		cli.DisplayInfo(fmt.Sprintf("Loaded %d tools from MCP servers", mcpCount))
	}

	return cli, nil
}