diff --git a/cmd/auth.go b/cmd/auth.go index 5a769725..1b5800bb 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -10,6 +10,10 @@ import ( "github.com/spf13/cobra" ) +// authCmd represents the auth command for managing AI provider authentication. +// This command provides subcommands for login, logout, and status checking +// of authentication credentials for various AI providers, with OAuth support +// for providers like Anthropic. var authCmd = &cobra.Command{ Use: "auth", Short: "Manage authentication credentials for AI providers", @@ -27,6 +31,9 @@ Examples: mcphost auth status`, } +// authLoginCmd represents the login subcommand for authenticating with AI providers. +// It handles OAuth flow for supported providers, opening a browser for authentication +// and securely storing the resulting credentials for future use. var authLoginCmd = &cobra.Command{ Use: "login [provider]", Short: "Authenticate with an AI provider using OAuth", @@ -45,6 +52,9 @@ Example: RunE: runAuthLogin, } +// authLogoutCmd represents the logout subcommand for removing stored authentication credentials. +// This command removes stored API keys or OAuth tokens for specified providers, +// requiring the user to authenticate again or use environment variables. var authLogoutCmd = &cobra.Command{ Use: "logout [provider]", Short: "Remove stored authentication credentials for a provider", @@ -62,6 +72,9 @@ Example: RunE: runAuthLogout, } +// authStatusCmd represents the status subcommand for checking authentication status. +// It displays which providers have stored credentials, their types (OAuth vs API key), +// creation dates, and expiration status without revealing the actual credentials. 
var authStatusCmd = &cobra.Command{ Use: "status", Short: "Show authentication status for all providers", diff --git a/cmd/hooks.go b/cmd/hooks.go index 3e5dad19..863cd462 100644 --- a/cmd/hooks.go +++ b/cmd/hooks.go @@ -10,12 +10,18 @@ import ( "gopkg.in/yaml.v3" ) +// hooksCmd represents the hooks command for managing MCPHost hook configurations. +// Hooks allow users to execute custom scripts or commands at various points +// during MCPHost execution, such as before/after tool use or when prompts are submitted. var hooksCmd = &cobra.Command{ Use: "hooks", Short: "Manage MCPHost hooks", Long: "Commands for managing and testing MCPHost hooks configuration", } +// hooksListCmd represents the list subcommand for displaying all configured hooks. +// It shows a formatted table of hook events, matchers, commands, and timeouts +// to help users understand their current hook configuration. var hooksListCmd = &cobra.Command{ Use: "list", Short: "List all configured hooks", @@ -45,6 +51,9 @@ var hooksListCmd = &cobra.Command{ }, } +// hooksValidateCmd represents the validate subcommand for checking hook configuration validity. +// It loads and validates the hooks configuration file, ensuring proper syntax, +// valid event types, and correct matcher patterns before use. var hooksValidateCmd = &cobra.Command{ Use: "validate", Short: "Validate hooks configuration", @@ -64,6 +73,9 @@ var hooksValidateCmd = &cobra.Command{ }, } +// hooksInitCmd represents the init subcommand for generating an example hooks configuration. +// It creates a .mcphost/hooks.yml file with sample hook configurations demonstrating +// various hook events and common use cases like logging commands and tool usage. 
var hooksInitCmd = &cobra.Command{ Use: "init", Short: "Generate example hooks configuration", diff --git a/cmd/root.go b/cmd/root.go index 19008e4a..a3860f90 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -84,6 +84,10 @@ func (a *agentUIAdapter) GetLoadedServerNames() []string { return a.agent.GetLoadedServerNames() } +// rootCmd represents the base command when called without any subcommands. +// This is the main entry point for the MCPHost CLI application, providing +// a unified interface for interacting with various AI models, +// with support for MCP servers and tool integration. var rootCmd = &cobra.Command{ Use: "mcphost", Short: "Chat with AI models through a unified interface", @@ -120,12 +124,21 @@ Examples: }, } -// GetRootCommand returns the root command with the version set +// GetRootCommand returns the root command with the version set. +// This function is the main entry point for the MCPHost CLI and should be +// called from main.go with the appropriate version string. func GetRootCommand(v string) *cobra.Command { rootCmd.Version = v return rootCmd } +// InitConfig initializes the configuration for MCPHost by loading config files, +// environment variables, and hooks configuration. It follows this priority order: +// 1. Command-line specified config file (--config flag) +// 2. Current directory config file (.mcphost or .mcp) +// 3. Home directory config file (~/.mcphost or ~/.mcp) +// 4. Environment variables (MCPHOST_* prefix) +// This function is automatically called by cobra before command execution. func InitConfig() { if configFile != "" { // Use config file from the flag @@ -202,7 +215,12 @@ func InitConfig() { } -// LoadConfigWithEnvSubstitution loads a config file with environment variable substitution +// LoadConfigWithEnvSubstitution loads a config file with environment variable substitution. 
+// It reads the config file, replaces any ${ENV_VAR} patterns with their corresponding +// environment variable values, and then parses the resulting configuration using viper. +// The function automatically detects JSON or YAML format based on file extension. +// Returns an error if the file cannot be read, environment variable substitution fails, +// or the configuration cannot be parsed. func LoadConfigWithEnvSubstitution(configPath string) error { // Read raw config file content rawContent, err := os.ReadFile(configPath) @@ -728,7 +746,9 @@ func runNormalMode(ctx context.Context) error { return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor) } -// AgenticLoopConfig configures the behavior of the unified agentic loop +// AgenticLoopConfig configures the behavior of the unified agentic loop. +// This struct controls how the main interaction loop operates, whether in +// interactive or non-interactive mode, and manages various UI and session options. type AgenticLoopConfig struct { // Mode configuration IsInteractive bool // true for interactive mode, false for non-interactive diff --git a/cmd/script.go b/cmd/script.go index 2b3a7b20..67e34626 100644 --- a/cmd/script.go +++ b/cmd/script.go @@ -21,6 +21,10 @@ import ( "github.com/spf13/viper" ) +// scriptCmd represents the script command for executing MCPHost script files. +// Script files can contain YAML frontmatter configuration followed by a prompt, +// allowing for reproducible AI interactions with custom configurations and +// variable substitution support. var scriptCmd = &cobra.Command{ Use: "script ", Short: "Execute a script file with YAML frontmatter configuration", @@ -413,11 +417,13 @@ func parseScriptContent(content string, variables map[string]string) (*config.Co return &scriptConfig, nil } -// Variable represents a script variable with optional default value +// Variable represents a script variable with optional default value. 
+// Variables can be declared in scripts using ${variable} syntax for required variables +// or ${variable:-default} syntax for variables with default values. type Variable struct { - Name string - DefaultValue string - HasDefault bool + Name string // The name of the variable as it appears in the script + DefaultValue string // The default value if specified using ${variable:-default} syntax + HasDefault bool // Whether this variable has a default value } // findVariables extracts all unique variable names from ${variable} patterns in content diff --git a/internal/agent/agent.go b/internal/agent/agent.go index e0d76726..8a0e3cfd 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -16,35 +16,49 @@ import ( "time" ) -// AgentConfig is the config for agent. +// AgentConfig holds configuration options for creating a new Agent. +// It includes model configuration, MCP settings, and various behavioral options. type AgentConfig struct { - ModelConfig *models.ProviderConfig - MCPConfig *config.Config - SystemPrompt string - MaxSteps int + // ModelConfig specifies the LLM provider and model to use + ModelConfig *models.ProviderConfig + // MCPConfig contains MCP server configurations + MCPConfig *config.Config + // SystemPrompt is the initial system message for the agent + SystemPrompt string + // MaxSteps limits the number of tool calls (0 for unlimited) + MaxSteps int + // StreamingEnabled controls whether responses are streamed StreamingEnabled bool - DebugLogger tools.DebugLogger // Optional debug logger + // DebugLogger is an optional logger for debugging MCP communications + DebugLogger tools.DebugLogger // Optional debug logger } -// ToolCallHandler is a function type for handling tool calls as they happen +// ToolCallHandler is a function type for handling tool calls as they happen. +// It receives the tool name and its arguments when a tool is about to be invoked. 
type ToolCallHandler func(toolName, toolArgs string) -// ToolExecutionHandler is a function type for handling tool execution start/end +// ToolExecutionHandler is a function type for handling tool execution start/end events. +// The isStarting parameter indicates whether the tool is starting (true) or finished (false). type ToolExecutionHandler func(toolName string, isStarting bool) -// ToolResultHandler is a function type for handling tool results +// ToolResultHandler is a function type for handling tool results. +// It receives the tool name, arguments, result, and whether the result is an error. type ToolResultHandler func(toolName, toolArgs, result string, isError bool) -// ResponseHandler is a function type for handling LLM responses +// ResponseHandler is a function type for handling LLM responses. +// It receives the complete response content from the model. type ResponseHandler func(content string) -// StreamingResponseHandler is a function type for handling streaming LLM responses +// StreamingResponseHandler is a function type for handling streaming LLM responses. +// It receives content chunks as they are streamed from the model. type StreamingResponseHandler func(content string) -// ToolCallContentHandler is a function type for handling content that accompanies tool calls +// ToolCallContentHandler is a function type for handling content that accompanies tool calls. +// It receives any text content that the model generates alongside tool calls. type ToolCallContentHandler func(content string) -// Agent is the agent with real-time tool call display. +// Agent represents an AI agent with MCP tool integration and real-time tool call display. +// It manages the interaction between an LLM and various tools through the MCP protocol. 
type Agent struct { toolManager *tools.MCPToolManager model model.ToolCallingChatModel @@ -55,7 +69,10 @@ type Agent struct { streamingEnabled bool // Whether streaming is enabled } -// NewAgent creates an agent with MCP tool integration and real-time tool call display +// NewAgent creates a new Agent with MCP tool integration and streaming support. +// It initializes the LLM provider, loads MCP tools, and configures the agent +// based on the provided configuration. Returns an error if provider creation +// or tool loading fails. func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) { // Create the LLM provider providerResult, err := models.CreateProvider(ctx, config.ModelConfig) @@ -98,20 +115,27 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) { }, nil } -// GenerateWithLoopResult contains the result and conversation history +// GenerateWithLoopResult contains the result and conversation history from an agent interaction. +// It includes both the final response and the complete message history with tool interactions. type GenerateWithLoopResult struct { - FinalResponse *schema.Message + // FinalResponse is the last message generated by the model + FinalResponse *schema.Message + // ConversationMessages contains all messages in the conversation including tool calls and results ConversationMessages []*schema.Message // All messages in the conversation (including tool calls and results) } -// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time +// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time. +// It handles the conversation flow, executing tools as needed and invoking callbacks for various events. +// This method does not support streaming responses; use GenerateWithLoopAndStreaming for streaming support. 
func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message, onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) { return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil) } -// GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks +// GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks. +// It handles the conversation flow, executing tools as needed and invoking callbacks for various events including streaming chunks. +// The onStreamingResponse callback is invoked for each content chunk during streaming if streaming is enabled. func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*schema.Message, onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler) (*GenerateWithLoopResult, error) { @@ -256,17 +280,20 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*sc }, nil } -// GetTools returns the list of available tools +// GetTools returns the list of available tools loaded in the agent. +// These tools are available for the model to use during interactions. func (a *Agent) GetTools() []tool.BaseTool { return a.toolManager.GetTools() } -// GetLoadingMessage returns the loading message from provider creation (e.g., GPU fallback info) +// GetLoadingMessage returns the loading message from provider creation. +// This may contain information about GPU fallback or other provider-specific initialization details. 
func (a *Agent) GetLoadingMessage() string { return a.loadingMessage } -// GetLoadedServerNames returns the names of successfully loaded MCP servers +// GetLoadedServerNames returns the names of successfully loaded MCP servers. +// This includes both builtin servers and external MCP server configurations. func (a *Agent) GetLoadedServerNames() []string { return a.toolManager.GetLoadedServerNames() } @@ -486,7 +513,8 @@ func (a *Agent) listenForESC(stopChan chan bool, readyChan chan bool) bool { } } -// Close closes the agent and cleans up resources +// Close closes the agent and cleans up resources. +// It ensures all MCP connections are properly closed and resources are released. func (a *Agent) Close() error { return a.toolManager.Close() } diff --git a/internal/agent/factory.go b/internal/agent/factory.go index 42cc4226..dfbe6657 100644 --- a/internal/agent/factory.go +++ b/internal/agent/factory.go @@ -10,23 +10,36 @@ import ( "github.com/mark3labs/mcphost/internal/tools" ) -// SpinnerFunc is a function type for showing spinners during agent creation +// SpinnerFunc is a function type for showing spinners during agent creation. +// It executes the provided function while displaying a spinner with the given message. type SpinnerFunc func(message string, fn func() error) error -// AgentCreationOptions contains options for creating an agent +// AgentCreationOptions contains options for creating an agent. +// It extends AgentConfig with UI-related options for showing progress during creation. 
type AgentCreationOptions struct { - ModelConfig *models.ProviderConfig - MCPConfig *config.Config - SystemPrompt string - MaxSteps int + // ModelConfig specifies the LLM provider and model to use + ModelConfig *models.ProviderConfig + // MCPConfig contains MCP server configurations + MCPConfig *config.Config + // SystemPrompt is the initial system message for the agent + SystemPrompt string + // MaxSteps limits the number of tool calls (0 for unlimited) + MaxSteps int + // StreamingEnabled controls whether responses are streamed StreamingEnabled bool - ShowSpinner bool // For Ollama models - Quiet bool // Skip spinner if quiet - SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller) - DebugLogger tools.DebugLogger // Optional debug logger + // ShowSpinner indicates whether to show a spinner for Ollama models during loading + ShowSpinner bool // For Ollama models + // Quiet suppresses the spinner even if ShowSpinner is true + Quiet bool // Skip spinner if quiet + // SpinnerFunc is the function to show spinner, provided by the caller + SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller) + // DebugLogger is an optional logger for debugging MCP communications + DebugLogger tools.DebugLogger // Optional debug logger } -// CreateAgent creates an agent with optional spinner for Ollama models +// CreateAgent creates an agent with optional spinner for Ollama models. +// It shows a loading spinner for Ollama models if ShowSpinner is true and not in quiet mode. +// Returns the created agent or an error if creation fails. func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error) { agentConfig := &AgentConfig{ ModelConfig: opts.ModelConfig, @@ -57,7 +70,9 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error return agent, nil } -// ParseModelName extracts provider and model name from model string +// ParseModelName extracts provider and model name from a model string. 
+// Model strings are formatted as "provider:model" (e.g., "anthropic:claude-3-5-sonnet-20241022"). +// If the string doesn't contain a colon, returns "unknown" for both provider and model. func ParseModelName(modelString string) (provider, model string) { parts := strings.SplitN(modelString, ":", 2) if len(parts) == 2 { diff --git a/internal/agent/streaming.go b/internal/agent/streaming.go index b877a3c1..ce6bb1de 100644 --- a/internal/agent/streaming.go +++ b/internal/agent/streaming.go @@ -8,7 +8,8 @@ import ( "github.com/cloudwego/eino/schema" ) -// StreamWithCallback streams content with real-time callbacks and returns complete response +// StreamWithCallback streams content with real-time callbacks and returns the complete response. +// It accumulates content and tool calls from the stream, invoking the callback for each content chunk. // IMPORTANT: Tool calls are only processed after EOF is reached to ensure we have the complete // and final tool call information. This prevents premature tool execution on partial data. // Handles different provider streaming patterns: diff --git a/internal/auth/browser.go b/internal/auth/browser.go index 4be76aad..35e6906f 100644 --- a/internal/auth/browser.go +++ b/internal/auth/browser.go @@ -6,7 +6,11 @@ import ( "runtime" ) -// OpenBrowser opens the default browser to the specified URL +// OpenBrowser opens the default web browser to the specified URL. +// It automatically detects the operating system and uses the appropriate +// command to launch the browser (xdg-open on Linux, rundll32 on Windows, +// open on macOS). Returns an error if the platform is unsupported or if +// the browser fails to launch. func OpenBrowser(url string) error { var err error @@ -24,7 +28,9 @@ func OpenBrowser(url string) error { return err } -// TryOpenBrowser attempts to open the browser but doesn't fail if it can't +// TryOpenBrowser attempts to open the default web browser to the specified URL +// but silently ignores any errors. 
This is useful when browser access is optional +// and users can manually copy and paste the URL if automatic browser launching fails. func TryOpenBrowser(url string) { // Silently ignore errors - user can still copy/paste the URL _ = OpenBrowser(url) diff --git a/internal/auth/credentials.go b/internal/auth/credentials.go index 94ec7abe..dd91f408 100644 --- a/internal/auth/credentials.go +++ b/internal/auth/credentials.go @@ -9,12 +9,16 @@ import ( "time" ) -// CredentialStore holds all stored credentials +// CredentialStore holds all stored credentials for various providers. +// Currently supports Anthropic credentials with both OAuth and API key authentication methods. type CredentialStore struct { Anthropic *AnthropicCredentials `json:"anthropic,omitempty"` } -// AnthropicCredentials holds Anthropic API credentials +// AnthropicCredentials holds Anthropic API credentials supporting both OAuth +// and API key authentication methods. The Type field indicates which authentication +// method is being used. For OAuth, tokens are stored with expiration timestamps +// for automatic refresh. For API keys, only the key itself is stored. type AnthropicCredentials struct { Type string `json:"type"` // "oauth" or "api_key" APIKey string `json:"api_key,omitempty"` // For API key auth @@ -24,7 +28,8 @@ type AnthropicCredentials struct { CreatedAt time.Time `json:"created_at"` } -// IsExpired checks if the OAuth token is expired +// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp. +// Returns false for API key authentication or if no expiration is set. 
func (c *AnthropicCredentials) IsExpired() bool { if c.Type != "oauth" || c.ExpiresAt == 0 { return false @@ -32,7 +37,10 @@ func (c *AnthropicCredentials) IsExpired() bool { return time.Now().Unix() >= c.ExpiresAt } -// NeedsRefresh checks if the OAuth token needs refresh (5 minutes before expiry) +// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token +// will expire within the next 5 minutes. This allows for proactive token refresh +// to avoid authentication failures during operations. Returns false for API key +// authentication or if no expiration is set. func (c *AnthropicCredentials) NeedsRefresh() bool { if c.Type != "oauth" || c.ExpiresAt == 0 { return false @@ -40,12 +48,17 @@ func (c *AnthropicCredentials) NeedsRefresh() bool { return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer } -// CredentialManager handles credential storage and retrieval +// CredentialManager handles secure storage and retrieval of authentication credentials. +// It manages a JSON file stored in the user's config directory with appropriate +// file permissions for security. type CredentialManager struct { credentialsPath string } -// NewCredentialManager creates a new credential manager +// NewCredentialManager creates a new credential manager instance. It determines +// the appropriate credentials path based on XDG_CONFIG_HOME or falls back to +// ~/.config/.mcphost/credentials.json. Returns an error if the home directory +// cannot be determined. func NewCredentialManager() (*CredentialManager, error) { credentialsPath, err := getCredentialsPath() if err != nil { @@ -73,7 +86,9 @@ func getCredentialsPath() (string, error) { return filepath.Join(homeDir, ".config", ".mcphost", "credentials.json"), nil } -// LoadCredentials loads credentials from the file +// LoadCredentials loads credentials from the JSON file. 
If the file doesn't exist, +// it returns an empty CredentialStore instead of an error, allowing for graceful +// initialization. Returns an error if the file exists but cannot be read or parsed. func (cm *CredentialManager) LoadCredentials() (*CredentialStore, error) { // If file doesn't exist, return empty store if _, err := os.Stat(cm.credentialsPath); os.IsNotExist(err) { @@ -93,7 +108,10 @@ func (cm *CredentialManager) LoadCredentials() (*CredentialStore, error) { return &store, nil } -// SaveCredentials saves credentials to the file +// SaveCredentials saves credentials to the JSON file with secure permissions (0600). +// It creates the parent directory if it doesn't exist. The file is written atomically +// to prevent corruption. Returns an error if the directory cannot be created or the +// file cannot be written. func (cm *CredentialManager) SaveCredentials(store *CredentialStore) error { // Ensure directory exists dir := filepath.Dir(cm.credentialsPath) @@ -114,7 +132,10 @@ func (cm *CredentialManager) SaveCredentials(store *CredentialStore) error { return nil } -// SetAnthropicCredentials stores Anthropic API credentials (for API key auth) +// SetAnthropicCredentials stores Anthropic API key credentials. It validates the +// API key format before storing. The API key must start with "sk-ant-" and be +// at least 20 characters long. Returns an error if the API key is invalid or +// if storage fails. func (cm *CredentialManager) SetAnthropicCredentials(apiKey string) error { if err := validateAnthropicAPIKey(apiKey); err != nil { return err @@ -134,7 +155,9 @@ func (cm *CredentialManager) SetAnthropicCredentials(apiKey string) error { return cm.SaveCredentials(store) } -// GetAnthropicCredentials retrieves Anthropic API credentials +// GetAnthropicCredentials retrieves stored Anthropic credentials. Returns nil if +// no credentials are stored. The returned credentials may be either OAuth or API +// key type, check the Type field to determine which. 
func (cm *CredentialManager) GetAnthropicCredentials() (*AnthropicCredentials, error) { store, err := cm.LoadCredentials() if err != nil { @@ -144,7 +167,9 @@ func (cm *CredentialManager) GetAnthropicCredentials() (*AnthropicCredentials, e return store.Anthropic, nil } -// RemoveAnthropicCredentials removes stored Anthropic credentials +// RemoveAnthropicCredentials removes stored Anthropic credentials from storage. +// If this was the only credential stored, the entire credentials file is removed. +// Returns an error if the removal fails. func (cm *CredentialManager) RemoveAnthropicCredentials() error { store, err := cm.LoadCredentials() if err != nil { @@ -164,7 +189,9 @@ func (cm *CredentialManager) RemoveAnthropicCredentials() error { return cm.SaveCredentials(store) } -// HasAnthropicCredentials checks if Anthropic credentials are stored +// HasAnthropicCredentials checks if valid Anthropic credentials are stored. +// Returns true if either a non-empty OAuth access token or API key is present, +// false otherwise. Returns an error if credentials cannot be loaded. func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) { creds, err := cm.GetAnthropicCredentials() if err != nil { @@ -185,7 +212,8 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) { } } -// GetCredentialsPath returns the path to the credentials file +// GetCredentialsPath returns the absolute path to the credentials JSON file. +// This is useful for debugging or displaying the storage location to users. func (cm *CredentialManager) GetCredentialsPath() string { return cm.credentialsPath } @@ -210,8 +238,12 @@ func validateAnthropicAPIKey(apiKey string) error { return nil } -// GetAnthropicAPIKey is a convenience function that checks stored credentials first, -// then falls back to environment variables and flags +// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order: +// 1. Command-line flag value (highest priority) +// 2. 
Stored credentials (OAuth or API key) +// 3. ANTHROPIC_API_KEY environment variable (lowest priority) +// Returns the API key, a description of its source, and any error encountered. +// For OAuth credentials, it automatically refreshes expired tokens. func GetAnthropicAPIKey(flagValue string) (string, string, error) { // 1. Check flag value first (highest priority) if flagValue != "" { diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go index df60eb0b..7d528992 100644 --- a/internal/auth/oauth.go +++ b/internal/auth/oauth.go @@ -13,7 +13,9 @@ import ( "time" ) -// OAuthClient handles OAuth authentication with Anthropic +// OAuthClient handles OAuth 2.0 authentication flow with Anthropic using the +// PKCE (Proof Key for Code Exchange) extension for enhanced security in public clients. +// It manages the authorization URL generation, code exchange, and token refresh operations. type OAuthClient struct { ClientID string AuthorizeURL string @@ -22,13 +24,18 @@ type OAuthClient struct { Scopes string } -// AuthData contains authorization URL and PKCE verifier +// AuthData contains the authorization URL for user authentication and the PKCE +// verifier needed for the subsequent code exchange. The verifier must be stored +// securely and used when exchanging the authorization code for tokens. type AuthData struct { URL string Verifier string } -// NewOAuthClient creates a new OAuth client with Anthropic configuration +// NewOAuthClient creates a new OAuth client configured for Anthropic's OAuth service. +// The client uses a public client ID (as per OAuth 2.0 public client specification) +// with PKCE for security. The configuration includes the authorization endpoint, +// token endpoint, redirect URI, and required scopes for API key creation and inference. func NewOAuthClient() *OAuthClient { return &OAuthClient{ // OAuth client ID is public by design for CLI applications (OAuth public clients). 
@@ -42,7 +49,11 @@ func NewOAuthClient() *OAuthClient { } } -// GeneratePKCE generates PKCE verifier and challenge for OAuth flow +// GeneratePKCE generates a cryptographically secure PKCE verifier and challenge pair +// for the OAuth 2.0 PKCE flow. The verifier is a random 32-byte string encoded as +// base64url, and the challenge is the SHA256 hash of the verifier, also base64url encoded. +// Returns the verifier (to be stored securely), challenge (to be sent with auth request), +// and any error encountered during generation. func GeneratePKCE() (verifier, challenge string, err error) { // Generate 32 bytes of random data verifierBytes := make([]byte, 32) @@ -60,7 +71,10 @@ func GeneratePKCE() (verifier, challenge string, err error) { return verifier, challenge, nil } -// GetAuthorizationURL generates the authorization URL with PKCE parameters +// GetAuthorizationURL generates a complete authorization URL for the OAuth flow with +// PKCE parameters. The URL includes the client ID, redirect URI, requested scopes, +// and PKCE challenge. Returns an AuthData structure containing the URL for user +// authentication and the PKCE verifier for the subsequent code exchange. func (c *OAuthClient) GetAuthorizationURL() (*AuthData, error) { verifier, challenge, err := GeneratePKCE() if err != nil { @@ -86,7 +100,10 @@ func (c *OAuthClient) GetAuthorizationURL() (*AuthData, error) { }, nil } -// ExchangeCode exchanges an authorization code for tokens +// ExchangeCode exchanges an authorization code for access and refresh tokens. +// The code parameter should be the authorization code received from the OAuth callback. +// The verifier parameter must be the same PKCE verifier generated during GetAuthorizationURL. +// Returns AnthropicCredentials containing the tokens and expiration information. 
func (c *OAuthClient) ExchangeCode(code, verifier string) (*AnthropicCredentials, error) { // Parse code and state parsedCode, parsedState := c.parseCodeAndState(code) @@ -109,7 +126,10 @@ func (c *OAuthClient) ExchangeCode(code, verifier string) (*AnthropicCredentials return c.makeTokenRequest(reqBody) } -// RefreshToken refreshes an access token using a refresh token +// RefreshToken refreshes an expired or expiring access token using a refresh token. +// Returns new AnthropicCredentials with updated access token, refresh token (may be +// rotated), and new expiration timestamp. Returns an error if the refresh fails or +// the refresh token is invalid. func (c *OAuthClient) RefreshToken(refreshToken string) (*AnthropicCredentials, error) { reqBody := map[string]interface{}{ "grant_type": "refresh_token", @@ -179,7 +199,9 @@ func (c *OAuthClient) parseCodeAndState(code string) (parsedCode, parsedState st return } -// SetOAuthCredentials stores OAuth credentials +// SetOAuthCredentials stores OAuth credentials in the credential manager's secure storage. +// The credentials should include access token, refresh token, and expiration information. +// Returns an error if the credentials cannot be saved. func (cm *CredentialManager) SetOAuthCredentials(creds *AnthropicCredentials) error { store, err := cm.LoadCredentials() if err != nil { @@ -190,7 +212,10 @@ func (cm *CredentialManager) SetOAuthCredentials(creds *AnthropicCredentials) er return cm.SaveCredentials(store) } -// GetValidAccessToken returns a valid access token, refreshing if necessary +// GetValidAccessToken returns a valid access token for API requests. For OAuth credentials, +// it automatically refreshes the token if it's expired or about to expire. For API key +// credentials, it simply returns the API key. Returns an error if no credentials are found, +// if token refresh fails, or if the credential type is unknown. 
func (cm *CredentialManager) GetValidAccessToken() (string, error) { creds, err := cm.GetAnthropicCredentials() if err != nil { diff --git a/internal/builtin/bash.go b/internal/builtin/bash.go index 1afb43f5..ff9ffa12 100644 --- a/internal/builtin/bash.go +++ b/internal/builtin/bash.go @@ -37,7 +37,10 @@ var bannedCommands = []string{ "safari", } -// NewBashServer creates a new bash MCP server +// NewBashServer creates a new MCP server that provides bash command execution capabilities. +// The server includes a single tool "run_shell_cmd" that executes shell commands with +// security restrictions, timeout controls, and output truncation. Returns an error if +// server initialization fails. func NewBashServer() (*server.MCPServer, error) { s := server.NewMCPServer("bash-server", "1.0.0", server.WithToolCapabilities(true)) diff --git a/internal/builtin/fetch.go b/internal/builtin/fetch.go index 1276369e..1af3b8f3 100644 --- a/internal/builtin/fetch.go +++ b/internal/builtin/fetch.go @@ -21,7 +21,9 @@ const ( maxFetchTimeout = 120 * time.Second ) -// NewFetchServer creates a new fetch MCP server +// NewFetchServer creates a new MCP server that provides web content fetching capabilities. +// The server includes a single tool "fetch" that retrieves content from URLs and converts +// it to text, markdown, or HTML format. Returns an error if server initialization fails. func NewFetchServer() (*server.MCPServer, error) { s := server.NewMCPServer("fetch-server", "1.0.0", server.WithToolCapabilities(true)) diff --git a/internal/builtin/http.go b/internal/builtin/http.go index c6cce1d0..49ec01f2 100644 --- a/internal/builtin/http.go +++ b/internal/builtin/http.go @@ -28,7 +28,11 @@ const ( // httpServerModel holds the model for the HTTP server var httpServerModel model.ToolCallingChatModel -// NewHTTPServer creates a new HTTP MCP server +// NewHTTPServer creates a new MCP server providing advanced HTTP fetching capabilities. 
+// The server includes tools for fetching web content, summarizing pages, extracting +// specific information, and filtering JSON responses. If an LLM model is provided, +// AI-powered summarization and extraction tools are enabled. Returns an error if +// server initialization fails. func NewHTTPServer(llmModel model.ToolCallingChatModel) (*server.MCPServer, error) { // Store the model globally for use in tool handlers httpServerModel = llmModel diff --git a/internal/builtin/registry.go b/internal/builtin/registry.go index ea599dc0..1bb24a82 100644 --- a/internal/builtin/registry.go +++ b/internal/builtin/registry.go @@ -9,28 +9,36 @@ import ( "github.com/mark3labs/mcp-go/server" ) -// BuiltinServerWrapper wraps an external MCP server for builtin use +// BuiltinServerWrapper wraps an external MCP server for builtin use, providing +// a consistent interface for all builtin servers regardless of their implementation. type BuiltinServerWrapper struct { server *server.MCPServer } -// Initialize initializes the wrapped server +// Initialize initializes the wrapped server. For builtin servers, this is typically +// a no-op as the server is initialized during creation. Returns an error if +// initialization fails. func (w *BuiltinServerWrapper) Initialize() error { // The server is already initialized when created return nil } -// GetServer returns the wrapped MCP server +// GetServer returns the wrapped MCP server instance that can be used to handle +// tool calls and other MCP protocol operations. func (w *BuiltinServerWrapper) GetServer() *server.MCPServer { return w.server } -// Registry holds all available builtin servers +// Registry holds all available builtin servers and their factory functions. +// It provides a centralized registry for creating instances of builtin MCP servers +// with their respective configurations. 
type Registry struct { servers map[string]func(options map[string]any, model model.ToolCallingChatModel) (*BuiltinServerWrapper, error) } -// NewRegistry creates a new builtin server registry +// NewRegistry creates a new builtin server registry with all available builtin +// servers registered. The registry includes filesystem (fs), bash, todo, fetch, +// and HTTP servers. func NewRegistry() *Registry { r := &Registry{ servers: make(map[string]func(options map[string]any, model model.ToolCallingChatModel) (*BuiltinServerWrapper, error)), @@ -46,7 +54,10 @@ func NewRegistry() *Registry { return r } -// CreateServer creates a new instance of a builtin server +// CreateServer creates a new instance of a builtin server by name. The options +// parameter provides server-specific configuration, and the model parameter provides +// an optional LLM for AI-powered features. Returns an error if the server name +// is unknown or if creation fails. func (r *Registry) CreateServer(name string, options map[string]any, model model.ToolCallingChatModel) (*BuiltinServerWrapper, error) { factory, exists := r.servers[name] if !exists { @@ -56,7 +67,8 @@ func (r *Registry) CreateServer(name string, options map[string]any, model model return factory(options, model) } -// ListServers returns a list of available builtin server names +// ListServers returns a list of all available builtin server names that can be +// created using CreateServer. The order of names is not guaranteed. func (r *Registry) ListServers() []string { names := make([]string, 0, len(r.servers)) for name := range r.servers { diff --git a/internal/builtin/todo.go b/internal/builtin/todo.go index c9ded972..4e4ad5be 100644 --- a/internal/builtin/todo.go +++ b/internal/builtin/todo.go @@ -11,7 +11,9 @@ import ( "github.com/mark3labs/mcp-go/server" ) -// TodoInfo represents a single todo item +// TodoInfo represents a single todo item with content, status, priority, and ID. 
+// Status can be "pending", "in_progress", or "completed". Priority can be "high", +// "medium", or "low". Each todo must have a unique ID. type TodoInfo struct { Content string `json:"content"` Status string `json:"status"` @@ -19,13 +21,18 @@ type TodoInfo struct { ID string `json:"id"` } -// TodoServer implements a todo management MCP server with in-memory storage +// TodoServer implements a todo management MCP server with in-memory storage. +// It provides thread-safe operations for reading and writing todo lists, with +// support for task status tracking and priority levels. type TodoServer struct { todos []TodoInfo mutex sync.RWMutex } -// NewTodoServer creates a new todo MCP server with in-memory storage +// NewTodoServer creates a new MCP server that provides todo list management capabilities. +// The server includes two tools: "todowrite" for updating the todo list and "todoread" +// for retrieving the current list. Todos are stored in memory and not persisted. +// Returns an error if server initialization fails. func NewTodoServer() (*server.MCPServer, error) { todoServer := &TodoServer{ todos: make([]TodoInfo, 0), diff --git a/internal/config/config.go b/internal/config/config.go index b5f5062f..55a32a76 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,7 +11,9 @@ import ( "gopkg.in/yaml.v3" ) -// MCPServerConfig represents configuration for an MCP server +// MCPServerConfig represents configuration for an MCP server, supporting both +// local (stdio), remote (StreamableHTTP/SSE), and builtin (in-process) server types. +// It maintains backward compatibility with legacy configuration formats. 
type MCPServerConfig struct { Type string `json:"type"` Command []string `json:"command,omitempty"` @@ -29,7 +31,9 @@ type MCPServerConfig struct { Headers []string `json:"headers,omitempty"` } -// UnmarshalJSON handles both new and legacy config formats +// UnmarshalJSON handles both new and legacy config formats for backward compatibility. +// New format uses "type" field with "local", "remote", or "builtin" values. +// Legacy format uses "transport", "command", "args", and "env" fields. func (s *MCPServerConfig) UnmarshalJSON(data []byte) error { // First try to unmarshal as the new format type newFormat struct { @@ -100,11 +104,15 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error { return nil } +// AdaptiveColor represents a color that adapts to light and dark themes. +// Either light or dark can be specified, or both for theme-aware coloring. type AdaptiveColor struct { Light string `json:"light,omitempty" yaml:"light,omitempty"` Dark string `json:"dark,omitempty" yaml:"dark,omitempty"` } +// Theme defines the color scheme for the application UI with adaptive colors +// that support both light and dark modes. type Theme struct { Primary AdaptiveColor `json:"primary" yaml:"primary"` Secondary AdaptiveColor `json:"secondary" yaml:"secondary"` @@ -124,6 +132,8 @@ type Theme struct { Highlight AdaptiveColor `json:"highlight" yaml:"highlight"` } +// MarkdownTheme defines the color scheme for markdown rendering with syntax +// highlighting support and adaptive colors for light and dark modes. type MarkdownTheme struct { Text AdaptiveColor `json:"text" yaml:"text"` Muted AdaptiveColor `json:"muted" yaml:"muted"` @@ -139,7 +149,9 @@ type MarkdownTheme struct { Comment AdaptiveColor `json:"comment" yaml:"comment"` } -// Config represents the application configuration +// Config represents the complete application configuration including MCP servers, +// model settings, UI preferences, and API credentials. 
It supports both command-line +// flags and configuration file settings. type Config struct { MCPServers map[string]MCPServerConfig `json:"mcpServers" yaml:"mcpServers"` Model string `json:"model,omitempty" yaml:"model,omitempty"` @@ -166,7 +178,9 @@ type Config struct { TLSSkipVerify bool `json:"tls-skip-verify,omitempty" yaml:"tls-skip-verify,omitempty"` } -// GetTransportType returns the transport type for the server config +// GetTransportType returns the transport type for the server config, mapping +// simplified type names to actual transport protocols. Supports legacy format +// detection and automatic type inference from configuration. func (s *MCPServerConfig) GetTransportType() string { // Legacy format support - check explicit transport first if s.Transport != "" { @@ -197,7 +211,9 @@ func (s *MCPServerConfig) GetTransportType() string { return "stdio" // default } -// Validate validates the configuration +// Validate validates the configuration, ensuring required fields are present +// for each server type and that tool filters are used correctly. Returns an +// error describing any validation failures. func (c *Config) Validate() error { for serverName, serverConfig := range c.MCPServers { if len(serverConfig.AllowedTools) > 0 && len(serverConfig.ExcludedTools) > 0 { @@ -226,7 +242,9 @@ func (c *Config) Validate() error { return nil } -// LoadSystemPrompt loads system prompt from file or returns the string directly +// LoadSystemPrompt loads system prompt from file or returns the string directly. +// If input is a path to an existing file, its contents are read and returned. +// Otherwise, the input string is returned as-is. 
func LoadSystemPrompt(input string) (string, error) { if input == "" { return "", nil @@ -246,7 +264,9 @@ func LoadSystemPrompt(input string) (string, error) { return input, nil } -// EnsureConfigExists checks if a config file exists and creates a default one if not +// EnsureConfigExists checks if a config file exists and creates a default one if not. +// It searches for .mcphost.{yml,yaml,json} or legacy .mcp.{yml,yaml,json} files in +// the user's home directory. If none exist, creates a default .mcphost.yml with examples. func EnsureConfigExists() error { homeDir, err := os.UserHomeDir() if err != nil { @@ -378,6 +398,10 @@ mcpServers: return nil } +// FilepathOr reads a configuration value that can be either a direct value or a +// filepath to a JSON/YAML file containing the value. If the value is a string +// starting with "~/" or a relative path, it's expanded to an absolute path. +// The contents of the file are then unmarshaled into the provided value pointer. func FilepathOr[T any](key string, value *T) error { var field any err := viper.UnmarshalKey(key, &field) @@ -428,6 +452,9 @@ func FilepathOr[T any](key string, value *T) error { var configPath string +// SetConfigPath sets the configuration file path for resolving relative paths +// in configuration values. This should be called when the configuration file +// location is known. func SetConfigPath(path string) { configPath = path } diff --git a/internal/config/merger.go b/internal/config/merger.go index 95d41804..1ada1959 100644 --- a/internal/config/merger.go +++ b/internal/config/merger.go @@ -7,7 +7,9 @@ import ( "github.com/spf13/viper" ) -// MergeConfigs merges script frontmatter config with base config +// MergeConfigs merges script frontmatter config with base config, allowing scripts +// to override MCP server configurations. The script config takes precedence over +// the base config for any fields that are specified. 
func MergeConfigs(baseConfig *Config, scriptConfig *Config) *Config { merged := *baseConfig // Copy base config @@ -20,7 +22,9 @@ func MergeConfigs(baseConfig *Config, scriptConfig *Config) *Config { return &merged } -// LoadAndValidateConfig loads config from viper and validates it +// LoadAndValidateConfig loads configuration from viper, fixes environment variable +// casing issues, and validates the configuration. Returns an error if loading or +// validation fails. func LoadAndValidateConfig() (*Config, error) { config := &Config{ MCPServers: make(map[string]MCPServerConfig), diff --git a/internal/config/substitution.go b/internal/config/substitution.go index 567a826f..9dbc7a8f 100644 --- a/internal/config/substitution.go +++ b/internal/config/substitution.go @@ -24,10 +24,13 @@ func parseVariableWithDefault(varPart string) (varName, defaultValue string, has return varPart, "", false } -// EnvSubstituter handles environment variable substitution +// EnvSubstituter handles environment variable substitution in configuration strings, +// supporting both ${env://VAR} and ${env://VAR:-default} patterns. type EnvSubstituter struct{} -// SubstituteEnvVars replaces ${env://VAR} and ${env://VAR:-default} patterns with environment variables +// SubstituteEnvVars replaces ${env://VAR} and ${env://VAR:-default} patterns with environment variables. +// If a variable is not set and has a default value, the default is used. Returns an error +// if required variables (those without defaults) are not set. func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) { var errors []string @@ -57,17 +60,21 @@ func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) { return result, nil } -// ArgsSubstituter handles script argument substitution +// ArgsSubstituter handles script argument substitution in configuration strings, +// supporting both ${VAR} and ${VAR:-default} patterns for template variable replacement. 
type ArgsSubstituter struct { args map[string]string } -// NewArgsSubstituter creates a new args substituter with the given arguments +// NewArgsSubstituter creates a new args substituter with the given arguments map. +// The arguments are used to replace template variables in configuration strings. func NewArgsSubstituter(args map[string]string) *ArgsSubstituter { return &ArgsSubstituter{args: args} } -// SubstituteArgs replaces ${VAR} and ${VAR:-default} patterns with script arguments +// SubstituteArgs replaces ${VAR} and ${VAR:-default} patterns with script arguments. +// If an argument is not provided and has a default value, the default is used. +// Returns an error if required arguments (those without defaults) are not provided. func (a *ArgsSubstituter) SubstituteArgs(content string) (string, error) { var errors []string @@ -97,12 +104,14 @@ func (a *ArgsSubstituter) SubstituteArgs(content string) (string, error) { return result, nil } -// HasEnvVars checks if content contains environment variable patterns +// HasEnvVars checks if content contains environment variable patterns (${env://...}). +// This is useful for determining if substitution is needed before processing. func HasEnvVars(content string) bool { return envVarPattern.MatchString(content) } -// HasScriptArgs checks if content contains script argument patterns +// HasScriptArgs checks if content contains script argument patterns (${...}). +// This is useful for determining if argument substitution is needed before processing. 
func HasScriptArgs(content string) bool { return scriptArgsPattern.MatchString(content) } diff --git a/internal/hooks/config.go b/internal/hooks/config.go index dc59c15e..dc4b673d 100644 --- a/internal/hooks/config.go +++ b/internal/hooks/config.go @@ -9,26 +9,35 @@ import ( "path/filepath" ) -// HookConfig represents the complete hooks configuration +// HookConfig represents the complete hooks configuration containing event-triggered +// hooks for tool execution lifecycle events. type HookConfig struct { Hooks map[HookEvent][]HookMatcher `yaml:"hooks" json:"hooks"` } -// HookMatcher matches specific tools and defines hooks to execute +// HookMatcher matches specific tools and defines hooks to execute. The Matcher field +// contains a pattern to match tool names, and Hooks contains the commands to run +// when a match occurs. The Merge field controls how this matcher combines with others. type HookMatcher struct { Matcher string `yaml:"matcher,omitempty" json:"matcher,omitempty"` Merge string `yaml:"_merge,omitempty" json:"_merge,omitempty"` Hooks []HookEntry `yaml:"hooks" json:"hooks"` } -// HookEntry defines a single hook command +// HookEntry defines a single hook command to execute. Type specifies the command +// type (e.g., "bash"), Command contains the actual command to run, and Timeout +// optionally specifies the maximum execution time in seconds. type HookEntry struct { Type string `yaml:"type" json:"type"` Command string `yaml:"command" json:"command"` Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` } -// LoadHooksConfig loads and merges hook configurations from multiple sources +// LoadHooksConfig loads and merges hook configurations from multiple sources. +// It searches for hooks.{json,yml} files in standard locations (XDG config directory, +// local .mcphost directory) and any custom paths provided. Configurations are merged +// with later sources taking precedence. 
Environment variable substitution is applied +// to all loaded configurations. func LoadHooksConfig(customPaths ...string) (*HookConfig, error) { // Get config directory following XDG Base Directory specification configDir := getConfigDir() diff --git a/internal/hooks/events.go b/internal/hooks/events.go index 7b61397a..4dfe1ba7 100644 --- a/internal/hooks/events.go +++ b/internal/hooks/events.go @@ -1,23 +1,25 @@ package hooks -// HookEvent represents a point in MCPHost's lifecycle where hooks can be executed +// HookEvent represents a point in MCPHost's lifecycle where hooks can be executed. +// Events can be tool-related (requiring matchers) or lifecycle-related. type HookEvent string const ( - // PreToolUse fires before any tool execution + // PreToolUse fires before any tool execution, allowing pre-processing or validation PreToolUse HookEvent = "PreToolUse" - // PostToolUse fires after tool execution completes + // PostToolUse fires after tool execution completes, allowing post-processing or logging PostToolUse HookEvent = "PostToolUse" - // UserPromptSubmit fires when user submits a prompt + // UserPromptSubmit fires when user submits a prompt, before agent processing UserPromptSubmit HookEvent = "UserPromptSubmit" - // Stop fires when the main agent finishes responding + // Stop fires when the main agent finishes responding to a user prompt Stop HookEvent = "Stop" ) -// IsValid returns true if the event is a valid hook event +// IsValid returns true if the event is a valid hook event. +// Valid events are PreToolUse, PostToolUse, UserPromptSubmit, and Stop. func (e HookEvent) IsValid() bool { switch e { case PreToolUse, PostToolUse, UserPromptSubmit, Stop: @@ -26,7 +28,9 @@ func (e HookEvent) IsValid() bool { return false } -// RequiresMatcher returns true if the event uses tool matchers +// RequiresMatcher returns true if the event uses tool matchers. +// PreToolUse and PostToolUse events require matchers to determine which +// tools trigger the hooks. 
Other events apply globally without matchers. func (e HookEvent) RequiresMatcher() bool { return e == PreToolUse || e == PostToolUse } diff --git a/internal/hooks/executor.go b/internal/hooks/executor.go index 0c050f10..ab5b5047 100644 --- a/internal/hooks/executor.go +++ b/internal/hooks/executor.go @@ -12,7 +12,9 @@ import ( "time" ) -// Executor handles hook execution +// Executor handles hook execution for MCPHost lifecycle events. It manages +// hook configuration, executes matching hooks in parallel, and processes +// their outputs to determine application behavior. type Executor struct { config *HookConfig sessionID string @@ -22,7 +24,9 @@ type Executor struct { mu sync.RWMutex } -// NewExecutor creates a new hook executor +// NewExecutor creates a new hook executor with the given configuration, +// session ID, and transcript path. The executor manages hook execution +// throughout the application lifecycle. func NewExecutor(config *HookConfig, sessionID, transcriptPath string) *Executor { return &Executor{ config: config, @@ -31,21 +35,25 @@ func NewExecutor(config *HookConfig, sessionID, transcriptPath string) *Executor } } -// SetModel sets the model name for hook context +// SetModel sets the model name for hook context. This information is passed +// to hooks as part of their input data for context-aware processing. func (e *Executor) SetModel(model string) { e.mu.Lock() defer e.mu.Unlock() e.model = model } -// SetInteractive sets whether we're in interactive mode +// SetInteractive sets whether the application is running in interactive mode. +// This information is passed to hooks for mode-specific behavior. 
func (e *Executor) SetInteractive(interactive bool) { e.mu.Lock() defer e.mu.Unlock() e.interactive = interactive } -// PopulateCommonFields fills in the common fields for any hook input +// PopulateCommonFields fills in the common fields for any hook input, including +// session ID, transcript path, working directory, event name, timestamp, model, +// and interactive mode. These fields provide context to hooks regardless of event type. func (e *Executor) PopulateCommonFields(event HookEvent) CommonInput { e.mu.RLock() defer e.mu.RUnlock() @@ -62,7 +70,10 @@ func (e *Executor) PopulateCommonFields(event HookEvent) CommonInput { } } -// ExecuteHooks runs all matching hooks for an event +// ExecuteHooks runs all matching hooks for an event. For tool-related events, +// it matches hooks based on tool name patterns. Hooks are executed in parallel +// with configurable timeouts. Returns a combined HookOutput from all executed +// hooks, with blocking decisions taking precedence. func (e *Executor) ExecuteHooks(ctx context.Context, event HookEvent, input interface{}) (*HookOutput, error) { matchers, ok := e.config.Hooks[event] if !ok || len(matchers) == 0 { diff --git a/internal/hooks/schemas.go b/internal/hooks/schemas.go index 9d79ca92..01091534 100644 --- a/internal/hooks/schemas.go +++ b/internal/hooks/schemas.go @@ -4,7 +4,9 @@ import ( "encoding/json" ) -// CommonInput contains fields common to all hook inputs +// CommonInput contains fields common to all hook inputs, providing context +// information that is available to every hook regardless of the event type. +// These fields help hooks understand the execution environment and session state. 
type CommonInput struct { SessionID string `json:"session_id"` // Unique session identifier TranscriptPath string `json:"transcript_path"` // Path to transcript file (if enabled) @@ -15,14 +17,18 @@ type CommonInput struct { Interactive bool `json:"interactive"` // Whether in interactive mode } -// PreToolUseInput is passed to PreToolUse hooks +// PreToolUseInput is passed to PreToolUse hooks before a tool is executed. +// It contains the tool name and input parameters, allowing hooks to validate, +// modify, or block tool execution. type PreToolUseInput struct { CommonInput ToolName string `json:"tool_name"` ToolInput json.RawMessage `json:"tool_input"` } -// PostToolUseInput is passed to PostToolUse hooks +// PostToolUseInput is passed to PostToolUse hooks after a tool has been executed. +// It contains the tool name, input parameters, and the tool's response, allowing +// hooks to log, analyze, or react to tool execution results. type PostToolUseInput struct { CommonInput ToolName string `json:"tool_name"` @@ -30,13 +36,17 @@ type PostToolUseInput struct { ToolResponse json.RawMessage `json:"tool_response"` } -// UserPromptSubmitInput is passed to UserPromptSubmit hooks +// UserPromptSubmitInput is passed to UserPromptSubmit hooks when a user submits +// a prompt. It contains the user's input text, allowing hooks to validate, +// modify, or log user interactions before processing. type UserPromptSubmitInput struct { CommonInput Prompt string `json:"prompt"` } -// StopInput is passed to Stop hooks +// StopInput is passed to Stop hooks when the agent finishes responding to a prompt. +// It contains the final response, completion reason, and optional metadata about +// the interaction, allowing hooks to perform cleanup or logging operations. 
type StopInput struct { CommonInput StopHookActive bool `json:"stop_hook_active"` @@ -45,7 +55,10 @@ type StopInput struct { Meta json.RawMessage `json:"meta,omitempty"` // Additional metadata (e.g., token usage, model info) } -// HookOutput represents the JSON output from a hook +// HookOutput represents the JSON output from a hook that controls MCPHost behavior. +// Hooks can decide whether to continue execution, provide reasons for stopping, +// suppress output, or block tool execution. The Decision field can be "approve", +// "block", or empty (default behavior). type HookOutput struct { Continue *bool `json:"continue,omitempty"` StopReason string `json:"stopReason,omitempty"` diff --git a/internal/hooks/validator.go b/internal/hooks/validator.go index 504b5e68..bf0d6c44 100644 --- a/internal/hooks/validator.go +++ b/internal/hooks/validator.go @@ -68,7 +68,10 @@ func containsDangerousPattern(command string) bool { return separatorCount > 2 } -// ValidateHookConfig validates the entire hook configuration +// ValidateHookConfig validates the entire hook configuration for correctness +// and security. It checks event validity, regex patterns, hook definitions, +// and performs security validation on all commands. Returns an error describing +// any validation failures. func ValidateHookConfig(config *HookConfig) error { if config == nil { return fmt.Errorf("nil configuration") diff --git a/internal/models/anthropic/anthropic.go b/internal/models/anthropic/anthropic.go index 3522d068..7f456913 100644 --- a/internal/models/anthropic/anthropic.go +++ b/internal/models/anthropic/anthropic.go @@ -13,17 +13,42 @@ import ( "github.com/cloudwego/eino/schema" ) -// CustomChatModel wraps the eino-ext Claude model with custom tool schema handling +// CustomChatModel wraps the eino-ext Claude model with custom tool schema handling. 
+// It provides a compatibility layer that fixes malformed JSON in tool calls and +// ensures proper schema validation for Anthropic's API requirements. +// This wrapper is necessary to handle edge cases where the underlying library +// may generate invalid JSON for empty tool inputs or missing properties. type CustomChatModel struct { + // wrapped is the underlying eino-ext Claude model instance wrapped *einoclaude.ChatModel } -// CustomRoundTripper intercepts HTTP requests to fix Anthropic function schemas +// CustomRoundTripper intercepts HTTP requests to fix Anthropic function schemas. +// It acts as a middleware that modifies requests before they reach the Anthropic API, +// ensuring that tool schemas and function calls are properly formatted. +// This is particularly important for handling edge cases like empty tool inputs +// or missing schema properties that would otherwise cause API errors. type CustomRoundTripper struct { + // wrapped is the underlying HTTP transport to use for actual requests wrapped http.RoundTripper } -// NewCustomChatModel creates a new custom Anthropic chat model +// NewCustomChatModel creates a new custom Anthropic chat model. +// It wraps the standard eino-ext Claude model with additional request +// preprocessing to ensure compatibility with Anthropic's API requirements. 
+// +// Parameters: +// - ctx: Context for the operation +// - config: Configuration for the Claude model including API key, model name, and parameters +// +// Returns: +// - *CustomChatModel: A wrapped Claude model with enhanced compatibility +// - error: Returns an error if model creation fails +// +// The custom model automatically: +// - Fixes malformed JSON in tool calls +// - Ensures tool schemas have required properties +// - Handles empty or missing input fields in function calls func NewCustomChatModel(ctx context.Context, config *einoclaude.Config) (*CustomChatModel, error) { // Create a custom HTTP client that intercepts requests if config.HTTPClient == nil { @@ -49,7 +74,21 @@ func NewCustomChatModel(ctx context.Context, config *einoclaude.Config) (*Custom }, nil } -// RoundTrip implements http.RoundTripper to intercept and fix requests +// RoundTrip implements http.RoundTripper to intercept and fix requests. +// It preprocesses outgoing requests to the Anthropic API to ensure +// they meet the API's requirements for tool schemas and function calls. +// +// Parameters: +// - req: The HTTP request to be sent to the Anthropic API +// +// Returns: +// - *http.Response: The response from the Anthropic API +// - error: Any error that occurred during the request +// +// The method performs the following fixes: +// - Ensures tool input_schema properties are not null +// - Fixes malformed JSON patterns in tool_use content +// - Validates and corrects empty or invalid function call inputs func (rt *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // Only process Anthropic API requests if !strings.Contains(req.URL.Host, "anthropic.com") { @@ -191,17 +230,47 @@ func (rt *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro return rt.wrapped.RoundTrip(req) } -// Generate implements the model.BaseChatModel interface +// Generate implements the model.BaseChatModel interface. 
+// It generates a single response from the model based on the input messages. +// +// Parameters: +// - ctx: Context for the operation, supporting cancellation and deadlines +// - input: The conversation history as a slice of messages +// - opts: Optional configuration options for the generation +// +// Returns: +// - *schema.Message: The generated response message +// - error: Any error that occurred during generation func (m *CustomChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { return m.wrapped.Generate(ctx, input, opts...) } -// Stream implements the model.BaseChatModel interface +// Stream implements the model.BaseChatModel interface. +// It generates a streaming response from the model, allowing incremental +// processing of the model's output as it's generated. +// +// Parameters: +// - ctx: Context for the operation, supporting cancellation and deadlines +// - input: The conversation history as a slice of messages +// - opts: Optional configuration options for the generation +// +// Returns: +// - *schema.StreamReader[*schema.Message]: A reader for the streaming response +// - error: Any error that occurred during stream setup func (m *CustomChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { return m.wrapped.Stream(ctx, input, opts...) } -// WithTools implements the model.ToolCallingChatModel interface +// WithTools implements the model.ToolCallingChatModel interface. +// It creates a new model instance with the specified tools available for function calling. +// The original model instance remains unchanged. 
+// +// Parameters: +// - tools: A slice of tool definitions that the model can use +// +// Returns: +// - model.ToolCallingChatModel: A new model instance with tools enabled +// - error: Returns an error if tool binding fails func (m *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { wrappedWithTools, err := m.wrapped.WithTools(tools) if err != nil { diff --git a/internal/models/gemini/gemini.go b/internal/models/gemini/gemini.go index a856fddf..0b1da68a 100644 --- a/internal/models/gemini/gemini.go +++ b/internal/models/gemini/gemini.go @@ -17,21 +17,27 @@ import ( var _ model.ToolCallingChatModel = (*ChatModel)(nil) -// NewChatModel creates a new Gemini chat model instance +// NewChatModel creates a new Gemini chat model instance. +// It initializes a Google Gemini model with the specified configuration, +// supporting both text generation and tool calling capabilities. // // Parameters: -// - ctx: The context for the operation -// - cfg: Configuration for the Gemini model +// - ctx: The context for the operation (currently unused but kept for interface consistency) +// - cfg: Configuration for the Gemini model including client, model name, and parameters // // Returns: -// - model.ChatModel: A chat model interface implementation +// - *ChatModel: A Gemini chat model instance implementing ToolCallingChatModel // - error: Any error that occurred during creation // // Example: // +// client, _ := genai.NewClient(ctx, &genai.ClientConfig{ +// APIKey: "your-api-key", +// }) // model, err := gemini.NewChatModel(ctx, &gemini.Config{ // Client: client, // Model: "gemini-pro", +// MaxTokens: &maxTokens, // }) func NewChatModel(_ context.Context, cfg *Config) (*ChatModel, error) { return &ChatModel{ @@ -90,28 +96,58 @@ type Config struct { SafetySettings []*genai.SafetySetting } -// options contains Gemini-specific options for model configuration +// options contains Gemini-specific options for model configuration. 
+// These are options that are specific to the Gemini API and not part +// of the common model options interface. type options struct { - TopK *int32 + // TopK limits the number of tokens to sample from + TopK *int32 + // ResponseSchema defines the expected JSON structure for responses ResponseSchema *openapi3.Schema } +// ChatModel implements the Gemini chat model for the eino framework. +// It provides integration with Google's Gemini API, supporting both +// text generation and tool calling capabilities. type ChatModel struct { + // cli is the Gemini API client instance cli *genai.Client - model string - maxTokens *int - topP *float32 - temperature *float32 - topK *int32 - responseSchema *openapi3.Schema - tools []*genai.Tool - origTools []*schema.ToolInfo - toolChoice *schema.ToolChoice + // model specifies which Gemini model to use + model string + // maxTokens limits the response length + maxTokens *int + // topP controls nucleus sampling + topP *float32 + // temperature controls randomness + temperature *float32 + // topK limits token sampling + topK *int32 + // responseSchema for structured JSON output + responseSchema *openapi3.Schema + // tools converted to Gemini format + tools []*genai.Tool + // origTools stores the original tool definitions + origTools []*schema.ToolInfo + // toolChoice controls how tools are used + toolChoice *schema.ToolChoice + // enableCodeExecution allows code execution (use with caution) enableCodeExecution bool - safetySettings []*genai.SafetySetting + // safetySettings for content filtering + safetySettings []*genai.SafetySetting } +// Generate generates a single response from the Gemini model. +// It processes the input messages and returns a complete response. 
+// +// Parameters: +// - ctx: Context for the operation, supporting cancellation and callbacks +// - input: The conversation history as a slice of messages +// - opts: Optional configuration options for the generation +// +// Returns: +// - *schema.Message: The generated response message with content and metadata +// - error: Any error that occurred during generation func (cm *ChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (message *schema.Message, err error) { ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel) @@ -154,6 +190,17 @@ func (cm *ChatModel) Generate(ctx context.Context, input []*schema.Message, opts return message, nil } +// Stream generates a streaming response from the Gemini model. +// It allows incremental processing of the model's output as it's generated. +// +// Parameters: +// - ctx: Context for the operation, supporting cancellation and callbacks +// - input: The conversation history as a slice of messages +// - opts: Optional configuration options for the generation +// +// Returns: +// - *schema.StreamReader[*schema.Message]: A reader for the streaming response +// - error: Any error that occurred during stream setup func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (result *schema.StreamReader[*schema.Message], err error) { ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel) @@ -218,6 +265,16 @@ func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts . }), nil } +// WithTools creates a new model instance with the specified tools available. +// It returns a new ChatModel with tools configured for function calling. +// The original model instance remains unchanged. 
+// +// Parameters: +// - tools: A slice of tool definitions that the model can use +// +// Returns: +// - model.ToolCallingChatModel: A new model instance with tools enabled +// - error: Returns an error if no tools provided or conversion fails func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { if len(tools) == 0 { return nil, errors.New("no tools to bind") @@ -235,6 +292,15 @@ func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatM return &ncm, nil } +// BindTools binds tools to the current model instance. +// Unlike WithTools, this modifies the current instance rather than +// creating a new one. Tools are set to "allowed" mode by default. +// +// Parameters: +// - tools: A slice of tool definitions to bind to the model +// +// Returns: +// - error: Returns an error if no tools provided or conversion fails func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error { if len(tools) == 0 { return errors.New("no tools to bind") @@ -251,6 +317,15 @@ func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error { return nil } +// BindForcedTools binds tools to the current model instance in forced mode. +// This ensures the model will always use one of the provided tools +// rather than generating a text response. +// +// Parameters: +// - tools: A slice of tool definitions to bind to the model +// +// Returns: +// - error: Returns an error if no tools provided or conversion fails func (cm *ChatModel) BindForcedTools(tools []*schema.ToolInfo) error { if len(tools) == 0 { return errors.New("no tools to bind") @@ -679,12 +754,24 @@ func (cm *ChatModel) convertCallbackOutput(message *schema.Message, conf *model. return callbackOutput } +// IsCallbacksEnabled indicates whether this model supports callbacks. +// For the Gemini model, callbacks are always enabled to support +// token usage tracking and other monitoring features. 
+// +// Returns: +// - bool: Always returns true for Gemini models func (cm *ChatModel) IsCallbacksEnabled() bool { return true } const typ = "Gemini" +// GetType returns the type identifier for this model. +// This is used for logging and debugging purposes to identify +// which model implementation is being used. +// +// Returns: +// - string: Returns "Gemini" as the model type func (cm *ChatModel) GetType() string { return typ } diff --git a/internal/models/generate_models.go b/internal/models/generate_models.go index 0459bd5d..4b669a8e 100644 --- a/internal/models/generate_models.go +++ b/internal/models/generate_models.go @@ -12,37 +12,61 @@ import ( "time" ) -// ModelInfo represents information about a specific model +// ModelInfo represents information about a specific model. +// This struct is used during code generation to parse model data +// from the models.dev API and generate the static Go code. type ModelInfo struct { - ID string `json:"id"` - Name string `json:"name"` - Attachment bool `json:"attachment"` - Reasoning bool `json:"reasoning"` - Temperature bool `json:"temperature"` - Cost Cost `json:"cost"` - Limit Limit `json:"limit"` + // ID is the unique identifier for the model + ID string `json:"id"` + // Name is the human-readable name of the model + Name string `json:"name"` + // Attachment indicates whether the model supports file attachments + Attachment bool `json:"attachment"` + // Reasoning indicates whether this is a reasoning/chain-of-thought model + Reasoning bool `json:"reasoning"` + // Temperature indicates whether the model supports temperature parameter + Temperature bool `json:"temperature"` + // Cost contains the pricing information for the model + Cost Cost `json:"cost"` + // Limit contains the context and output token limits + Limit Limit `json:"limit"` } -// Cost represents the pricing information for a model +// Cost represents the pricing information for a model. 
+// Used during code generation to parse pricing data from models.dev. type Cost struct { - Input float64 `json:"input"` - Output float64 `json:"output"` - CacheRead *float64 `json:"cache_read,omitempty"` + // Input is the cost per million input tokens + Input float64 `json:"input"` + // Output is the cost per million output tokens + Output float64 `json:"output"` + // CacheRead is the cost per million cached read tokens (optional) + CacheRead *float64 `json:"cache_read,omitempty"` + // CacheWrite is the cost per million cached write tokens (optional) CacheWrite *float64 `json:"cache_write,omitempty"` } -// Limit represents the context and output limits for a model +// Limit represents the context and output limits for a model. +// Used during code generation to parse token limit data from models.dev. type Limit struct { + // Context is the maximum number of input tokens Context int `json:"context"` - Output int `json:"output"` + // Output is the maximum number of output tokens + Output int `json:"output"` } -// ProviderInfo represents information about a model provider +// ProviderInfo represents information about a model provider. +// Used during code generation to parse provider data from models.dev +// and generate the static provider registry. 
type ProviderInfo struct { - ID string `json:"id"` - Env []string `json:"env"` - NPM string `json:"npm"` - Name string `json:"name"` + // ID is the unique identifier for the provider + ID string `json:"id"` + // Env lists the environment variables for API credentials + Env []string `json:"env"` + // NPM is the NPM package name (for reference) + NPM string `json:"npm"` + // Name is the human-readable provider name + Name string `json:"name"` + // Models maps model IDs to their information Models map[string]ModelInfo `json:"models"` } diff --git a/internal/models/models_data.go b/internal/models/models_data.go index 928d4bc4..a8eb379e 100644 --- a/internal/models/models_data.go +++ b/internal/models/models_data.go @@ -3,41 +3,99 @@ package models -// ModelInfo represents information about a specific model +// ModelInfo represents information about a specific model. +// It contains comprehensive metadata about a model's capabilities, +// pricing, and limitations sourced from models.dev. type ModelInfo struct { - ID string - Name string - Attachment bool - Reasoning bool + // ID is the unique identifier for the model + // Example: "claude-3-sonnet-20240620" or "gpt-4" + ID string + + // Name is the human-readable name of the model + // Example: "Claude 3 Sonnet" or "GPT-4" + Name string + + // Attachment indicates whether the model supports file attachments + Attachment bool + + // Reasoning indicates whether this is a reasoning/chain-of-thought model + // Example: OpenAI's o1 models have this set to true + Reasoning bool + + // Temperature indicates whether the model supports temperature parameter Temperature bool - Cost Cost - Limit Limit + + // Cost contains the pricing information for input/output tokens + Cost Cost + + // Limit contains the context window and output token limits + Limit Limit } -// Cost represents the pricing information for a model +// Cost represents the pricing information for a model. +// Prices are typically in USD per million tokens. 
type Cost struct { - Input float64 - Output float64 - CacheRead *float64 + // Input is the cost per million input tokens + Input float64 + + // Output is the cost per million output tokens + Output float64 + + // CacheRead is the cost per million cached read tokens (optional) + // Only applicable for models that support prompt caching + CacheRead *float64 + + // CacheWrite is the cost per million cached write tokens (optional) + // Only applicable for models that support prompt caching CacheWrite *float64 } -// Limit represents the context and output limits for a model +// Limit represents the context and output limits for a model. +// These define the maximum number of tokens the model can process +// and generate in a single interaction. type Limit struct { + // Context is the maximum number of input tokens (context window size) Context int - Output int + + // Output is the maximum number of output tokens that can be generated + Output int } -// ProviderInfo represents information about a model provider +// ProviderInfo represents information about a model provider. +// It contains metadata about the provider and all models it offers. type ProviderInfo struct { - ID string - Env []string - NPM string - Name string + // ID is the unique identifier for the provider + // Example: "anthropic", "openai", "google" + ID string + + // Env lists the environment variables checked for API credentials + // Example: ["ANTHROPIC_API_KEY"] or ["OPENAI_API_KEY"] + Env []string + + // NPM is the NPM package name used by the provider (for reference) + NPM string + + // Name is the human-readable name of the provider + // Example: "Anthropic", "OpenAI", "Google" + Name string + + // Models maps model IDs to their detailed information Models map[string]ModelInfo } -// GetModelsData returns the static models data from models.dev +// GetModelsData returns the static models data from models.dev. 
+// This data is automatically generated from the models.dev API +// and provides comprehensive information about all supported +// LLM providers and their models. +// +// The data includes: +// - Provider information (ID, name, environment variables) +// - Model capabilities (reasoning, attachments, temperature support) +// - Pricing information (input/output costs, cache costs) +// - Token limits (context window, max output tokens) +// +// Returns: +// - map[string]ProviderInfo: A map of provider IDs to their information func GetModelsData() map[string]ProviderInfo { return map[string]ProviderInfo{ "alibaba": { diff --git a/internal/models/openai/openai.go b/internal/models/openai/openai.go index cbf9eb05..3f51490d 100644 --- a/internal/models/openai/openai.go +++ b/internal/models/openai/openai.go @@ -14,17 +14,42 @@ import ( "github.com/cloudwego/eino/schema" ) -// CustomChatModel wraps the eino-ext OpenAI model with custom tool schema handling +// CustomChatModel wraps the eino-ext OpenAI model with custom tool schema handling. +// It provides a compatibility layer that ensures proper JSON schema formatting +// for OpenAI's function calling feature. This wrapper addresses cases where +// tool schemas might have missing or empty properties that would cause API errors. type CustomChatModel struct { + // wrapped is the underlying eino-ext OpenAI model instance wrapped *einoopenai.ChatModel } -// CustomRoundTripper intercepts HTTP requests to fix OpenAI function schemas +// CustomRoundTripper intercepts HTTP requests to fix OpenAI function schemas. +// It acts as middleware that modifies outgoing requests to ensure that +// function/tool schemas are properly formatted according to OpenAI's requirements. +// This is particularly important for handling edge cases where tool schemas +// might have missing or empty properties fields. 
type CustomRoundTripper struct { + // wrapped is the underlying HTTP transport to use for actual requests wrapped http.RoundTripper } -// NewCustomChatModel creates a new custom OpenAI chat model +// NewCustomChatModel creates a new custom OpenAI chat model. +// It wraps the standard eino-ext OpenAI model with additional request +// preprocessing to ensure compatibility with OpenAI's API requirements, +// particularly for function calling and tool schemas. +// +// Parameters: +// - ctx: Context for the operation +// - config: Configuration for the OpenAI model including API key, model name, and parameters +// +// Returns: +// - *CustomChatModel: A wrapped OpenAI model with enhanced compatibility +// - error: Returns an error if model creation fails +// +// The custom model automatically: +// - Ensures function parameter schemas have properties fields +// - Fixes missing or empty properties in tool schemas +// - Maintains compatibility with OpenAI's function calling requirements func NewCustomChatModel(ctx context.Context, config *einoopenai.ChatModelConfig) (*CustomChatModel, error) { // Create a custom HTTP client that intercepts requests if config.HTTPClient == nil { @@ -49,7 +74,20 @@ func NewCustomChatModel(ctx context.Context, config *einoopenai.ChatModelConfig) }, nil } -// RoundTrip implements http.RoundTripper to intercept and fix OpenAI requests +// RoundTrip implements http.RoundTripper to intercept and fix OpenAI requests. +// It preprocesses outgoing requests to the OpenAI API to ensure tool/function +// schemas meet the API's requirements. 
+// +// Parameters: +// - req: The HTTP request to be sent to the OpenAI API +// +// Returns: +// - *http.Response: The response from the OpenAI API +// - error: Any error that occurred during the request +// +// The method performs the following fixes: +// - Ensures function parameter schemas of type "object" have a properties field +// - Adds empty properties object if missing to prevent API validation errors func (c *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // Only intercept OpenAI chat completions requests if !strings.Contains(req.URL.Path, "/chat/completions") { @@ -109,17 +147,47 @@ func (c *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error return c.wrapped.RoundTrip(req) } -// Generate implements model.ChatModel +// Generate implements the model.ChatModel interface. +// It generates a single response from the OpenAI model based on the input messages. +// +// Parameters: +// - ctx: Context for the operation, supporting cancellation and deadlines +// - in: The conversation history as a slice of messages +// - opts: Optional configuration options for the generation +// +// Returns: +// - *schema.Message: The generated response message +// - error: Any error that occurred during generation func (c *CustomChatModel) Generate(ctx context.Context, in []*schema.Message, opts ...model.Option) (*schema.Message, error) { return c.wrapped.Generate(ctx, in, opts...) } -// Stream implements model.ChatModel +// Stream implements the model.ChatModel interface. +// It generates a streaming response from the OpenAI model, allowing +// incremental processing of the model's output as it's generated. 
+// +// Parameters: +// - ctx: Context for the operation, supporting cancellation and deadlines +// - in: The conversation history as a slice of messages +// - opts: Optional configuration options for the generation +// +// Returns: +// - *schema.StreamReader[*schema.Message]: A reader for the streaming response +// - error: Any error that occurred during stream setup func (c *CustomChatModel) Stream(ctx context.Context, in []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { return c.wrapped.Stream(ctx, in, opts...) } -// WithTools implements model.ToolCallingChatModel +// WithTools implements the model.ToolCallingChatModel interface. +// It creates a new model instance with the specified tools available for function calling. +// The original model instance remains unchanged. +// +// Parameters: +// - tools: A slice of tool definitions that the model can use +// +// Returns: +// - model.ToolCallingChatModel: A new model instance with tools enabled +// - error: Returns an error if tool binding fails func (c *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { wrappedWithTools, err := c.wrapped.WithTools(tools) if err != nil { @@ -135,22 +203,47 @@ func (c *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCalling return &CustomChatModel{wrapped: wrappedChatModel}, nil } -// BindTools implements model.ToolCallingChatModel +// BindTools implements the model.ToolCallingChatModel interface. +// It binds tools to the current model instance, modifying it in place +// rather than creating a new instance. +// +// Parameters: +// - tools: A slice of tool definitions to bind to the model +// +// Returns: +// - error: Returns an error if tool binding fails func (c *CustomChatModel) BindTools(tools []*schema.ToolInfo) error { return c.wrapped.BindTools(tools) } -// BindForcedTools implements model.ToolCallingChatModel +// BindForcedTools implements the model.ToolCallingChatModel interface. 
+// It binds tools to the current model instance in forced mode, +// ensuring the model will always use one of the provided tools. +// +// Parameters: +// - tools: A slice of tool definitions to bind to the model +// +// Returns: +// - error: Returns an error if tool binding fails func (c *CustomChatModel) BindForcedTools(tools []*schema.ToolInfo) error { return c.wrapped.BindForcedTools(tools) } -// GetType implements model.ChatModel +// GetType implements the model.ChatModel interface. +// It returns the type identifier for this model implementation. +// +// Returns: +// - string: Returns "CustomOpenAI" as the model type identifier func (c *CustomChatModel) GetType() string { return "CustomOpenAI" } -// IsCallbacksEnabled implements model.ChatModel +// IsCallbacksEnabled implements the model.ChatModel interface. +// It indicates whether this model supports callbacks for monitoring +// and tracking purposes. +// +// Returns: +// - bool: Returns the callback enabled status from the wrapped model func (c *CustomChatModel) IsCallbacksEnabled() bool { return c.wrapped.IsCallbacksEnabled() } diff --git a/internal/models/providers.go b/internal/models/providers.go index 7065d1a6..4f2caedc 100644 --- a/internal/models/providers.go +++ b/internal/models/providers.go @@ -27,7 +27,9 @@ import ( ) const ( - // ClaudeCodePrompt is the required system prompt for OAuth authentication + // ClaudeCodePrompt is the required system prompt for OAuth authentication. + // This prompt must be included as the first system message when using OAuth-based + // authentication with Anthropic's API to properly identify the application. ClaudeCodePrompt = "You are Claude Code, Anthropic's official CLI for Claude." ) @@ -62,35 +64,93 @@ func resolveModelAlias(provider, modelName string) string { return modelName } -// ProviderConfig holds configuration for creating LLM providers +// ProviderConfig holds configuration for creating LLM providers. 
+// It contains all necessary settings to initialize and configure +// various LLM providers including API keys, model parameters, and +// provider-specific settings. type ProviderConfig struct { - ModelString string - SystemPrompt string - ProviderAPIKey string // API key for OpenAI and Anthropic - ProviderURL string // Base URL for OpenAI, Anthropic, and Ollama + // ModelString specifies the model in the format "provider:model" + // Example: "anthropic:claude-3-sonnet-20240620" or "openai:gpt-4" + ModelString string + + // SystemPrompt sets the system message/instructions for the model + SystemPrompt string + + // ProviderAPIKey is the API key for authentication with the provider + // Used for OpenAI, Anthropic, Google, and Azure providers + ProviderAPIKey string + + // ProviderURL is the base URL for the provider's API endpoint + // Can be used to specify custom endpoints for OpenAI, Anthropic, Ollama, or Azure + ProviderURL string // Model generation parameters - MaxTokens int - Temperature *float32 - TopP *float32 - TopK *int32 + + // MaxTokens limits the maximum number of tokens in the response + MaxTokens int + + // Temperature controls randomness in generation (0.0 to 1.0) + // Lower values make output more focused and deterministic + Temperature *float32 + + // TopP implements nucleus sampling, controlling diversity + // Value between 0.0 and 1.0, where 1.0 considers all tokens + TopP *float32 + + // TopK limits the number of tokens to sample from + // Used by Anthropic and Google providers + TopK *int32 + + // StopSequences is a list of sequences that will stop generation StopSequences []string // Ollama-specific parameters - NumGPU *int32 + + // NumGPU specifies the number of GPUs to use for inference + NumGPU *int32 + + // MainGPU specifies which GPU to use as the primary device MainGPU *int32 // TLS configuration - TLSSkipVerify bool // Skip TLS certificate verification (insecure) + + // TLSSkipVerify skips TLS certificate verification (insecure) + // 
Should only be used for development or with self-signed certificates + TLSSkipVerify bool } -// ProviderResult contains the result of provider creation +// ProviderResult contains the result of provider creation. +// It includes both the created model instance and any informational +// messages that should be displayed to the user. type ProviderResult struct { - Model model.ToolCallingChatModel - Message string // Optional message for user feedback (e.g., GPU fallback info) + // Model is the created LLM provider instance that implements + // the ToolCallingChatModel interface for tool-enabled conversations + Model model.ToolCallingChatModel + + // Message contains optional feedback for the user + // Example: "Insufficient GPU memory, falling back to CPU inference" + Message string } -// CreateProvider creates an eino ToolCallingChatModel based on the provider configuration +// CreateProvider creates an eino ToolCallingChatModel based on the provider configuration. +// It validates the model, checks required environment variables, and initializes +// the appropriate provider (Anthropic, OpenAI, Google, Ollama, or Azure). 
+// +// Parameters: +// - ctx: Context for the operation, used for cancellation and deadlines +// - config: Provider configuration containing model details and parameters +// +// Returns: +// - *ProviderResult: Contains the created model and any informational messages +// - error: Returns an error if provider creation fails, model validation fails, +// or required credentials are missing +// +// Supported providers: +// - anthropic: Claude models with API key or OAuth authentication +// - openai: GPT models including reasoning models like o1 +// - google: Gemini models +// - ollama: Local models with automatic GPU/CPU fallback +// - azure: Azure OpenAI Service deployments func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) { parts := strings.SplitN(config.ModelString, ":", 2) if len(parts) < 2 { @@ -407,9 +467,17 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName return gemini.NewChatModel(ctx, geminiConfig) } -// OllamaLoadingResult contains the result of model loading with actual settings used +// OllamaLoadingResult contains the result of model loading with actual settings used. +// It provides information about how the model was loaded, including any fallback +// behavior that occurred during initialization. 
type OllamaLoadingResult struct { + // Options contains the actual Ollama options used for loading + // May differ from requested options if fallback occurred (e.g., CPU instead of GPU) Options *api.Options + + // Message describes the loading result + // Example: "Model loaded successfully on GPU" or + // "Insufficient GPU memory, falling back to CPU inference" Message string } diff --git a/internal/models/registry.go b/internal/models/registry.go index 057a8bf4..3ac93df3 100644 --- a/internal/models/registry.go +++ b/internal/models/registry.go @@ -8,19 +8,39 @@ import ( "strings" ) -// ModelsRegistry provides validation and information about models +// ModelsRegistry provides validation and information about models. +// It maintains a registry of all supported LLM providers and their models, +// including capabilities, pricing, and configuration requirements. +// The registry data is generated from models.dev and provides a single +// source of truth for model validation and discovery. type ModelsRegistry struct { + // providers maps provider IDs to their information and available models providers map[string]ProviderInfo } -// NewModelsRegistry creates a new models registry with static data +// NewModelsRegistry creates a new models registry with static data. +// The registry is populated with model information generated from models.dev, +// providing comprehensive metadata about available models across all supported providers. +// +// Returns: +// - *ModelsRegistry: A new registry instance populated with current model data func NewModelsRegistry() *ModelsRegistry { return &ModelsRegistry{ providers: GetModelsData(), } } -// ValidateModel validates if a model exists and returns detailed information +// ValidateModel validates if a model exists and returns detailed information. +// It checks whether a specific model is available for a given provider and +// returns comprehensive information about the model's capabilities and limits. 
+// +// Parameters: +// - provider: The provider ID (e.g., "anthropic", "openai", "google") +// - modelID: The specific model ID (e.g., "claude-3-sonnet-20240620", "gpt-4") +// +// Returns: +// - *ModelInfo: Detailed information about the model including pricing, limits, and capabilities +// - error: Returns an error if the provider is unsupported or model is not found func (r *ModelsRegistry) ValidateModel(provider, modelID string) (*ModelInfo, error) { providerInfo, exists := r.providers[provider] if !exists { @@ -35,7 +55,21 @@ func (r *ModelsRegistry) ValidateModel(provider, modelID string) (*ModelInfo, er return &modelInfo, nil } -// GetRequiredEnvVars returns the required environment variables for a provider +// GetRequiredEnvVars returns the required environment variables for a provider. +// These are the environment variable names that should contain API keys or +// other authentication credentials for the specified provider. +// +// Parameters: +// - provider: The provider ID (e.g., "anthropic", "openai", "google") +// +// Returns: +// - []string: List of environment variable names the provider checks for credentials +// - error: Returns an error if the provider is unsupported +// +// Example: +// +// For "anthropic", returns ["ANTHROPIC_API_KEY"] +// For "google", returns ["GOOGLE_API_KEY", "GEMINI_API_KEY", "GOOGLE_GENERATIVE_AI_API_KEY"] func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) { providerInfo, exists := r.providers[provider] if !exists { @@ -45,7 +79,16 @@ func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) { return providerInfo.Env, nil } -// ValidateEnvironment checks if required environment variables are set +// ValidateEnvironment checks if required environment variables are set. +// It verifies that at least one of the provider's required environment variables +// contains an API key, unless an API key is explicitly provided via configuration. 
+// +// Parameters: +// - provider: The provider ID to validate environment for +// - apiKey: An API key provided via configuration (if empty, checks environment variables) +// +// Returns: +// - error: Returns nil if validation passes, or an error describing missing credentials func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error { envVars, err := r.GetRequiredEnvVars(provider) if err != nil { @@ -74,7 +117,16 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err return nil } -// SuggestModels returns similar model names when an invalid model is provided +// SuggestModels returns similar model names when an invalid model is provided. +// It helps users discover the correct model ID by finding models that partially +// match the provided input, useful for correcting typos or finding alternatives. +// +// Parameters: +// - provider: The provider ID to search within +// - invalidModel: The invalid or misspelled model name to find suggestions for +// +// Returns: +// - []string: A list of up to 5 suggested model IDs that partially match the input func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string { providerInfo, exists := r.providers[provider] if !exists { @@ -105,7 +157,12 @@ func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string { return suggestions } -// GetSupportedProviders returns a list of all supported providers +// GetSupportedProviders returns a list of all supported providers. +// This includes all providers that have models registered in the system, +// such as "anthropic", "openai", "google", "alibaba", etc. 
+// +// Returns: +// - []string: A list of all provider IDs available in the registry func (r *ModelsRegistry) GetSupportedProviders() []string { var providers []string for providerID := range r.providers { @@ -114,7 +171,16 @@ func (r *ModelsRegistry) GetSupportedProviders() []string { return providers } -// GetModelsForProvider returns all models for a specific provider +// GetModelsForProvider returns all models for a specific provider. +// This is useful for listing available models when a user wants to see +// all options for a particular provider. +// +// Parameters: +// - provider: The provider ID to get models for +// +// Returns: +// - map[string]ModelInfo: A map of model IDs to their detailed information +// - error: Returns an error if the provider is unsupported func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]ModelInfo, error) { providerInfo, exists := r.providers[provider] if !exists { @@ -127,7 +193,13 @@ func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]Model // Global registry instance var globalRegistry = NewModelsRegistry() -// GetGlobalRegistry returns the global models registry instance +// GetGlobalRegistry returns the global models registry instance. +// This provides a singleton registry that can be accessed throughout +// the application for model validation and information retrieval. +// The registry is initialized once with data from models.dev. +// +// Returns: +// - *ModelsRegistry: The global registry instance func GetGlobalRegistry() *ModelsRegistry { return globalRegistry } diff --git a/internal/session/manager.go b/internal/session/manager.go index abe7f0cb..82773327 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -7,14 +7,21 @@ import ( "github.com/cloudwego/eino/schema" ) -// Manager manages session state and auto-saving +// Manager manages session state and auto-saving functionality. 
+// It provides thread-safe operations for managing a conversation session, +// including automatic persistence to disk after each modification. +// The Manager ensures that all session operations are synchronized and +// that the session file is kept up-to-date with any changes. type Manager struct { session *Session filePath string mutex sync.RWMutex } -// NewManager creates a new session manager +// NewManager creates a new session manager with a fresh session. +// The filePath parameter specifies where the session will be auto-saved. +// If filePath is empty, the session will not be automatically saved to disk. +// Returns a Manager instance ready to track conversation messages. func NewManager(filePath string) *Manager { return &Manager{ session: NewSession(), @@ -22,7 +29,11 @@ func NewManager(filePath string) *Manager { } } -// NewManagerWithSession creates a new session manager with an existing session +// NewManagerWithSession creates a new session manager with an existing session. +// This is useful when loading a session from a file and wanting to continue +// managing it with auto-save functionality. +// The session parameter is the existing session to manage. +// The filePath parameter specifies where the session will be auto-saved. func NewManagerWithSession(session *Session, filePath string) *Manager { return &Manager{ session: session, @@ -30,7 +41,12 @@ func NewManagerWithSession(session *Session, filePath string) *Manager { } } -// AddMessage adds a message to the session and auto-saves +// AddMessage adds a message to the session and auto-saves. +// The message is converted from schema.Message format to the internal +// session Message format before being added. If a filePath was specified +// when creating the Manager, the session is automatically saved to disk. +// This operation is thread-safe. +// Returns an error if auto-saving fails, nil otherwise. 
func (m *Manager) AddMessage(msg *schema.Message) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -45,7 +61,11 @@ func (m *Manager) AddMessage(msg *schema.Message) error { return nil } -// AddMessages adds multiple messages to the session and auto-saves +// AddMessages adds multiple messages to the session and auto-saves. +// All messages are added in order and then the session is saved once. +// This is more efficient than calling AddMessage multiple times when +// adding several messages at once. The operation is thread-safe. +// Returns an error if auto-saving fails, nil otherwise. func (m *Manager) AddMessages(msgs []*schema.Message) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -62,7 +82,12 @@ func (m *Manager) AddMessages(msgs []*schema.Message) error { return nil } -// ReplaceAllMessages replaces all messages in the session with the provided messages +// ReplaceAllMessages replaces all messages in the session with the provided messages. +// This method completely clears the existing message history and replaces it with +// the new set of messages. Useful for resetting a conversation or loading a +// different conversation context. The operation is thread-safe and triggers +// an auto-save if a filePath is configured. +// Returns an error if auto-saving fails, nil otherwise. func (m *Manager) ReplaceAllMessages(msgs []*schema.Message) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -83,7 +108,11 @@ func (m *Manager) ReplaceAllMessages(msgs []*schema.Message) error { return nil } -// SetMetadata sets the session metadata +// SetMetadata sets the session metadata. +// This updates the session's metadata with information about the provider, +// model, and MCPHost version. The operation is thread-safe and triggers +// an auto-save if a filePath is configured. +// Returns an error if auto-saving fails, nil otherwise. 
func (m *Manager) SetMetadata(metadata Metadata) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -97,7 +126,11 @@ func (m *Manager) SetMetadata(metadata Metadata) error { return nil } -// GetMessages returns all messages as schema.Message slice +// GetMessages returns all messages as a schema.Message slice. +// This method converts all stored session messages to the schema format +// used by LLM providers. The returned slice is a new allocation, so +// modifications to it won't affect the stored session. This operation +// is thread-safe for concurrent reads. func (m *Manager) GetMessages() []*schema.Message { m.mutex.RLock() defer m.mutex.RUnlock() @@ -110,7 +143,11 @@ func (m *Manager) GetMessages() []*schema.Message { return messages } -// GetSession returns a copy of the current session +// GetSession returns a copy of the current session. +// The returned session is a deep copy, including all messages, so +// modifications to it won't affect the managed session. This is useful +// for safely inspecting the session state without risk of concurrent +// modification. This operation is thread-safe for concurrent reads. func (m *Manager) GetSession() *Session { m.mutex.RLock() defer m.mutex.RUnlock() @@ -123,7 +160,11 @@ func (m *Manager) GetSession() *Session { return &sessionCopy } -// Save manually saves the session to file +// Save manually saves the session to file. +// This forces a save operation even if no changes have been made. +// Useful for ensuring the session is persisted at specific points. +// Returns an error if no filePath was specified when creating the +// Manager, or if the save operation fails. func (m *Manager) Save() error { m.mutex.RLock() defer m.mutex.RUnlock() @@ -135,12 +176,16 @@ func (m *Manager) Save() error { return m.session.SaveToFile(m.filePath) } -// GetFilePath returns the file path for this session +// GetFilePath returns the file path for this session. 
+// Returns the path where the session is being auto-saved, or an +// empty string if no auto-save path was configured. func (m *Manager) GetFilePath() string { return m.filePath } -// MessageCount returns the number of messages in the session +// MessageCount returns the number of messages in the session. +// This provides a quick way to check the conversation length without +// retrieving all messages. This operation is thread-safe for concurrent reads. func (m *Manager) MessageCount() int { m.mutex.RLock() defer m.mutex.RUnlock() diff --git a/internal/session/session.go b/internal/session/session.go index b3baf6f5..580160ef 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -11,40 +11,71 @@ import ( "github.com/cloudwego/eino/schema" ) -// Session represents a complete conversation session with metadata +// Session represents a complete conversation session with metadata. +// It stores all messages exchanged during a conversation along with +// contextual information about the session such as the provider, model, +// and timestamps. Sessions can be saved to and loaded from JSON files +// for persistence across program runs. 
type Session struct { - Version string `json:"version"` + // Version indicates the session format version for compatibility + Version string `json:"version"` + // CreatedAt is the timestamp when the session was first created CreatedAt time.Time `json:"created_at"` + // UpdatedAt is the timestamp when the session was last modified UpdatedAt time.Time `json:"updated_at"` - Metadata Metadata `json:"metadata"` - Messages []Message `json:"messages"` + // Metadata contains contextual information about the session + Metadata Metadata `json:"metadata"` + // Messages is the ordered list of all messages in this session + Messages []Message `json:"messages"` } -// Metadata contains session metadata +// Metadata contains session metadata that provides context about the +// environment and configuration used during the conversation. This helps +// with debugging and understanding the session's context when reviewing +// conversation history. type Metadata struct { + // MCPHostVersion is the version of MCPHost used for this session MCPHostVersion string `json:"mcphost_version"` - Provider string `json:"provider"` - Model string `json:"model"` + // Provider is the LLM provider used (e.g., "anthropic", "openai", "gemini") + Provider string `json:"provider"` + // Model is the specific model identifier used for the conversation + Model string `json:"model"` } -// Message represents a single message in the session +// Message represents a single message in the conversation session. +// Messages can be from different roles (user, assistant, tool) and may +// include tool calls for assistant messages or tool results for tool messages. 
type Message struct { - ID string `json:"id"` - Role string `json:"role"` - Content string `json:"content"` - Timestamp time.Time `json:"timestamp"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` // For tool result messages + // ID is a unique identifier for this message, auto-generated if not provided + ID string `json:"id"` + // Role indicates who sent the message ("user", "assistant", "tool", or "system") + Role string `json:"role"` + // Content is the text content of the message + Content string `json:"content"` + // Timestamp is when the message was created + Timestamp time.Time `json:"timestamp"` + // ToolCalls contains any tool invocations made by the assistant in this message + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + // ToolCallID links a tool result message to its corresponding tool call + ToolCallID string `json:"tool_call_id,omitempty"` } -// ToolCall represents a tool call within a message +// ToolCall represents a tool invocation within an assistant message. +// When the assistant decides to use a tool, it creates a ToolCall with +// the necessary information to execute that tool. type ToolCall struct { - ID string `json:"id"` - Name string `json:"name"` - Arguments any `json:"arguments"` + // ID is a unique identifier for this tool call, used to link results + ID string `json:"id"` + // Name is the name of the tool being invoked + Name string `json:"name"` + // Arguments contains the parameters passed to the tool, typically as JSON + Arguments any `json:"arguments"` } -// NewSession creates a new session with default values +// NewSession creates a new session with default values. +// It initializes a session with version 1.0, current timestamps, +// empty message list, and empty metadata. The returned session +// is ready to receive messages and can be saved to a file. 
func NewSession() *Session { return &Session{ Version: "1.0", @@ -55,7 +86,10 @@ func NewSession() *Session { } } -// AddMessage adds a message to the session +// AddMessage adds a message to the session. +// If the message doesn't have an ID, one will be auto-generated. +// If the message doesn't have a timestamp, the current time will be used. +// The session's UpdatedAt timestamp is automatically updated. func (s *Session) AddMessage(msg Message) { if msg.ID == "" { msg.ID = generateMessageID() @@ -68,13 +102,21 @@ func (s *Session) AddMessage(msg Message) { s.UpdatedAt = time.Now() } -// SetMetadata sets the session metadata +// SetMetadata sets the session metadata. +// This replaces the existing metadata with the provided metadata +// and updates the session's UpdatedAt timestamp. Use this to record +// information about the provider, model, and MCPHost version. func (s *Session) SetMetadata(metadata Metadata) { s.Metadata = metadata s.UpdatedAt = time.Now() } -// SaveToFile saves the session to a JSON file +// SaveToFile saves the session to a JSON file. +// The session is serialized as indented JSON for readability. +// The UpdatedAt timestamp is automatically updated before saving. +// The file is created with 0644 permissions if it doesn't exist, +// or overwritten if it does exist. +// Returns an error if marshaling fails or file writing fails. func (s *Session) SaveToFile(filePath string) error { s.UpdatedAt = time.Now() @@ -86,7 +128,12 @@ func (s *Session) SaveToFile(filePath string) error { return os.WriteFile(filePath, data, 0644) } -// LoadFromFile loads a session from a JSON file +// LoadFromFile loads a session from a JSON file. +// It reads the file at the specified path and deserializes it into +// a Session struct. This is useful for resuming previous conversations +// or reviewing session history. +// Returns the loaded session on success, or an error if the file +// cannot be read or the JSON is invalid. 
func LoadFromFile(filePath string) (*Session, error) { data, err := os.ReadFile(filePath) if err != nil { @@ -101,7 +148,12 @@ func LoadFromFile(filePath string) (*Session, error) { return &session, nil } -// ConvertFromSchemaMessage converts a schema.Message to a session Message +// ConvertFromSchemaMessage converts a schema.Message to a session Message. +// This function bridges between the eino schema message format and the +// session's internal message format. It preserves role, content, and +// tool-related information while adding a timestamp. +// Tool calls from assistant messages and tool call IDs from tool messages +// are properly converted and preserved. func ConvertFromSchemaMessage(msg *schema.Message) Message { sessionMsg := Message{ Role: string(msg.Role), @@ -129,7 +181,12 @@ func ConvertFromSchemaMessage(msg *schema.Message) Message { return sessionMsg } -// ConvertToSchemaMessage converts a session Message to a schema.Message +// ConvertToSchemaMessage converts a session Message to a schema.Message. +// This method bridges between the session's internal message format and +// the eino schema message format used by the LLM providers. +// It properly handles tool calls for assistant messages and tool call IDs +// for tool result messages. Arguments are converted to string format as +// required by the schema. func (m *Message) ConvertToSchemaMessage() *schema.Message { msg := &schema.Message{ Role: schema.RoleType(m.Role), diff --git a/internal/tokens/anthropic.go b/internal/tokens/anthropic.go index a7c90664..a493669b 100644 --- a/internal/tokens/anthropic.go +++ b/internal/tokens/anthropic.go @@ -1 +1,16 @@ +// Package tokens provides token counting and estimation functionality for +// various language model providers. It includes utilities for estimating +// token counts in text, as well as provider-specific implementations for +// more accurate token counting. 
+// +// The package supports multiple approaches to token counting: +// - Quick estimation using character-based heuristics +// - Provider-specific tokenizers for accurate counts +// - Initialization functions for setting up token counters +// +// Token counting is essential for: +// - Managing API rate limits +// - Calculating costs for API usage +// - Ensuring prompts fit within model context windows +// - Optimizing prompt engineering and response handling package tokens diff --git a/internal/tokens/counter.go b/internal/tokens/counter.go index d88a125b..0cc3a96c 100644 --- a/internal/tokens/counter.go +++ b/internal/tokens/counter.go @@ -1,6 +1,23 @@ package tokens -// EstimateTokens provides a rough estimate of tokens in text +// EstimateTokens estimates the number of tokens in the given text string. +// It uses a rough approximation of 4 characters per token, which is a common +// heuristic for most language models. This function provides a quick estimation +// without requiring model-specific tokenizers. +// +// The estimation may not be accurate for all models or text types, particularly +// for texts with many special characters, non-English languages, or code snippets. +// For more accurate token counting, use model-specific tokenizers when available. 
+// +// Parameters: +// - text: The input text string to estimate tokens for +// +// Returns: +// - int: The estimated number of tokens in the text +// +// Example: +// +// count := EstimateTokens("Hello, world!") // Returns 3 (13 bytes / 4) func EstimateTokens(text string) int { // Rough approximation: ~4 characters per token for most models return len(text) / 4 diff --git a/internal/tokens/init.go b/internal/tokens/init.go index 2ad98305..a4e13063 100644 --- a/internal/tokens/init.go +++ b/internal/tokens/init.go @@ -1,11 +1,52 @@ package tokens -// InitializeTokenCounters registers all available token counters +// InitializeTokenCounters registers all available token counters for various +// language model providers. This function should be called during application +// startup to ensure that token counting functionality is available for all +// supported models. +// +// Currently, this function is a placeholder for future provider-specific +// token counter implementations. As new providers are added (OpenAI, Anthropic, +// Google, etc.), their respective token counters will be registered here. +// +// This function does not require any API keys and will only initialize +// counters that can work without authentication. +// +// Example: +// +// func main() { +// tokens.InitializeTokenCounters() +// // Token counting is now available +// } func InitializeTokenCounters() { // Future provider-specific counters can be registered here } -// InitializeTokenCountersWithKeys registers token counters with provided API keys +// InitializeTokenCountersWithKeys is the registration point for token counters +// that need provider API keys. It takes no parameters yet; once keys are wired +// in, it is meant to enable more accurate token counting through provider-specific +// tokenization endpoints or libraries that require authentication. +// +// This function should be called during application startup after API keys +// have been loaded from configuration or environment variables.
It is +// currently a placeholder that registers nothing; once implemented, it is +// meant to initialize token counters for providers whose API keys are +// available, matching each provider's actual tokenization logic. +// +// The intended behavior is to silently skip providers for which no API keys +// are configured, allowing the application to continue with partial token +// counting capabilities. +// +// Future implementations will accept provider-specific API keys through +// parameters or read them from a configuration context. +// +// Example: +// +// func main() { +// // Load API keys from environment or config +// tokens.InitializeTokenCountersWithKeys() +// // Currently a no-op; provider-specific counters are not registered yet +// } func InitializeTokenCountersWithKeys() { // Future provider-specific counters can be registered here } diff --git a/internal/tools/buffered_logger.go b/internal/tools/buffered_logger.go index 7ab75ddd..d582c5ee 100644 --- a/internal/tools/buffered_logger.go +++ b/internal/tools/buffered_logger.go @@ -4,14 +4,19 @@ import ( "sync" ) -// BufferedDebugLogger stores debug messages until they can be displayed +// BufferedDebugLogger implements DebugLogger by storing debug messages in memory +// until they can be retrieved and displayed. This is useful when debug output +// needs to be deferred or batch-processed rather than immediately displayed. +// All methods are thread-safe for concurrent use. type BufferedDebugLogger struct { enabled bool messages []string mu sync.Mutex } -// NewBufferedDebugLogger creates a new buffered debug logger +// NewBufferedDebugLogger creates a new buffered debug logger instance. +// The enabled parameter determines whether debug messages will be stored. +// If enabled is false, all LogDebug calls become no-ops for performance. 
func NewBufferedDebugLogger(enabled bool) *BufferedDebugLogger { return &BufferedDebugLogger{ enabled: enabled, @@ -19,7 +24,10 @@ func NewBufferedDebugLogger(enabled bool) *BufferedDebugLogger { } } -// LogDebug stores a debug message +// LogDebug stores a debug message in the internal buffer if debug logging is enabled. +// Messages are appended to the buffer and retained until GetMessages is called. +// If debug logging is disabled, this method is a no-op. +// Thread-safe for concurrent calls. func (l *BufferedDebugLogger) LogDebug(message string) { if !l.enabled { return @@ -29,12 +37,17 @@ func (l *BufferedDebugLogger) LogDebug(message string) { l.messages = append(l.messages, message) } -// IsDebugEnabled returns whether debug logging is enabled +// IsDebugEnabled returns whether debug logging is enabled for this logger. +// This can be used to conditionally execute expensive debug operations +// only when debugging is actually enabled. func (l *BufferedDebugLogger) IsDebugEnabled() bool { return l.enabled } -// GetMessages returns all buffered messages and clears the buffer +// GetMessages returns all buffered debug messages and clears the internal buffer. +// The returned slice contains all messages logged since the last call to GetMessages. +// After this call, the internal buffer is empty and ready to accumulate new messages. +// Thread-safe for concurrent calls. func (l *BufferedDebugLogger) GetMessages() []string { l.mu.Lock() defer l.mu.Unlock() diff --git a/internal/tools/connection_pool.go b/internal/tools/connection_pool.go index edffa588..caab7f7d 100644 --- a/internal/tools/connection_pool.go +++ b/internal/tools/connection_pool.go @@ -15,16 +15,20 @@ import ( "github.com/mark3labs/mcphost/internal/config" ) -// ConnectionPoolConfig configuration for connection pool +// ConnectionPoolConfig defines configuration parameters for the MCP connection pool. +// It controls connection lifecycle, health checking, and error handling behaviors. 
type ConnectionPoolConfig struct { - MaxIdleTime time.Duration - MaxRetries int - HealthCheckInterval time.Duration - MaxErrorCount int - ReconnectDelay time.Duration + MaxIdleTime time.Duration // Maximum time a connection can remain idle before being marked unhealthy + MaxRetries int // Maximum number of retry attempts for failed operations + HealthCheckInterval time.Duration // Interval between background health checks of all connections + MaxErrorCount int // Maximum consecutive errors before marking a connection unhealthy + ReconnectDelay time.Duration // Delay before attempting to reconnect after connection failure } -// DefaultConnectionPoolConfig returns default configuration +// DefaultConnectionPoolConfig returns a connection pool configuration with sensible defaults. +// Default values: 5 minute max idle time, 3 retries, 30 second health check interval, +// 3 max errors before marking unhealthy, and 2 second reconnect delay. +// These defaults are suitable for most MCP server deployments. func DefaultConnectionPoolConfig() *ConnectionPoolConfig { return &ConnectionPoolConfig{ MaxIdleTime: 5 * time.Minute, @@ -35,7 +39,10 @@ func DefaultConnectionPoolConfig() *ConnectionPoolConfig { } } -// MCPConnection represents an MCP connection +// MCPConnection represents a single MCP client connection with health tracking and metadata. +// It wraps an MCP client and maintains state about connection health, usage patterns, +// and error history. Access to connection state is protected by a read-write mutex +// for thread-safe concurrent access. type MCPConnection struct { client client.MCPClient serverName string @@ -47,7 +54,11 @@ type MCPConnection struct { mu sync.RWMutex } -// MCPConnectionPool manages MCP connections +// MCPConnectionPool manages a pool of MCP client connections with automatic health checking, +// connection reuse, and failure recovery. 
It provides thread-safe connection management +// across multiple MCP servers, automatically handling connection lifecycle including +// creation, health monitoring, and cleanup. The pool runs background health checks +// to proactively identify and remove unhealthy connections. type MCPConnectionPool struct { connections map[string]*MCPConnection config *ConnectionPoolConfig @@ -59,7 +70,11 @@ type MCPConnectionPool struct { debugLogger DebugLogger } -// NewMCPConnectionPool creates a new connection pool +// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration. +// If config is nil, default configuration values will be used. The pool starts a background +// goroutine for periodic health checks that runs until Close is called. +// The model parameter is used for MCP servers that require sampling support. +// Thread-safe for concurrent use immediately after creation. func NewMCPConnectionPool(config *ConnectionPoolConfig, model model.ToolCallingChatModel, debug bool) *MCPConnectionPool { if config == nil { config = DefaultConnectionPoolConfig() @@ -79,14 +94,20 @@ func NewMCPConnectionPool(config *ConnectionPoolConfig, model model.ToolCallingC return pool } -// SetDebugLogger sets the debug logger for the connection pool +// SetDebugLogger sets the debug logger for the connection pool. +// The logger will be used to output detailed information about connection lifecycle, +// health checks, and error conditions. Thread-safe and can be called at any time. func (p *MCPConnectionPool) SetDebugLogger(logger DebugLogger) { p.mu.Lock() defer p.mu.Unlock() p.debugLogger = logger } -// GetConnection gets a connection from the pool +// GetConnection retrieves or creates a connection for the specified MCP server. +// If a healthy, non-idle connection exists in the pool, it will be reused. +// Otherwise, a new connection is created and added to the pool. +// Returns an error if connection creation or initialization fails. 
+// Thread-safe for concurrent calls. func (p *MCPConnectionPool) GetConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) { p.mu.Lock() defer p.mu.Unlock() @@ -127,7 +148,12 @@ func (p *MCPConnectionPool) GetConnection(ctx context.Context, serverName string return conn, nil } -// GetConnectionWithHealthCheck gets a connection from the pool with proactive health check +// GetConnectionWithHealthCheck retrieves a connection with an additional proactive health check. +// Unlike GetConnection, this method performs a health check on existing connections before +// returning them, ensuring the connection is truly healthy. This is useful for critical +// operations where connection reliability is paramount. Creates a new connection if the +// existing one fails the health check or doesn't exist. +// Thread-safe for concurrent calls. func (p *MCPConnectionPool) GetConnectionWithHealthCheck(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) { p.mu.Lock() defer p.mu.Unlock() @@ -430,7 +456,11 @@ func (p *MCPConnectionPool) checkConnectionsHealth() { } } -// HandleConnectionError handles connection errors +// HandleConnectionError records and handles errors for a specific connection. +// It increments the error count and may mark the connection as unhealthy based on +// the error type and configured thresholds. Connection errors (network, transport, 404) +// immediately mark the connection as unhealthy for removal on next access. +// Thread-safe for concurrent error reporting. func (p *MCPConnectionPool) HandleConnectionError(serverName string, err error) { p.mu.RLock() conn, exists := p.connections[serverName] @@ -460,7 +490,10 @@ func (p *MCPConnectionPool) HandleConnectionError(serverName string, err error) } } -// GetConnectionStats returns connection statistics +// GetConnectionStats returns detailed statistics for all connections in the pool. 
+// The returned map includes health status, last usage time, error counts, and +// last error for each connection. Useful for monitoring and debugging connection +// pool behavior. The returned data is a snapshot and safe for concurrent access. func (p *MCPConnectionPool) GetConnectionStats() map[string]interface{} { p.mu.RLock() defer p.mu.RUnlock() @@ -480,12 +513,17 @@ func (p *MCPConnectionPool) GetConnectionStats() map[string]interface{} { return stats } -// ServerName returns the server name for this connection +// ServerName returns the server name associated with this MCP connection. +// This is the configured name from the MCPHost configuration, not necessarily +// the actual server implementation name. func (c *MCPConnection) ServerName() string { return c.serverName } -// GetClients returns all client names in the pool +// GetClients returns a map of all MCP clients currently in the pool. +// The map keys are server names and values are the corresponding MCP client instances. +// The returned map is a copy and modifications won't affect the pool. +// Note that clients may be unhealthy; use GetConnectionStats to check health status. func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient { p.mu.RLock() defer p.mu.RUnlock() @@ -497,7 +535,11 @@ func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient { return clients } -// Close closes the connection pool +// Close gracefully shuts down the connection pool, closing all client connections +// and stopping the background health check goroutine. It attempts to close all +// connections even if some fail, logging any errors encountered. +// Safe to call multiple times; subsequent calls are no-ops. +// Always call Close when done with the pool to prevent resource leaks. 
func (p *MCPConnectionPool) Close() error { p.cancel() diff --git a/internal/tools/debug_logger.go b/internal/tools/debug_logger.go index 9df5767a..acd697b2 100644 --- a/internal/tools/debug_logger.go +++ b/internal/tools/debug_logger.go @@ -1,28 +1,45 @@ package tools -// DebugLogger interface for debug logging +// DebugLogger defines the interface for debug logging in the MCP tools package. +// Implementations can provide different strategies for handling debug output, +// such as immediate console output, buffering, or file logging. +// All implementations must be thread-safe for concurrent use. type DebugLogger interface { + // LogDebug logs a debug message. Implementations determine how the message is handled. LogDebug(message string) + // IsDebugEnabled returns true if debug logging is enabled, allowing callers + // to skip expensive debug operations when debugging is disabled. IsDebugEnabled() bool } -// SimpleDebugLogger is a simple implementation that prints to stdout +// SimpleDebugLogger provides a minimal implementation of the DebugLogger interface. +// It is intentionally silent by default to prevent duplicate or unstyled debug output +// during initialization. Debug messages are only displayed when using the CLI debug logger +// which provides proper formatting and styling. type SimpleDebugLogger struct { enabled bool } -// NewSimpleDebugLogger creates a new simple debug logger +// NewSimpleDebugLogger creates a new simple debug logger instance. +// The enabled parameter determines whether IsDebugEnabled will return true. +// Note that LogDebug is intentionally a no-op to avoid unstyled output; +// actual debug output is handled by the CLI's debug logger. func NewSimpleDebugLogger(enabled bool) *SimpleDebugLogger { return &SimpleDebugLogger{enabled: enabled} } -// LogDebug logs a debug message +// LogDebug is intentionally a no-op in SimpleDebugLogger. 
+// Debug messages are only displayed when using the CLI debug logger which provides +// proper formatting and styling. This prevents duplicate or unstyled debug output +// during initialization and ensures consistent debug output presentation. func (l *SimpleDebugLogger) LogDebug(message string) { // Silent by default - messages will only appear when using CLI debug logger // This prevents duplicate or unstyled debug output during initialization } -// IsDebugEnabled returns whether debug logging is enabled +// IsDebugEnabled returns whether debug logging is enabled for this logger. +// This allows code to conditionally execute expensive debug operations +// only when debugging is active, improving performance in production. func (l *SimpleDebugLogger) IsDebugEnabled() bool { return l.enabled } diff --git a/internal/tools/mcp.go b/internal/tools/mcp.go index 39d84be6..2da192d5 100644 --- a/internal/tools/mcp.go +++ b/internal/tools/mcp.go @@ -19,7 +19,11 @@ import ( "github.com/mark3labs/mcphost/internal/config" ) -// MCPToolManager manages MCP tools and clients +// MCPToolManager manages MCP (Model Context Protocol) tools and clients across multiple servers. +// It provides a unified interface for loading, managing, and executing tools from various MCP servers, +// including stdio, SSE, streamable HTTP, and built-in server types. The manager handles connection +// pooling, health checks, tool name prefixing to avoid conflicts, and sampling support for LLM interactions. +// Thread-safe for concurrent tool invocations. type MCPToolManager struct { connectionPool *MCPConnectionPool tools []tool.BaseTool @@ -44,7 +48,9 @@ type mcpToolImpl struct { mapping *toolMapping } -// NewMCPToolManager creates a new MCP tool manager +// NewMCPToolManager creates a new MCP tool manager instance. +// Returns an initialized manager with empty tool collections ready to load tools from MCP servers. +// The manager must be configured with SetModel and LoadTools before use. 
func NewMCPToolManager() *MCPToolManager { return &MCPToolManager{ tools: make([]tool.BaseTool, 0), @@ -52,12 +58,18 @@ func NewMCPToolManager() *MCPToolManager { } } -// SetModel sets the LLM model for sampling support +// SetModel sets the LLM model for sampling support. +// The model is used when MCP servers request sampling operations, allowing them to +// leverage the host's LLM capabilities for text generation tasks. +// This method should be called before LoadTools if any MCP servers require sampling support. func (m *MCPToolManager) SetModel(model model.ToolCallingChatModel) { m.model = model } -// SetDebugLogger sets the debug logger +// SetDebugLogger sets the debug logger for the tool manager. +// The logger will be used to output detailed debugging information about MCP connections, +// tool loading, and execution. If a connection pool exists, it will also be configured +// to use the same logger for consistent debugging output. func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) { m.debugLogger = logger if m.connectionPool != nil { @@ -70,7 +82,10 @@ type samplingHandler struct { model model.ToolCallingChatModel } -// CreateMessage handles sampling requests from MCP servers +// CreateMessage handles sampling requests from MCP servers by forwarding them to the configured LLM model. +// It converts MCP message formats to eino message formats, invokes the model for generation, +// and converts the response back to MCP format. Returns an error if no model is available +// or if generation fails. 
func (h *samplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { if h.model == nil { return nil, fmt.Errorf("no model available for sampling") @@ -126,7 +141,11 @@ func (h *samplingHandler) CreateMessage(ctx context.Context, request mcp.CreateM return result, nil } -// LoadTools loads tools from MCP servers based on configuration +// LoadTools loads tools from all configured MCP servers based on the provided configuration. +// It initializes the connection pool, connects to each configured server, and loads their tools. +// Tools from different servers are prefixed with the server name to avoid naming conflicts. +// Returns an error only if all configured servers fail to load; partial failures are logged as warnings. +// NOTE(review): m.config is assigned before any lock is taken, so avoid concurrent LoadTools calls; idempotence is unconfirmed. func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error { // Initialize connection pool m.config = config @@ -243,12 +262,18 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, return nil } -// Info returns the tool information +// Info returns the tool information including name, description, and parameter schema. +// This method implements the eino tool.BaseTool interface. +// The returned ToolInfo contains the prefixed tool name to ensure uniqueness across servers. func (t *mcpToolImpl) Info(ctx context.Context) (*schema.ToolInfo, error) { return t.info, nil } -// InvokableRun executes the tool by mapping back to the original name and server +// InvokableRun executes the tool by mapping the prefixed name back to the original tool name and server. +// It retrieves a healthy connection from the pool, invokes the tool on the appropriate MCP server, +// and returns the result as a JSON string. The method handles connection errors by marking +// connections as unhealthy in the pool for automatic recovery on subsequent requests. +// Thread-safe for concurrent invocations.
func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { // Handle empty or invalid JSON arguments var arguments any @@ -300,12 +325,16 @@ func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string, return marshaledResult, nil } -// GetTools returns all loaded tools +// GetTools returns all loaded tools from all configured MCP servers. +// Tools are returned with their prefixed names (serverName__toolName) to ensure uniqueness. +// The returned slice is the manager's internal slice, not a copy; callers must not modify it. func (m *MCPToolManager) GetTools() []tool.BaseTool { return m.tools } -// GetLoadedServerNames returns the names of successfully loaded MCP servers +// GetLoadedServerNames returns the names of all successfully loaded MCP servers. +// This includes servers that are currently connected and have had their tools loaded, +// regardless of their current health status. Useful for debugging and status reporting. func (m *MCPToolManager) GetLoadedServerNames() []string { var names []string for serverName := range m.connectionPool.GetClients() { @@ -314,7 +343,10 @@ func (m *MCPToolManager) GetLoadedServerNames() []string { return names } -// Close closes all MCP clients +// Close closes all MCP client connections and cleans up resources. +// This method should be called when the tool manager is no longer needed to ensure +// proper cleanup of stdio processes, network connections, and other resources. +// It is safe to call Close multiple times.
func (m *MCPToolManager) Close() error { return m.connectionPool.Close() } diff --git a/internal/ui/block_renderer.go b/internal/ui/block_renderer.go index 50a39d03..013ad476 100644 --- a/internal/ui/block_renderer.go +++ b/internal/ui/block_renderer.go @@ -21,70 +21,90 @@ type blockRenderer struct { // renderingOption configures block rendering type renderingOption func(*blockRenderer) -// WithFullWidth makes the block take full available width +// WithFullWidth returns a renderingOption that configures the block renderer +// to expand to the full available width of its container. When enabled, the +// block will fill the entire horizontal space rather than sizing to its content. func WithFullWidth() renderingOption { return func(c *blockRenderer) { c.fullWidth = true } } -// WithAlign sets the horizontal alignment of the block +// WithAlign returns a renderingOption that sets the horizontal alignment +// of the block content within its container. The align parameter accepts +// lipgloss.Left, lipgloss.Center, or lipgloss.Right positions. func WithAlign(align lipgloss.Position) renderingOption { return func(c *blockRenderer) { c.align = &align } } -// WithBorderColor sets the border color +// WithBorderColor returns a renderingOption that sets the border color +// for the block. The color parameter uses lipgloss.AdaptiveColor to support +// both light and dark terminal themes automatically. func WithBorderColor(color lipgloss.AdaptiveColor) renderingOption { return func(c *blockRenderer) { c.borderColor = &color } } -// WithMarginTop sets the top margin +// WithMarginTop returns a renderingOption that sets the top margin +// for the block. The margin is specified in number of lines and adds +// vertical space above the block. 
func WithMarginTop(margin int) renderingOption { return func(c *blockRenderer) { c.marginTop = margin } } -// WithMarginBottom sets the bottom margin +// WithMarginBottom returns a renderingOption that sets the bottom margin +// for the block. The margin is specified in number of lines and adds +// vertical space below the block. func WithMarginBottom(margin int) renderingOption { return func(c *blockRenderer) { c.marginBottom = margin } } -// WithPaddingLeft sets the left padding +// WithPaddingLeft returns a renderingOption that sets the left padding +// for the block content. The padding is specified in number of characters +// and adds horizontal space between the left border and the content. func WithPaddingLeft(padding int) renderingOption { return func(c *blockRenderer) { c.paddingLeft = padding } } -// WithPaddingRight sets the right padding +// WithPaddingRight returns a renderingOption that sets the right padding +// for the block content. The padding is specified in number of characters +// and adds horizontal space between the content and the right border. func WithPaddingRight(padding int) renderingOption { return func(c *blockRenderer) { c.paddingRight = padding } } -// WithPaddingTop sets the top padding +// WithPaddingTop returns a renderingOption that sets the top padding +// for the block content. The padding is specified in number of lines +// and adds vertical space between the top border and the content. func WithPaddingTop(padding int) renderingOption { return func(c *blockRenderer) { c.paddingTop = padding } } -// WithPaddingBottom sets the bottom padding +// WithPaddingBottom returns a renderingOption that sets the bottom padding +// for the block content. The padding is specified in number of lines +// and adds vertical space between the content and the bottom border. 
func WithPaddingBottom(padding int) renderingOption { return func(c *blockRenderer) { c.paddingBottom = padding } } -// WithWidth sets a specific width for the block +// WithWidth returns a renderingOption that sets a specific width for the block +// in characters. This overrides the default container width and allows precise +// control over the block's horizontal dimensions. func WithWidth(width int) renderingOption { return func(c *blockRenderer) { c.width = width diff --git a/internal/ui/callbacks.go b/internal/ui/callbacks.go index 92e0b7c5..a9af937a 100644 --- a/internal/ui/callbacks.go +++ b/internal/ui/callbacks.go @@ -9,7 +9,11 @@ import ( utilCallbacks "github.com/cloudwego/eino/utils/callbacks" ) -// CreateCallbackHandler creates a callback handler using HandlerHelper +// CreateCallbackHandler creates and returns a callbacks.Handler that manages +// tool execution callbacks for the CLI. The handler displays tool calls, +// handles errors, and manages streaming output for interactive tool operations. +// It integrates with the eino callback system to provide real-time UI feedback +// during tool execution. func (c *CLI) CreateCallbackHandler() callbacks.Handler { toolHandler := &utilCallbacks.ToolCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *tool.CallbackInput) context.Context { diff --git a/internal/ui/cli.go b/internal/ui/cli.go index 4eacd6b3..2bfb451b 100644 --- a/internal/ui/cli.go +++ b/internal/ui/cli.go @@ -17,7 +17,10 @@ var ( promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")) ) -// CLI handles the command line interface with improved message rendering +// CLI manages the command-line interface for MCPHost, providing message rendering, +// user input handling, and display management. It supports both standard and compact +// display modes, handles streaming responses, tracks token usage, and manages the +// overall conversation flow between the user and AI assistants. 
type CLI struct { messageRenderer *MessageRenderer compactRenderer *CompactRenderer // Add compact renderer @@ -32,7 +35,10 @@ type CLI struct { usageDisplayed bool // track if usage info was displayed after last assistant message } -// NewCLI creates a new CLI instance with message container +// NewCLI creates and initializes a new CLI instance with the specified display modes. +// The debug parameter enables debug message rendering, while compact enables a more +// condensed display format. Returns an initialized CLI ready for interaction or an +// error if initialization fails. func NewCLI(debug bool, compact bool) (*CLI, error) { cli := &CLI{ compactMode: compact, @@ -46,7 +52,9 @@ func NewCLI(debug bool, compact bool) (*CLI, error) { return cli, nil } -// SetUsageTracker sets the usage tracker for the CLI +// SetUsageTracker attaches a usage tracker to the CLI for monitoring token +// consumption and costs. The tracker will be automatically updated with the +// current display width for proper rendering. func (c *CLI) SetUsageTracker(tracker *UsageTracker) { c.usageTracker = tracker if c.usageTracker != nil { @@ -54,12 +62,14 @@ func (c *CLI) SetUsageTracker(tracker *UsageTracker) { } } -// GetDebugLogger returns a debug logger that uses the CLI for rendering +// GetDebugLogger returns a CLIDebugLogger instance that routes debug output +// through the CLI's rendering system for consistent message formatting and display. func (c *CLI) GetDebugLogger() *CLIDebugLogger { return NewCLIDebugLogger(c) } -// SetModelName sets the current model name for the CLI +// SetModelName updates the current AI model name being used in the conversation. +// This name is displayed in message headers to indicate which model is responding. 
func (c *CLI) SetModelName(modelName string) { c.modelName = modelName if c.messageContainer != nil { @@ -67,7 +77,10 @@ func (c *CLI) SetModelName(modelName string) { } } -// GetPrompt gets user input using the huh library with divider and padding +// GetPrompt displays an interactive prompt and waits for user input. It provides +// slash command support, multi-line editing, and cancellation handling. Returns +// the user's input as a string, or an error if the operation was cancelled or +// failed. Returns io.EOF for clean exit signals. func (c *CLI) GetPrompt() (string, error) { // Usage info is now displayed immediately after responses via DisplayUsageAfterResponse() // No need to display it here to avoid duplication @@ -107,7 +120,9 @@ func (c *CLI) GetPrompt() (string, error) { return "", fmt.Errorf("unexpected model type") } -// ShowSpinner displays a spinner with the given message and executes the action +// ShowSpinner displays an animated spinner with the specified message while +// executing the provided action function. The spinner automatically stops when +// the action completes. Returns any error returned by the action function. func (c *CLI) ShowSpinner(message string, action func() error) error { spinner := NewSpinner(message) spinner.Start() @@ -119,7 +134,9 @@ func (c *CLI) ShowSpinner(message string, action func() error) error { return err } -// DisplayUserMessage displays the user's message using the appropriate renderer +// DisplayUserMessage renders and displays a user's message with appropriate +// formatting based on the current display mode (standard or compact). The message +// is timestamped and styled according to the active theme. 
func (c *CLI) DisplayUserMessage(message string) { var msg UIMessage if c.compactMode { @@ -131,12 +148,16 @@ func (c *CLI) DisplayUserMessage(message string) { c.displayContainer() } -// DisplayAssistantMessage displays the assistant's message using the new renderer +// DisplayAssistantMessage renders and displays an AI assistant's response message +// with appropriate formatting. This method delegates to DisplayAssistantMessageWithModel +// with an empty model name for backward compatibility. func (c *CLI) DisplayAssistantMessage(message string) error { return c.DisplayAssistantMessageWithModel(message, "") } -// DisplayAssistantMessageWithModel displays the assistant's message with model info +// DisplayAssistantMessageWithModel renders and displays an AI assistant's response +// with the specified model name shown in the message header. The message is +// formatted according to the current display mode and includes timestamp information. func (c *CLI) DisplayAssistantMessageWithModel(message, modelName string) error { var msg UIMessage if c.compactMode { @@ -149,7 +170,9 @@ func (c *CLI) DisplayAssistantMessageWithModel(message, modelName string) error return nil } -// DisplayToolCallMessage displays a tool call in progress +// DisplayToolCallMessage renders and displays a message indicating that a tool +// is being executed. Shows the tool name and its arguments formatted appropriately +// for the current display mode. This is typically shown while a tool is running. func (c *CLI) DisplayToolCallMessage(toolName, toolArgs string) { c.messageContainer.messages = nil // clear previous messages (they should have been printed already) @@ -167,7 +190,9 @@ func (c *CLI) DisplayToolCallMessage(toolName, toolArgs string) { c.displayContainer() } -// DisplayToolMessage displays a tool call message +// DisplayToolMessage renders and displays the complete result of a tool execution, +// including the tool name, arguments, and result. 
The isError parameter determines +// whether the result should be displayed as an error or success message. func (c *CLI) DisplayToolMessage(toolName, toolArgs, toolResult string, isError bool) { var msg UIMessage if c.compactMode { @@ -181,7 +206,9 @@ func (c *CLI) DisplayToolMessage(toolName, toolArgs, toolResult string, isError c.displayContainer() } -// StartStreamingMessage starts a streaming assistant message +// StartStreamingMessage initializes a new streaming message display for real-time +// AI responses. The message will be progressively updated as content arrives. +// The modelName parameter indicates which AI model is generating the response. func (c *CLI) StartStreamingMessage(modelName string) { // Add an empty assistant message that we'll update during streaming var msg UIMessage @@ -196,14 +223,18 @@ func (c *CLI) StartStreamingMessage(modelName string) { c.displayContainer() } -// UpdateStreamingMessage updates the streaming message with new content +// UpdateStreamingMessage updates the currently streaming message with new content. +// This method should be called after StartStreamingMessage to progressively display +// AI responses as they are generated in real-time. func (c *CLI) UpdateStreamingMessage(content string) { // Update the last message (which should be the streaming assistant message) c.messageContainer.UpdateLastMessage(content) c.displayContainer() } -// DisplayError displays an error message using the appropriate renderer +// DisplayError renders and displays an error message with distinctive formatting +// to ensure visibility. The error is timestamped and styled according to the +// current display mode's error theme. func (c *CLI) DisplayError(err error) { var msg UIMessage if c.compactMode { @@ -215,7 +246,9 @@ func (c *CLI) DisplayError(err error) { c.displayContainer() } -// DisplayInfo displays an informational message using the appropriate renderer +// DisplayInfo renders and displays an informational system message. 
These messages +// are typically used for status updates, notifications, or other non-error system +// communications to the user. func (c *CLI) DisplayInfo(message string) { var msg UIMessage if c.compactMode { @@ -227,7 +260,8 @@ func (c *CLI) DisplayInfo(message string) { c.displayContainer() } -// DisplayCancellation displays a cancellation message +// DisplayCancellation displays a system message indicating that the current +// AI generation has been cancelled by the user (typically via ESC key). func (c *CLI) DisplayCancellation() { var msg UIMessage if c.compactMode { @@ -239,7 +273,9 @@ func (c *CLI) DisplayCancellation() { c.displayContainer() } -// DisplayDebugMessage displays debug messages using the appropriate renderer +// DisplayDebugMessage renders and displays a debug message if debug mode is enabled. +// Debug messages are formatted distinctively and only shown when the CLI is +// initialized with debug=true. func (c *CLI) DisplayDebugMessage(message string) { if !c.debug { return @@ -254,7 +290,9 @@ func (c *CLI) DisplayDebugMessage(message string) { c.displayContainer() } -// DisplayDebugConfig displays configuration settings using the appropriate renderer +// DisplayDebugConfig renders and displays configuration settings in a formatted +// debug message. The config parameter should contain key-value pairs representing +// configuration options that will be displayed for debugging purposes. func (c *CLI) DisplayDebugConfig(config map[string]any) { var msg UIMessage if c.compactMode { @@ -266,7 +304,9 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) { c.displayContainer() } -// DisplayHelp displays help information in a message block +// DisplayHelp renders and displays comprehensive help information showing all +// available slash commands, keyboard shortcuts, and usage instructions in a +// formatted system message block. 
func (c *CLI) DisplayHelp() { help := `## Available Commands @@ -288,7 +328,9 @@ You can also just type your message to chat with the AI assistant.` c.displayContainer() } -// DisplayTools displays available tools in a message block +// DisplayTools renders and displays a formatted list of all available tools +// that can be used by the AI assistant. Each tool is numbered and shown in +// a system message block for easy reference. func (c *CLI) DisplayTools(tools []string) { var content strings.Builder content.WriteString("## Available Tools\n\n") @@ -307,7 +349,9 @@ func (c *CLI) DisplayTools(tools []string) { c.displayContainer() } -// DisplayServers displays configured MCP servers in a message block +// DisplayServers renders and displays a formatted list of all configured MCP +// (Model Context Protocol) servers. Each server is numbered and shown in a +// system message block for easy reference. func (c *CLI) DisplayServers(servers []string) { var content strings.Builder content.WriteString("## Configured MCP Servers\n\n") @@ -326,18 +370,25 @@ func (c *CLI) DisplayServers(servers []string) { c.displayContainer() } -// IsSlashCommand checks if the input is a slash command +// IsSlashCommand determines whether the provided input string is a slash command +// by checking if it starts with a forward slash (/). Returns true for commands +// like "/help", "/tools", etc. func (c *CLI) IsSlashCommand(input string) bool { return strings.HasPrefix(input, "/") } -// SlashCommandResult represents the result of handling a slash command +// SlashCommandResult encapsulates the outcome of processing a slash command, +// indicating whether the command was recognized and handled, and whether the +// conversation history should be cleared as a result of the command. 
type SlashCommandResult struct { Handled bool ClearHistory bool } -// HandleSlashCommand handles slash commands and returns the result +// HandleSlashCommand processes and executes slash commands, returning a result +// that indicates whether the command was handled and any side effects. The servers +// and tools parameters provide context for commands that display available resources. +// Supported commands include /help, /tools, /servers, /clear, /usage, /reset-usage, and /quit. func (c *CLI) HandleSlashCommand(input string, servers []string, tools []string) SlashCommandResult { switch input { case "/help": @@ -369,7 +420,9 @@ func (c *CLI) HandleSlashCommand(input string, servers []string, tools []string) } } -// ClearMessages clears all messages from the container +// ClearMessages removes all messages from the display container and refreshes +// the screen. This is typically used when starting a new conversation or +// clearing the chat history. func (c *CLI) ClearMessages() { c.messageContainer.Clear() c.displayContainer() @@ -422,14 +475,19 @@ func (c *CLI) displayContainer() { } } -// UpdateUsage updates the usage tracker with token counts and costs +// UpdateUsage estimates and records token usage based on input and output text. +// This method uses text-based estimation when actual token counts are not available +// from the AI provider's response metadata. func (c *CLI) UpdateUsage(inputText, outputText string) { if c.usageTracker != nil { c.usageTracker.EstimateAndUpdateUsage(inputText, outputText) } } -// UpdateUsageFromResponse updates the usage tracker using token usage from response metadata +// UpdateUsageFromResponse records token usage using metadata from the AI provider's +// response when available. Falls back to text-based estimation if the metadata is +// missing or appears unreliable. This provides more accurate usage tracking when +// providers supply token count information. 
func (c *CLI) UpdateUsageFromResponse(response *schema.Message, inputText string) { if c.usageTracker == nil { return @@ -461,7 +519,9 @@ func (c *CLI) UpdateUsageFromResponse(response *schema.Message, inputText string } } -// DisplayUsageStats displays current usage statistics +// DisplayUsageStats renders and displays comprehensive token usage statistics +// including the last request's token counts and costs, as well as session totals. +// Shows a message if usage tracking is not available for the current model. func (c *CLI) DisplayUsageStats() { if c.usageTracker == nil { c.DisplayInfo("Usage tracking is not available for this model.") @@ -492,7 +552,9 @@ func (c *CLI) DisplayUsageStats() { c.displayContainer() } -// ResetUsageStats resets the usage tracking statistics +// ResetUsageStats clears all accumulated usage statistics, resetting token counts +// and costs to zero. Displays a confirmation message after resetting or an info +// message if usage tracking is not available. func (c *CLI) ResetUsageStats() { if c.usageTracker == nil { c.DisplayInfo("Usage tracking is not available for this model.") @@ -503,7 +565,9 @@ func (c *CLI) ResetUsageStats() { c.DisplayInfo("Usage statistics have been reset.") } -// DisplayUsageAfterResponse displays usage information immediately after a response +// DisplayUsageAfterResponse renders and displays token usage information immediately +// following an AI response. This provides real-time feedback about the cost and +// token consumption of each interaction. func (c *CLI) DisplayUsageAfterResponse() { if c.usageTracker == nil { return diff --git a/internal/ui/commands.go b/internal/ui/commands.go index 02200ccb..fbd52ac5 100644 --- a/internal/ui/commands.go +++ b/internal/ui/commands.go @@ -1,6 +1,8 @@ package ui -// SlashCommand represents a slash command with its metadata +// SlashCommand represents a user-invokable slash command with its metadata. 
+// Commands can have multiple aliases and are organized by category for better +// discoverability and help display. type SlashCommand struct { Name string Description string @@ -8,7 +10,9 @@ type SlashCommand struct { Category string // e.g., "Navigation", "System", "Info" } -// SlashCommands is the registry of all available slash commands +// SlashCommands provides the global registry of all available slash commands +// in the application. Commands are organized by category (Info, System) and +// include their primary names, descriptions, and alternative aliases. var SlashCommands = []SlashCommand{ { Name: "/help", @@ -55,7 +59,9 @@ var SlashCommands = []SlashCommand{ }, } -// GetCommandByName returns a command by its name or alias +// GetCommandByName looks up a slash command by its primary name or any of its +// aliases. Returns a pointer to the matching SlashCommand, or nil if no command +// matches the provided name. func GetCommandByName(name string) *SlashCommand { for i := range SlashCommands { cmd := &SlashCommands[i] @@ -71,7 +77,9 @@ func GetCommandByName(name string) *SlashCommand { return nil } -// GetAllCommandNames returns all command names and aliases +// GetAllCommandNames returns a complete list of all command names and their aliases. +// This is useful for command completion, validation, and help display. The returned +// slice contains both primary command names and all alternative aliases. func GetAllCommandNames() []string { var names []string for _, cmd := range SlashCommands { diff --git a/internal/ui/compact_renderer.go b/internal/ui/compact_renderer.go index b394a232..1b614a1b 100644 --- a/internal/ui/compact_renderer.go +++ b/internal/ui/compact_renderer.go @@ -8,13 +8,17 @@ import ( "github.com/charmbracelet/lipgloss" ) -// CompactRenderer handles rendering messages in compact format +// CompactRenderer handles rendering messages in a space-efficient compact format, +// optimized for terminals with limited vertical space. 
It displays messages with +// minimal decorations while maintaining readability and essential information. type CompactRenderer struct { width int debug bool } -// NewCompactRenderer creates a new compact message renderer +// NewCompactRenderer creates and initializes a new CompactRenderer with the specified +// terminal width and debug mode setting. The width parameter determines line wrapping, +// while debug enables additional diagnostic output in rendered messages. func NewCompactRenderer(width int, debug bool) *CompactRenderer { return &CompactRenderer{ width: width, @@ -22,12 +26,16 @@ func NewCompactRenderer(width int, debug bool) *CompactRenderer { } } -// SetWidth updates the renderer width +// SetWidth updates the terminal width for the renderer, affecting how content +// is wrapped and formatted in subsequent render operations. func (r *CompactRenderer) SetWidth(width int) { r.width = width } -// RenderUserMessage renders a user message in compact format +// RenderUserMessage renders a user's input message in compact format with a +// distinctive symbol (>) and label. The content is formatted to preserve structure +// while minimizing vertical space usage. Returns a UIMessage with formatted content +// and metadata. func (r *CompactRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage { theme := getTheme() symbol := lipgloss.NewStyle().Foreground(theme.Secondary).Render(">") @@ -58,7 +66,9 @@ func (r *CompactRenderer) RenderUserMessage(content string, timestamp time.Time) } } -// RenderAssistantMessage renders an assistant message in compact format +// RenderAssistantMessage renders an AI assistant's response in compact format with +// a distinctive symbol (<) and the model name as label. Empty content is displayed +// as "(no output)". Returns a UIMessage with formatted content and metadata. 
func (r *CompactRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage { theme := getTheme() symbol := lipgloss.NewStyle().Foreground(theme.Primary).Render("<") @@ -97,7 +107,9 @@ func (r *CompactRenderer) RenderAssistantMessage(content string, timestamp time. } } -// RenderToolCallMessage renders a tool call in progress in compact format +// RenderToolCallMessage renders a tool call notification in compact format, showing +// the tool being executed with its arguments in a single line. The tool name is +// highlighted and arguments are displayed in a muted color for visual distinction. func (r *CompactRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage { theme := getTheme() symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("[") @@ -119,7 +131,9 @@ func (r *CompactRenderer) RenderToolCallMessage(toolName, toolArgs string, times } } -// RenderToolMessage renders a tool result in compact format +// RenderToolMessage renders the result of a tool execution in compact format, +// displaying the outcome with appropriate styling based on success or error status. +// Results are limited to 5 lines to maintain compact display while preserving key information. func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage { theme := getTheme() symbol := lipgloss.NewStyle().Foreground(theme.Muted).Render("]") @@ -165,7 +179,9 @@ func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin } } -// RenderSystemMessage renders a system message in compact format +// RenderSystemMessage renders a system notification or informational message in +// compact format with a distinctive symbol (*) and "System" label. Content is +// formatted to fit on a single line for minimal space usage. 
func (r *CompactRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage { theme := getTheme() symbol := lipgloss.NewStyle().Foreground(theme.System).Render("*") @@ -183,7 +199,9 @@ func (r *CompactRenderer) RenderSystemMessage(content string, timestamp time.Tim } } -// RenderErrorMessage renders an error message in compact format +// RenderErrorMessage renders an error notification in compact format with a +// distinctive error symbol (!) and styling to ensure visibility. The error +// content is displayed in a single line with appropriate color highlighting. func (r *CompactRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage { theme := getTheme() symbol := lipgloss.NewStyle().Foreground(theme.Error).Render("!") @@ -201,7 +219,9 @@ func (r *CompactRenderer) RenderErrorMessage(errorMsg string, timestamp time.Tim } } -// RenderDebugMessage renders debug messages in compact format +// RenderDebugMessage renders diagnostic information in compact format when debug +// mode is enabled. Messages are truncated if they exceed the available width to +// maintain single-line display. func (r *CompactRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage { theme := getTheme() symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*") @@ -223,7 +243,9 @@ func (r *CompactRenderer) RenderDebugMessage(message string, timestamp time.Time } } -// RenderDebugConfigMessage renders debug config in compact format +// RenderDebugConfigMessage renders configuration settings in compact format for +// debugging purposes. Config entries are displayed as key=value pairs separated +// by commas, truncated if necessary to fit on a single line. 
func (r *CompactRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage { theme := getTheme() symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*") diff --git a/internal/ui/debug_logger.go b/internal/ui/debug_logger.go index f9fc52b2..30fc7a5e 100644 --- a/internal/ui/debug_logger.go +++ b/internal/ui/debug_logger.go @@ -6,17 +6,25 @@ import ( "time" ) -// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering +// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering. +// It provides debug logging functionality that integrates with the CLI's display +// system, ensuring debug messages are properly formatted and displayed alongside +// other conversation content. type CLIDebugLogger struct { cli *CLI } -// NewCLIDebugLogger creates a new CLI debug logger +// NewCLIDebugLogger creates and returns a new CLIDebugLogger instance that routes +// debug output through the provided CLI instance. The logger will respect the CLI's +// debug mode setting and display format preferences. func NewCLIDebugLogger(cli *CLI) *CLIDebugLogger { return &CLIDebugLogger{cli: cli} } -// LogDebug logs a debug message using the CLI's debug message renderer +// LogDebug processes and displays a debug message through the CLI's rendering system. +// Messages are formatted with appropriate emojis and tags based on their content type +// (DEBUG, POOL, etc.) and only displayed when debug mode is enabled. The method handles +// multi-line debug output and connection pool status messages with context-aware formatting. func (l *CLIDebugLogger) LogDebug(message string) { if l.cli == nil || !l.cli.debug { return @@ -70,7 +78,9 @@ func (l *CLIDebugLogger) LogDebug(message string) { l.cli.displayContainer() } -// IsDebugEnabled returns whether debug logging is enabled +// IsDebugEnabled checks whether debug logging is currently active. 
Returns true +// if the CLI instance exists and has debug mode enabled, allowing callers to +// conditionally perform expensive debug operations only when necessary. func (l *CLIDebugLogger) IsDebugEnabled() bool { return l.cli != nil && l.cli.debug } diff --git a/internal/ui/enhanced_styles.go b/internal/ui/enhanced_styles.go index 3f29db16..d54711f3 100644 --- a/internal/ui/enhanced_styles.go +++ b/internal/ui/enhanced_styles.go @@ -11,17 +11,21 @@ import ( // Global theme instance var currentTheme = DefaultTheme() -// GetTheme returns the current theme +// GetTheme returns the currently active UI theme. The theme controls all color +// and styling decisions throughout the application's interface. func GetTheme() Theme { return currentTheme } -// SetTheme sets the current theme +// SetTheme updates the global UI theme, affecting all subsequent rendering +// operations. This allows runtime theme switching for different visual preferences. func SetTheme(theme Theme) { currentTheme = theme } -// Theme represents a complete UI theme +// Theme defines a comprehensive color scheme for the application's UI, supporting +// both light and dark terminal modes through adaptive colors. It includes semantic +// colors for different message types and UI elements, based on the Catppuccin color palette. type Theme struct { Primary lipgloss.AdaptiveColor Secondary lipgloss.AdaptiveColor @@ -41,7 +45,9 @@ type Theme struct { Highlight lipgloss.AdaptiveColor } -// DefaultTheme returns the default MCPHost theme (Catppuccin Mocha) +// DefaultTheme creates and returns the default MCPHost theme based on the Catppuccin +// Mocha (dark) and Latte (light) color palettes. This theme provides a cohesive, +// pleasant visual experience with carefully selected colors for different UI elements. 
func DefaultTheme() Theme { return Theme{ Primary: lipgloss.AdaptiveColor{ @@ -111,7 +117,9 @@ func DefaultTheme() Theme { } } -// StyleCard creates a styled card container +// StyleCard creates a lipgloss style for card-like containers with rounded borders, +// padding, and appropriate width. Used for grouping related content in a visually +// distinct box. func StyleCard(width int, theme Theme) lipgloss.Style { return lipgloss.NewStyle(). Width(width). @@ -121,56 +129,64 @@ func StyleCard(width int, theme Theme) lipgloss.Style { MarginBottom(1) } -// StyleHeader creates a styled header +// StyleHeader creates a lipgloss style for primary headers using the theme's +// primary color with bold text for emphasis and hierarchy. func StyleHeader(theme Theme) lipgloss.Style { return lipgloss.NewStyle(). Foreground(theme.Primary). Bold(true) } -// StyleSubheader creates a styled subheader +// StyleSubheader creates a lipgloss style for secondary headers using the theme's +// secondary color with bold text, providing visual hierarchy below primary headers. func StyleSubheader(theme Theme) lipgloss.Style { return lipgloss.NewStyle(). Foreground(theme.Secondary). Bold(true) } -// StyleMuted creates muted text styling +// StyleMuted creates a lipgloss style for de-emphasized text using muted colors +// and italic formatting, suitable for supplementary or less important information. func StyleMuted(theme Theme) lipgloss.Style { return lipgloss.NewStyle(). Foreground(theme.Muted). Italic(true) } -// StyleSuccess creates success text styling +// StyleSuccess creates a lipgloss style for success messages using green colors +// with bold text to indicate successful operations or positive outcomes. func StyleSuccess(theme Theme) lipgloss.Style { return lipgloss.NewStyle(). Foreground(theme.Success). 
Bold(true) } -// StyleError creates error text styling +// StyleError creates a lipgloss style for error messages using red colors +// with bold text to ensure visibility of problems or failures. func StyleError(theme Theme) lipgloss.Style { return lipgloss.NewStyle(). Foreground(theme.Error). Bold(true) } -// StyleWarning creates warning text styling +// StyleWarning creates a lipgloss style for warning messages using yellow/amber +// colors with bold text to draw attention to potential issues or cautions. func StyleWarning(theme Theme) lipgloss.Style { return lipgloss.NewStyle(). Foreground(theme.Warning). Bold(true) } -// StyleInfo creates info text styling +// StyleInfo creates a lipgloss style for informational messages using blue colors +// with bold text for general notifications and status updates. func StyleInfo(theme Theme) lipgloss.Style { return lipgloss.NewStyle(). Foreground(theme.Info). Bold(true) } -// CreateSeparator creates a styled separator line +// CreateSeparator generates a horizontal separator line with the specified width, +// character, and color. Useful for visually dividing sections of content in the UI. func CreateSeparator(width int, char string, color lipgloss.AdaptiveColor) string { return lipgloss.NewStyle(). Foreground(color). @@ -178,7 +194,9 @@ func CreateSeparator(width int, char string, color lipgloss.AdaptiveColor) strin Render(lipgloss.PlaceHorizontal(width, lipgloss.Center, char)) } -// CreateProgressBar creates a simple progress bar +// CreateProgressBar generates a visual progress bar with filled and empty segments +// based on the percentage complete. The bar uses Unicode block characters for smooth +// appearance and theme colors to indicate progress. 
func CreateProgressBar(width int, percentage float64, theme Theme) string { filled := int(float64(width) * percentage / 100) empty := width - filled @@ -194,7 +212,8 @@ func CreateProgressBar(width int, percentage float64, theme Theme) string { return filledBar + emptyBar } -// CreateBadge creates a styled badge +// CreateBadge generates a styled badge or label with inverted colors (text on +// colored background) for highlighting important tags, statuses, or categories. func CreateBadge(text string, color lipgloss.AdaptiveColor) string { return lipgloss.NewStyle(). Foreground(lipgloss.AdaptiveColor{Light: "#FFFFFF", Dark: "#000000"}). @@ -204,7 +223,9 @@ func CreateBadge(text string, color lipgloss.AdaptiveColor) string { Render(text) } -// CreateGradientText creates text with gradient-like effect using different shades +// CreateGradientText creates styled text with a gradient-like effect. Currently +// implements a simplified version using the start color only, as true gradients +// would require a more complex implementation. func CreateGradientText(text string, startColor, endColor lipgloss.AdaptiveColor) string { // For now, just use the start color - true gradients would require more complex implementation return lipgloss.NewStyle(). @@ -215,14 +236,16 @@ func CreateGradientText(text string, startColor, endColor lipgloss.AdaptiveColor // Compact styling utilities -// StyleCompactSymbol creates a styled symbol for compact mode +// StyleCompactSymbol creates a lipgloss style for message type indicators in +// compact mode, using bold colored text to distinguish different message categories. func StyleCompactSymbol(symbol string, color lipgloss.AdaptiveColor) lipgloss.Style { return lipgloss.NewStyle(). Foreground(color). Bold(true) } -// StyleCompactLabel creates a styled label for compact mode +// StyleCompactLabel creates a lipgloss style for message labels in compact mode +// with fixed width for alignment and bold colored text for readability.
func StyleCompactLabel(color lipgloss.AdaptiveColor) lipgloss.Style { return lipgloss.NewStyle(). Foreground(color). @@ -230,13 +253,16 @@ func StyleCompactLabel(color lipgloss.AdaptiveColor) lipgloss.Style { Width(8) } -// StyleCompactContent creates basic content styling for compact mode +// StyleCompactContent creates a simple lipgloss style for message content in +// compact mode, applying only color without additional formatting. func StyleCompactContent(color lipgloss.AdaptiveColor) lipgloss.Style { return lipgloss.NewStyle(). Foreground(color) } -// FormatCompactLine formats a complete compact line with consistent spacing +// FormatCompactLine assembles a complete compact mode message line with consistent +// spacing and styling. Combines a symbol, fixed-width label, and content with their +// respective colors to create a uniform appearance across all message types. func FormatCompactLine(symbol, label, content string, symbolColor, labelColor, contentColor lipgloss.AdaptiveColor) string { styledSymbol := StyleCompactSymbol(symbol, symbolColor).Render(symbol) styledLabel := StyleCompactLabel(labelColor).Render(label) diff --git a/internal/ui/factory.go b/internal/ui/factory.go index c5ffcac1..4a57d3dc 100644 --- a/internal/ui/factory.go +++ b/internal/ui/factory.go @@ -8,14 +8,17 @@ import ( "github.com/mark3labs/mcphost/internal/models" ) -// AgentInterface defines the interface we need from agent to avoid import cycles +// AgentInterface defines the minimal interface required from the agent package +// to avoid circular dependencies while still accessing necessary agent functionality. 
type AgentInterface interface { GetLoadingMessage() string GetTools() []any // Using any to avoid importing tool types GetLoadedServerNames() []string // Add this method for debug config } -// CLISetupOptions contains options for setting up CLI +// CLISetupOptions encapsulates all configuration parameters needed to initialize +// and set up a CLI instance, including display preferences, model information, +// and debugging settings. type CLISetupOptions struct { Agent AgentInterface ModelString string @@ -35,7 +38,10 @@ func parseModelName(modelString string) (provider, model string) { return "unknown", "unknown" } -// SetupCLI creates and configures CLI with standard info display +// SetupCLI creates, configures, and initializes a CLI instance with the provided +// options. It sets up model display, usage tracking for supported providers, and +// shows initial loading information. Returns nil in quiet mode or an initialized +// CLI instance ready for user interaction. func SetupCLI(opts *CLISetupOptions) (*CLI, error) { if opts.Quiet { return nil, nil // No CLI in quiet mode diff --git a/internal/ui/fuzzy.go b/internal/ui/fuzzy.go index 1c7bddd3..cbc6a231 100644 --- a/internal/ui/fuzzy.go +++ b/internal/ui/fuzzy.go @@ -4,13 +4,17 @@ import ( "strings" ) -// FuzzyMatch represents a match result with score +// FuzzyMatch represents the result of a fuzzy string matching operation, +// containing the matched command and its relevance score. Higher scores +// indicate better matches. type FuzzyMatch struct { Command *SlashCommand Score int } -// FuzzyMatchCommands performs fuzzy matching on slash commands +// FuzzyMatchCommands performs fuzzy string matching on the provided slash commands +// based on the query string. Returns a slice of matches sorted by relevance score +// in descending order. An empty query or a bare "/" returns all commands with zero scores.
func FuzzyMatchCommands(query string, commands []SlashCommand) []FuzzyMatch { if query == "" || query == "/" { // Return all commands when query is empty or just "/" diff --git a/internal/ui/messages.go b/internal/ui/messages.go index ed4abdfd..eb2dfe13 100644 --- a/internal/ui/messages.go +++ b/internal/ui/messages.go @@ -10,7 +10,8 @@ import ( "github.com/charmbracelet/lipgloss" ) -// MessageType represents the type of message +// MessageType represents different categories of messages displayed in the UI, +// each with distinct visual styling and formatting rules. type MessageType int const ( @@ -22,7 +23,9 @@ const ( ErrorMessage // New type for error messages ) -// UIMessage represents a rendered message for display +// UIMessage encapsulates a fully rendered message ready for display in the UI, +// including its formatted content, display metrics, and metadata. Messages can +// be static or streaming (progressively updated). type UIMessage struct { ID string Type MessageType @@ -38,7 +41,9 @@ func getTheme() Theme { return GetTheme() } -// MessageRenderer handles rendering of messages with proper styling +// MessageRenderer handles the formatting and rendering of different message types +// with consistent styling, markdown support, and appropriate visual hierarchies +// for the standard (non-compact) display mode. type MessageRenderer struct { width int debug bool @@ -59,7 +64,9 @@ func getSystemUsername() string { return "User" } -// NewMessageRenderer creates a new message renderer +// NewMessageRenderer creates and initializes a new MessageRenderer with the specified +// terminal width and debug mode setting. The width parameter determines line wrapping +// and layout calculations. 
func NewMessageRenderer(width int, debug bool) *MessageRenderer { return &MessageRenderer{ width: width, @@ -67,12 +74,15 @@ func NewMessageRenderer(width int, debug bool) *MessageRenderer { } } -// SetWidth updates the renderer width +// SetWidth updates the terminal width for the renderer, affecting how content +// is wrapped and formatted in subsequent render operations. func (r *MessageRenderer) SetWidth(width int) { r.width = width } -// RenderUserMessage renders a user message with right border and background header +// RenderUserMessage renders a user's input message with distinctive right-aligned +// formatting, including the system username, timestamp, and markdown-rendered content. +// The message is displayed with a colored right border for visual distinction. func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage { // Format timestamp and username timeStr := timestamp.Local().Format("15:04") @@ -106,7 +116,10 @@ func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) } } -// RenderAssistantMessage renders an assistant message with left border and background header +// RenderAssistantMessage renders an AI assistant's response with left-aligned formatting, +// including the model name, timestamp, and markdown-rendered content. Empty responses +// are displayed with a special "Finished without output" message. The message features +// a colored left border for visual distinction. func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage { // Format timestamp and model info with better defaults timeStr := timestamp.Local().Format("15:04") @@ -151,7 +164,9 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time. 
} } -// RenderSystemMessage renders a system message with left border and background header +// RenderSystemMessage renders MCPHost system messages such as help text, command outputs, +// and informational notifications. These messages are displayed with a distinctive system +// color border and "MCPHost System" label to differentiate them from user and AI content. func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage { // Format timestamp timeStr := timestamp.Local().Format("15:04") @@ -193,7 +208,9 @@ func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Tim } } -// RenderDebugMessage renders debug messages with tool response block styling +// RenderDebugMessage renders diagnostic and debugging information with special formatting +// including a debug icon, colored border, and structured layout. Debug messages are only +// displayed when debug mode is enabled and help developers troubleshoot issues. func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage { baseStyle := lipgloss.NewStyle() @@ -251,7 +268,9 @@ func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time } } -// RenderDebugConfigMessage renders debug configuration settings with tool response block styling +// RenderDebugConfigMessage renders configuration settings in a formatted debug display +// with key-value pairs shown in a structured layout. Used to display runtime configuration +// for debugging purposes with a distinctive icon and border styling. 
func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage { baseStyle := lipgloss.NewStyle() @@ -311,7 +330,9 @@ func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timest } } -// RenderErrorMessage renders an error message with left border and background header +// RenderErrorMessage renders error notifications with distinctive red coloring and +// bold text to ensure visibility. Error messages include timestamp information and +// are displayed with an error-colored border for immediate recognition. func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage { // Format timestamp timeStr := timestamp.Local().Format("15:04") @@ -347,7 +368,9 @@ func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Tim } } -// RenderToolCallMessage renders a tool call in progress with left border and background header +// RenderToolCallMessage renders a notification that a tool is being executed, showing +// the tool name, formatted arguments (if any), and execution timestamp. The message +// uses tool-specific coloring to distinguish it from regular conversation messages. func (r *MessageRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage { // Format timestamp timeStr := timestamp.Local().Format("15:04") @@ -391,7 +414,10 @@ func (r *MessageRenderer) RenderToolCallMessage(toolName, toolArgs string, times } } -// RenderToolMessage renders a tool call message with proper styling +// RenderToolMessage renders the result of a tool execution, formatting the output +// based on the tool type and whether it succeeded or failed. Error results are +// displayed in red, while successful results are formatted according to the tool's +// output type (bash, file content, etc.). 
func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage { theme := getTheme() @@ -595,7 +621,9 @@ func (r *MessageRenderer) renderMarkdown(content string, width int) string { return strings.TrimSuffix(rendered, "\n") } -// MessageContainer wraps multiple messages in a container +// MessageContainer manages a collection of UI messages, handling their display, +// updates, and layout within the terminal. It supports both standard and compact +// display modes and maintains state for streaming message updates. type MessageContainer struct { messages []UIMessage width int @@ -605,7 +633,9 @@ type MessageContainer struct { wasCleared bool // Track if container was explicitly cleared } -// NewMessageContainer creates a new message container +// NewMessageContainer creates and initializes a new MessageContainer with the +// specified dimensions and display mode. The container starts empty and will +// display a welcome message until the first message is added. func NewMessageContainer(width, height int, compact bool) *MessageContainer { return &MessageContainer{ messages: make([]UIMessage, 0), @@ -615,18 +645,22 @@ func NewMessageContainer(width, height int, compact bool) *MessageContainer { } } -// AddMessage adds a message to the container +// AddMessage appends a new UIMessage to the container's collection and resets +// the cleared state flag. Messages are displayed in the order they were added. func (c *MessageContainer) AddMessage(msg UIMessage) { c.messages = append(c.messages, msg) c.wasCleared = false // Reset the cleared flag when adding messages } -// SetModelName sets the current model name for the container +// SetModelName updates the AI model name used for rendering assistant messages. +// This name is displayed in message headers to indicate which model is responding. 
func (c *MessageContainer) SetModelName(modelName string) { c.modelName = modelName } -// UpdateLastMessage updates the content of the last message efficiently +// UpdateLastMessage efficiently updates the content of the most recent message +// in the container. This is primarily used for streaming responses where the +// assistant's message is progressively built. Only works for assistant messages. func (c *MessageContainer) UpdateLastMessage(content string) { if len(c.messages) == 0 { return @@ -651,19 +685,23 @@ func (c *MessageContainer) UpdateLastMessage(content string) { } } -// Clear clears all messages from the container +// Clear removes all messages from the container and sets a flag to prevent +// showing the welcome screen. Used when starting a fresh conversation. func (c *MessageContainer) Clear() { c.messages = make([]UIMessage, 0) c.wasCleared = true } -// SetSize updates the container size +// SetSize updates the container's dimensions, typically called when the terminal +// is resized. This affects how messages are wrapped and displayed. func (c *MessageContainer) SetSize(width, height int) { c.width = width c.height = height } -// Render renders all messages in the container +// Render generates the complete visual representation of all messages in the +// container. Returns an empty state display if no messages exist, or formats +// all messages according to the current display mode (standard or compact). func (c *MessageContainer) Render() string { if len(c.messages) == 0 { // Don't show welcome box if explicitly cleared diff --git a/internal/ui/progress/ollama.go b/internal/ui/progress/ollama.go index 7df4ca66..258dc5b9 100644 --- a/internal/ui/progress/ollama.go +++ b/internal/ui/progress/ollama.go @@ -21,7 +21,9 @@ const ( maxWidth = 80 ) -// OllamaPullProgress represents the progress information from Ollama pull API +// OllamaPullProgress represents the progress information received from Ollama's +// pull API when downloading model files. 
It includes status messages, digest +// information, and download progress counters. type OllamaPullProgress struct { Status string `json:"status"` Digest string `json:"digest,omitempty"` @@ -41,7 +43,9 @@ type progressErrMsg struct{ err error } // progressCompleteMsg indicates completion type progressCompleteMsg struct{} -// ProgressModel represents the progress bar model +// ProgressModel implements a tea.Model for displaying download progress with +// a visual progress bar and status messages. It handles progress updates, +// errors, and completion states for Ollama model downloads. type ProgressModel struct { progress progress.Model status string @@ -49,7 +53,8 @@ type ProgressModel struct { complete bool } -// NewProgressModel creates a new progress model +// NewProgressModel creates and initializes a new ProgressModel with a gradient +// progress bar and initial "Initializing..." status message. func NewProgressModel() ProgressModel { return ProgressModel{ progress: progress.New(progress.WithDefaultGradient()), @@ -57,12 +62,15 @@ func NewProgressModel() ProgressModel { } } -// Init initializes the progress model +// Init implements the tea.Model interface, returning nil as no initial commands +// are needed for the progress display. func (m ProgressModel) Init() tea.Cmd { return nil } -// Update handles progress updates +// Update implements the tea.Model interface, handling keyboard input, window +// resize events, and progress updates. It manages the progress bar state and +// triggers program exit on completion or cancellation. func (m ProgressModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.KeyMsg: @@ -108,7 +116,9 @@ func (m ProgressModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } -// View renders the progress bar +// View implements the tea.Model interface, rendering the progress bar with +// status information and help text. 
Displays error messages if present or +// a completion message when the download finishes. func (m ProgressModel) View() string { if m.err != nil { return fmt.Sprintf("Error: %s\n", m.err.Error()) @@ -128,7 +138,9 @@ func (m ProgressModel) View() string { pad+helpStyle("Press 'q' or Ctrl+C to cancel")) } -// ProgressReader wraps an io.Reader to provide progress updates for Ollama pull operations +// ProgressReader wraps an io.Reader to intercept and parse Ollama pull operation +// responses, extracting progress information and updating a visual progress bar. +// It manages a tea.Program for the UI and handles graceful shutdown. type ProgressReader struct { reader io.Reader program *tea.Program @@ -138,7 +150,9 @@ type ProgressReader struct { wg sync.WaitGroup } -// NewProgressReader creates a new progress reader for Ollama pull operations +// NewProgressReader creates and initializes a ProgressReader that wraps the provided +// io.Reader. It starts a tea.Program in a separate goroutine to display the progress +// bar UI while reading and parsing Ollama's streaming JSON responses. func NewProgressReader(reader io.Reader) *ProgressReader { model := NewProgressModel() // Create program with standard settings @@ -164,7 +178,9 @@ func NewProgressReader(reader io.Reader) *ProgressReader { return pr } -// Read implements io.Reader and parses Ollama streaming responses +// Read implements the io.Reader interface, passing through data from the wrapped +// reader while parsing JSON lines to extract progress information. Each complete +// JSON line is processed to update the progress bar display. func (pr *ProgressReader) Read(p []byte) (n int, err error) { n, err = pr.reader.Read(p) if n > 0 { @@ -239,7 +255,9 @@ func (pr *ProgressReader) parseProgressLine(line string) { }) } -// Close stops the progress display and waits for cleanup +// Close gracefully shuts down the progress display, sending a completion message +// and waiting for the tea.Program to exit. 
If the program doesn't exit within +// 2 seconds, it is forcefully terminated to prevent hanging. func (pr *ProgressReader) Close() error { // Send completion message to trigger quit pr.program.Send(progressCompleteMsg{}) diff --git a/internal/ui/slash_command_input.go b/internal/ui/slash_command_input.go index a2cac6e9..01d81edf 100644 --- a/internal/ui/slash_command_input.go +++ b/internal/ui/slash_command_input.go @@ -9,7 +9,9 @@ import ( "github.com/charmbracelet/lipgloss" ) -// SlashCommandInput is a custom input field with slash command autocomplete +// SlashCommandInput provides an interactive text input field with intelligent +// slash command autocomplete functionality. It displays a popup menu of matching +// commands as the user types, supporting fuzzy matching and keyboard navigation. type SlashCommandInput struct { textarea textarea.Model commands []SlashCommand @@ -26,7 +28,9 @@ type SlashCommandInput struct { renderedLines int // Track how many lines were rendered } -// NewSlashCommandInput creates a new slash command input field +// NewSlashCommandInput creates and initializes a new slash command input field with +// the specified width and title. The input supports multi-line text entry, command +// autocomplete, and is styled to match the application's theme. func NewSlashCommandInput(width int, title string) *SlashCommandInput { ta := textarea.New() ta.Placeholder = "Type your message..." @@ -54,12 +58,15 @@ func NewSlashCommandInput(width int, title string) *SlashCommandInput { } } -// Init implements tea.Model +// Init implements the tea.Model interface, returning the initial command to start +// the cursor blinking animation for the text input field. func (s *SlashCommandInput) Init() tea.Cmd { return textarea.Blink } -// Update implements tea.Model +// Update implements the tea.Model interface, handling keyboard input for text entry, +// command selection, and navigation. 
Manages the autocomplete popup display and +// processes submission or cancellation actions. func (s *SlashCommandInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd @@ -168,7 +175,9 @@ func (s *SlashCommandInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } -// View implements tea.Model +// View implements the tea.Model interface, rendering the complete input field +// including the title, text area, autocomplete popup (when active), and help text. +// The view adapts based on whether single or multi-line input is detected. func (s *SlashCommandInput) View() string { // Add left padding to entire component (2 spaces like other UI elements) containerStyle := lipgloss.NewStyle().PaddingLeft(2) @@ -324,17 +333,21 @@ func (s *SlashCommandInput) renderPopup() string { return popupStyle.Render(popupContent) } -// Value returns the final value +// Value returns the final text value entered by the user after submission. +// This will be empty if the input was cancelled. func (s *SlashCommandInput) Value() string { return s.value } -// Cancelled returns true if the user cancelled +// Cancelled returns true if the user cancelled the input operation (e.g., by +// pressing ESC or Ctrl+C) without submitting any text. func (s *SlashCommandInput) Cancelled() bool { return s.quitting && s.value == "" } -// RenderedLines returns how many lines were rendered +// RenderedLines returns the total number of terminal lines used by the last +// rendered view, including the title, input area, popup, and help text. This +// is used for proper screen clearing when the input is dismissed. 
func (s *SlashCommandInput) RenderedLines() int { return s.renderedLines } diff --git a/internal/ui/spinner.go b/internal/ui/spinner.go index 367a537b..994e3a49 100644 --- a/internal/ui/spinner.go +++ b/internal/ui/spinner.go @@ -10,7 +10,9 @@ import ( "github.com/charmbracelet/lipgloss" ) -// Spinner wraps the bubbles spinner for both interactive and non-interactive mode +// Spinner provides an animated loading indicator that displays while long-running +// operations are in progress. It wraps the bubbles spinner component and manages +// its lifecycle through a tea.Program for proper terminal handling. type Spinner struct { model spinner.Model done chan struct{} @@ -72,7 +74,9 @@ func (m spinnerModel) View() string { // quitMsg is sent when we want to quit the spinner type quitMsg struct{} -// NewSpinner creates a new spinner with enhanced styling +// NewSpinner creates a new animated spinner with the specified message. The spinner +// uses the theme's primary color and a modern animation style. It runs in a separate +// tea.Program to avoid interfering with other terminal operations. func NewSpinner(message string) *Spinner { s := spinner.New() s.Spinner = spinner.Points // More modern spinner style @@ -97,7 +101,9 @@ func NewSpinner(message string) *Spinner { } } -// NewThemedSpinner creates a new spinner with the given message and color +// NewThemedSpinner creates a new animated spinner with custom color styling. +// This allows for different spinner colors based on the operation type or status. +// The spinner runs independently in its own tea.Program. func NewThemedSpinner(message string, color lipgloss.AdaptiveColor) *Spinner { s := spinner.New() s.Spinner = spinner.Dot @@ -121,7 +127,9 @@ func NewThemedSpinner(message string, color lipgloss.AdaptiveColor) *Spinner { } } -// Start begins the spinner animation +// Start begins the spinner animation in a separate goroutine. The spinner will +// continue animating until Stop is called. 
The animation runs in a separate +// tea.Program to maintain smooth animation independent of other operations. func (s *Spinner) Start() { go func() { defer close(s.done) @@ -136,7 +144,8 @@ func (s *Spinner) Start() { }() } -// Stop ends the spinner animation +// Stop halts the spinner animation and cleans up resources. This method blocks +// until the spinner has fully stopped and the terminal state is restored. func (s *Spinner) Stop() { s.cancel() <-s.done diff --git a/internal/ui/styles.go b/internal/ui/styles.go index 17b0c622..f4d092e9 100644 --- a/internal/ui/styles.go +++ b/internal/ui/styles.go @@ -15,12 +15,16 @@ func boolPtr(b bool) *bool { return &b } func stringPtr(s string) *string { return &s } func uintPtr(u uint) *uint { return &u } -// BaseStyle returns a basic lipgloss style +// BaseStyle returns a new, empty lipgloss style that can be customized with +// additional styling methods. This serves as the foundation for building more +// complex styled components. func BaseStyle() lipgloss.Style { return lipgloss.NewStyle() } -// GetMarkdownRenderer returns a glamour TermRenderer configured for our use +// GetMarkdownRenderer creates and returns a configured glamour.TermRenderer for +// rendering markdown content with syntax highlighting and proper formatting. The +// renderer is customized with our theme colors and adapted to the specified width. 
func GetMarkdownRenderer(width int) *glamour.TermRenderer { r, _ := glamour.NewTermRenderer( glamour.WithStyles(generateMarkdownStyleConfig()), diff --git a/internal/ui/usage_tracker.go b/internal/ui/usage_tracker.go index f79e1dc0..f9656914 100644 --- a/internal/ui/usage_tracker.go +++ b/internal/ui/usage_tracker.go @@ -9,7 +9,9 @@ import ( "github.com/mark3labs/mcphost/internal/tokens" ) -// UsageStats represents token and cost information for a single request/response +// UsageStats encapsulates detailed token usage and cost breakdown for a single +// LLM request/response cycle, including input, output, and cache token counts +// along with their associated costs. type UsageStats struct { InputTokens int OutputTokens int @@ -22,7 +24,9 @@ type UsageStats struct { TotalCost float64 } -// SessionStats represents cumulative stats for the entire session +// SessionStats aggregates token usage and cost information across all requests +// in a session, providing totals and request counts for usage analysis and +// cost tracking. type SessionStats struct { TotalInputTokens int TotalOutputTokens int @@ -32,7 +36,9 @@ type SessionStats struct { RequestCount int } -// UsageTracker tracks token usage and costs for LLM interactions +// UsageTracker monitors and accumulates token usage statistics and associated costs +// for LLM interactions throughout a session. It provides real-time usage information +// and supports both estimated and actual token counts. OAuth users see $0 costs. type UsageTracker struct { mu sync.RWMutex modelInfo *models.ModelInfo @@ -43,7 +49,10 @@ type UsageTracker struct { isOAuth bool // Whether OAuth credentials are being used (costs should be $0) } -// NewUsageTracker creates a new usage tracker for the given model +// NewUsageTracker creates and initializes a new UsageTracker for the specified model. 
+// The tracker uses model-specific pricing information to calculate costs, unless OAuth +// credentials are being used (in which case costs are shown as $0). Width determines +// the display formatting. func NewUsageTracker(modelInfo *models.ModelInfo, provider string, width int, isOAuth bool) *UsageTracker { return &UsageTracker{ modelInfo: modelInfo, @@ -53,15 +62,19 @@ func NewUsageTracker(modelInfo *models.ModelInfo, provider string, width int, is } } -// EstimateTokens provides a rough estimate of tokens in text -// This is a simple approximation - real token counting would require the actual tokenizer +// EstimateTokens provides a rough estimate of the number of tokens in the given text. +// This uses a simple heuristic of approximately 4 bytes per token (len(text) / 4, so +// multi-byte characters weigh more than ASCII), which is reasonable for most models but not +// precise. Actual token counts may vary significantly based on each model's tokenizer. func EstimateTokens(text string) int { // Rough approximation: ~4 characters per token for most models // This is not accurate but gives a reasonable estimate return len(text) / 4 } -// UpdateUsage updates the tracker with new usage information +// UpdateUsage records new token usage data and calculates associated costs based on +// the model's pricing. It updates both the last request statistics and cumulative session +// totals. For OAuth users, costs are recorded as $0 while still tracking token counts. func (ut *UsageTracker) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int) { ut.mu.Lock() defer ut.mu.Unlock() @@ -107,21 +120,27 @@ func (ut *UsageTracker) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, ut.sessionStats.RequestCount++ } -// EstimateAndUpdateUsage estimates tokens from text and updates usage +// EstimateAndUpdateUsage estimates token counts from raw text strings and updates +// the usage statistics.
This method is used when actual token counts are not available +// from the API response. func (ut *UsageTracker) EstimateAndUpdateUsage(inputText, outputText string) { inputTokens := tokens.EstimateTokens(inputText) outputTokens := tokens.EstimateTokens(outputText) ut.UpdateUsage(inputTokens, outputTokens, 0, 0) } -// EstimateAndUpdateUsageFromText estimates tokens from text and updates usage +// EstimateAndUpdateUsageFromText behaves identically to EstimateAndUpdateUsage (it repeats +// that method's body rather than calling it): it estimates token counts from text and updates usage statistics. func (ut *UsageTracker) EstimateAndUpdateUsageFromText(inputText, outputText string) { inputTokens := tokens.EstimateTokens(inputText) outputTokens := tokens.EstimateTokens(outputText) ut.UpdateUsage(inputTokens, outputTokens, 0, 0) } -// RenderUsageInfo renders enhanced usage information with better styling +// RenderUsageInfo generates a formatted string displaying current usage statistics +// including token counts, context utilization percentage, and costs. The display +// adapts colors based on usage levels and formats large numbers with K/M suffixes +// for readability. func (ut *UsageTracker) RenderUsageInfo() string { ut.mu.RLock() defer ut.mu.RUnlock() @@ -199,14 +218,18 @@ func (ut *UsageTracker) RenderUsageInfo() string { tokensLabel, tokensValue, percentageStr, costLabel, costStr) } -// GetSessionStats returns a copy of the current session statistics +// GetSessionStats returns a copy of the cumulative session statistics including +// total token counts, costs, and request count. The returned copy is safe to use +// without additional synchronization. func (ut *UsageTracker) GetSessionStats() SessionStats { ut.mu.RLock() defer ut.mu.RUnlock() return ut.sessionStats } -// GetLastRequestStats returns a copy of the last request statistics +// GetLastRequestStats returns a copy of the usage statistics from the most recent +// request, or nil if no requests have been made.
The returned copy is safe to use +// without additional synchronization. func (ut *UsageTracker) GetLastRequestStats() *UsageStats { ut.mu.RLock() defer ut.mu.RUnlock() @@ -217,7 +240,9 @@ func (ut *UsageTracker) GetLastRequestStats() *UsageStats { return &stats } -// Reset clears all usage statistics +// Reset clears all accumulated usage statistics, resetting both session totals +// and last request information to their initial empty state. This is typically +// used when starting a new conversation or clearing usage history. func (ut *UsageTracker) Reset() { ut.mu.Lock() defer ut.mu.Unlock() @@ -225,7 +250,9 @@ func (ut *UsageTracker) Reset() { ut.lastRequest = nil } -// SetWidth updates the display width for rendering +// SetWidth updates the terminal width used for formatting usage information display. +// This should be called when the terminal is resized to ensure proper text wrapping +// and alignment. func (ut *UsageTracker) SetWidth(width int) { ut.mu.Lock() defer ut.mu.Unlock() diff --git a/sdk/mcphost.go b/sdk/mcphost.go index a984555f..95c01db2 100644 --- a/sdk/mcphost.go +++ b/sdk/mcphost.go @@ -13,14 +13,18 @@ import ( "github.com/spf13/viper" ) -// MCPHost provides programmatic access to mcphost +// MCPHost provides programmatic access to mcphost functionality, allowing +// integration of MCP tools and LLM interactions into Go applications. It manages +// agents, sessions, and model configurations. type MCPHost struct { agent *agent.Agent sessionMgr *session.Manager modelString string } -// Options for creating MCPHost (all optional - will use CLI defaults) +// Options configures MCPHost creation with optional overrides for model, +// prompts, configuration, and behavior settings. All fields are optional +// and will use CLI defaults if not specified. 
type Options struct { Model string // Override model (e.g., "anthropic:claude-3-sonnet") SystemPrompt string // Override system prompt @@ -30,7 +34,9 @@ type Options struct { Quiet bool // Suppress debug output } -// New creates MCPHost instance using the same initialization as CLI +// New creates an MCPHost instance using the same initialization as the CLI. +// It loads configuration, initializes MCP servers, creates the LLM model, and +// sets up the agent for interaction. Returns an error if initialization fails. func New(ctx context.Context, opts *Options) (*MCPHost, error) { if opts == nil { opts = &Options{} @@ -118,7 +124,9 @@ func New(ctx context.Context, opts *Options) (*MCPHost, error) { }, nil } -// Prompt sends a message and returns the response +// Prompt sends a message to the agent and returns the response. The agent may +// use tools as needed to generate the response. The conversation history is +// automatically maintained in the session. Returns an error if generation fails. func (m *MCPHost) Prompt(ctx context.Context, message string) (string, error) { // Get messages from session messages := m.sessionMgr.GetMessages() @@ -148,7 +156,9 @@ func (m *MCPHost) Prompt(ctx context.Context, message string) (string, error) { return result.FinalResponse.Content, nil } -// PromptWithCallbacks sends a message with callbacks for tool execution +// PromptWithCallbacks sends a message with callbacks for monitoring tool execution +// and streaming responses. The callbacks allow real-time observation of tool calls, +// results, and response generation. Returns the final response or an error. func (m *MCPHost) PromptWithCallbacks( ctx context.Context, message string, @@ -184,12 +194,14 @@ func (m *MCPHost) PromptWithCallbacks( return result.FinalResponse.Content, nil } -// GetSessionManager returns the current session manager +// GetSessionManager returns the current session manager for direct access +// to conversation history and session manipulation. 
func (m *MCPHost) GetSessionManager() *session.Manager { return m.sessionMgr } -// LoadSession loads a session from file +// LoadSession loads a previously saved session from a file, restoring the +// conversation history. Returns an error if the file cannot be loaded or parsed. func (m *MCPHost) LoadSession(path string) error { s, err := session.LoadFromFile(path) if err != nil { @@ -199,22 +211,27 @@ func (m *MCPHost) LoadSession(path string) error { return nil } -// SaveSession saves the current session to file +// SaveSession saves the current session to a file for later restoration. +// Returns an error if the session cannot be written to the specified path. func (m *MCPHost) SaveSession(path string) error { return m.sessionMgr.GetSession().SaveToFile(path) } -// ClearSession clears the current session history +// ClearSession clears the current session history, starting a new conversation +// with an empty message history. func (m *MCPHost) ClearSession() { m.sessionMgr = session.NewManager("") } -// GetModelString returns the current model string +// GetModelString returns the current model string identifier (e.g., +// "anthropic:claude-3-sonnet" or "openai:gpt-4") being used by the agent. func (m *MCPHost) GetModelString() string { return m.modelString } -// Close cleans up resources +// Close cleans up resources including MCP server connections and model resources. +// Should be called when the MCPHost instance is no longer needed. Returns an +// error if cleanup fails. func (m *MCPHost) Close() error { return m.agent.Close() } diff --git a/sdk/types.go b/sdk/types.go index f1a1a2a0..f77313c8 100644 --- a/sdk/types.go +++ b/sdk/types.go @@ -5,18 +5,22 @@ import ( "github.com/mark3labs/mcphost/internal/session" ) -// Message is an alias for session.Message for SDK users +// Message is an alias for session.Message providing SDK users with access +// to message structures for conversation history and tool interactions. 
type Message = session.Message -// ToolCall is an alias for session.ToolCall +// ToolCall is an alias for session.ToolCall representing a tool invocation +// with its name, arguments, and result within a conversation. type ToolCall = session.ToolCall -// ConvertToSchemaMessage converts SDK message to schema message +// ConvertToSchemaMessage converts an SDK message to the underlying schema message +// format used by the agent for LLM interactions. func ConvertToSchemaMessage(msg *Message) *schema.Message { return msg.ConvertToSchemaMessage() } -// ConvertFromSchemaMessage converts schema message to SDK message +// ConvertFromSchemaMessage converts a schema message from the agent to an SDK +// message format for use in the SDK API. func ConvertFromSchemaMessage(msg *schema.Message) Message { return session.ConvertFromSchemaMessage(msg) }