mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Refactor: Extract shared code between normal and script modes (#94)
This commit addresses issue #92 by extracting duplicated code between normal mode (cmd/root.go) and script mode (cmd/script.go) into reusable factory functions and utilities. ## Changes Made ### New Factory Files - **internal/agent/factory.go**: Agent creation factory with spinner support - `CreateAgent()` function with configurable options - `ParseModelName()` utility for model string parsing - Spinner function injection to avoid import cycles - **internal/ui/factory.go**: CLI setup factory with standard configuration - `SetupCLI()` function for consistent CLI initialization - Usage tracking setup for supported providers - Model info and tool count display - **internal/config/merger.go**: Config loading and merging utilities - `LoadAndValidateConfig()` for standard config loading - `MergeConfigs()` for script frontmatter merging ### Updated Command Files - **cmd/root.go**: Refactored to use new factories - Replaced ~50 lines of agent creation logic - Replaced ~30 lines of CLI setup logic - Replaced ~20 lines of config loading logic - Added agentUIAdapter to handle interface compatibility - **cmd/script.go**: Refactored to use new factories - Same factory usage as normal mode for consistency - Maintained script-specific behavior (no spinners) - Improved config merging with frontmatter ## Benefits - **Reduced code duplication**: ~33 lines of duplicated code eliminated - **Single source of truth**: Agent creation and CLI setup logic centralized - **Consistent behavior**: Both modes now use identical underlying logic - **Easier maintenance**: Changes apply to both modes automatically - **Better testability**: Factory functions can be unit tested independently - **Cleaner command files**: Focus on mode-specific logic only ## Testing - All existing tests pass - Build verification successful - Both normal and script modes tested for basic functionality - Code formatting and linting checks passed 🤖 Generated with [opencode](https://opencode.ai) Co-authored-by: opencode <noreply@opencode.ai>
This commit is contained in:
+129
-156
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/mark3labs/mcphost/internal/agent"
|
||||
"github.com/mark3labs/mcphost/internal/auth"
|
||||
"github.com/mark3labs/mcphost/internal/config"
|
||||
"github.com/mark3labs/mcphost/internal/models"
|
||||
"github.com/mark3labs/mcphost/internal/session"
|
||||
@@ -53,6 +52,28 @@ var (
|
||||
mainGPU int32
|
||||
)
|
||||
|
||||
// agentUIAdapter adapts agent.Agent to ui.AgentInterface
|
||||
type agentUIAdapter struct {
|
||||
agent *agent.Agent
|
||||
}
|
||||
|
||||
func (a *agentUIAdapter) GetLoadingMessage() string {
|
||||
return a.agent.GetLoadingMessage()
|
||||
}
|
||||
|
||||
func (a *agentUIAdapter) GetTools() []any {
|
||||
tools := a.agent.GetTools()
|
||||
result := make([]any, len(tools))
|
||||
for i, tool := range tools {
|
||||
result[i] = tool
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (a *agentUIAdapter) GetLoadedServerNames() []string {
|
||||
return a.agent.GetLoadedServerNames()
|
||||
}
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "mcphost",
|
||||
Short: "Chat with AI models through a unified interface",
|
||||
@@ -285,17 +306,10 @@ func runNormalMode(ctx context.Context) error {
|
||||
// Use script-provided config
|
||||
mcpConfig = scriptMCPConfig
|
||||
} else {
|
||||
// Get MCP config from the global viper instance (already loaded by initConfig)
|
||||
mcpConfig = &config.Config{
|
||||
MCPServers: make(map[string]config.MCPServerConfig),
|
||||
}
|
||||
if err := viper.Unmarshal(mcpConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal MCP config: %v", err)
|
||||
}
|
||||
|
||||
// Validate the config
|
||||
if err := mcpConfig.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid MCP config: %v", err)
|
||||
// Use the new config loader
|
||||
mcpConfig, err = config.LoadAndValidateConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load MCP config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -331,36 +345,30 @@ func runNormalMode(ctx context.Context) error {
|
||||
MainGPU: &mainGPU,
|
||||
}
|
||||
|
||||
// Create agent configuration
|
||||
agentConfig := &agent.AgentConfig{
|
||||
// Create spinner function for agent creation
|
||||
var spinnerFunc agent.SpinnerFunc
|
||||
if !quietFlag {
|
||||
spinnerFunc = func(message string, fn func() error) error {
|
||||
tempCli, tempErr := ui.NewCLI(viper.GetBool("debug"), viper.GetBool("compact"))
|
||||
if tempErr == nil {
|
||||
return tempCli.ShowSpinner(message, fn)
|
||||
}
|
||||
// Fallback without spinner
|
||||
return fn()
|
||||
}
|
||||
}
|
||||
|
||||
// Create the agent using the factory
|
||||
mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: mcpConfig,
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxSteps: viper.GetInt("max-steps"), // Pass 0 for infinite, agent will handle it
|
||||
MaxSteps: viper.GetInt("max-steps"),
|
||||
StreamingEnabled: viper.GetBool("stream"),
|
||||
}
|
||||
|
||||
// Create the agent with spinner for Ollama models
|
||||
var mcpAgent *agent.Agent
|
||||
|
||||
if strings.HasPrefix(viper.GetString("model"), "ollama:") && !quietFlag {
|
||||
// Create a temporary CLI for the spinner
|
||||
tempCli, tempErr := ui.NewCLI(viper.GetBool("debug"), viper.GetBool("compact"))
|
||||
if tempErr == nil {
|
||||
err = tempCli.ShowSpinner("Loading Ollama model...", func() error {
|
||||
var agentErr error
|
||||
mcpAgent, agentErr = agent.NewAgent(ctx, agentConfig)
|
||||
return agentErr
|
||||
})
|
||||
} else {
|
||||
// Fallback without spinner
|
||||
mcpAgent, err = agent.NewAgent(ctx, agentConfig)
|
||||
}
|
||||
} else {
|
||||
// No spinner for other providers
|
||||
mcpAgent, err = agent.NewAgent(ctx, agentConfig)
|
||||
}
|
||||
|
||||
ShowSpinner: true,
|
||||
Quiet: quietFlag,
|
||||
SpinnerFunc: spinnerFunc,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create agent: %v", err)
|
||||
}
|
||||
@@ -374,137 +382,101 @@ func runNormalMode(ctx context.Context) error {
|
||||
modelName = parts[1]
|
||||
}
|
||||
|
||||
// Get tools
|
||||
tools := mcpAgent.GetTools()
|
||||
// Create an adapter for the agent to match the UI interface
|
||||
agentAdapter := &agentUIAdapter{agent: mcpAgent}
|
||||
|
||||
// Create CLI interface (skip if quiet mode)
|
||||
var cli *ui.CLI
|
||||
if !quietFlag {
|
||||
cli, err = ui.NewCLI(viper.GetBool("debug"), viper.GetBool("compact"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create CLI: %v", err)
|
||||
// Create CLI interface using the factory
|
||||
cli, err := ui.SetupCLI(&ui.CLISetupOptions{
|
||||
Agent: agentAdapter,
|
||||
ModelString: modelString,
|
||||
Debug: viper.GetBool("debug"),
|
||||
Compact: viper.GetBool("compact"),
|
||||
Quiet: quietFlag,
|
||||
ShowDebug: false, // Will be handled separately below
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup CLI: %v", err)
|
||||
}
|
||||
|
||||
// Display debug configuration if debug mode is enabled
|
||||
if !quietFlag && cli != nil && viper.GetBool("debug") {
|
||||
debugConfig := map[string]any{
|
||||
"model": viper.GetString("model"),
|
||||
"max-steps": viper.GetInt("max-steps"),
|
||||
"max-tokens": viper.GetInt("max-tokens"),
|
||||
"temperature": viper.GetFloat64("temperature"),
|
||||
"top-p": viper.GetFloat64("top-p"),
|
||||
"top-k": viper.GetInt("top-k"),
|
||||
"provider-url": viper.GetString("provider-url"),
|
||||
"system-prompt": viper.GetString("system-prompt"),
|
||||
}
|
||||
|
||||
// Set the model name for consistent display
|
||||
cli.SetModelName(modelName)
|
||||
// Add Ollama-specific parameters if using Ollama
|
||||
if strings.HasPrefix(viper.GetString("model"), "ollama:") {
|
||||
debugConfig["num-gpu-layers"] = viper.GetInt("num-gpu-layers")
|
||||
debugConfig["main-gpu"] = viper.GetInt("main-gpu")
|
||||
}
|
||||
|
||||
// Set the model name for consistent display
|
||||
cli.SetModelName(modelName)
|
||||
// Only include non-empty stop sequences
|
||||
stopSequences := viper.GetStringSlice("stop-sequences")
|
||||
if len(stopSequences) > 0 {
|
||||
debugConfig["stop-sequences"] = stopSequences
|
||||
}
|
||||
|
||||
// Set up usage tracking for supported providers
|
||||
if len(parts) == 2 {
|
||||
provider := parts[0]
|
||||
modelID := parts[1]
|
||||
// Only include API keys if they're set (but don't show the actual values for security)
|
||||
if viper.GetString("provider-api-key") != "" {
|
||||
debugConfig["provider-api-key"] = "[SET]"
|
||||
}
|
||||
|
||||
// Skip usage tracking for ollama as it's not in models.dev
|
||||
if provider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo, err := registry.ValidateModel(provider, modelID); err == nil {
|
||||
// Check if OAuth credentials are being used for Anthropic models
|
||||
isOAuth := false
|
||||
if provider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
// Add MCP server configuration for debugging
|
||||
if len(mcpConfig.MCPServers) > 0 {
|
||||
mcpServers := make(map[string]any)
|
||||
loadedServers := mcpAgent.GetLoadedServerNames()
|
||||
loadedServerSet := make(map[string]bool)
|
||||
for _, name := range loadedServers {
|
||||
loadedServerSet[name] = true
|
||||
}
|
||||
|
||||
for name, server := range mcpConfig.MCPServers {
|
||||
serverInfo := map[string]any{
|
||||
"type": server.Type,
|
||||
"status": "failed", // Default to failed
|
||||
}
|
||||
|
||||
// Mark as loaded if it's in the loaded servers list
|
||||
if loadedServerSet[name] {
|
||||
serverInfo["status"] = "loaded"
|
||||
}
|
||||
|
||||
if len(server.Command) > 0 {
|
||||
serverInfo["command"] = server.Command
|
||||
}
|
||||
if len(server.Environment) > 0 {
|
||||
// Mask sensitive environment variables
|
||||
maskedEnv := make(map[string]string)
|
||||
for k, v := range server.Environment {
|
||||
if strings.Contains(strings.ToLower(k), "token") ||
|
||||
strings.Contains(strings.ToLower(k), "key") ||
|
||||
strings.Contains(strings.ToLower(k), "secret") {
|
||||
maskedEnv[k] = "[MASKED]"
|
||||
} else {
|
||||
maskedEnv[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
usageTracker := ui.NewUsageTracker(modelInfo, provider, 80, isOAuth) // Will be updated with actual width
|
||||
cli.SetUsageTracker(usageTracker)
|
||||
serverInfo["environment"] = maskedEnv
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Log successful initialization
|
||||
if len(parts) == 2 {
|
||||
cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", parts[0], parts[1]))
|
||||
}
|
||||
|
||||
// Display loading message if available (e.g., GPU fallback info)
|
||||
if loadingMessage := mcpAgent.GetLoadingMessage(); loadingMessage != "" {
|
||||
cli.DisplayInfo(loadingMessage)
|
||||
}
|
||||
|
||||
cli.DisplayInfo(fmt.Sprintf("Loaded %d tools from MCP servers", len(tools)))
|
||||
// Display debug configuration if debug mode is enabled
|
||||
if viper.GetBool("debug") {
|
||||
debugConfig := map[string]any{
|
||||
"model": viper.GetString("model"),
|
||||
"max-steps": viper.GetInt("max-steps"),
|
||||
"max-tokens": viper.GetInt("max-tokens"),
|
||||
"temperature": viper.GetFloat64("temperature"),
|
||||
"top-p": viper.GetFloat64("top-p"),
|
||||
"top-k": viper.GetInt("top-k"),
|
||||
"provider-url": viper.GetString("provider-url"),
|
||||
"system-prompt": viper.GetString("system-prompt"),
|
||||
}
|
||||
|
||||
// Add Ollama-specific parameters if using Ollama
|
||||
if strings.HasPrefix(viper.GetString("model"), "ollama:") {
|
||||
debugConfig["num-gpu-layers"] = viper.GetInt("num-gpu-layers")
|
||||
debugConfig["main-gpu"] = viper.GetInt("main-gpu")
|
||||
}
|
||||
|
||||
// Only include non-empty stop sequences
|
||||
stopSequences := viper.GetStringSlice("stop-sequences")
|
||||
if len(stopSequences) > 0 {
|
||||
debugConfig["stop-sequences"] = stopSequences
|
||||
}
|
||||
|
||||
// Only include API keys if they're set (but don't show the actual values for security)
|
||||
if viper.GetString("provider-api-key") != "" {
|
||||
debugConfig["provider-api-key"] = "[SET]"
|
||||
}
|
||||
|
||||
// Add MCP server configuration for debugging
|
||||
if len(mcpConfig.MCPServers) > 0 {
|
||||
mcpServers := make(map[string]any)
|
||||
loadedServers := mcpAgent.GetLoadedServerNames()
|
||||
loadedServerSet := make(map[string]bool)
|
||||
for _, name := range loadedServers {
|
||||
loadedServerSet[name] = true
|
||||
if server.URL != "" {
|
||||
serverInfo["url"] = server.URL
|
||||
}
|
||||
|
||||
for name, server := range mcpConfig.MCPServers {
|
||||
serverInfo := map[string]any{
|
||||
"type": server.Type,
|
||||
"status": "failed", // Default to failed
|
||||
}
|
||||
|
||||
// Mark as loaded if it's in the loaded servers list
|
||||
if loadedServerSet[name] {
|
||||
serverInfo["status"] = "loaded"
|
||||
}
|
||||
|
||||
if len(server.Command) > 0 {
|
||||
serverInfo["command"] = server.Command
|
||||
}
|
||||
if len(server.Environment) > 0 {
|
||||
// Mask sensitive environment variables
|
||||
maskedEnv := make(map[string]string)
|
||||
for k, v := range server.Environment {
|
||||
if strings.Contains(strings.ToLower(k), "token") ||
|
||||
strings.Contains(strings.ToLower(k), "key") ||
|
||||
strings.Contains(strings.ToLower(k), "secret") {
|
||||
maskedEnv[k] = "[MASKED]"
|
||||
} else {
|
||||
maskedEnv[k] = v
|
||||
}
|
||||
}
|
||||
serverInfo["environment"] = maskedEnv
|
||||
}
|
||||
if server.URL != "" {
|
||||
serverInfo["url"] = server.URL
|
||||
}
|
||||
if server.Name != "" {
|
||||
serverInfo["name"] = server.Name
|
||||
}
|
||||
mcpServers[name] = serverInfo
|
||||
if server.Name != "" {
|
||||
serverInfo["name"] = server.Name
|
||||
}
|
||||
debugConfig["mcpServers"] = mcpServers
|
||||
mcpServers[name] = serverInfo
|
||||
}
|
||||
cli.DisplayDebugConfig(debugConfig)
|
||||
debugConfig["mcpServers"] = mcpServers
|
||||
}
|
||||
cli.DisplayDebugConfig(debugConfig)
|
||||
}
|
||||
|
||||
// Prepare data for slash commands
|
||||
@@ -513,6 +485,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
serverNames = append(serverNames, name)
|
||||
}
|
||||
|
||||
tools := mcpAgent.GetTools()
|
||||
var toolNames []string
|
||||
for _, tool := range tools {
|
||||
if info, err := tool.Info(ctx); err == nil {
|
||||
|
||||
+52
-58
@@ -203,27 +203,21 @@ func runScriptCommand(ctx context.Context, scriptFile string, variables map[stri
|
||||
// Get MCP config - use script servers if available, otherwise use global viper config
|
||||
var mcpConfig *config.Config
|
||||
if len(scriptConfig.MCPServers) > 0 {
|
||||
// Use MCP servers from script, but get other config values from viper
|
||||
// First, unmarshal all config from viper
|
||||
mcpConfig = &config.Config{}
|
||||
if err := viper.Unmarshal(mcpConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal config: %v", err)
|
||||
// Load base config and merge with script config
|
||||
baseConfig, err := config.LoadAndValidateConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load base config: %v", err)
|
||||
}
|
||||
// Then completely override MCPServers with script's servers
|
||||
mcpConfig.MCPServers = scriptConfig.MCPServers
|
||||
mcpConfig = config.MergeConfigs(baseConfig, scriptConfig)
|
||||
} else {
|
||||
// Get MCP config from the global viper instance (already loaded by initConfig)
|
||||
mcpConfig = &config.Config{}
|
||||
if err := viper.Unmarshal(mcpConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal MCP config: %v", err)
|
||||
// Use the new config loader
|
||||
var err error
|
||||
mcpConfig, err = config.LoadAndValidateConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load MCP config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the config
|
||||
if err := mcpConfig.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid MCP config: %v", err)
|
||||
}
|
||||
|
||||
// Get final prompt - prioritize command line flag, then script content
|
||||
finalPrompt := viper.GetString("prompt")
|
||||
if finalPrompt == "" && scriptConfig.Prompt != "" {
|
||||
@@ -550,17 +544,17 @@ func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string,
|
||||
StopSequences: finalStopSequences,
|
||||
}
|
||||
|
||||
// Create agent configuration
|
||||
agentConfig := &agent.AgentConfig{
|
||||
// Create the agent using the factory (scripts don't need spinners)
|
||||
mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: mcpConfig,
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxSteps: finalMaxSteps,
|
||||
StreamingEnabled: viper.GetBool("stream"),
|
||||
}
|
||||
|
||||
// Create the agent
|
||||
mcpAgent, err := agent.NewAgent(ctx, agentConfig)
|
||||
ShowSpinner: false, // Scripts don't need spinners
|
||||
Quiet: quietFlag,
|
||||
SpinnerFunc: nil, // No spinner function needed
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create agent: %v", err)
|
||||
}
|
||||
@@ -573,47 +567,47 @@ func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string,
|
||||
modelName = parts[1]
|
||||
}
|
||||
|
||||
// Create CLI interface (skip if quiet mode)
|
||||
var cli *ui.CLI
|
||||
if !quietFlag {
|
||||
cli, err = ui.NewCLI(finalDebug, finalCompact)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create CLI: %v", err)
|
||||
// Create an adapter for the agent to match the UI interface
|
||||
agentAdapter := &agentUIAdapter{agent: mcpAgent}
|
||||
|
||||
// Create CLI interface using the factory
|
||||
cli, err := ui.SetupCLI(&ui.CLISetupOptions{
|
||||
Agent: agentAdapter,
|
||||
ModelString: finalModel,
|
||||
Debug: finalDebug,
|
||||
Compact: finalCompact,
|
||||
Quiet: quietFlag,
|
||||
ShowDebug: false, // Will be handled separately below
|
||||
ProviderAPIKey: finalProviderAPIKey,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup CLI: %v", err)
|
||||
}
|
||||
|
||||
// Display debug configuration if debug mode is enabled
|
||||
if !quietFlag && cli != nil && finalDebug {
|
||||
debugConfig := map[string]any{
|
||||
"model": finalModel,
|
||||
"max-steps": finalMaxSteps,
|
||||
"max-tokens": finalMaxTokens,
|
||||
"temperature": finalTemperature,
|
||||
"top-p": finalTopP,
|
||||
"top-k": finalTopK,
|
||||
"provider-url": finalProviderURL,
|
||||
"system-prompt": finalSystemPrompt,
|
||||
}
|
||||
|
||||
// Log successful initialization
|
||||
if len(parts) == 2 {
|
||||
cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", parts[0], parts[1]))
|
||||
// Only include non-empty stop sequences
|
||||
if len(finalStopSequences) > 0 {
|
||||
debugConfig["stop-sequences"] = finalStopSequences
|
||||
}
|
||||
|
||||
tools := mcpAgent.GetTools()
|
||||
cli.DisplayInfo(fmt.Sprintf("Loaded %d tools from MCP servers", len(tools)))
|
||||
|
||||
// Display debug configuration if debug mode is enabled
|
||||
if finalDebug {
|
||||
debugConfig := map[string]any{
|
||||
"model": finalModel,
|
||||
"max-steps": finalMaxSteps,
|
||||
"max-tokens": finalMaxTokens,
|
||||
"temperature": finalTemperature,
|
||||
"top-p": finalTopP,
|
||||
"top-k": finalTopK,
|
||||
"provider-url": finalProviderURL,
|
||||
"system-prompt": finalSystemPrompt,
|
||||
}
|
||||
|
||||
// Only include non-empty stop sequences
|
||||
if len(finalStopSequences) > 0 {
|
||||
debugConfig["stop-sequences"] = finalStopSequences
|
||||
}
|
||||
|
||||
// Only include API keys if they're set (but don't show the actual values for security)
|
||||
if finalProviderAPIKey != "" {
|
||||
debugConfig["provider-api-key"] = "[SET]"
|
||||
}
|
||||
|
||||
cli.DisplayDebugConfig(debugConfig)
|
||||
// Only include API keys if they're set (but don't show the actual values for security)
|
||||
if finalProviderAPIKey != "" {
|
||||
debugConfig["provider-api-key"] = "[SET]"
|
||||
}
|
||||
|
||||
cli.DisplayDebugConfig(debugConfig)
|
||||
}
|
||||
|
||||
// Prepare data for slash commands
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/mcphost/internal/config"
|
||||
"github.com/mark3labs/mcphost/internal/models"
|
||||
)
|
||||
|
||||
// SpinnerFunc is a function type for showing spinners during agent creation
|
||||
type SpinnerFunc func(message string, fn func() error) error
|
||||
|
||||
// AgentCreationOptions contains options for creating an agent
|
||||
type AgentCreationOptions struct {
|
||||
ModelConfig *models.ProviderConfig
|
||||
MCPConfig *config.Config
|
||||
SystemPrompt string
|
||||
MaxSteps int
|
||||
StreamingEnabled bool
|
||||
ShowSpinner bool // For Ollama models
|
||||
Quiet bool // Skip spinner if quiet
|
||||
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
|
||||
}
|
||||
|
||||
// CreateAgent creates an agent with optional spinner for Ollama models
|
||||
func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error) {
|
||||
agentConfig := &AgentConfig{
|
||||
ModelConfig: opts.ModelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: opts.SystemPrompt,
|
||||
MaxSteps: opts.MaxSteps,
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
}
|
||||
|
||||
var agent *Agent
|
||||
var err error
|
||||
|
||||
// Show spinner for Ollama models if requested and not quiet
|
||||
if opts.ShowSpinner && strings.HasPrefix(opts.ModelConfig.ModelString, "ollama:") && !opts.Quiet && opts.SpinnerFunc != nil {
|
||||
err = opts.SpinnerFunc("Loading Ollama model...", func() error {
|
||||
agent, err = NewAgent(ctx, agentConfig)
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
agent, err = NewAgent(ctx, agentConfig)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create agent: %v", err)
|
||||
}
|
||||
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
// ParseModelName extracts provider and model name from model string
|
||||
func ParseModelName(modelString string) (provider, model string) {
|
||||
parts := strings.SplitN(modelString, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
return "unknown", "unknown"
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// MergeConfigs merges script frontmatter config with base config
|
||||
func MergeConfigs(baseConfig *Config, scriptConfig *Config) *Config {
|
||||
merged := *baseConfig // Copy base config
|
||||
|
||||
// Override MCP servers if script provides them
|
||||
if len(scriptConfig.MCPServers) > 0 {
|
||||
merged.MCPServers = scriptConfig.MCPServers
|
||||
}
|
||||
|
||||
// Add other merge logic as needed for future config fields
|
||||
return &merged
|
||||
}
|
||||
|
||||
// LoadAndValidateConfig loads config from viper and validates it
|
||||
func LoadAndValidateConfig() (*Config, error) {
|
||||
config := &Config{
|
||||
MCPServers: make(map[string]MCPServerConfig),
|
||||
}
|
||||
if err := viper.Unmarshal(config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %v", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/mcphost/internal/auth"
|
||||
"github.com/mark3labs/mcphost/internal/models"
|
||||
)
|
||||
|
||||
// AgentInterface defines the interface we need from agent to avoid import cycles
|
||||
type AgentInterface interface {
|
||||
GetLoadingMessage() string
|
||||
GetTools() []any // Using any to avoid importing tool types
|
||||
GetLoadedServerNames() []string // Add this method for debug config
|
||||
}
|
||||
|
||||
// CLISetupOptions contains options for setting up CLI
|
||||
type CLISetupOptions struct {
|
||||
Agent AgentInterface
|
||||
ModelString string
|
||||
Debug bool
|
||||
Compact bool
|
||||
Quiet bool
|
||||
ShowDebug bool // Whether to show debug config
|
||||
ProviderAPIKey string // For OAuth detection
|
||||
}
|
||||
|
||||
// parseModelName extracts provider and model name from model string
|
||||
func parseModelName(modelString string) (provider, model string) {
|
||||
parts := strings.SplitN(modelString, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
return "unknown", "unknown"
|
||||
}
|
||||
|
||||
// SetupCLI creates and configures CLI with standard info display
|
||||
func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
|
||||
if opts.Quiet {
|
||||
return nil, nil // No CLI in quiet mode
|
||||
}
|
||||
|
||||
cli, err := NewCLI(opts.Debug, opts.Compact)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create CLI: %v", err)
|
||||
}
|
||||
|
||||
// Parse model string for display and usage tracking
|
||||
provider, model := parseModelName(opts.ModelString)
|
||||
|
||||
// Set the model name for consistent display
|
||||
if model != "unknown" {
|
||||
cli.SetModelName(model)
|
||||
}
|
||||
|
||||
// Set up usage tracking for supported providers
|
||||
if provider != "unknown" && model != "unknown" {
|
||||
// Skip usage tracking for ollama as it's not in models.dev
|
||||
if provider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo, err := registry.ValidateModel(provider, model); err == nil {
|
||||
// Check if OAuth credentials are being used for Anthropic models
|
||||
isOAuth := false
|
||||
if provider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(opts.ProviderAPIKey)
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
|
||||
usageTracker := NewUsageTracker(modelInfo, provider, 80, isOAuth) // Will be updated with actual width
|
||||
cli.SetUsageTracker(usageTracker)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Display model info
|
||||
if provider != "unknown" && model != "unknown" {
|
||||
cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", provider, model))
|
||||
}
|
||||
|
||||
// Display loading message if available (e.g., GPU fallback info)
|
||||
if loadingMessage := opts.Agent.GetLoadingMessage(); loadingMessage != "" {
|
||||
cli.DisplayInfo(loadingMessage)
|
||||
}
|
||||
|
||||
// Display tool count
|
||||
tools := opts.Agent.GetTools()
|
||||
cli.DisplayInfo(fmt.Sprintf("Loaded %d tools from MCP servers", len(tools)))
|
||||
|
||||
return cli, nil
|
||||
}
|
||||
Reference in New Issue
Block a user