This commit is contained in:
Ed Zynda
2025-11-12 16:48:46 +03:00
parent d3281a2f01
commit 63704f55b5
56 changed files with 1769 additions and 459 deletions
+13
View File
@@ -10,6 +10,10 @@ import (
"github.com/spf13/cobra"
)
// authCmd represents the auth command for managing AI provider authentication.
// This command provides subcommands for login, logout, and status checking
// of authentication credentials for various AI providers, with OAuth support
// for providers like Anthropic.
var authCmd = &cobra.Command{
Use: "auth",
Short: "Manage authentication credentials for AI providers",
@@ -27,6 +31,9 @@ Examples:
mcphost auth status`,
}
// authLoginCmd represents the login subcommand for authenticating with AI providers.
// It handles OAuth flow for supported providers, opening a browser for authentication
// and securely storing the resulting credentials for future use.
var authLoginCmd = &cobra.Command{
Use: "login [provider]",
Short: "Authenticate with an AI provider using OAuth",
@@ -45,6 +52,9 @@ Example:
RunE: runAuthLogin,
}
// authLogoutCmd represents the logout subcommand for removing stored authentication credentials.
// This command removes stored API keys or OAuth tokens for specified providers,
// requiring the user to authenticate again or use environment variables.
var authLogoutCmd = &cobra.Command{
Use: "logout [provider]",
Short: "Remove stored authentication credentials for a provider",
@@ -62,6 +72,9 @@ Example:
RunE: runAuthLogout,
}
// authStatusCmd represents the status subcommand for checking authentication status.
// It displays which providers have stored credentials, their types (OAuth vs API key),
// creation dates, and expiration status without revealing the actual credentials.
var authStatusCmd = &cobra.Command{
Use: "status",
Short: "Show authentication status for all providers",
+12
View File
@@ -10,12 +10,18 @@ import (
"gopkg.in/yaml.v3"
)
// hooksCmd represents the hooks command for managing MCPHost hook configurations.
// Hooks allow users to execute custom scripts or commands at various points
// during MCPHost execution, such as before/after tool use or when prompts are submitted.
var hooksCmd = &cobra.Command{
Use: "hooks",
Short: "Manage MCPHost hooks",
Long: "Commands for managing and testing MCPHost hooks configuration",
}
// hooksListCmd represents the list subcommand for displaying all configured hooks.
// It shows a formatted table of hook events, matchers, commands, and timeouts
// to help users understand their current hook configuration.
var hooksListCmd = &cobra.Command{
Use: "list",
Short: "List all configured hooks",
@@ -45,6 +51,9 @@ var hooksListCmd = &cobra.Command{
},
}
// hooksValidateCmd represents the validate subcommand for checking hook configuration validity.
// It loads and validates the hooks configuration file, ensuring proper syntax,
// valid event types, and correct matcher patterns before use.
var hooksValidateCmd = &cobra.Command{
Use: "validate",
Short: "Validate hooks configuration",
@@ -64,6 +73,9 @@ var hooksValidateCmd = &cobra.Command{
},
}
// hooksInitCmd represents the init subcommand for generating an example hooks configuration.
// It creates a .mcphost/hooks.yml file with sample hook configurations demonstrating
// various hook events and common use cases like logging commands and tool usage.
var hooksInitCmd = &cobra.Command{
Use: "init",
Short: "Generate example hooks configuration",
+23 -3
View File
@@ -84,6 +84,10 @@ func (a *agentUIAdapter) GetLoadedServerNames() []string {
return a.agent.GetLoadedServerNames()
}
// rootCmd represents the base command when called without any subcommands.
// This is the main entry point for the MCPHost CLI application, providing
// an interface to interact with various AI models through a unified interface
// with support for MCP servers and tool integration.
var rootCmd = &cobra.Command{
Use: "mcphost",
Short: "Chat with AI models through a unified interface",
@@ -120,12 +124,21 @@ Examples:
},
}
// GetRootCommand returns the root command with the version set
// GetRootCommand returns the root command with the version set.
// This function is the main entry point for the MCPHost CLI and should be
// called from main.go with the appropriate version string.
func GetRootCommand(v string) *cobra.Command {
rootCmd.Version = v
return rootCmd
}
// InitConfig initializes the configuration for MCPHost by loading config files,
// environment variables, and hooks configuration. It follows this priority order:
// 1. Command-line specified config file (--config flag)
// 2. Current directory config file (.mcphost or .mcp)
// 3. Home directory config file (~/.mcphost or ~/.mcp)
// 4. Environment variables (MCPHOST_* prefix)
// This function is automatically called by cobra before command execution.
func InitConfig() {
if configFile != "" {
// Use config file from the flag
@@ -202,7 +215,12 @@ func InitConfig() {
}
// LoadConfigWithEnvSubstitution loads a config file with environment variable substitution
// LoadConfigWithEnvSubstitution loads a config file with environment variable substitution.
// It reads the config file, replaces any ${ENV_VAR} patterns with their corresponding
// environment variable values, and then parses the resulting configuration using viper.
// The function automatically detects JSON or YAML format based on file extension.
// Returns an error if the file cannot be read, environment variable substitution fails,
// or the configuration cannot be parsed.
func LoadConfigWithEnvSubstitution(configPath string) error {
// Read raw config file content
rawContent, err := os.ReadFile(configPath)
@@ -728,7 +746,9 @@ func runNormalMode(ctx context.Context) error {
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor)
}
// AgenticLoopConfig configures the behavior of the unified agentic loop
// AgenticLoopConfig configures the behavior of the unified agentic loop.
// This struct controls how the main interaction loop operates, whether in
// interactive or non-interactive mode, and manages various UI and session options.
type AgenticLoopConfig struct {
// Mode configuration
IsInteractive bool // true for interactive mode, false for non-interactive
+10 -4
View File
@@ -21,6 +21,10 @@ import (
"github.com/spf13/viper"
)
// scriptCmd represents the script command for executing MCPHost script files.
// Script files can contain YAML frontmatter configuration followed by a prompt,
// allowing for reproducible AI interactions with custom configurations and
// variable substitution support.
var scriptCmd = &cobra.Command{
Use: "script <script-file>",
Short: "Execute a script file with YAML frontmatter configuration",
@@ -413,11 +417,13 @@ func parseScriptContent(content string, variables map[string]string) (*config.Co
return &scriptConfig, nil
}
// Variable represents a script variable with optional default value
// Variable represents a script variable with optional default value.
// Variables can be declared in scripts using ${variable} syntax for required variables
// or ${variable:-default} syntax for variables with default values.
type Variable struct {
Name string
DefaultValue string
HasDefault bool
Name string // The name of the variable as it appears in the script
DefaultValue string // The default value if specified using ${variable:-default} syntax
HasDefault bool // Whether this variable has a default value
}
// findVariables extracts all unique variable names from ${variable} patterns in content
+50 -22
View File
@@ -16,35 +16,49 @@ import (
"time"
)
// AgentConfig is the config for agent.
// AgentConfig holds configuration options for creating a new Agent.
// It includes model configuration, MCP settings, and various behavioral options.
type AgentConfig struct {
ModelConfig *models.ProviderConfig
MCPConfig *config.Config
SystemPrompt string
MaxSteps int
// ModelConfig specifies the LLM provider and model to use
ModelConfig *models.ProviderConfig
// MCPConfig contains MCP server configurations
MCPConfig *config.Config
// SystemPrompt is the initial system message for the agent
SystemPrompt string
// MaxSteps limits the number of tool calls (0 for unlimited)
MaxSteps int
// StreamingEnabled controls whether responses are streamed
StreamingEnabled bool
DebugLogger tools.DebugLogger // Optional debug logger
// DebugLogger is an optional logger for debugging MCP communications
DebugLogger tools.DebugLogger // Optional debug logger
}
// ToolCallHandler is a function type for handling tool calls as they happen
// ToolCallHandler is a function type for handling tool calls as they happen.
// It receives the tool name and its arguments when a tool is about to be invoked.
type ToolCallHandler func(toolName, toolArgs string)
// ToolExecutionHandler is a function type for handling tool execution start/end
// ToolExecutionHandler is a function type for handling tool execution start/end events.
// The isStarting parameter indicates whether the tool is starting (true) or finished (false).
type ToolExecutionHandler func(toolName string, isStarting bool)
// ToolResultHandler is a function type for handling tool results
// ToolResultHandler is a function type for handling tool results.
// It receives the tool name, arguments, result, and whether the result is an error.
type ToolResultHandler func(toolName, toolArgs, result string, isError bool)
// ResponseHandler is a function type for handling LLM responses
// ResponseHandler is a function type for handling LLM responses.
// It receives the complete response content from the model.
type ResponseHandler func(content string)
// StreamingResponseHandler is a function type for handling streaming LLM responses
// StreamingResponseHandler is a function type for handling streaming LLM responses.
// It receives content chunks as they are streamed from the model.
type StreamingResponseHandler func(content string)
// ToolCallContentHandler is a function type for handling content that accompanies tool calls
// ToolCallContentHandler is a function type for handling content that accompanies tool calls.
// It receives any text content that the model generates alongside tool calls.
type ToolCallContentHandler func(content string)
// Agent is the agent with real-time tool call display.
// Agent represents an AI agent with MCP tool integration and real-time tool call display.
// It manages the interaction between an LLM and various tools through the MCP protocol.
type Agent struct {
toolManager *tools.MCPToolManager
model model.ToolCallingChatModel
@@ -55,7 +69,10 @@ type Agent struct {
streamingEnabled bool // Whether streaming is enabled
}
// NewAgent creates an agent with MCP tool integration and real-time tool call display
// NewAgent creates a new Agent with MCP tool integration and streaming support.
// It initializes the LLM provider, loads MCP tools, and configures the agent
// based on the provided configuration. Returns an error if provider creation
// or tool loading fails.
func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
// Create the LLM provider
providerResult, err := models.CreateProvider(ctx, config.ModelConfig)
@@ -98,20 +115,27 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
}, nil
}
// GenerateWithLoopResult contains the result and conversation history
// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
// It includes both the final response and the complete message history with tool interactions.
type GenerateWithLoopResult struct {
FinalResponse *schema.Message
// FinalResponse is the last message generated by the model
FinalResponse *schema.Message
// ConversationMessages contains all messages in the conversation including tool calls and results
ConversationMessages []*schema.Message // All messages in the conversation (including tool calls and results)
}
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time.
// It handles the conversation flow, executing tools as needed and invoking callbacks for various events.
// This method does not support streaming responses; use GenerateWithLoopAndStreaming for streaming support.
func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message,
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) {
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil)
}
// GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks
// GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks.
// It handles the conversation flow, executing tools as needed and invoking callbacks for various events including streaming chunks.
// The onStreamingResponse callback is invoked for each content chunk during streaming if streaming is enabled.
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*schema.Message,
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler) (*GenerateWithLoopResult, error) {
@@ -256,17 +280,20 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*sc
}, nil
}
// GetTools returns the list of available tools
// GetTools returns the list of available tools loaded in the agent.
// These tools are available for the model to use during interactions.
func (a *Agent) GetTools() []tool.BaseTool {
return a.toolManager.GetTools()
}
// GetLoadingMessage returns the loading message from provider creation (e.g., GPU fallback info)
// GetLoadingMessage returns the loading message from provider creation.
// This may contain information about GPU fallback or other provider-specific initialization details.
func (a *Agent) GetLoadingMessage() string {
return a.loadingMessage
}
// GetLoadedServerNames returns the names of successfully loaded MCP servers
// GetLoadedServerNames returns the names of successfully loaded MCP servers.
// This includes both builtin servers and external MCP server configurations.
func (a *Agent) GetLoadedServerNames() []string {
return a.toolManager.GetLoadedServerNames()
}
@@ -486,7 +513,8 @@ func (a *Agent) listenForESC(stopChan chan bool, readyChan chan bool) bool {
}
}
// Close closes the agent and cleans up resources
// Close closes the agent and cleans up resources.
// It ensures all MCP connections are properly closed and resources are released.
func (a *Agent) Close() error {
return a.toolManager.Close()
}
+27 -12
View File
@@ -10,23 +10,36 @@ import (
"github.com/mark3labs/mcphost/internal/tools"
)
// SpinnerFunc is a function type for showing spinners during agent creation
// SpinnerFunc is a function type for showing spinners during agent creation.
// It executes the provided function while displaying a spinner with the given message.
type SpinnerFunc func(message string, fn func() error) error
// AgentCreationOptions contains options for creating an agent
// AgentCreationOptions contains options for creating an agent.
// It extends AgentConfig with UI-related options for showing progress during creation.
type AgentCreationOptions struct {
ModelConfig *models.ProviderConfig
MCPConfig *config.Config
SystemPrompt string
MaxSteps int
// ModelConfig specifies the LLM provider and model to use
ModelConfig *models.ProviderConfig
// MCPConfig contains MCP server configurations
MCPConfig *config.Config
// SystemPrompt is the initial system message for the agent
SystemPrompt string
// MaxSteps limits the number of tool calls (0 for unlimited)
MaxSteps int
// StreamingEnabled controls whether responses are streamed
StreamingEnabled bool
ShowSpinner bool // For Ollama models
Quiet bool // Skip spinner if quiet
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
DebugLogger tools.DebugLogger // Optional debug logger
// ShowSpinner indicates whether to show a spinner for Ollama models during loading
ShowSpinner bool // For Ollama models
// Quiet suppresses the spinner even if ShowSpinner is true
Quiet bool // Skip spinner if quiet
// SpinnerFunc is the function to show spinner, provided by the caller
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
// DebugLogger is an optional logger for debugging MCP communications
DebugLogger tools.DebugLogger // Optional debug logger
}
// CreateAgent creates an agent with optional spinner for Ollama models
// CreateAgent creates an agent with optional spinner for Ollama models.
// It shows a loading spinner for Ollama models if ShowSpinner is true and not in quiet mode.
// Returns the created agent or an error if creation fails.
func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error) {
agentConfig := &AgentConfig{
ModelConfig: opts.ModelConfig,
@@ -57,7 +70,9 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
return agent, nil
}
// ParseModelName extracts provider and model name from model string
// ParseModelName extracts provider and model name from a model string.
// Model strings are formatted as "provider:model" (e.g., "anthropic:claude-3-5-sonnet-20241022").
// If the string doesn't contain a colon, returns "unknown" for both provider and model.
func ParseModelName(modelString string) (provider, model string) {
parts := strings.SplitN(modelString, ":", 2)
if len(parts) == 2 {
+2 -1
View File
@@ -8,7 +8,8 @@ import (
"github.com/cloudwego/eino/schema"
)
// StreamWithCallback streams content with real-time callbacks and returns complete response
// StreamWithCallback streams content with real-time callbacks and returns the complete response.
// It accumulates content and tool calls from the stream, invoking the callback for each content chunk.
// IMPORTANT: Tool calls are only processed after EOF is reached to ensure we have the complete
// and final tool call information. This prevents premature tool execution on partial data.
// Handles different provider streaming patterns:
+8 -2
View File
@@ -6,7 +6,11 @@ import (
"runtime"
)
// OpenBrowser opens the default browser to the specified URL
// OpenBrowser opens the default web browser to the specified URL.
// It automatically detects the operating system and uses the appropriate
// command to launch the browser (xdg-open on Linux, rundll32 on Windows,
// open on macOS). Returns an error if the platform is unsupported or if
// the browser fails to launch.
func OpenBrowser(url string) error {
var err error
@@ -24,7 +28,9 @@ func OpenBrowser(url string) error {
return err
}
// TryOpenBrowser attempts to open the browser but doesn't fail if it can't
// TryOpenBrowser attempts to open the default web browser to the specified URL
// but silently ignores any errors. This is useful when browser access is optional
// and users can manually copy and paste the URL if automatic browser launching fails.
func TryOpenBrowser(url string) {
// Silently ignore errors - user can still copy/paste the URL
_ = OpenBrowser(url)
+47 -15
View File
@@ -9,12 +9,16 @@ import (
"time"
)
// CredentialStore holds all stored credentials
// CredentialStore holds all stored credentials for various providers.
// Currently supports Anthropic credentials with both OAuth and API key authentication methods.
type CredentialStore struct {
Anthropic *AnthropicCredentials `json:"anthropic,omitempty"`
}
// AnthropicCredentials holds Anthropic API credentials
// AnthropicCredentials holds Anthropic API credentials supporting both OAuth
// and API key authentication methods. The Type field indicates which authentication
// method is being used. For OAuth, tokens are stored with expiration timestamps
// for automatic refresh. For API keys, only the key itself is stored.
type AnthropicCredentials struct {
Type string `json:"type"` // "oauth" or "api_key"
APIKey string `json:"api_key,omitempty"` // For API key auth
@@ -24,7 +28,8 @@ type AnthropicCredentials struct {
CreatedAt time.Time `json:"created_at"`
}
// IsExpired checks if the OAuth token is expired
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
// Returns false for API key authentication or if no expiration is set.
func (c *AnthropicCredentials) IsExpired() bool {
if c.Type != "oauth" || c.ExpiresAt == 0 {
return false
@@ -32,7 +37,10 @@ func (c *AnthropicCredentials) IsExpired() bool {
return time.Now().Unix() >= c.ExpiresAt
}
// NeedsRefresh checks if the OAuth token needs refresh (5 minutes before expiry)
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
// will expire within the next 5 minutes. This allows for proactive token refresh
// to avoid authentication failures during operations. Returns false for API key
// authentication or if no expiration is set.
func (c *AnthropicCredentials) NeedsRefresh() bool {
if c.Type != "oauth" || c.ExpiresAt == 0 {
return false
@@ -40,12 +48,17 @@ func (c *AnthropicCredentials) NeedsRefresh() bool {
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
}
// CredentialManager handles credential storage and retrieval
// CredentialManager handles secure storage and retrieval of authentication credentials.
// It manages a JSON file stored in the user's config directory with appropriate
// file permissions for security.
type CredentialManager struct {
credentialsPath string
}
// NewCredentialManager creates a new credential manager
// NewCredentialManager creates a new credential manager instance. It determines
// the appropriate credentials path based on XDG_CONFIG_HOME or falls back to
// ~/.config/.mcphost/credentials.json. Returns an error if the home directory
// cannot be determined.
func NewCredentialManager() (*CredentialManager, error) {
credentialsPath, err := getCredentialsPath()
if err != nil {
@@ -73,7 +86,9 @@ func getCredentialsPath() (string, error) {
return filepath.Join(homeDir, ".config", ".mcphost", "credentials.json"), nil
}
// LoadCredentials loads credentials from the file
// LoadCredentials loads credentials from the JSON file. If the file doesn't exist,
// it returns an empty CredentialStore instead of an error, allowing for graceful
// initialization. Returns an error if the file exists but cannot be read or parsed.
func (cm *CredentialManager) LoadCredentials() (*CredentialStore, error) {
// If file doesn't exist, return empty store
if _, err := os.Stat(cm.credentialsPath); os.IsNotExist(err) {
@@ -93,7 +108,10 @@ func (cm *CredentialManager) LoadCredentials() (*CredentialStore, error) {
return &store, nil
}
// SaveCredentials saves credentials to the file
// SaveCredentials saves credentials to the JSON file with secure permissions (0600).
// It creates the parent directory if it doesn't exist. The file is written atomically
// to prevent corruption. Returns an error if the directory cannot be created or the
// file cannot be written.
func (cm *CredentialManager) SaveCredentials(store *CredentialStore) error {
// Ensure directory exists
dir := filepath.Dir(cm.credentialsPath)
@@ -114,7 +132,10 @@ func (cm *CredentialManager) SaveCredentials(store *CredentialStore) error {
return nil
}
// SetAnthropicCredentials stores Anthropic API credentials (for API key auth)
// SetAnthropicCredentials stores Anthropic API key credentials. It validates the
// API key format before storing. The API key must start with "sk-ant-" and be
// at least 20 characters long. Returns an error if the API key is invalid or
// if storage fails.
func (cm *CredentialManager) SetAnthropicCredentials(apiKey string) error {
if err := validateAnthropicAPIKey(apiKey); err != nil {
return err
@@ -134,7 +155,9 @@ func (cm *CredentialManager) SetAnthropicCredentials(apiKey string) error {
return cm.SaveCredentials(store)
}
// GetAnthropicCredentials retrieves Anthropic API credentials
// GetAnthropicCredentials retrieves stored Anthropic credentials. Returns nil if
// no credentials are stored. The returned credentials may be either OAuth or API
// key type, check the Type field to determine which.
func (cm *CredentialManager) GetAnthropicCredentials() (*AnthropicCredentials, error) {
store, err := cm.LoadCredentials()
if err != nil {
@@ -144,7 +167,9 @@ func (cm *CredentialManager) GetAnthropicCredentials() (*AnthropicCredentials, e
return store.Anthropic, nil
}
// RemoveAnthropicCredentials removes stored Anthropic credentials
// RemoveAnthropicCredentials removes stored Anthropic credentials from storage.
// If this was the only credential stored, the entire credentials file is removed.
// Returns an error if the removal fails.
func (cm *CredentialManager) RemoveAnthropicCredentials() error {
store, err := cm.LoadCredentials()
if err != nil {
@@ -164,7 +189,9 @@ func (cm *CredentialManager) RemoveAnthropicCredentials() error {
return cm.SaveCredentials(store)
}
// HasAnthropicCredentials checks if Anthropic credentials are stored
// HasAnthropicCredentials checks if valid Anthropic credentials are stored.
// Returns true if either a non-empty OAuth access token or API key is present,
// false otherwise. Returns an error if credentials cannot be loaded.
func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
creds, err := cm.GetAnthropicCredentials()
if err != nil {
@@ -185,7 +212,8 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
}
}
// GetCredentialsPath returns the path to the credentials file
// GetCredentialsPath returns the absolute path to the credentials JSON file.
// This is useful for debugging or displaying the storage location to users.
func (cm *CredentialManager) GetCredentialsPath() string {
return cm.credentialsPath
}
@@ -210,8 +238,12 @@ func validateAnthropicAPIKey(apiKey string) error {
return nil
}
// GetAnthropicAPIKey is a convenience function that checks stored credentials first,
// then falls back to environment variables and flags
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
// 1. Command-line flag value (highest priority)
// 2. Stored credentials (OAuth or API key)
// 3. ANTHROPIC_API_KEY environment variable (lowest priority)
// Returns the API key, a description of its source, and any error encountered.
// For OAuth credentials, it automatically refreshes expired tokens.
func GetAnthropicAPIKey(flagValue string) (string, string, error) {
// 1. Check flag value first (highest priority)
if flagValue != "" {
+34 -9
View File
@@ -13,7 +13,9 @@ import (
"time"
)
// OAuthClient handles OAuth authentication with Anthropic
// OAuthClient handles OAuth 2.0 authentication flow with Anthropic using the
// PKCE (Proof Key for Code Exchange) extension for enhanced security in public clients.
// It manages the authorization URL generation, code exchange, and token refresh operations.
type OAuthClient struct {
ClientID string
AuthorizeURL string
@@ -22,13 +24,18 @@ type OAuthClient struct {
Scopes string
}
// AuthData contains authorization URL and PKCE verifier
// AuthData contains the authorization URL for user authentication and the PKCE
// verifier needed for the subsequent code exchange. The verifier must be stored
// securely and used when exchanging the authorization code for tokens.
type AuthData struct {
URL string
Verifier string
}
// NewOAuthClient creates a new OAuth client with Anthropic configuration
// NewOAuthClient creates a new OAuth client configured for Anthropic's OAuth service.
// The client uses a public client ID (as per OAuth 2.0 public client specification)
// with PKCE for security. The configuration includes the authorization endpoint,
// token endpoint, redirect URI, and required scopes for API key creation and inference.
func NewOAuthClient() *OAuthClient {
return &OAuthClient{
// OAuth client ID is public by design for CLI applications (OAuth public clients).
@@ -42,7 +49,11 @@ func NewOAuthClient() *OAuthClient {
}
}
// GeneratePKCE generates PKCE verifier and challenge for OAuth flow
// GeneratePKCE generates a cryptographically secure PKCE verifier and challenge pair
// for the OAuth 2.0 PKCE flow. The verifier is a random 32-byte string encoded as
// base64url, and the challenge is the SHA256 hash of the verifier, also base64url encoded.
// Returns the verifier (to be stored securely), challenge (to be sent with auth request),
// and any error encountered during generation.
func GeneratePKCE() (verifier, challenge string, err error) {
// Generate 32 bytes of random data
verifierBytes := make([]byte, 32)
@@ -60,7 +71,10 @@ func GeneratePKCE() (verifier, challenge string, err error) {
return verifier, challenge, nil
}
// GetAuthorizationURL generates the authorization URL with PKCE parameters
// GetAuthorizationURL generates a complete authorization URL for the OAuth flow with
// PKCE parameters. The URL includes the client ID, redirect URI, requested scopes,
// and PKCE challenge. Returns an AuthData structure containing the URL for user
// authentication and the PKCE verifier for the subsequent code exchange.
func (c *OAuthClient) GetAuthorizationURL() (*AuthData, error) {
verifier, challenge, err := GeneratePKCE()
if err != nil {
@@ -86,7 +100,10 @@ func (c *OAuthClient) GetAuthorizationURL() (*AuthData, error) {
}, nil
}
// ExchangeCode exchanges an authorization code for tokens
// ExchangeCode exchanges an authorization code for access and refresh tokens.
// The code parameter should be the authorization code received from the OAuth callback.
// The verifier parameter must be the same PKCE verifier generated during GetAuthorizationURL.
// Returns AnthropicCredentials containing the tokens and expiration information.
func (c *OAuthClient) ExchangeCode(code, verifier string) (*AnthropicCredentials, error) {
// Parse code and state
parsedCode, parsedState := c.parseCodeAndState(code)
@@ -109,7 +126,10 @@ func (c *OAuthClient) ExchangeCode(code, verifier string) (*AnthropicCredentials
return c.makeTokenRequest(reqBody)
}
// RefreshToken refreshes an access token using a refresh token
// RefreshToken refreshes an expired or expiring access token using a refresh token.
// Returns new AnthropicCredentials with updated access token, refresh token (may be
// rotated), and new expiration timestamp. Returns an error if the refresh fails or
// the refresh token is invalid.
func (c *OAuthClient) RefreshToken(refreshToken string) (*AnthropicCredentials, error) {
reqBody := map[string]interface{}{
"grant_type": "refresh_token",
@@ -179,7 +199,9 @@ func (c *OAuthClient) parseCodeAndState(code string) (parsedCode, parsedState st
return
}
// SetOAuthCredentials stores OAuth credentials
// SetOAuthCredentials stores OAuth credentials in the credential manager's secure storage.
// The credentials should include access token, refresh token, and expiration information.
// Returns an error if the credentials cannot be saved.
func (cm *CredentialManager) SetOAuthCredentials(creds *AnthropicCredentials) error {
store, err := cm.LoadCredentials()
if err != nil {
@@ -190,7 +212,10 @@ func (cm *CredentialManager) SetOAuthCredentials(creds *AnthropicCredentials) er
return cm.SaveCredentials(store)
}
// GetValidAccessToken returns a valid access token, refreshing if necessary
// GetValidAccessToken returns a valid access token for API requests. For OAuth credentials,
// it automatically refreshes the token if it's expired or about to expire. For API key
// credentials, it simply returns the API key. Returns an error if no credentials are found,
// if token refresh fails, or if the credential type is unknown.
func (cm *CredentialManager) GetValidAccessToken() (string, error) {
creds, err := cm.GetAnthropicCredentials()
if err != nil {
+4 -1
View File
@@ -37,7 +37,10 @@ var bannedCommands = []string{
"safari",
}
// NewBashServer creates a new bash MCP server
// NewBashServer creates a new MCP server that provides bash command execution capabilities.
// The server includes a single tool "run_shell_cmd" that executes shell commands with
// security restrictions, timeout controls, and output truncation. Returns an error if
// server initialization fails.
func NewBashServer() (*server.MCPServer, error) {
s := server.NewMCPServer("bash-server", "1.0.0", server.WithToolCapabilities(true))
+3 -1
View File
@@ -21,7 +21,9 @@ const (
maxFetchTimeout = 120 * time.Second
)
// NewFetchServer creates a new fetch MCP server
// NewFetchServer creates a new MCP server that provides web content fetching capabilities.
// The server includes a single tool "fetch" that retrieves content from URLs and converts
// it to text, markdown, or HTML format. Returns an error if server initialization fails.
func NewFetchServer() (*server.MCPServer, error) {
s := server.NewMCPServer("fetch-server", "1.0.0", server.WithToolCapabilities(true))
+5 -1
View File
@@ -28,7 +28,11 @@ const (
// httpServerModel holds the model for the HTTP server
var httpServerModel model.ToolCallingChatModel
// NewHTTPServer creates a new HTTP MCP server
// NewHTTPServer creates a new MCP server providing advanced HTTP fetching capabilities.
// The server includes tools for fetching web content, summarizing pages, extracting
// specific information, and filtering JSON responses. If an LLM model is provided,
// AI-powered summarization and extraction tools are enabled. Returns an error if
// server initialization fails.
func NewHTTPServer(llmModel model.ToolCallingChatModel) (*server.MCPServer, error) {
// Store the model globally for use in tool handlers
httpServerModel = llmModel
+19 -7
View File
@@ -9,28 +9,36 @@ import (
"github.com/mark3labs/mcp-go/server"
)
// BuiltinServerWrapper wraps an external MCP server for builtin use
// BuiltinServerWrapper wraps an external MCP server for builtin use, providing
// a consistent interface for all builtin servers regardless of their implementation.
type BuiltinServerWrapper struct {
server *server.MCPServer
}
// Initialize initializes the wrapped server
// Initialize initializes the wrapped server. For builtin servers, this is typically
// a no-op as the server is initialized during creation. Returns an error if
// initialization fails.
func (w *BuiltinServerWrapper) Initialize() error {
// The server is already initialized when created
return nil
}
// GetServer returns the wrapped MCP server
// GetServer returns the wrapped MCP server instance that can be used to handle
// tool calls and other MCP protocol operations.
func (w *BuiltinServerWrapper) GetServer() *server.MCPServer {
return w.server
}
// Registry holds all available builtin servers
// Registry holds all available builtin servers and their factory functions.
// It provides a centralized registry for creating instances of builtin MCP servers
// with their respective configurations.
type Registry struct {
servers map[string]func(options map[string]any, model model.ToolCallingChatModel) (*BuiltinServerWrapper, error)
}
// NewRegistry creates a new builtin server registry
// NewRegistry creates a new builtin server registry with all available builtin
// servers registered. The registry includes filesystem (fs), bash, todo, fetch,
// and HTTP servers.
func NewRegistry() *Registry {
r := &Registry{
servers: make(map[string]func(options map[string]any, model model.ToolCallingChatModel) (*BuiltinServerWrapper, error)),
@@ -46,7 +54,10 @@ func NewRegistry() *Registry {
return r
}
// CreateServer creates a new instance of a builtin server
// CreateServer creates a new instance of a builtin server by name. The options
// parameter provides server-specific configuration, and the model parameter provides
// an optional LLM for AI-powered features. Returns an error if the server name
// is unknown or if creation fails.
func (r *Registry) CreateServer(name string, options map[string]any, model model.ToolCallingChatModel) (*BuiltinServerWrapper, error) {
factory, exists := r.servers[name]
if !exists {
@@ -56,7 +67,8 @@ func (r *Registry) CreateServer(name string, options map[string]any, model model
return factory(options, model)
}
// ListServers returns a list of available builtin server names
// ListServers returns a list of all available builtin server names that can be
// created using CreateServer. The order of names is not guaranteed.
func (r *Registry) ListServers() []string {
names := make([]string, 0, len(r.servers))
for name := range r.servers {
+10 -3
View File
@@ -11,7 +11,9 @@ import (
"github.com/mark3labs/mcp-go/server"
)
// TodoInfo represents a single todo item
// TodoInfo represents a single todo item with content, status, priority, and ID.
// Status can be "pending", "in_progress", or "completed". Priority can be "high",
// "medium", or "low". Each todo must have a unique ID.
type TodoInfo struct {
Content string `json:"content"`
Status string `json:"status"`
@@ -19,13 +21,18 @@ type TodoInfo struct {
ID string `json:"id"`
}
// TodoServer implements a todo management MCP server with in-memory storage
// TodoServer implements a todo management MCP server with in-memory storage.
// It provides thread-safe operations for reading and writing todo lists, with
// support for task status tracking and priority levels.
type TodoServer struct {
todos []TodoInfo
mutex sync.RWMutex
}
// NewTodoServer creates a new todo MCP server with in-memory storage
// NewTodoServer creates a new MCP server that provides todo list management capabilities.
// The server includes two tools: "todowrite" for updating the todo list and "todoread"
// for retrieving the current list. Todos are stored in memory and not persisted.
// Returns an error if server initialization fails.
func NewTodoServer() (*server.MCPServer, error) {
todoServer := &TodoServer{
todos: make([]TodoInfo, 0),
+34 -7
View File
@@ -11,7 +11,9 @@ import (
"gopkg.in/yaml.v3"
)
// MCPServerConfig represents configuration for an MCP server
// MCPServerConfig represents configuration for an MCP server, supporting both
// local (stdio), remote (StreamableHTTP/SSE), and builtin (in-process) server types.
// It maintains backward compatibility with legacy configuration formats.
type MCPServerConfig struct {
Type string `json:"type"`
Command []string `json:"command,omitempty"`
@@ -29,7 +31,9 @@ type MCPServerConfig struct {
Headers []string `json:"headers,omitempty"`
}
// UnmarshalJSON handles both new and legacy config formats
// UnmarshalJSON handles both new and legacy config formats for backward compatibility.
// New format uses "type" field with "local", "remote", or "builtin" values.
// Legacy format uses "transport", "command", "args", and "env" fields.
func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
// First try to unmarshal as the new format
type newFormat struct {
@@ -100,11 +104,15 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
return nil
}
// AdaptiveColor represents a color that adapts to light and dark themes.
// Either light or dark can be specified, or both for theme-aware coloring.
type AdaptiveColor struct {
Light string `json:"light,omitempty" yaml:"light,omitempty"`
Dark string `json:"dark,omitempty" yaml:"dark,omitempty"`
}
// Theme defines the color scheme for the application UI with adaptive colors
// that support both light and dark modes.
type Theme struct {
Primary AdaptiveColor `json:"primary" yaml:"primary"`
Secondary AdaptiveColor `json:"secondary" yaml:"secondary"`
@@ -124,6 +132,8 @@ type Theme struct {
Highlight AdaptiveColor `json:"highlight" yaml:"highlight"`
}
// MarkdownTheme defines the color scheme for markdown rendering with syntax
// highlighting support and adaptive colors for light and dark modes.
type MarkdownTheme struct {
Text AdaptiveColor `json:"text" yaml:"text"`
Muted AdaptiveColor `json:"muted" yaml:"muted"`
@@ -139,7 +149,9 @@ type MarkdownTheme struct {
Comment AdaptiveColor `json:"comment" yaml:"comment"`
}
// Config represents the application configuration
// Config represents the complete application configuration including MCP servers,
// model settings, UI preferences, and API credentials. It supports both command-line
// flags and configuration file settings.
type Config struct {
MCPServers map[string]MCPServerConfig `json:"mcpServers" yaml:"mcpServers"`
Model string `json:"model,omitempty" yaml:"model,omitempty"`
@@ -166,7 +178,9 @@ type Config struct {
TLSSkipVerify bool `json:"tls-skip-verify,omitempty" yaml:"tls-skip-verify,omitempty"`
}
// GetTransportType returns the transport type for the server config
// GetTransportType returns the transport type for the server config, mapping
// simplified type names to actual transport protocols. Supports legacy format
// detection and automatic type inference from configuration.
func (s *MCPServerConfig) GetTransportType() string {
// Legacy format support - check explicit transport first
if s.Transport != "" {
@@ -197,7 +211,9 @@ func (s *MCPServerConfig) GetTransportType() string {
return "stdio" // default
}
// Validate validates the configuration
// Validate validates the configuration, ensuring required fields are present
// for each server type and that tool filters are used correctly. Returns an
// error describing any validation failures.
func (c *Config) Validate() error {
for serverName, serverConfig := range c.MCPServers {
if len(serverConfig.AllowedTools) > 0 && len(serverConfig.ExcludedTools) > 0 {
@@ -226,7 +242,9 @@ func (c *Config) Validate() error {
return nil
}
// LoadSystemPrompt loads system prompt from file or returns the string directly
// LoadSystemPrompt loads system prompt from file or returns the string directly.
// If input is a path to an existing file, its contents are read and returned.
// Otherwise, the input string is returned as-is.
func LoadSystemPrompt(input string) (string, error) {
if input == "" {
return "", nil
@@ -246,7 +264,9 @@ func LoadSystemPrompt(input string) (string, error) {
return input, nil
}
// EnsureConfigExists checks if a config file exists and creates a default one if not
// EnsureConfigExists checks if a config file exists and creates a default one if not.
// It searches for .mcphost.{yml,yaml,json} or legacy .mcp.{yml,yaml,json} files in
// the user's home directory. If none exist, creates a default .mcphost.yml with examples.
func EnsureConfigExists() error {
homeDir, err := os.UserHomeDir()
if err != nil {
@@ -378,6 +398,10 @@ mcpServers:
return nil
}
// FilepathOr reads a configuration value that can be either a direct value or a
// filepath to a JSON/YAML file containing the value. If the value is a string
// starting with "~/" or a relative path, it's expanded to an absolute path.
// The contents of the file are then unmarshaled into the provided value pointer.
func FilepathOr[T any](key string, value *T) error {
var field any
err := viper.UnmarshalKey(key, &field)
@@ -428,6 +452,9 @@ func FilepathOr[T any](key string, value *T) error {
var configPath string
// SetConfigPath sets the configuration file path for resolving relative paths
// in configuration values. This should be called when the configuration file
// location is known.
func SetConfigPath(path string) {
configPath = path
}
+6 -2
View File
@@ -7,7 +7,9 @@ import (
"github.com/spf13/viper"
)
// MergeConfigs merges script frontmatter config with base config
// MergeConfigs merges script frontmatter config with base config, allowing scripts
// to override MCP server configurations. The script config takes precedence over
// the base config for any fields that are specified.
func MergeConfigs(baseConfig *Config, scriptConfig *Config) *Config {
merged := *baseConfig // Copy base config
@@ -20,7 +22,9 @@ func MergeConfigs(baseConfig *Config, scriptConfig *Config) *Config {
return &merged
}
// LoadAndValidateConfig loads config from viper and validates it
// LoadAndValidateConfig loads configuration from viper, fixes environment variable
// casing issues, and validates the configuration. Returns an error if loading or
// validation fails.
func LoadAndValidateConfig() (*Config, error) {
config := &Config{
MCPServers: make(map[string]MCPServerConfig),
+16 -7
View File
@@ -24,10 +24,13 @@ func parseVariableWithDefault(varPart string) (varName, defaultValue string, has
return varPart, "", false
}
// EnvSubstituter handles environment variable substitution
// EnvSubstituter handles environment variable substitution in configuration strings,
// supporting both ${env://VAR} and ${env://VAR:-default} patterns.
type EnvSubstituter struct{}
// SubstituteEnvVars replaces ${env://VAR} and ${env://VAR:-default} patterns with environment variables
// SubstituteEnvVars replaces ${env://VAR} and ${env://VAR:-default} patterns with environment variables.
// If a variable is not set and has a default value, the default is used. Returns an error
// if required variables (those without defaults) are not set.
func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) {
var errors []string
@@ -57,17 +60,21 @@ func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) {
return result, nil
}
// ArgsSubstituter handles script argument substitution
// ArgsSubstituter handles script argument substitution in configuration strings,
// supporting both ${VAR} and ${VAR:-default} patterns for template variable replacement.
type ArgsSubstituter struct {
args map[string]string
}
// NewArgsSubstituter creates a new args substituter with the given arguments
// NewArgsSubstituter creates a new args substituter with the given arguments map.
// The arguments are used to replace template variables in configuration strings.
func NewArgsSubstituter(args map[string]string) *ArgsSubstituter {
return &ArgsSubstituter{args: args}
}
// SubstituteArgs replaces ${VAR} and ${VAR:-default} patterns with script arguments
// SubstituteArgs replaces ${VAR} and ${VAR:-default} patterns with script arguments.
// If an argument is not provided and has a default value, the default is used.
// Returns an error if required arguments (those without defaults) are not provided.
func (a *ArgsSubstituter) SubstituteArgs(content string) (string, error) {
var errors []string
@@ -97,12 +104,14 @@ func (a *ArgsSubstituter) SubstituteArgs(content string) (string, error) {
return result, nil
}
// HasEnvVars checks if content contains environment variable patterns
// HasEnvVars checks if content contains environment variable patterns (${env://...}).
// This is useful for determining if substitution is needed before processing.
func HasEnvVars(content string) bool {
return envVarPattern.MatchString(content)
}
// HasScriptArgs checks if content contains script argument patterns
// HasScriptArgs checks if content contains script argument patterns (${...}).
// This is useful for determining if argument substitution is needed before processing.
func HasScriptArgs(content string) bool {
return scriptArgsPattern.MatchString(content)
}
+13 -4
View File
@@ -9,26 +9,35 @@ import (
"path/filepath"
)
// HookConfig represents the complete hooks configuration
// HookConfig represents the complete hooks configuration containing event-triggered
// hooks for tool execution lifecycle events.
type HookConfig struct {
Hooks map[HookEvent][]HookMatcher `yaml:"hooks" json:"hooks"`
}
// HookMatcher matches specific tools and defines hooks to execute
// HookMatcher matches specific tools and defines hooks to execute. The Matcher field
// contains a pattern to match tool names, and Hooks contains the commands to run
// when a match occurs. The Merge field controls how this matcher combines with others.
type HookMatcher struct {
Matcher string `yaml:"matcher,omitempty" json:"matcher,omitempty"`
Merge string `yaml:"_merge,omitempty" json:"_merge,omitempty"`
Hooks []HookEntry `yaml:"hooks" json:"hooks"`
}
// HookEntry defines a single hook command
// HookEntry defines a single hook command to execute. Type specifies the command
// type (e.g., "bash"), Command contains the actual command to run, and Timeout
// optionally specifies the maximum execution time in seconds.
type HookEntry struct {
Type string `yaml:"type" json:"type"`
Command string `yaml:"command" json:"command"`
Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"`
}
// LoadHooksConfig loads and merges hook configurations from multiple sources
// LoadHooksConfig loads and merges hook configurations from multiple sources.
// It searches for hooks.{json,yml} files in standard locations (XDG config directory,
// local .mcphost directory) and any custom paths provided. Configurations are merged
// with later sources taking precedence. Environment variable substitution is applied
// to all loaded configurations.
func LoadHooksConfig(customPaths ...string) (*HookConfig, error) {
// Get config directory following XDG Base Directory specification
configDir := getConfigDir()
+11 -7
View File
@@ -1,23 +1,25 @@
package hooks
// HookEvent represents a point in MCPHost's lifecycle where hooks can be executed
// HookEvent represents a point in MCPHost's lifecycle where hooks can be executed.
// Events can be tool-related (requiring matchers) or lifecycle-related.
type HookEvent string
const (
// PreToolUse fires before any tool execution
// PreToolUse fires before any tool execution, allowing pre-processing or validation
PreToolUse HookEvent = "PreToolUse"
// PostToolUse fires after tool execution completes
// PostToolUse fires after tool execution completes, allowing post-processing or logging
PostToolUse HookEvent = "PostToolUse"
// UserPromptSubmit fires when user submits a prompt
// UserPromptSubmit fires when user submits a prompt, before agent processing
UserPromptSubmit HookEvent = "UserPromptSubmit"
// Stop fires when the main agent finishes responding
// Stop fires when the main agent finishes responding to a user prompt
Stop HookEvent = "Stop"
)
// IsValid returns true if the event is a valid hook event
// IsValid returns true if the event is a valid hook event.
// Valid events are PreToolUse, PostToolUse, UserPromptSubmit, and Stop.
func (e HookEvent) IsValid() bool {
switch e {
case PreToolUse, PostToolUse, UserPromptSubmit, Stop:
@@ -26,7 +28,9 @@ func (e HookEvent) IsValid() bool {
return false
}
// RequiresMatcher returns true if the event uses tool matchers
// RequiresMatcher returns true if the event uses tool matchers.
// PreToolUse and PostToolUse events require matchers to determine which
// tools trigger the hooks. Other events apply globally without matchers.
func (e HookEvent) RequiresMatcher() bool {
return e == PreToolUse || e == PostToolUse
}
+17 -6
View File
@@ -12,7 +12,9 @@ import (
"time"
)
// Executor handles hook execution
// Executor handles hook execution for MCPHost lifecycle events. It manages
// hook configuration, executes matching hooks in parallel, and processes
// their outputs to determine application behavior.
type Executor struct {
config *HookConfig
sessionID string
@@ -22,7 +24,9 @@ type Executor struct {
mu sync.RWMutex
}
// NewExecutor creates a new hook executor
// NewExecutor creates a new hook executor with the given configuration,
// session ID, and transcript path. The executor manages hook execution
// throughout the application lifecycle.
func NewExecutor(config *HookConfig, sessionID, transcriptPath string) *Executor {
return &Executor{
config: config,
@@ -31,21 +35,25 @@ func NewExecutor(config *HookConfig, sessionID, transcriptPath string) *Executor
}
}
// SetModel sets the model name for hook context
// SetModel sets the model name for hook context. This information is passed
// to hooks as part of their input data for context-aware processing.
func (e *Executor) SetModel(model string) {
e.mu.Lock()
defer e.mu.Unlock()
e.model = model
}
// SetInteractive sets whether we're in interactive mode
// SetInteractive sets whether the application is running in interactive mode.
// This information is passed to hooks for mode-specific behavior.
func (e *Executor) SetInteractive(interactive bool) {
e.mu.Lock()
defer e.mu.Unlock()
e.interactive = interactive
}
// PopulateCommonFields fills in the common fields for any hook input
// PopulateCommonFields fills in the common fields for any hook input, including
// session ID, transcript path, working directory, event name, timestamp, model,
// and interactive mode. These fields provide context to hooks regardless of event type.
func (e *Executor) PopulateCommonFields(event HookEvent) CommonInput {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -62,7 +70,10 @@ func (e *Executor) PopulateCommonFields(event HookEvent) CommonInput {
}
}
// ExecuteHooks runs all matching hooks for an event
// ExecuteHooks runs all matching hooks for an event. For tool-related events,
// it matches hooks based on tool name patterns. Hooks are executed in parallel
// with configurable timeouts. Returns a combined HookOutput from all executed
// hooks, with blocking decisions taking precedence.
func (e *Executor) ExecuteHooks(ctx context.Context, event HookEvent, input interface{}) (*HookOutput, error) {
matchers, ok := e.config.Hooks[event]
if !ok || len(matchers) == 0 {
+19 -6
View File
@@ -4,7 +4,9 @@ import (
"encoding/json"
)
// CommonInput contains fields common to all hook inputs
// CommonInput contains fields common to all hook inputs, providing context
// information that is available to every hook regardless of the event type.
// These fields help hooks understand the execution environment and session state.
type CommonInput struct {
SessionID string `json:"session_id"` // Unique session identifier
TranscriptPath string `json:"transcript_path"` // Path to transcript file (if enabled)
@@ -15,14 +17,18 @@ type CommonInput struct {
Interactive bool `json:"interactive"` // Whether in interactive mode
}
// PreToolUseInput is passed to PreToolUse hooks
// PreToolUseInput is passed to PreToolUse hooks before a tool is executed.
// It contains the tool name and input parameters, allowing hooks to validate,
// modify, or block tool execution.
type PreToolUseInput struct {
CommonInput
ToolName string `json:"tool_name"`
ToolInput json.RawMessage `json:"tool_input"`
}
// PostToolUseInput is passed to PostToolUse hooks
// PostToolUseInput is passed to PostToolUse hooks after a tool has been executed.
// It contains the tool name, input parameters, and the tool's response, allowing
// hooks to log, analyze, or react to tool execution results.
type PostToolUseInput struct {
CommonInput
ToolName string `json:"tool_name"`
@@ -30,13 +36,17 @@ type PostToolUseInput struct {
ToolResponse json.RawMessage `json:"tool_response"`
}
// UserPromptSubmitInput is passed to UserPromptSubmit hooks
// UserPromptSubmitInput is passed to UserPromptSubmit hooks when a user submits
// a prompt. It contains the user's input text, allowing hooks to validate,
// modify, or log user interactions before processing.
type UserPromptSubmitInput struct {
CommonInput
Prompt string `json:"prompt"`
}
// StopInput is passed to Stop hooks
// StopInput is passed to Stop hooks when the agent finishes responding to a prompt.
// It contains the final response, completion reason, and optional metadata about
// the interaction, allowing hooks to perform cleanup or logging operations.
type StopInput struct {
CommonInput
StopHookActive bool `json:"stop_hook_active"`
@@ -45,7 +55,10 @@ type StopInput struct {
Meta json.RawMessage `json:"meta,omitempty"` // Additional metadata (e.g., token usage, model info)
}
// HookOutput represents the JSON output from a hook
// HookOutput represents the JSON output from a hook that controls MCPHost behavior.
// Hooks can decide whether to continue execution, provide reasons for stopping,
// suppress output, or block tool execution. The Decision field can be "approve",
// "block", or empty (default behavior).
type HookOutput struct {
Continue *bool `json:"continue,omitempty"`
StopReason string `json:"stopReason,omitempty"`
+4 -1
View File
@@ -68,7 +68,10 @@ func containsDangerousPattern(command string) bool {
return separatorCount > 2
}
// ValidateHookConfig validates the entire hook configuration
// ValidateHookConfig validates the entire hook configuration for correctness
// and security. It checks event validity, regex patterns, hook definitions,
// and performs security validation on all commands. Returns an error describing
// any validation failures.
func ValidateHookConfig(config *HookConfig) error {
if config == nil {
return fmt.Errorf("nil configuration")
+76 -7
View File
@@ -13,17 +13,42 @@ import (
"github.com/cloudwego/eino/schema"
)
// CustomChatModel wraps the eino-ext Claude model with custom tool schema handling
// CustomChatModel wraps the eino-ext Claude model with custom tool schema handling.
// It provides a compatibility layer that fixes malformed JSON in tool calls and
// ensures proper schema validation for Anthropic's API requirements.
// This wrapper is necessary to handle edge cases where the underlying library
// may generate invalid JSON for empty tool inputs or missing properties.
type CustomChatModel struct {
// wrapped is the underlying eino-ext Claude model instance
wrapped *einoclaude.ChatModel
}
// CustomRoundTripper intercepts HTTP requests to fix Anthropic function schemas
// CustomRoundTripper intercepts HTTP requests to fix Anthropic function schemas.
// It acts as a middleware that modifies requests before they reach the Anthropic API,
// ensuring that tool schemas and function calls are properly formatted.
// This is particularly important for handling edge cases like empty tool inputs
// or missing schema properties that would otherwise cause API errors.
type CustomRoundTripper struct {
// wrapped is the underlying HTTP transport to use for actual requests
wrapped http.RoundTripper
}
// NewCustomChatModel creates a new custom Anthropic chat model
// NewCustomChatModel creates a new custom Anthropic chat model.
// It wraps the standard eino-ext Claude model with additional request
// preprocessing to ensure compatibility with Anthropic's API requirements.
//
// Parameters:
// - ctx: Context for the operation
// - config: Configuration for the Claude model including API key, model name, and parameters
//
// Returns:
// - *CustomChatModel: A wrapped Claude model with enhanced compatibility
// - error: Returns an error if model creation fails
//
// The custom model automatically:
// - Fixes malformed JSON in tool calls
// - Ensures tool schemas have required properties
// - Handles empty or missing input fields in function calls
func NewCustomChatModel(ctx context.Context, config *einoclaude.Config) (*CustomChatModel, error) {
// Create a custom HTTP client that intercepts requests
if config.HTTPClient == nil {
@@ -49,7 +74,21 @@ func NewCustomChatModel(ctx context.Context, config *einoclaude.Config) (*Custom
}, nil
}
// RoundTrip implements http.RoundTripper to intercept and fix requests
// RoundTrip implements http.RoundTripper to intercept and fix requests.
// It preprocesses outgoing requests to the Anthropic API to ensure
// they meet the API's requirements for tool schemas and function calls.
//
// Parameters:
// - req: The HTTP request to be sent to the Anthropic API
//
// Returns:
// - *http.Response: The response from the Anthropic API
// - error: Any error that occurred during the request
//
// The method performs the following fixes:
// - Ensures tool input_schema properties are not null
// - Fixes malformed JSON patterns in tool_use content
// - Validates and corrects empty or invalid function call inputs
func (rt *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// Only process Anthropic API requests
if !strings.Contains(req.URL.Host, "anthropic.com") {
@@ -191,17 +230,47 @@ func (rt *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
return rt.wrapped.RoundTrip(req)
}
// Generate implements the model.BaseChatModel interface
// Generate implements the model.BaseChatModel interface.
// It generates a single response from the model based on the input messages.
//
// Parameters:
// - ctx: Context for the operation, supporting cancellation and deadlines
// - input: The conversation history as a slice of messages
// - opts: Optional configuration options for the generation
//
// Returns:
// - *schema.Message: The generated response message
// - error: Any error that occurred during generation
func (m *CustomChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
return m.wrapped.Generate(ctx, input, opts...)
}
// Stream implements the model.BaseChatModel interface
// Stream implements the model.BaseChatModel interface.
// It generates a streaming response from the model, allowing incremental
// processing of the model's output as it's generated.
//
// Parameters:
// - ctx: Context for the operation, supporting cancellation and deadlines
// - input: The conversation history as a slice of messages
// - opts: Optional configuration options for the generation
//
// Returns:
// - *schema.StreamReader[*schema.Message]: A reader for the streaming response
// - error: Any error that occurred during stream setup
func (m *CustomChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
return m.wrapped.Stream(ctx, input, opts...)
}
// WithTools implements the model.ToolCallingChatModel interface
// WithTools implements the model.ToolCallingChatModel interface.
// It creates a new model instance with the specified tools available for function calling.
// The original model instance remains unchanged.
//
// Parameters:
// - tools: A slice of tool definitions that the model can use
//
// Returns:
// - model.ToolCallingChatModel: A new model instance with tools enabled
// - error: Returns an error if tool binding fails
func (m *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
wrappedWithTools, err := m.wrapped.WithTools(tools)
if err != nil {
+103 -16
View File
@@ -17,21 +17,27 @@ import (
var _ model.ToolCallingChatModel = (*ChatModel)(nil)
// NewChatModel creates a new Gemini chat model instance
// NewChatModel creates a new Gemini chat model instance.
// It initializes a Google Gemini model with the specified configuration,
// supporting both text generation and tool calling capabilities.
//
// Parameters:
// - ctx: The context for the operation
// - cfg: Configuration for the Gemini model
// - ctx: The context for the operation (currently unused but kept for interface consistency)
// - cfg: Configuration for the Gemini model including client, model name, and parameters
//
// Returns:
// - model.ChatModel: A chat model interface implementation
// - *ChatModel: A Gemini chat model instance implementing ToolCallingChatModel
// - error: Any error that occurred during creation
//
// Example:
//
// client, _ := genai.NewClient(ctx, &genai.ClientConfig{
// APIKey: "your-api-key",
// })
// model, err := gemini.NewChatModel(ctx, &gemini.Config{
// Client: client,
// Model: "gemini-pro",
// MaxTokens: &maxTokens,
// })
func NewChatModel(_ context.Context, cfg *Config) (*ChatModel, error) {
return &ChatModel{
@@ -90,28 +96,58 @@ type Config struct {
SafetySettings []*genai.SafetySetting
}
// options contains Gemini-specific options for model configuration
// options contains Gemini-specific options for model configuration.
// These are options that are specific to the Gemini API and not part
// of the common model options interface.
type options struct {
TopK *int32
// TopK limits the number of tokens to sample from
TopK *int32
// ResponseSchema defines the expected JSON structure for responses
ResponseSchema *openapi3.Schema
}
// ChatModel implements the Gemini chat model for the eino framework.
// It provides integration with Google's Gemini API, supporting both
// text generation and tool calling capabilities.
type ChatModel struct {
// cli is the Gemini API client instance
cli *genai.Client
model string
maxTokens *int
topP *float32
temperature *float32
topK *int32
responseSchema *openapi3.Schema
tools []*genai.Tool
origTools []*schema.ToolInfo
toolChoice *schema.ToolChoice
// model specifies which Gemini model to use
model string
// maxTokens limits the response length
maxTokens *int
// topP controls nucleus sampling
topP *float32
// temperature controls randomness
temperature *float32
// topK limits token sampling
topK *int32
// responseSchema for structured JSON output
responseSchema *openapi3.Schema
// tools converted to Gemini format
tools []*genai.Tool
// origTools stores the original tool definitions
origTools []*schema.ToolInfo
// toolChoice controls how tools are used
toolChoice *schema.ToolChoice
// enableCodeExecution allows code execution (use with caution)
enableCodeExecution bool
safetySettings []*genai.SafetySetting
// safetySettings for content filtering
safetySettings []*genai.SafetySetting
}
// Generate generates a single response from the Gemini model.
// It processes the input messages and returns a complete response.
//
// Parameters:
// - ctx: Context for the operation, supporting cancellation and callbacks
// - input: The conversation history as a slice of messages
// - opts: Optional configuration options for the generation
//
// Returns:
// - *schema.Message: The generated response message with content and metadata
// - error: Any error that occurred during generation
func (cm *ChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (message *schema.Message, err error) {
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
@@ -154,6 +190,17 @@ func (cm *ChatModel) Generate(ctx context.Context, input []*schema.Message, opts
return message, nil
}
// Stream generates a streaming response from the Gemini model.
// It allows incremental processing of the model's output as it's generated.
//
// Parameters:
// - ctx: Context for the operation, supporting cancellation and callbacks
// - input: The conversation history as a slice of messages
// - opts: Optional configuration options for the generation
//
// Returns:
// - *schema.StreamReader[*schema.Message]: A reader for the streaming response
// - error: Any error that occurred during stream setup
func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (result *schema.StreamReader[*schema.Message], err error) {
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
@@ -218,6 +265,16 @@ func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts .
}), nil
}
// WithTools creates a new model instance with the specified tools available.
// It returns a new ChatModel with tools configured for function calling.
// The original model instance remains unchanged.
//
// Parameters:
// - tools: A slice of tool definitions that the model can use
//
// Returns:
// - model.ToolCallingChatModel: A new model instance with tools enabled
// - error: Returns an error if no tools provided or conversion fails
func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
if len(tools) == 0 {
return nil, errors.New("no tools to bind")
@@ -235,6 +292,15 @@ func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatM
return &ncm, nil
}
// BindTools binds tools to the current model instance.
// Unlike WithTools, this modifies the current instance rather than
// creating a new one. Tools are set to "allowed" mode by default.
//
// Parameters:
// - tools: A slice of tool definitions to bind to the model
//
// Returns:
// - error: Returns an error if no tools provided or conversion fails
func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error {
if len(tools) == 0 {
return errors.New("no tools to bind")
@@ -251,6 +317,15 @@ func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error {
return nil
}
// BindForcedTools binds tools to the current model instance in forced mode.
// This ensures the model will always use one of the provided tools
// rather than generating a text response.
//
// Parameters:
// - tools: A slice of tool definitions to bind to the model
//
// Returns:
// - error: Returns an error if no tools provided or conversion fails
func (cm *ChatModel) BindForcedTools(tools []*schema.ToolInfo) error {
if len(tools) == 0 {
return errors.New("no tools to bind")
@@ -679,12 +754,24 @@ func (cm *ChatModel) convertCallbackOutput(message *schema.Message, conf *model.
return callbackOutput
}
// IsCallbacksEnabled indicates whether this model supports callbacks.
// For the Gemini model, callbacks are always enabled to support
// token usage tracking and other monitoring features.
//
// Returns:
// - bool: Always returns true for Gemini models
func (cm *ChatModel) IsCallbacksEnabled() bool {
return true
}
const typ = "Gemini"
// GetType returns the type identifier for this model.
// This is used for logging and debugging purposes to identify
// which model implementation is being used.
//
// Returns:
// - string: Returns "Gemini" as the model type
func (cm *ChatModel) GetType() string {
return typ
}
+43 -19
View File
@@ -12,37 +12,61 @@ import (
"time"
)
// ModelInfo represents information about a specific model
// ModelInfo represents information about a specific model.
// This struct is used during code generation to parse model data
// from the models.dev API and generate the static Go code.
type ModelInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Attachment bool `json:"attachment"`
Reasoning bool `json:"reasoning"`
Temperature bool `json:"temperature"`
Cost Cost `json:"cost"`
Limit Limit `json:"limit"`
// ID is the unique identifier for the model
ID string `json:"id"`
// Name is the human-readable name of the model
Name string `json:"name"`
// Attachment indicates whether the model supports file attachments
Attachment bool `json:"attachment"`
// Reasoning indicates whether this is a reasoning/chain-of-thought model
Reasoning bool `json:"reasoning"`
// Temperature indicates whether the model supports temperature parameter
Temperature bool `json:"temperature"`
// Cost contains the pricing information for the model
Cost Cost `json:"cost"`
// Limit contains the context and output token limits
Limit Limit `json:"limit"`
}
// Cost represents the pricing information for a model
// Cost represents the pricing information for a model.
// Used during code generation to parse pricing data from models.dev.
type Cost struct {
Input float64 `json:"input"`
Output float64 `json:"output"`
CacheRead *float64 `json:"cache_read,omitempty"`
// Input is the cost per million input tokens
Input float64 `json:"input"`
// Output is the cost per million output tokens
Output float64 `json:"output"`
// CacheRead is the cost per million cached read tokens (optional)
CacheRead *float64 `json:"cache_read,omitempty"`
// CacheWrite is the cost per million cached write tokens (optional)
CacheWrite *float64 `json:"cache_write,omitempty"`
}
// Limit represents the context and output limits for a model
// Limit represents the context and output limits for a model.
// Used during code generation to parse token limit data from models.dev.
type Limit struct {
// Context is the maximum number of input tokens
Context int `json:"context"`
Output int `json:"output"`
// Output is the maximum number of output tokens
Output int `json:"output"`
}
// ProviderInfo represents information about a model provider
// ProviderInfo represents information about a model provider.
// Used during code generation to parse provider data from models.dev
// and generate the static provider registry.
type ProviderInfo struct {
ID string `json:"id"`
Env []string `json:"env"`
NPM string `json:"npm"`
Name string `json:"name"`
// ID is the unique identifier for the provider
ID string `json:"id"`
// Env lists the environment variables for API credentials
Env []string `json:"env"`
// NPM is the NPM package name (for reference)
NPM string `json:"npm"`
// Name is the human-readable provider name
Name string `json:"name"`
// Models maps model IDs to their information
Models map[string]ModelInfo `json:"models"`
}
+77 -19
View File
@@ -3,41 +3,99 @@
package models
// ModelInfo represents information about a specific model
// ModelInfo represents information about a specific model.
// It contains comprehensive metadata about a model's capabilities,
// pricing, and limitations sourced from models.dev.
type ModelInfo struct {
ID string
Name string
Attachment bool
Reasoning bool
// ID is the unique identifier for the model
// Example: "claude-3-sonnet-20240620" or "gpt-4"
ID string
// Name is the human-readable name of the model
// Example: "Claude 3 Sonnet" or "GPT-4"
Name string
// Attachment indicates whether the model supports file attachments
Attachment bool
// Reasoning indicates whether this is a reasoning/chain-of-thought model
// Example: OpenAI's o1 models have this set to true
Reasoning bool
// Temperature indicates whether the model supports temperature parameter
Temperature bool
Cost Cost
Limit Limit
// Cost contains the pricing information for input/output tokens
Cost Cost
// Limit contains the context window and output token limits
Limit Limit
}
// Cost represents the pricing information for a model
// Cost represents the pricing information for a model.
// Prices are typically in USD per million tokens.
type Cost struct {
Input float64
Output float64
CacheRead *float64
// Input is the cost per million input tokens
Input float64
// Output is the cost per million output tokens
Output float64
// CacheRead is the cost per million cached read tokens (optional)
// Only applicable for models that support prompt caching
CacheRead *float64
// CacheWrite is the cost per million cached write tokens (optional)
// Only applicable for models that support prompt caching
CacheWrite *float64
}
// Limit represents the context and output limits for a model
// Limit represents the context and output limits for a model.
// These define the maximum number of tokens the model can process
// and generate in a single interaction.
type Limit struct {
// Context is the maximum number of input tokens (context window size)
Context int
Output int
// Output is the maximum number of output tokens that can be generated
Output int
}
// ProviderInfo represents information about a model provider
// ProviderInfo represents information about a model provider.
// It contains metadata about the provider and all models it offers.
type ProviderInfo struct {
ID string
Env []string
NPM string
Name string
// ID is the unique identifier for the provider
// Example: "anthropic", "openai", "google"
ID string
// Env lists the environment variables checked for API credentials
// Example: ["ANTHROPIC_API_KEY"] or ["OPENAI_API_KEY"]
Env []string
// NPM is the NPM package name used by the provider (for reference)
NPM string
// Name is the human-readable name of the provider
// Example: "Anthropic", "OpenAI", "Google"
Name string
// Models maps model IDs to their detailed information
Models map[string]ModelInfo
}
// GetModelsData returns the static models data from models.dev
// GetModelsData returns the static models data from models.dev.
// This data is automatically generated from the models.dev API
// and provides comprehensive information about all supported
// LLM providers and their models.
//
// The data includes:
// - Provider information (ID, name, environment variables)
// - Model capabilities (reasoning, attachments, temperature support)
// - Pricing information (input/output costs, cache costs)
// - Token limits (context window, max output tokens)
//
// Returns:
// - map[string]ProviderInfo: A map of provider IDs to their information
func GetModelsData() map[string]ProviderInfo {
return map[string]ProviderInfo{
"alibaba": {
+104 -11
View File
@@ -14,17 +14,42 @@ import (
"github.com/cloudwego/eino/schema"
)
// CustomChatModel wraps the eino-ext OpenAI model with custom tool schema handling
// CustomChatModel wraps the eino-ext OpenAI model with custom tool schema handling.
// It provides a compatibility layer that ensures proper JSON schema formatting
// for OpenAI's function calling feature. This wrapper addresses cases where
// tool schemas might have missing or empty properties that would cause API errors.
type CustomChatModel struct {
// wrapped is the underlying eino-ext OpenAI model instance
wrapped *einoopenai.ChatModel
}
// CustomRoundTripper intercepts HTTP requests to fix OpenAI function schemas
// CustomRoundTripper intercepts HTTP requests to fix OpenAI function schemas.
// It acts as middleware that modifies outgoing requests to ensure that
// function/tool schemas are properly formatted according to OpenAI's requirements.
// This is particularly important for handling edge cases where tool schemas
// might have missing or empty properties fields.
type CustomRoundTripper struct {
// wrapped is the underlying HTTP transport to use for actual requests
wrapped http.RoundTripper
}
// NewCustomChatModel creates a new custom OpenAI chat model
// NewCustomChatModel creates a new custom OpenAI chat model.
// It wraps the standard eino-ext OpenAI model with additional request
// preprocessing to ensure compatibility with OpenAI's API requirements,
// particularly for function calling and tool schemas.
//
// Parameters:
// - ctx: Context for the operation
// - config: Configuration for the OpenAI model including API key, model name, and parameters
//
// Returns:
// - *CustomChatModel: A wrapped OpenAI model with enhanced compatibility
// - error: Returns an error if model creation fails
//
// The custom model automatically:
// - Ensures function parameter schemas have properties fields
// - Fixes missing or empty properties in tool schemas
// - Maintains compatibility with OpenAI's function calling requirements
func NewCustomChatModel(ctx context.Context, config *einoopenai.ChatModelConfig) (*CustomChatModel, error) {
// Create a custom HTTP client that intercepts requests
if config.HTTPClient == nil {
@@ -49,7 +74,20 @@ func NewCustomChatModel(ctx context.Context, config *einoopenai.ChatModelConfig)
}, nil
}
// RoundTrip implements http.RoundTripper to intercept and fix OpenAI requests
// RoundTrip implements http.RoundTripper to intercept and fix OpenAI requests.
// It preprocesses outgoing requests to the OpenAI API to ensure tool/function
// schemas meet the API's requirements.
//
// Parameters:
// - req: The HTTP request to be sent to the OpenAI API
//
// Returns:
// - *http.Response: The response from the OpenAI API
// - error: Any error that occurred during the request
//
// The method performs the following fixes:
// - Ensures function parameter schemas of type "object" have a properties field
// - Adds empty properties object if missing to prevent API validation errors
func (c *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// Only intercept OpenAI chat completions requests
if !strings.Contains(req.URL.Path, "/chat/completions") {
@@ -109,17 +147,47 @@ func (c *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
return c.wrapped.RoundTrip(req)
}
// Generate implements model.ChatModel
// Generate implements model.ChatModel interface.
// It generates a single response from the OpenAI model based on the input messages.
//
// Parameters:
// - ctx: Context for the operation, supporting cancellation and deadlines
// - in: The conversation history as a slice of messages
// - opts: Optional configuration options for the generation
//
// Returns:
// - *schema.Message: The generated response message
// - error: Any error that occurred during generation
func (c *CustomChatModel) Generate(ctx context.Context, in []*schema.Message, opts ...model.Option) (*schema.Message, error) {
return c.wrapped.Generate(ctx, in, opts...)
}
// Stream implements model.ChatModel
// Stream implements model.ChatModel interface.
// It generates a streaming response from the OpenAI model, allowing
// incremental processing of the model's output as it's generated.
//
// Parameters:
// - ctx: Context for the operation, supporting cancellation and deadlines
// - in: The conversation history as a slice of messages
// - opts: Optional configuration options for the generation
//
// Returns:
// - *schema.StreamReader[*schema.Message]: A reader for the streaming response
// - error: Any error that occurred during stream setup
func (c *CustomChatModel) Stream(ctx context.Context, in []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
return c.wrapped.Stream(ctx, in, opts...)
}
// WithTools implements model.ToolCallingChatModel
// WithTools implements model.ToolCallingChatModel interface.
// It creates a new model instance with the specified tools available for function calling.
// The original model instance remains unchanged.
//
// Parameters:
// - tools: A slice of tool definitions that the model can use
//
// Returns:
// - model.ToolCallingChatModel: A new model instance with tools enabled
// - error: Returns an error if tool binding fails
func (c *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
wrappedWithTools, err := c.wrapped.WithTools(tools)
if err != nil {
@@ -135,22 +203,47 @@ func (c *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCalling
return &CustomChatModel{wrapped: wrappedChatModel}, nil
}
// BindTools implements model.ToolCallingChatModel
// BindTools implements model.ToolCallingChatModel interface.
// It binds tools to the current model instance, modifying it in place
// rather than creating a new instance.
//
// Parameters:
// - tools: A slice of tool definitions to bind to the model
//
// Returns:
// - error: Returns an error if tool binding fails
func (c *CustomChatModel) BindTools(tools []*schema.ToolInfo) error {
return c.wrapped.BindTools(tools)
}
// BindForcedTools implements model.ToolCallingChatModel
// BindForcedTools implements model.ToolCallingChatModel interface.
// It binds tools to the current model instance in forced mode,
// ensuring the model will always use one of the provided tools.
//
// Parameters:
// - tools: A slice of tool definitions to bind to the model
//
// Returns:
// - error: Returns an error if tool binding fails
func (c *CustomChatModel) BindForcedTools(tools []*schema.ToolInfo) error {
return c.wrapped.BindForcedTools(tools)
}
// GetType implements model.ChatModel
// GetType implements model.ChatModel interface.
// It returns the type identifier for this model implementation.
//
// Returns:
// - string: Returns "CustomOpenAI" as the model type identifier
func (c *CustomChatModel) GetType() string {
return "CustomOpenAI"
}
// IsCallbacksEnabled implements model.ChatModel
// IsCallbacksEnabled implements model.ChatModel interface.
// It indicates whether this model supports callbacks for monitoring
// and tracking purposes.
//
// Returns:
// - bool: Returns the callback enabled status from the wrapped model
func (c *CustomChatModel) IsCallbacksEnabled() bool {
return c.wrapped.IsCallbacksEnabled()
}
+85 -17
View File
@@ -27,7 +27,9 @@ import (
)
const (
// ClaudeCodePrompt is the required system prompt for OAuth authentication
// ClaudeCodePrompt is the required system prompt for OAuth authentication.
// This prompt must be included as the first system message when using OAuth-based
// authentication with Anthropic's API to properly identify the application.
ClaudeCodePrompt = "You are Claude Code, Anthropic's official CLI for Claude."
)
@@ -62,35 +64,93 @@ func resolveModelAlias(provider, modelName string) string {
return modelName
}
// ProviderConfig holds configuration for creating LLM providers
// ProviderConfig holds configuration for creating LLM providers.
// It contains all necessary settings to initialize and configure
// various LLM providers including API keys, model parameters, and
// provider-specific settings.
type ProviderConfig struct {
ModelString string
SystemPrompt string
ProviderAPIKey string // API key for OpenAI and Anthropic
ProviderURL string // Base URL for OpenAI, Anthropic, and Ollama
// ModelString specifies the model in the format "provider:model"
// Example: "anthropic:claude-3-sonnet-20240620" or "openai:gpt-4"
ModelString string
// SystemPrompt sets the system message/instructions for the model
SystemPrompt string
// ProviderAPIKey is the API key for authentication with the provider
// Used for OpenAI, Anthropic, Google, and Azure providers
ProviderAPIKey string
// ProviderURL is the base URL for the provider's API endpoint
// Can be used to specify custom endpoints for OpenAI, Anthropic, Ollama, or Azure
ProviderURL string
// Model generation parameters
MaxTokens int
Temperature *float32
TopP *float32
TopK *int32
// MaxTokens limits the maximum number of tokens in the response
MaxTokens int
// Temperature controls randomness in generation (0.0 to 1.0)
// Lower values make output more focused and deterministic
Temperature *float32
// TopP implements nucleus sampling, controlling diversity
// Value between 0.0 and 1.0, where 1.0 considers all tokens
TopP *float32
// TopK limits the number of tokens to sample from
// Used by Anthropic and Google providers
TopK *int32
// StopSequences is a list of sequences that will stop generation
StopSequences []string
// Ollama-specific parameters
NumGPU *int32
// NumGPU specifies the number of GPUs to use for inference
NumGPU *int32
// MainGPU specifies which GPU to use as the primary device
MainGPU *int32
// TLS configuration
TLSSkipVerify bool // Skip TLS certificate verification (insecure)
// TLSSkipVerify skips TLS certificate verification (insecure)
// Should only be used for development or with self-signed certificates
TLSSkipVerify bool
}
// ProviderResult contains the result of provider creation
// ProviderResult contains the result of provider creation.
// It includes both the created model instance and any informational
// messages that should be displayed to the user.
type ProviderResult struct {
Model model.ToolCallingChatModel
Message string // Optional message for user feedback (e.g., GPU fallback info)
// Model is the created LLM provider instance that implements
// the ToolCallingChatModel interface for tool-enabled conversations
Model model.ToolCallingChatModel
// Message contains optional feedback for the user
// Example: "Insufficient GPU memory, falling back to CPU inference"
Message string
}
// CreateProvider creates an eino ToolCallingChatModel based on the provider configuration
// CreateProvider creates an eino ToolCallingChatModel based on the provider configuration.
// It validates the model, checks required environment variables, and initializes
// the appropriate provider (Anthropic, OpenAI, Google, Ollama, or Azure).
//
// Parameters:
// - ctx: Context for the operation, used for cancellation and deadlines
// - config: Provider configuration containing model details and parameters
//
// Returns:
// - *ProviderResult: Contains the created model and any informational messages
// - error: Returns an error if provider creation fails, model validation fails,
// or required credentials are missing
//
// Supported providers:
// - anthropic: Claude models with API key or OAuth authentication
// - openai: GPT models including reasoning models like o1
// - google: Gemini models
// - ollama: Local models with automatic GPU/CPU fallback
// - azure: Azure OpenAI Service deployments
func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) {
parts := strings.SplitN(config.ModelString, ":", 2)
if len(parts) < 2 {
@@ -407,9 +467,17 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName
return gemini.NewChatModel(ctx, geminiConfig)
}
// OllamaLoadingResult contains the result of model loading with actual settings used
// OllamaLoadingResult contains the result of model loading with actual settings used.
// It provides information about how the model was loaded, including any fallback
// behavior that occurred during initialization.
type OllamaLoadingResult struct {
// Options contains the actual Ollama options used for loading
// May differ from requested options if fallback occurred (e.g., CPU instead of GPU)
Options *api.Options
// Message describes the loading result
// Example: "Model loaded successfully on GPU" or
// "Insufficient GPU memory, falling back to CPU inference"
Message string
}
+81 -9
View File
@@ -8,19 +8,39 @@ import (
"strings"
)
// ModelsRegistry provides validation and information about models
// ModelsRegistry provides validation and information about models.
// It maintains a registry of all supported LLM providers and their models,
// including capabilities, pricing, and configuration requirements.
// The registry data is generated from models.dev and provides a single
// source of truth for model validation and discovery.
type ModelsRegistry struct {
// providers maps provider IDs to their information and available models
providers map[string]ProviderInfo
}
// NewModelsRegistry creates a new models registry with static data
// NewModelsRegistry creates a new models registry with static data.
// The registry is populated with model information generated from models.dev,
// providing comprehensive metadata about available models across all supported providers.
//
// Returns:
// - *ModelsRegistry: A new registry instance populated with current model data
func NewModelsRegistry() *ModelsRegistry {
return &ModelsRegistry{
providers: GetModelsData(),
}
}
// ValidateModel validates if a model exists and returns detailed information
// ValidateModel validates if a model exists and returns detailed information.
// It checks whether a specific model is available for a given provider and
// returns comprehensive information about the model's capabilities and limits.
//
// Parameters:
// - provider: The provider ID (e.g., "anthropic", "openai", "google")
// - modelID: The specific model ID (e.g., "claude-3-sonnet-20240620", "gpt-4")
//
// Returns:
// - *ModelInfo: Detailed information about the model including pricing, limits, and capabilities
// - error: Returns an error if the provider is unsupported or model is not found
func (r *ModelsRegistry) ValidateModel(provider, modelID string) (*ModelInfo, error) {
providerInfo, exists := r.providers[provider]
if !exists {
@@ -35,7 +55,21 @@ func (r *ModelsRegistry) ValidateModel(provider, modelID string) (*ModelInfo, er
return &modelInfo, nil
}
// GetRequiredEnvVars returns the required environment variables for a provider
// GetRequiredEnvVars returns the required environment variables for a provider.
// These are the environment variable names that should contain API keys or
// other authentication credentials for the specified provider.
//
// Parameters:
// - provider: The provider ID (e.g., "anthropic", "openai", "google")
//
// Returns:
// - []string: List of environment variable names the provider checks for credentials
// - error: Returns an error if the provider is unsupported
//
// Example:
//
// For "anthropic", returns ["ANTHROPIC_API_KEY"]
// For "google", returns ["GOOGLE_API_KEY", "GEMINI_API_KEY", "GOOGLE_GENERATIVE_AI_API_KEY"]
func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) {
providerInfo, exists := r.providers[provider]
if !exists {
@@ -45,7 +79,16 @@ func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) {
return providerInfo.Env, nil
}
// ValidateEnvironment checks if required environment variables are set
// ValidateEnvironment checks if required environment variables are set.
// It verifies that at least one of the provider's required environment variables
// contains an API key, unless an API key is explicitly provided via configuration.
//
// Parameters:
// - provider: The provider ID to validate environment for
// - apiKey: An API key provided via configuration (if empty, checks environment variables)
//
// Returns:
// - error: Returns nil if validation passes, or an error describing missing credentials
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
envVars, err := r.GetRequiredEnvVars(provider)
if err != nil {
@@ -74,7 +117,16 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
return nil
}
// SuggestModels returns similar model names when an invalid model is provided
// SuggestModels returns similar model names when an invalid model is provided.
// It helps users discover the correct model ID by finding models that partially
// match the provided input, useful for correcting typos or finding alternatives.
//
// Parameters:
// - provider: The provider ID to search within
// - invalidModel: The invalid or misspelled model name to find suggestions for
//
// Returns:
// - []string: A list of up to 5 suggested model IDs that partially match the input
func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
providerInfo, exists := r.providers[provider]
if !exists {
@@ -105,7 +157,12 @@ func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
return suggestions
}
// GetSupportedProviders returns a list of all supported providers
// GetSupportedProviders returns a list of all supported providers.
// This includes all providers that have models registered in the system,
// such as "anthropic", "openai", "google", "alibaba", etc.
//
// Returns:
// - []string: A list of all provider IDs available in the registry
func (r *ModelsRegistry) GetSupportedProviders() []string {
var providers []string
for providerID := range r.providers {
@@ -114,7 +171,16 @@ func (r *ModelsRegistry) GetSupportedProviders() []string {
return providers
}
// GetModelsForProvider returns all models for a specific provider
// GetModelsForProvider returns all models for a specific provider.
// This is useful for listing available models when a user wants to see
// all options for a particular provider.
//
// Parameters:
// - provider: The provider ID to get models for
//
// Returns:
// - map[string]ModelInfo: A map of model IDs to their detailed information
// - error: Returns an error if the provider is unsupported
func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]ModelInfo, error) {
providerInfo, exists := r.providers[provider]
if !exists {
@@ -127,7 +193,13 @@ func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]Model
// Global registry instance
var globalRegistry = NewModelsRegistry()
// GetGlobalRegistry returns the global models registry instance
// GetGlobalRegistry returns the global models registry instance.
// This provides a singleton registry that can be accessed throughout
// the application for model validation and information retrieval.
// The registry is initialized once with data from models.dev.
//
// Returns:
// - *ModelsRegistry: The global registry instance
func GetGlobalRegistry() *ModelsRegistry {
return globalRegistry
}
+57 -12
View File
@@ -7,14 +7,21 @@ import (
"github.com/cloudwego/eino/schema"
)
// Manager manages session state and auto-saving
// Manager manages session state and auto-saving functionality.
// It provides thread-safe operations for managing a conversation session,
// including automatic persistence to disk after each modification.
// The Manager ensures that all session operations are synchronized and
// that the session file is kept up-to-date with any changes.
type Manager struct {
session *Session
filePath string
mutex sync.RWMutex
}
// NewManager creates a new session manager
// NewManager creates a new session manager with a fresh session.
// The filePath parameter specifies where the session will be auto-saved.
// If filePath is empty, the session will not be automatically saved to disk.
// Returns a Manager instance ready to track conversation messages.
func NewManager(filePath string) *Manager {
return &Manager{
session: NewSession(),
@@ -22,7 +29,11 @@ func NewManager(filePath string) *Manager {
}
}
// NewManagerWithSession creates a new session manager with an existing session
// NewManagerWithSession creates a new session manager with an existing session.
// This is useful when loading a session from a file and wanting to continue
// managing it with auto-save functionality.
// The session parameter is the existing session to manage.
// The filePath parameter specifies where the session will be auto-saved.
func NewManagerWithSession(session *Session, filePath string) *Manager {
return &Manager{
session: session,
@@ -30,7 +41,12 @@ func NewManagerWithSession(session *Session, filePath string) *Manager {
}
}
// AddMessage adds a message to the session and auto-saves
// AddMessage adds a message to the session and auto-saves.
// The message is converted from schema.Message format to the internal
// session Message format before being added. If a filePath was specified
// when creating the Manager, the session is automatically saved to disk.
// This operation is thread-safe.
// Returns an error if auto-saving fails, nil otherwise.
func (m *Manager) AddMessage(msg *schema.Message) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -45,7 +61,11 @@ func (m *Manager) AddMessage(msg *schema.Message) error {
return nil
}
// AddMessages adds multiple messages to the session and auto-saves
// AddMessages adds multiple messages to the session and auto-saves.
// All messages are added in order and then the session is saved once.
// This is more efficient than calling AddMessage multiple times when
// adding several messages at once. The operation is thread-safe.
// Returns an error if auto-saving fails, nil otherwise.
func (m *Manager) AddMessages(msgs []*schema.Message) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -62,7 +82,12 @@ func (m *Manager) AddMessages(msgs []*schema.Message) error {
return nil
}
// ReplaceAllMessages replaces all messages in the session with the provided messages
// ReplaceAllMessages replaces all messages in the session with the provided messages.
// This method completely clears the existing message history and replaces it with
// the new set of messages. Useful for resetting a conversation or loading a
// different conversation context. The operation is thread-safe and triggers
// an auto-save if a filePath is configured.
// Returns an error if auto-saving fails, nil otherwise.
func (m *Manager) ReplaceAllMessages(msgs []*schema.Message) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -83,7 +108,11 @@ func (m *Manager) ReplaceAllMessages(msgs []*schema.Message) error {
return nil
}
// SetMetadata sets the session metadata
// SetMetadata sets the session metadata.
// This updates the session's metadata with information about the provider,
// model, and MCPHost version. The operation is thread-safe and triggers
// an auto-save if a filePath is configured.
// Returns an error if auto-saving fails, nil otherwise.
func (m *Manager) SetMetadata(metadata Metadata) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -97,7 +126,11 @@ func (m *Manager) SetMetadata(metadata Metadata) error {
return nil
}
// GetMessages returns all messages as schema.Message slice
// GetMessages returns all messages as a schema.Message slice.
// This method converts all stored session messages to the schema format
// used by LLM providers. The returned slice is a new allocation, so
// modifications to it won't affect the stored session. This operation
// is thread-safe for concurrent reads.
func (m *Manager) GetMessages() []*schema.Message {
m.mutex.RLock()
defer m.mutex.RUnlock()
@@ -110,7 +143,11 @@ func (m *Manager) GetMessages() []*schema.Message {
return messages
}
// GetSession returns a copy of the current session
// GetSession returns a copy of the current session.
// The returned session is a deep copy, including all messages, so
// modifications to it won't affect the managed session. This is useful
// for safely inspecting the session state without risk of concurrent
// modification. This operation is thread-safe for concurrent reads.
func (m *Manager) GetSession() *Session {
m.mutex.RLock()
defer m.mutex.RUnlock()
@@ -123,7 +160,11 @@ func (m *Manager) GetSession() *Session {
return &sessionCopy
}
// Save manually saves the session to file
// Save manually saves the session to file.
// This forces a save operation even if no changes have been made.
// Useful for ensuring the session is persisted at specific points.
// Returns an error if no filePath was specified when creating the
// Manager, or if the save operation fails.
func (m *Manager) Save() error {
m.mutex.RLock()
defer m.mutex.RUnlock()
@@ -135,12 +176,16 @@ func (m *Manager) Save() error {
return m.session.SaveToFile(m.filePath)
}
// GetFilePath returns the file path for this session
// GetFilePath returns the file path for this session.
// Returns the path where the session is being auto-saved, or an
// empty string if no auto-save path was configured.
func (m *Manager) GetFilePath() string {
return m.filePath
}
// MessageCount returns the number of messages in the session
// MessageCount returns the number of messages in the session.
// This provides a quick way to check the conversation length without
// retrieving all messages. This operation is thread-safe for concurrent reads.
func (m *Manager) MessageCount() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
+82 -25
View File
@@ -11,40 +11,71 @@ import (
"github.com/cloudwego/eino/schema"
)
// Session represents a complete conversation session with metadata
// Session represents a complete conversation session with metadata.
// It stores all messages exchanged during a conversation along with
// contextual information about the session such as the provider, model,
// and timestamps. Sessions can be saved to and loaded from JSON files
// for persistence across program runs.
type Session struct {
Version string `json:"version"`
// Version indicates the session format version for compatibility
Version string `json:"version"`
// CreatedAt is the timestamp when the session was first created
CreatedAt time.Time `json:"created_at"`
// UpdatedAt is the timestamp when the session was last modified
UpdatedAt time.Time `json:"updated_at"`
Metadata Metadata `json:"metadata"`
Messages []Message `json:"messages"`
// Metadata contains contextual information about the session
Metadata Metadata `json:"metadata"`
// Messages is the ordered list of all messages in this session
Messages []Message `json:"messages"`
}
// Metadata contains session metadata
// Metadata contains session metadata that provides context about the
// environment and configuration used during the conversation. This helps
// with debugging and understanding the session's context when reviewing
// conversation history.
type Metadata struct {
// MCPHostVersion is the version of MCPHost used for this session
MCPHostVersion string `json:"mcphost_version"`
Provider string `json:"provider"`
Model string `json:"model"`
// Provider is the LLM provider used (e.g., "anthropic", "openai", "gemini")
Provider string `json:"provider"`
// Model is the specific model identifier used for the conversation
Model string `json:"model"`
}
// Message represents a single message in the session
// Message represents a single message in the conversation session.
// Messages can be from different roles (user, assistant, tool) and may
// include tool calls for assistant messages or tool results for tool messages.
type Message struct {
ID string `json:"id"`
Role string `json:"role"`
Content string `json:"content"`
Timestamp time.Time `json:"timestamp"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` // For tool result messages
// ID is a unique identifier for this message, auto-generated if not provided
ID string `json:"id"`
// Role indicates who sent the message ("user", "assistant", "tool", or "system")
Role string `json:"role"`
// Content is the text content of the message
Content string `json:"content"`
// Timestamp is when the message was created
Timestamp time.Time `json:"timestamp"`
// ToolCalls contains any tool invocations made by the assistant in this message
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
// ToolCallID links a tool result message to its corresponding tool call
ToolCallID string `json:"tool_call_id,omitempty"`
}
// ToolCall represents a tool call within a message
// ToolCall represents a tool invocation within an assistant message.
// When the assistant decides to use a tool, it creates a ToolCall with
// the necessary information to execute that tool.
type ToolCall struct {
ID string `json:"id"`
Name string `json:"name"`
Arguments any `json:"arguments"`
// ID is a unique identifier for this tool call, used to link results
ID string `json:"id"`
// Name is the name of the tool being invoked
Name string `json:"name"`
// Arguments contains the parameters passed to the tool, typically as JSON
Arguments any `json:"arguments"`
}
// NewSession creates a new session with default values
// NewSession creates a new session with default values.
// It initializes a session with version 1.0, current timestamps,
// empty message list, and empty metadata. The returned session
// is ready to receive messages and can be saved to a file.
func NewSession() *Session {
return &Session{
Version: "1.0",
@@ -55,7 +86,10 @@ func NewSession() *Session {
}
}
// AddMessage adds a message to the session
// AddMessage adds a message to the session.
// If the message doesn't have an ID, one will be auto-generated.
// If the message doesn't have a timestamp, the current time will be used.
// The session's UpdatedAt timestamp is automatically updated.
func (s *Session) AddMessage(msg Message) {
if msg.ID == "" {
msg.ID = generateMessageID()
@@ -68,13 +102,21 @@ func (s *Session) AddMessage(msg Message) {
s.UpdatedAt = time.Now()
}
// SetMetadata sets the session metadata
// SetMetadata sets the session metadata.
// This replaces the existing metadata with the provided metadata
// and updates the session's UpdatedAt timestamp. Use this to record
// information about the provider, model, and MCPHost version.
func (s *Session) SetMetadata(metadata Metadata) {
s.Metadata = metadata
s.UpdatedAt = time.Now()
}
// SaveToFile saves the session to a JSON file
// SaveToFile saves the session to a JSON file.
// The session is serialized as indented JSON for readability.
// The UpdatedAt timestamp is automatically updated before saving.
// The file is created with 0644 permissions if it doesn't exist,
// or overwritten if it does exist.
// Returns an error if marshaling fails or file writing fails.
func (s *Session) SaveToFile(filePath string) error {
s.UpdatedAt = time.Now()
@@ -86,7 +128,12 @@ func (s *Session) SaveToFile(filePath string) error {
return os.WriteFile(filePath, data, 0644)
}
// LoadFromFile loads a session from a JSON file
// LoadFromFile loads a session from a JSON file.
// It reads the file at the specified path and deserializes it into
// a Session struct. This is useful for resuming previous conversations
// or reviewing session history.
// Returns the loaded session on success, or an error if the file
// cannot be read or the JSON is invalid.
func LoadFromFile(filePath string) (*Session, error) {
data, err := os.ReadFile(filePath)
if err != nil {
@@ -101,7 +148,12 @@ func LoadFromFile(filePath string) (*Session, error) {
return &session, nil
}
// ConvertFromSchemaMessage converts a schema.Message to a session Message
// ConvertFromSchemaMessage converts a schema.Message to a session Message.
// This function bridges between the eino schema message format and the
// session's internal message format. It preserves role, content, and
// tool-related information while adding a timestamp.
// Tool calls from assistant messages and tool call IDs from tool messages
// are properly converted and preserved.
func ConvertFromSchemaMessage(msg *schema.Message) Message {
sessionMsg := Message{
Role: string(msg.Role),
@@ -129,7 +181,12 @@ func ConvertFromSchemaMessage(msg *schema.Message) Message {
return sessionMsg
}
// ConvertToSchemaMessage converts a session Message to a schema.Message
// ConvertToSchemaMessage converts a session Message to a schema.Message.
// This method bridges between the session's internal message format and
// the eino schema message format used by the LLM providers.
// It properly handles tool calls for assistant messages and tool call IDs
// for tool result messages. Arguments are converted to string format as
// required by the schema.
func (m *Message) ConvertToSchemaMessage() *schema.Message {
msg := &schema.Message{
Role: schema.RoleType(m.Role),
+15
View File
@@ -1 +1,16 @@
// Package tokens provides token counting and estimation functionality for
// various language model providers. It includes utilities for estimating
// token counts in text, as well as provider-specific implementations for
// more accurate token counting.
//
// The package supports multiple approaches to token counting:
// - Quick estimation using character-based heuristics
// - Provider-specific tokenizers for accurate counts
// - Initialization functions for setting up token counters
//
// Token counting is essential for:
// - Managing API rate limits
// - Calculating costs for API usage
// - Ensuring prompts fit within model context windows
// - Optimizing prompt engineering and response handling
package tokens
+18 -1
View File
@@ -1,6 +1,23 @@
package tokens
// EstimateTokens provides a rough estimate of tokens in text
// EstimateTokens estimates the number of tokens in the given text string.
// It uses a rough approximation of 4 characters per token, which is a common
// heuristic for most language models. This function provides a quick estimation
// without requiring model-specific tokenizers.
//
// The estimation may not be accurate for all models or text types, particularly
// for texts with many special characters, non-English languages, or code snippets.
// For more accurate token counting, use model-specific tokenizers when available.
//
// Parameters:
// - text: The input text string to estimate tokens for
//
// Returns:
// - int: The estimated number of tokens in the text
//
// Example:
//
// count := EstimateTokens("Hello, world!") // Returns approximately 3
func EstimateTokens(text string) int {
// Rough approximation: ~4 characters per token for most models
return len(text) / 4
+43 -2
View File
@@ -1,11 +1,52 @@
package tokens
// InitializeTokenCounters registers all available token counters
// InitializeTokenCounters registers all available token counters for various
// language model providers. This function should be called during application
// startup to ensure that token counting functionality is available for all
// supported models.
//
// Currently, this function is a placeholder for future provider-specific
// token counter implementations. As new providers are added (OpenAI, Anthropic,
// Google, etc.), their respective token counters will be registered here.
//
// This function does not require any API keys and will only initialize
// counters that can work without authentication.
//
// Example:
//
// func main() {
// tokens.InitializeTokenCounters()
// // Token counting is now available
// }
func InitializeTokenCounters() {
// Future provider-specific counters can be registered here
}
// InitializeTokenCountersWithKeys registers token counters with provided API keys
// InitializeTokenCountersWithKeys registers token counters for various language
// model providers using the provided API keys. This function enables more
// accurate token counting by allowing access to provider-specific tokenization
// endpoints or libraries that require authentication.
//
// This function should be called during application startup after API keys
// have been loaded from configuration or environment variables. It will
// initialize token counters for providers where API keys are available,
// enabling precise token counting that matches the provider's actual
// tokenization logic.
//
// The function will silently skip providers for which no API keys are
// configured, allowing the application to continue with partial token
// counting capabilities.
//
// Future implementations will accept provider-specific API keys through
// parameters or read them from a configuration context.
//
// Example:
//
// func main() {
// // Load API keys from environment or config
// tokens.InitializeTokenCountersWithKeys()
// // Provider-specific token counting is now available
// }
func InitializeTokenCountersWithKeys() {
// Future provider-specific counters can be registered here
}
+18 -5
View File
@@ -4,14 +4,19 @@ import (
"sync"
)
// BufferedDebugLogger stores debug messages until they can be displayed
// BufferedDebugLogger implements DebugLogger by storing debug messages in memory
// until they can be retrieved and displayed. This is useful when debug output
// needs to be deferred or batch-processed rather than immediately displayed.
// All methods are thread-safe for concurrent use.
type BufferedDebugLogger struct {
enabled bool
messages []string
mu sync.Mutex
}
// NewBufferedDebugLogger creates a new buffered debug logger
// NewBufferedDebugLogger creates a new buffered debug logger instance.
// The enabled parameter determines whether debug messages will be stored.
// If enabled is false, all LogDebug calls become no-ops for performance.
func NewBufferedDebugLogger(enabled bool) *BufferedDebugLogger {
return &BufferedDebugLogger{
enabled: enabled,
@@ -19,7 +24,10 @@ func NewBufferedDebugLogger(enabled bool) *BufferedDebugLogger {
}
}
// LogDebug stores a debug message
// LogDebug stores a debug message in the internal buffer if debug logging is enabled.
// Messages are appended to the buffer and retained until GetMessages is called.
// If debug logging is disabled, this method is a no-op.
// Thread-safe for concurrent calls.
func (l *BufferedDebugLogger) LogDebug(message string) {
if !l.enabled {
return
@@ -29,12 +37,17 @@ func (l *BufferedDebugLogger) LogDebug(message string) {
l.messages = append(l.messages, message)
}
// IsDebugEnabled returns whether debug logging is enabled
// IsDebugEnabled returns whether debug logging is enabled for this logger.
// This can be used to conditionally execute expensive debug operations
// only when debugging is actually enabled.
func (l *BufferedDebugLogger) IsDebugEnabled() bool {
return l.enabled
}
// GetMessages returns all buffered messages and clears the buffer
// GetMessages returns all buffered debug messages and clears the internal buffer.
// The returned slice contains all messages logged since the last call to GetMessages.
// After this call, the internal buffer is empty and ready to accumulate new messages.
// Thread-safe for concurrent calls.
func (l *BufferedDebugLogger) GetMessages() []string {
l.mu.Lock()
defer l.mu.Unlock()
+60 -18
View File
@@ -15,16 +15,20 @@ import (
"github.com/mark3labs/mcphost/internal/config"
)
// ConnectionPoolConfig configuration for connection pool
// ConnectionPoolConfig defines configuration parameters for the MCP connection pool.
// It controls connection lifecycle, health checking, and error handling behaviors.
type ConnectionPoolConfig struct {
MaxIdleTime time.Duration
MaxRetries int
HealthCheckInterval time.Duration
MaxErrorCount int
ReconnectDelay time.Duration
MaxIdleTime time.Duration // Maximum time a connection can remain idle before being marked unhealthy
MaxRetries int // Maximum number of retry attempts for failed operations
HealthCheckInterval time.Duration // Interval between background health checks of all connections
MaxErrorCount int // Maximum consecutive errors before marking a connection unhealthy
ReconnectDelay time.Duration // Delay before attempting to reconnect after connection failure
}
// DefaultConnectionPoolConfig returns default configuration
// DefaultConnectionPoolConfig returns a connection pool configuration with sensible defaults.
// Default values: 5 minute max idle time, 3 retries, 30 second health check interval,
// 3 max errors before marking unhealthy, and 2 second reconnect delay.
// These defaults are suitable for most MCP server deployments.
func DefaultConnectionPoolConfig() *ConnectionPoolConfig {
return &ConnectionPoolConfig{
MaxIdleTime: 5 * time.Minute,
@@ -35,7 +39,10 @@ func DefaultConnectionPoolConfig() *ConnectionPoolConfig {
}
}
// MCPConnection represents an MCP connection
// MCPConnection represents a single MCP client connection with health tracking and metadata.
// It wraps an MCP client and maintains state about connection health, usage patterns,
// and error history. Access to connection state is protected by a read-write mutex
// for thread-safe concurrent access.
type MCPConnection struct {
client client.MCPClient
serverName string
@@ -47,7 +54,11 @@ type MCPConnection struct {
mu sync.RWMutex
}
// MCPConnectionPool manages MCP connections
// MCPConnectionPool manages a pool of MCP client connections with automatic health checking,
// connection reuse, and failure recovery. It provides thread-safe connection management
// across multiple MCP servers, automatically handling connection lifecycle including
// creation, health monitoring, and cleanup. The pool runs background health checks
// to proactively identify and remove unhealthy connections.
type MCPConnectionPool struct {
connections map[string]*MCPConnection
config *ConnectionPoolConfig
@@ -59,7 +70,11 @@ type MCPConnectionPool struct {
debugLogger DebugLogger
}
// NewMCPConnectionPool creates a new connection pool
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
// If config is nil, default configuration values will be used. The pool starts a background
// goroutine for periodic health checks that runs until Close is called.
// The model parameter is used for MCP servers that require sampling support.
// Thread-safe for concurrent use immediately after creation.
func NewMCPConnectionPool(config *ConnectionPoolConfig, model model.ToolCallingChatModel, debug bool) *MCPConnectionPool {
if config == nil {
config = DefaultConnectionPoolConfig()
@@ -79,14 +94,20 @@ func NewMCPConnectionPool(config *ConnectionPoolConfig, model model.ToolCallingC
return pool
}
// SetDebugLogger sets the debug logger for the connection pool
// SetDebugLogger sets the debug logger for the connection pool.
// The logger will be used to output detailed information about connection lifecycle,
// health checks, and error conditions. Thread-safe and can be called at any time.
func (p *MCPConnectionPool) SetDebugLogger(logger DebugLogger) {
p.mu.Lock()
defer p.mu.Unlock()
p.debugLogger = logger
}
// GetConnection gets a connection from the pool
// GetConnection retrieves or creates a connection for the specified MCP server.
// If a healthy, non-idle connection exists in the pool, it will be reused.
// Otherwise, a new connection is created and added to the pool.
// Returns an error if connection creation or initialization fails.
// Thread-safe for concurrent calls.
func (p *MCPConnectionPool) GetConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -127,7 +148,12 @@ func (p *MCPConnectionPool) GetConnection(ctx context.Context, serverName string
return conn, nil
}
// GetConnectionWithHealthCheck gets a connection from the pool with proactive health check
// GetConnectionWithHealthCheck retrieves a connection with an additional proactive health check.
// Unlike GetConnection, this method performs a health check on existing connections before
// returning them, ensuring the connection is truly healthy. This is useful for critical
// operations where connection reliability is paramount. Creates a new connection if the
// existing one fails the health check or doesn't exist.
// Thread-safe for concurrent calls.
func (p *MCPConnectionPool) GetConnectionWithHealthCheck(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -430,7 +456,11 @@ func (p *MCPConnectionPool) checkConnectionsHealth() {
}
}
// HandleConnectionError handles connection errors
// HandleConnectionError records and handles errors for a specific connection.
// It increments the error count and may mark the connection as unhealthy based on
// the error type and configured thresholds. Connection errors (network, transport, 404)
// immediately mark the connection as unhealthy for removal on next access.
// Thread-safe for concurrent error reporting.
func (p *MCPConnectionPool) HandleConnectionError(serverName string, err error) {
p.mu.RLock()
conn, exists := p.connections[serverName]
@@ -460,7 +490,10 @@ func (p *MCPConnectionPool) HandleConnectionError(serverName string, err error)
}
}
// GetConnectionStats returns connection statistics
// GetConnectionStats returns detailed statistics for all connections in the pool.
// The returned map includes health status, last usage time, error counts, and
// last error for each connection. Useful for monitoring and debugging connection
// pool behavior. The returned data is a snapshot and safe for concurrent access.
func (p *MCPConnectionPool) GetConnectionStats() map[string]interface{} {
p.mu.RLock()
defer p.mu.RUnlock()
@@ -480,12 +513,17 @@ func (p *MCPConnectionPool) GetConnectionStats() map[string]interface{} {
return stats
}
// ServerName returns the server name for this connection
// ServerName returns the server name associated with this MCP connection.
// This is the configured name from the MCPHost configuration, not necessarily
// the actual server implementation name.
func (c *MCPConnection) ServerName() string {
return c.serverName
}
// GetClients returns all client names in the pool
// GetClients returns a map of all MCP clients currently in the pool.
// The map keys are server names and values are the corresponding MCP client instances.
// The returned map is a copy and modifications won't affect the pool.
// Note that clients may be unhealthy; use GetConnectionStats to check health status.
func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient {
p.mu.RLock()
defer p.mu.RUnlock()
@@ -497,7 +535,11 @@ func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient {
return clients
}
// Close closes the connection pool
// Close gracefully shuts down the connection pool, closing all client connections
// and stopping the background health check goroutine. It attempts to close all
// connections even if some fail, logging any errors encountered.
// Safe to call multiple times; subsequent calls are no-ops.
// Always call Close when done with the pool to prevent resource leaks.
func (p *MCPConnectionPool) Close() error {
p.cancel()
+22 -5
View File
@@ -1,28 +1,45 @@
package tools
// DebugLogger interface for debug logging
// DebugLogger defines the interface for debug logging in the MCP tools package.
// Implementations can provide different strategies for handling debug output,
// such as immediate console output, buffering, or file logging.
// All implementations must be thread-safe for concurrent use.
type DebugLogger interface {
// LogDebug logs a debug message. Implementations determine how the message is handled.
LogDebug(message string)
// IsDebugEnabled returns true if debug logging is enabled, allowing callers
// to skip expensive debug operations when debugging is disabled.
IsDebugEnabled() bool
}
// SimpleDebugLogger is a simple implementation that prints to stdout
// SimpleDebugLogger provides a minimal implementation of the DebugLogger interface.
// It is intentionally silent by default to prevent duplicate or unstyled debug output
// during initialization. Debug messages are only displayed when using the CLI debug logger
// which provides proper formatting and styling.
type SimpleDebugLogger struct {
enabled bool
}
// NewSimpleDebugLogger creates a new simple debug logger
// NewSimpleDebugLogger creates a new simple debug logger instance.
// The enabled parameter determines whether IsDebugEnabled will return true.
// Note that LogDebug is intentionally a no-op to avoid unstyled output;
// actual debug output is handled by the CLI's debug logger.
func NewSimpleDebugLogger(enabled bool) *SimpleDebugLogger {
return &SimpleDebugLogger{enabled: enabled}
}
// LogDebug logs a debug message
// LogDebug is intentionally a no-op in SimpleDebugLogger.
// Debug messages are only displayed when using the CLI debug logger which provides
// proper formatting and styling. This prevents duplicate or unstyled debug output
// during initialization and ensures consistent debug output presentation.
func (l *SimpleDebugLogger) LogDebug(message string) {
// Silent by default - messages will only appear when using CLI debug logger
// This prevents duplicate or unstyled debug output during initialization
}
// IsDebugEnabled returns whether debug logging is enabled
// IsDebugEnabled returns whether debug logging is enabled for this logger.
// This allows code to conditionally execute expensive debug operations
// only when debugging is active, improving performance in production.
func (l *SimpleDebugLogger) IsDebugEnabled() bool {
return l.enabled
}
+43 -11
View File
@@ -19,7 +19,11 @@ import (
"github.com/mark3labs/mcphost/internal/config"
)
// MCPToolManager manages MCP tools and clients
// MCPToolManager manages MCP (Model Context Protocol) tools and clients across multiple servers.
// It provides a unified interface for loading, managing, and executing tools from various MCP servers,
// including stdio, SSE, streamable HTTP, and built-in server types. The manager handles connection
// pooling, health checks, tool name prefixing to avoid conflicts, and sampling support for LLM interactions.
// Thread-safe for concurrent tool invocations.
type MCPToolManager struct {
connectionPool *MCPConnectionPool
tools []tool.BaseTool
@@ -44,7 +48,9 @@ type mcpToolImpl struct {
mapping *toolMapping
}
// NewMCPToolManager creates a new MCP tool manager
// NewMCPToolManager creates a new MCP tool manager instance.
// Returns an initialized manager with empty tool collections ready to load tools from MCP servers.
// The manager must be configured with SetModel and LoadTools before use.
func NewMCPToolManager() *MCPToolManager {
return &MCPToolManager{
tools: make([]tool.BaseTool, 0),
@@ -52,12 +58,18 @@ func NewMCPToolManager() *MCPToolManager {
}
}
// SetModel sets the LLM model for sampling support
// SetModel sets the LLM model for sampling support.
// The model is used when MCP servers request sampling operations, allowing them to
// leverage the host's LLM capabilities for text generation tasks.
// This method should be called before LoadTools if any MCP servers require sampling support.
func (m *MCPToolManager) SetModel(model model.ToolCallingChatModel) {
m.model = model
}
// SetDebugLogger sets the debug logger
// SetDebugLogger sets the debug logger for the tool manager.
// The logger will be used to output detailed debugging information about MCP connections,
// tool loading, and execution. If a connection pool exists, it will also be configured
// to use the same logger for consistent debugging output.
func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) {
m.debugLogger = logger
if m.connectionPool != nil {
@@ -70,7 +82,10 @@ type samplingHandler struct {
model model.ToolCallingChatModel
}
// CreateMessage handles sampling requests from MCP servers
// CreateMessage handles sampling requests from MCP servers by forwarding them to the configured LLM model.
// It converts MCP message formats to eino message formats, invokes the model for generation,
// and converts the response back to MCP format. Returns an error if no model is available
// or if generation fails.
func (h *samplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
if h.model == nil {
return nil, fmt.Errorf("no model available for sampling")
@@ -126,7 +141,11 @@ func (h *samplingHandler) CreateMessage(ctx context.Context, request mcp.CreateM
return result, nil
}
// LoadTools loads tools from MCP servers based on configuration
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
// It initializes the connection pool, connects to each configured server, and loads their tools.
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
// Returns an error only if all configured servers fail to load; partial failures are logged as warnings.
// This method is thread-safe and idempotent.
func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error {
// Initialize connection pool
m.config = config
@@ -243,12 +262,18 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
return nil
}
// Info returns the tool information
// Info returns the tool information including name, description, and parameter schema.
// This method implements the eino tool.BaseTool interface.
// The returned ToolInfo contains the prefixed tool name to ensure uniqueness across servers.
func (t *mcpToolImpl) Info(ctx context.Context) (*schema.ToolInfo, error) {
return t.info, nil
}
// InvokableRun executes the tool by mapping back to the original name and server
// InvokableRun executes the tool by mapping the prefixed name back to the original tool name and server.
// It retrieves a healthy connection from the pool, invokes the tool on the appropriate MCP server,
// and returns the result as a JSON string. The method handles connection errors by marking
// connections as unhealthy in the pool for automatic recovery on subsequent requests.
// Thread-safe for concurrent invocations.
func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
// Handle empty or invalid JSON arguments
var arguments any
@@ -300,12 +325,16 @@ func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string,
return marshaledResult, nil
}
// GetTools returns all loaded tools
// GetTools returns all loaded tools from all configured MCP servers.
// Tools are returned with their prefixed names (serverName__toolName) to ensure uniqueness.
// The returned slice is a copy and can be safely modified by the caller.
func (m *MCPToolManager) GetTools() []tool.BaseTool {
return m.tools
}
// GetLoadedServerNames returns the names of successfully loaded MCP servers
// GetLoadedServerNames returns the names of all successfully loaded MCP servers.
// This includes servers that are currently connected and have had their tools loaded,
// regardless of their current health status. Useful for debugging and status reporting.
func (m *MCPToolManager) GetLoadedServerNames() []string {
var names []string
for serverName := range m.connectionPool.GetClients() {
@@ -314,7 +343,10 @@ func (m *MCPToolManager) GetLoadedServerNames() []string {
return names
}
// Close closes all MCP clients
// Close closes all MCP client connections and cleans up resources.
// This method should be called when the tool manager is no longer needed to ensure
// proper cleanup of stdio processes, network connections, and other resources.
// It is safe to call Close multiple times.
func (m *MCPToolManager) Close() error {
return m.connectionPool.Close()
}
+30 -10
View File
@@ -21,70 +21,90 @@ type blockRenderer struct {
// renderingOption configures block rendering
type renderingOption func(*blockRenderer)
// WithFullWidth makes the block take full available width
// WithFullWidth returns a renderingOption that configures the block renderer
// to expand to the full available width of its container. When enabled, the
// block will fill the entire horizontal space rather than sizing to its content.
func WithFullWidth() renderingOption {
return func(c *blockRenderer) {
c.fullWidth = true
}
}
// WithAlign sets the horizontal alignment of the block
// WithAlign returns a renderingOption that sets the horizontal alignment
// of the block content within its container. The align parameter accepts
// lipgloss.Left, lipgloss.Center, or lipgloss.Right positions.
func WithAlign(align lipgloss.Position) renderingOption {
return func(c *blockRenderer) {
c.align = &align
}
}
// WithBorderColor sets the border color
// WithBorderColor returns a renderingOption that sets the border color
// for the block. The color parameter uses lipgloss.AdaptiveColor to support
// both light and dark terminal themes automatically.
func WithBorderColor(color lipgloss.AdaptiveColor) renderingOption {
return func(c *blockRenderer) {
c.borderColor = &color
}
}
// WithMarginTop sets the top margin
// WithMarginTop returns a renderingOption that sets the top margin
// for the block. The margin is specified in number of lines and adds
// vertical space above the block.
func WithMarginTop(margin int) renderingOption {
return func(c *blockRenderer) {
c.marginTop = margin
}
}
// WithMarginBottom sets the bottom margin
// WithMarginBottom returns a renderingOption that sets the bottom margin
// for the block. The margin is specified in number of lines and adds
// vertical space below the block.
func WithMarginBottom(margin int) renderingOption {
return func(c *blockRenderer) {
c.marginBottom = margin
}
}
// WithPaddingLeft sets the left padding
// WithPaddingLeft returns a renderingOption that sets the left padding
// for the block content. The padding is specified in number of characters
// and adds horizontal space between the left border and the content.
func WithPaddingLeft(padding int) renderingOption {
return func(c *blockRenderer) {
c.paddingLeft = padding
}
}
// WithPaddingRight sets the right padding
// WithPaddingRight returns a renderingOption that sets the right padding
// for the block content. The padding is specified in number of characters
// and adds horizontal space between the content and the right border.
func WithPaddingRight(padding int) renderingOption {
return func(c *blockRenderer) {
c.paddingRight = padding
}
}
// WithPaddingTop sets the top padding
// WithPaddingTop returns a renderingOption that sets the top padding
// for the block content. The padding is specified in number of lines
// and adds vertical space between the top border and the content.
func WithPaddingTop(padding int) renderingOption {
return func(c *blockRenderer) {
c.paddingTop = padding
}
}
// WithPaddingBottom sets the bottom padding
// WithPaddingBottom returns a renderingOption that sets the bottom padding
// for the block content. The padding is specified in number of lines
// and adds vertical space between the content and the bottom border.
func WithPaddingBottom(padding int) renderingOption {
return func(c *blockRenderer) {
c.paddingBottom = padding
}
}
// WithWidth sets a specific width for the block
// WithWidth returns a renderingOption that sets a specific width for the block
// in characters. This overrides the default container width and allows precise
// control over the block's horizontal dimensions.
func WithWidth(width int) renderingOption {
return func(c *blockRenderer) {
c.width = width
+5 -1
View File
@@ -9,7 +9,11 @@ import (
utilCallbacks "github.com/cloudwego/eino/utils/callbacks"
)
// CreateCallbackHandler creates a callback handler using HandlerHelper
// CreateCallbackHandler creates and returns a callbacks.Handler that manages
// tool execution callbacks for the CLI. The handler displays tool calls,
// handles errors, and manages streaming output for interactive tool operations.
// It integrates with the eino callback system to provide real-time UI feedback
// during tool execution.
func (c *CLI) CreateCallbackHandler() callbacks.Handler {
toolHandler := &utilCallbacks.ToolCallbackHandler{
OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *tool.CallbackInput) context.Context {
+95 -31
View File
@@ -17,7 +17,10 @@ var (
promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12"))
)
// CLI handles the command line interface with improved message rendering
// CLI manages the command-line interface for MCPHost, providing message rendering,
// user input handling, and display management. It supports both standard and compact
// display modes, handles streaming responses, tracks token usage, and manages the
// overall conversation flow between the user and AI assistants.
type CLI struct {
messageRenderer *MessageRenderer
compactRenderer *CompactRenderer // Add compact renderer
@@ -32,7 +35,10 @@ type CLI struct {
usageDisplayed bool // track if usage info was displayed after last assistant message
}
// NewCLI creates a new CLI instance with message container
// NewCLI creates and initializes a new CLI instance with the specified display modes.
// The debug parameter enables debug message rendering, while compact enables a more
// condensed display format. Returns an initialized CLI ready for interaction or an
// error if initialization fails.
func NewCLI(debug bool, compact bool) (*CLI, error) {
cli := &CLI{
compactMode: compact,
@@ -46,7 +52,9 @@ func NewCLI(debug bool, compact bool) (*CLI, error) {
return cli, nil
}
// SetUsageTracker sets the usage tracker for the CLI
// SetUsageTracker attaches a usage tracker to the CLI for monitoring token
// consumption and costs. The tracker will be automatically updated with the
// current display width for proper rendering.
func (c *CLI) SetUsageTracker(tracker *UsageTracker) {
c.usageTracker = tracker
if c.usageTracker != nil {
@@ -54,12 +62,14 @@ func (c *CLI) SetUsageTracker(tracker *UsageTracker) {
}
}
// GetDebugLogger returns a debug logger that uses the CLI for rendering
// GetDebugLogger returns a CLIDebugLogger instance that routes debug output
// through the CLI's rendering system for consistent message formatting and display.
func (c *CLI) GetDebugLogger() *CLIDebugLogger {
return NewCLIDebugLogger(c)
}
// SetModelName sets the current model name for the CLI
// SetModelName updates the current AI model name being used in the conversation.
// This name is displayed in message headers to indicate which model is responding.
func (c *CLI) SetModelName(modelName string) {
c.modelName = modelName
if c.messageContainer != nil {
@@ -67,7 +77,10 @@ func (c *CLI) SetModelName(modelName string) {
}
}
// GetPrompt gets user input using the huh library with divider and padding
// GetPrompt displays an interactive prompt and waits for user input. It provides
// slash command support, multi-line editing, and cancellation handling. Returns
// the user's input as a string, or an error if the operation was cancelled or
// failed. Returns io.EOF for clean exit signals.
func (c *CLI) GetPrompt() (string, error) {
// Usage info is now displayed immediately after responses via DisplayUsageAfterResponse()
// No need to display it here to avoid duplication
@@ -107,7 +120,9 @@ func (c *CLI) GetPrompt() (string, error) {
return "", fmt.Errorf("unexpected model type")
}
// ShowSpinner displays a spinner with the given message and executes the action
// ShowSpinner displays an animated spinner with the specified message while
// executing the provided action function. The spinner automatically stops when
// the action completes. Returns any error returned by the action function.
func (c *CLI) ShowSpinner(message string, action func() error) error {
spinner := NewSpinner(message)
spinner.Start()
@@ -119,7 +134,9 @@ func (c *CLI) ShowSpinner(message string, action func() error) error {
return err
}
// DisplayUserMessage displays the user's message using the appropriate renderer
// DisplayUserMessage renders and displays a user's message with appropriate
// formatting based on the current display mode (standard or compact). The message
// is timestamped and styled according to the active theme.
func (c *CLI) DisplayUserMessage(message string) {
var msg UIMessage
if c.compactMode {
@@ -131,12 +148,16 @@ func (c *CLI) DisplayUserMessage(message string) {
c.displayContainer()
}
// DisplayAssistantMessage displays the assistant's message using the new renderer
// DisplayAssistantMessage renders and displays an AI assistant's response message
// with appropriate formatting. This method delegates to DisplayAssistantMessageWithModel
// with an empty model name for backward compatibility.
func (c *CLI) DisplayAssistantMessage(message string) error {
return c.DisplayAssistantMessageWithModel(message, "")
}
// DisplayAssistantMessageWithModel displays the assistant's message with model info
// DisplayAssistantMessageWithModel renders and displays an AI assistant's response
// with the specified model name shown in the message header. The message is
// formatted according to the current display mode and includes timestamp information.
func (c *CLI) DisplayAssistantMessageWithModel(message, modelName string) error {
var msg UIMessage
if c.compactMode {
@@ -149,7 +170,9 @@ func (c *CLI) DisplayAssistantMessageWithModel(message, modelName string) error
return nil
}
// DisplayToolCallMessage displays a tool call in progress
// DisplayToolCallMessage renders and displays a message indicating that a tool
// is being executed. Shows the tool name and its arguments formatted appropriately
// for the current display mode. This is typically shown while a tool is running.
func (c *CLI) DisplayToolCallMessage(toolName, toolArgs string) {
c.messageContainer.messages = nil // clear previous messages (they should have been printed already)
@@ -167,7 +190,9 @@ func (c *CLI) DisplayToolCallMessage(toolName, toolArgs string) {
c.displayContainer()
}
// DisplayToolMessage displays a tool call message
// DisplayToolMessage renders and displays the complete result of a tool execution,
// including the tool name, arguments, and result. The isError parameter determines
// whether the result should be displayed as an error or success message.
func (c *CLI) DisplayToolMessage(toolName, toolArgs, toolResult string, isError bool) {
var msg UIMessage
if c.compactMode {
@@ -181,7 +206,9 @@ func (c *CLI) DisplayToolMessage(toolName, toolArgs, toolResult string, isError
c.displayContainer()
}
// StartStreamingMessage starts a streaming assistant message
// StartStreamingMessage initializes a new streaming message display for real-time
// AI responses. The message will be progressively updated as content arrives.
// The modelName parameter indicates which AI model is generating the response.
func (c *CLI) StartStreamingMessage(modelName string) {
// Add an empty assistant message that we'll update during streaming
var msg UIMessage
@@ -196,14 +223,18 @@ func (c *CLI) StartStreamingMessage(modelName string) {
c.displayContainer()
}
// UpdateStreamingMessage updates the streaming message with new content
// UpdateStreamingMessage updates the currently streaming message with new content.
// This method should be called after StartStreamingMessage to progressively display
// AI responses as they are generated in real-time.
func (c *CLI) UpdateStreamingMessage(content string) {
// Update the last message (which should be the streaming assistant message)
c.messageContainer.UpdateLastMessage(content)
c.displayContainer()
}
// DisplayError displays an error message using the appropriate renderer
// DisplayError renders and displays an error message with distinctive formatting
// to ensure visibility. The error is timestamped and styled according to the
// current display mode's error theme.
func (c *CLI) DisplayError(err error) {
var msg UIMessage
if c.compactMode {
@@ -215,7 +246,9 @@ func (c *CLI) DisplayError(err error) {
c.displayContainer()
}
// DisplayInfo displays an informational message using the appropriate renderer
// DisplayInfo renders and displays an informational system message. These messages
// are typically used for status updates, notifications, or other non-error system
// communications to the user.
func (c *CLI) DisplayInfo(message string) {
var msg UIMessage
if c.compactMode {
@@ -227,7 +260,8 @@ func (c *CLI) DisplayInfo(message string) {
c.displayContainer()
}
// DisplayCancellation displays a cancellation message
// DisplayCancellation displays a system message indicating that the current
// AI generation has been cancelled by the user (typically via ESC key).
func (c *CLI) DisplayCancellation() {
var msg UIMessage
if c.compactMode {
@@ -239,7 +273,9 @@ func (c *CLI) DisplayCancellation() {
c.displayContainer()
}
// DisplayDebugMessage displays debug messages using the appropriate renderer
// DisplayDebugMessage renders and displays a debug message if debug mode is enabled.
// Debug messages are formatted distinctively and only shown when the CLI is
// initialized with debug=true.
func (c *CLI) DisplayDebugMessage(message string) {
if !c.debug {
return
@@ -254,7 +290,9 @@ func (c *CLI) DisplayDebugMessage(message string) {
c.displayContainer()
}
// DisplayDebugConfig displays configuration settings using the appropriate renderer
// DisplayDebugConfig renders and displays configuration settings in a formatted
// debug message. The config parameter should contain key-value pairs representing
// configuration options that will be displayed for debugging purposes.
func (c *CLI) DisplayDebugConfig(config map[string]any) {
var msg UIMessage
if c.compactMode {
@@ -266,7 +304,9 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) {
c.displayContainer()
}
// DisplayHelp displays help information in a message block
// DisplayHelp renders and displays comprehensive help information showing all
// available slash commands, keyboard shortcuts, and usage instructions in a
// formatted system message block.
func (c *CLI) DisplayHelp() {
help := `## Available Commands
@@ -288,7 +328,9 @@ You can also just type your message to chat with the AI assistant.`
c.displayContainer()
}
// DisplayTools displays available tools in a message block
// DisplayTools renders and displays a formatted list of all available tools
// that can be used by the AI assistant. Each tool is numbered and shown in
// a system message block for easy reference.
func (c *CLI) DisplayTools(tools []string) {
var content strings.Builder
content.WriteString("## Available Tools\n\n")
@@ -307,7 +349,9 @@ func (c *CLI) DisplayTools(tools []string) {
c.displayContainer()
}
// DisplayServers displays configured MCP servers in a message block
// DisplayServers renders and displays a formatted list of all configured MCP
// (Model Context Protocol) servers. Each server is numbered and shown in a
// system message block for easy reference.
func (c *CLI) DisplayServers(servers []string) {
var content strings.Builder
content.WriteString("## Configured MCP Servers\n\n")
@@ -326,18 +370,25 @@ func (c *CLI) DisplayServers(servers []string) {
c.displayContainer()
}
// IsSlashCommand checks if the input is a slash command
// IsSlashCommand determines whether the provided input string is a slash command
// by checking if it starts with a forward slash (/). Returns true for commands
// like "/help", "/tools", etc.
func (c *CLI) IsSlashCommand(input string) bool {
return strings.HasPrefix(input, "/")
}
// SlashCommandResult represents the result of handling a slash command
// SlashCommandResult encapsulates the outcome of processing a slash command,
// indicating whether the command was recognized and handled, and whether the
// conversation history should be cleared as a result of the command.
type SlashCommandResult struct {
Handled bool
ClearHistory bool
}
// HandleSlashCommand handles slash commands and returns the result
// HandleSlashCommand processes and executes slash commands, returning a result
// that indicates whether the command was handled and any side effects. The servers
// and tools parameters provide context for commands that display available resources.
// Supported commands include /help, /tools, /servers, /clear, /usage, /reset-usage, and /quit.
func (c *CLI) HandleSlashCommand(input string, servers []string, tools []string) SlashCommandResult {
switch input {
case "/help":
@@ -369,7 +420,9 @@ func (c *CLI) HandleSlashCommand(input string, servers []string, tools []string)
}
}
// ClearMessages clears all messages from the container
// ClearMessages removes all messages from the display container and refreshes
// the screen. This is typically used when starting a new conversation or
// clearing the chat history.
func (c *CLI) ClearMessages() {
c.messageContainer.Clear()
c.displayContainer()
@@ -422,14 +475,19 @@ func (c *CLI) displayContainer() {
}
}
// UpdateUsage updates the usage tracker with token counts and costs
// UpdateUsage estimates and records token usage based on input and output text.
// This method uses text-based estimation when actual token counts are not available
// from the AI provider's response metadata.
func (c *CLI) UpdateUsage(inputText, outputText string) {
if c.usageTracker != nil {
c.usageTracker.EstimateAndUpdateUsage(inputText, outputText)
}
}
// UpdateUsageFromResponse updates the usage tracker using token usage from response metadata
// UpdateUsageFromResponse records token usage using metadata from the AI provider's
// response when available. Falls back to text-based estimation if the metadata is
// missing or appears unreliable. This provides more accurate usage tracking when
// providers supply token count information.
func (c *CLI) UpdateUsageFromResponse(response *schema.Message, inputText string) {
if c.usageTracker == nil {
return
@@ -461,7 +519,9 @@ func (c *CLI) UpdateUsageFromResponse(response *schema.Message, inputText string
}
}
// DisplayUsageStats displays current usage statistics
// DisplayUsageStats renders and displays comprehensive token usage statistics
// including the last request's token counts and costs, as well as session totals.
// Shows a message if usage tracking is not available for the current model.
func (c *CLI) DisplayUsageStats() {
if c.usageTracker == nil {
c.DisplayInfo("Usage tracking is not available for this model.")
@@ -492,7 +552,9 @@ func (c *CLI) DisplayUsageStats() {
c.displayContainer()
}
// ResetUsageStats resets the usage tracking statistics
// ResetUsageStats clears all accumulated usage statistics, resetting token counts
// and costs to zero. Displays a confirmation message after resetting or an info
// message if usage tracking is not available.
func (c *CLI) ResetUsageStats() {
if c.usageTracker == nil {
c.DisplayInfo("Usage tracking is not available for this model.")
@@ -503,7 +565,9 @@ func (c *CLI) ResetUsageStats() {
c.DisplayInfo("Usage statistics have been reset.")
}
// DisplayUsageAfterResponse displays usage information immediately after a response
// DisplayUsageAfterResponse renders and displays token usage information immediately
// following an AI response. This provides real-time feedback about the cost and
// token consumption of each interaction.
func (c *CLI) DisplayUsageAfterResponse() {
if c.usageTracker == nil {
return
+12 -4
View File
@@ -1,6 +1,8 @@
package ui
// SlashCommand represents a slash command with its metadata
// SlashCommand represents a user-invokable slash command with its metadata.
// Commands can have multiple aliases and are organized by category for better
// discoverability and help display.
type SlashCommand struct {
Name string
Description string
@@ -8,7 +10,9 @@ type SlashCommand struct {
Category string // e.g., "Navigation", "System", "Info"
}
// SlashCommands is the registry of all available slash commands
// SlashCommands provides the global registry of all available slash commands
// in the application. Commands are organized by category (Info, System) and
// include their primary names, descriptions, and alternative aliases.
var SlashCommands = []SlashCommand{
{
Name: "/help",
@@ -55,7 +59,9 @@ var SlashCommands = []SlashCommand{
},
}
// GetCommandByName returns a command by its name or alias
// GetCommandByName looks up a slash command by its primary name or any of its
// aliases. Returns a pointer to the matching SlashCommand, or nil if no command
// matches the provided name.
func GetCommandByName(name string) *SlashCommand {
for i := range SlashCommands {
cmd := &SlashCommands[i]
@@ -71,7 +77,9 @@ func GetCommandByName(name string) *SlashCommand {
return nil
}
// GetAllCommandNames returns all command names and aliases
// GetAllCommandNames returns a complete list of all command names and their aliases.
// This is useful for command completion, validation, and help display. The returned
// slice contains both primary command names and all alternative aliases.
func GetAllCommandNames() []string {
var names []string
for _, cmd := range SlashCommands {
+33 -11
View File
@@ -8,13 +8,17 @@ import (
"github.com/charmbracelet/lipgloss"
)
// CompactRenderer handles rendering messages in compact format
// CompactRenderer handles rendering messages in a space-efficient compact format,
// optimized for terminals with limited vertical space. It displays messages with
// minimal decorations while maintaining readability and essential information.
type CompactRenderer struct {
width int
debug bool
}
// NewCompactRenderer creates a new compact message renderer
// NewCompactRenderer creates and initializes a new CompactRenderer with the specified
// terminal width and debug mode setting. The width parameter determines line wrapping,
// while debug enables additional diagnostic output in rendered messages.
func NewCompactRenderer(width int, debug bool) *CompactRenderer {
return &CompactRenderer{
width: width,
@@ -22,12 +26,16 @@ func NewCompactRenderer(width int, debug bool) *CompactRenderer {
}
}
// SetWidth updates the renderer width
// SetWidth updates the terminal width for the renderer, affecting how content
// is wrapped and formatted in subsequent render operations.
func (r *CompactRenderer) SetWidth(width int) {
r.width = width
}
// RenderUserMessage renders a user message in compact format
// RenderUserMessage renders a user's input message in compact format with a
// distinctive symbol (>) and label. The content is formatted to preserve structure
// while minimizing vertical space usage. Returns a UIMessage with formatted content
// and metadata.
func (r *CompactRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
theme := getTheme()
symbol := lipgloss.NewStyle().Foreground(theme.Secondary).Render(">")
@@ -58,7 +66,9 @@ func (r *CompactRenderer) RenderUserMessage(content string, timestamp time.Time)
}
}
// RenderAssistantMessage renders an assistant message in compact format
// RenderAssistantMessage renders an AI assistant's response in compact format with
// a distinctive symbol (<) and the model name as label. Empty content is displayed
// as "(no output)". Returns a UIMessage with formatted content and metadata.
func (r *CompactRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
theme := getTheme()
symbol := lipgloss.NewStyle().Foreground(theme.Primary).Render("<")
@@ -97,7 +107,9 @@ func (r *CompactRenderer) RenderAssistantMessage(content string, timestamp time.
}
}
// RenderToolCallMessage renders a tool call in progress in compact format
// RenderToolCallMessage renders a tool call notification in compact format, showing
// the tool being executed with its arguments in a single line. The tool name is
// highlighted and arguments are displayed in a muted color for visual distinction.
func (r *CompactRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage {
theme := getTheme()
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("[")
@@ -119,7 +131,9 @@ func (r *CompactRenderer) RenderToolCallMessage(toolName, toolArgs string, times
}
}
// RenderToolMessage renders a tool result in compact format
// RenderToolMessage renders the result of a tool execution in compact format,
// displaying the outcome with appropriate styling based on success or error status.
// Results are limited to 5 lines to maintain compact display while preserving key information.
func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage {
theme := getTheme()
symbol := lipgloss.NewStyle().Foreground(theme.Muted).Render("]")
@@ -165,7 +179,9 @@ func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
}
}
// RenderSystemMessage renders a system message in compact format
// RenderSystemMessage renders a system notification or informational message in
// compact format with a distinctive symbol (*) and "System" label. Content is
// formatted to fit on a single line for minimal space usage.
func (r *CompactRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
theme := getTheme()
symbol := lipgloss.NewStyle().Foreground(theme.System).Render("*")
@@ -183,7 +199,9 @@ func (r *CompactRenderer) RenderSystemMessage(content string, timestamp time.Tim
}
}
// RenderErrorMessage renders an error message in compact format
// RenderErrorMessage renders an error notification in compact format with a
// distinctive error symbol (!) and styling to ensure visibility. The error
// content is displayed in a single line with appropriate color highlighting.
func (r *CompactRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
theme := getTheme()
symbol := lipgloss.NewStyle().Foreground(theme.Error).Render("!")
@@ -201,7 +219,9 @@ func (r *CompactRenderer) RenderErrorMessage(errorMsg string, timestamp time.Tim
}
}
// RenderDebugMessage renders debug messages in compact format
// RenderDebugMessage renders diagnostic information in compact format when debug
// mode is enabled. Messages are truncated if they exceed the available width to
// maintain single-line display.
func (r *CompactRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
theme := getTheme()
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*")
@@ -223,7 +243,9 @@ func (r *CompactRenderer) RenderDebugMessage(message string, timestamp time.Time
}
}
// RenderDebugConfigMessage renders debug config in compact format
// RenderDebugConfigMessage renders configuration settings in compact format for
// debugging purposes. Config entries are displayed as key=value pairs separated
// by commas, truncated if necessary to fit on a single line.
func (r *CompactRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage {
theme := getTheme()
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*")
+14 -4
View File
@@ -6,17 +6,25 @@ import (
"time"
)
// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering
// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering.
// It provides debug logging functionality that integrates with the CLI's display
// system, ensuring debug messages are properly formatted and displayed alongside
// other conversation content.
type CLIDebugLogger struct {
cli *CLI
}
// NewCLIDebugLogger creates a new CLI debug logger
// NewCLIDebugLogger creates and returns a new CLIDebugLogger instance that routes
// debug output through the provided CLI instance. The logger will respect the CLI's
// debug mode setting and display format preferences.
func NewCLIDebugLogger(cli *CLI) *CLIDebugLogger {
return &CLIDebugLogger{cli: cli}
}
// LogDebug logs a debug message using the CLI's debug message renderer
// LogDebug processes and displays a debug message through the CLI's rendering system.
// Messages are formatted with appropriate emojis and tags based on their content type
// (DEBUG, POOL, etc.) and only displayed when debug mode is enabled. The method handles
// multi-line debug output and connection pool status messages with context-aware formatting.
func (l *CLIDebugLogger) LogDebug(message string) {
if l.cli == nil || !l.cli.debug {
return
@@ -70,7 +78,9 @@ func (l *CLIDebugLogger) LogDebug(message string) {
l.cli.displayContainer()
}
// IsDebugEnabled returns whether debug logging is enabled
// IsDebugEnabled checks whether debug logging is currently active. Returns true
// if the CLI instance exists and has debug mode enabled, allowing callers to
// conditionally perform expensive debug operations only when necessary.
func (l *CLIDebugLogger) IsDebugEnabled() bool {
return l.cli != nil && l.cli.debug
}
+46 -20
View File
@@ -11,17 +11,21 @@ import (
// Global theme instance
var currentTheme = DefaultTheme()
// GetTheme returns the current theme
// GetTheme returns the currently active UI theme. The theme controls all color
// and styling decisions throughout the application's interface.
func GetTheme() Theme {
return currentTheme
}
// SetTheme sets the current theme
// SetTheme updates the global UI theme, affecting all subsequent rendering
// operations. This allows runtime theme switching for different visual preferences.
func SetTheme(theme Theme) {
currentTheme = theme
}
// Theme represents a complete UI theme
// Theme defines a comprehensive color scheme for the application's UI, supporting
// both light and dark terminal modes through adaptive colors. It includes semantic
// colors for different message types and UI elements, based on the Catppuccin color palette.
type Theme struct {
Primary lipgloss.AdaptiveColor
Secondary lipgloss.AdaptiveColor
@@ -41,7 +45,9 @@ type Theme struct {
Highlight lipgloss.AdaptiveColor
}
// DefaultTheme returns the default MCPHost theme (Catppuccin Mocha)
// DefaultTheme creates and returns the default MCPHost theme based on the Catppuccin
// Mocha (dark) and Latte (light) color palettes. This theme provides a cohesive,
// pleasant visual experience with carefully selected colors for different UI elements.
func DefaultTheme() Theme {
return Theme{
Primary: lipgloss.AdaptiveColor{
@@ -111,7 +117,9 @@ func DefaultTheme() Theme {
}
}
// StyleCard creates a styled card container
// StyleCard creates a lipgloss style for card-like containers with rounded borders,
// padding, and appropriate width. Used for grouping related content in a visually
// distinct box.
func StyleCard(width int, theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Width(width).
@@ -121,56 +129,64 @@ func StyleCard(width int, theme Theme) lipgloss.Style {
MarginBottom(1)
}
// StyleHeader creates a styled header
// StyleHeader creates a lipgloss style for primary headers using the theme's
// primary color with bold text for emphasis and hierarchy.
func StyleHeader(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Primary).
Bold(true)
}
// StyleSubheader creates a styled subheader
// StyleSubheader creates a lipgloss style for secondary headers using the theme's
// secondary color with bold text, providing visual hierarchy below primary headers.
func StyleSubheader(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Secondary).
Bold(true)
}
// StyleMuted creates muted text styling
// StyleMuted creates a lipgloss style for de-emphasized text using muted colors
// and italic formatting, suitable for supplementary or less important information.
func StyleMuted(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Muted).
Italic(true)
}
// StyleSuccess creates success text styling
// StyleSuccess creates a lipgloss style for success messages using green colors
// with bold text to indicate successful operations or positive outcomes.
func StyleSuccess(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Success).
Bold(true)
}
// StyleError creates error text styling
// StyleError creates a lipgloss style for error messages using red colors
// with bold text to ensure visibility of problems or failures.
func StyleError(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Error).
Bold(true)
}
// StyleWarning creates warning text styling
// StyleWarning creates a lipgloss style for warning messages using yellow/amber
// colors with bold text to draw attention to potential issues or cautions.
func StyleWarning(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Warning).
Bold(true)
}
// StyleInfo creates info text styling
// StyleInfo creates a lipgloss style for informational messages using blue colors
// with bold text for general notifications and status updates.
func StyleInfo(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Info).
Bold(true)
}
// CreateSeparator creates a styled separator line
// CreateSeparator generates a horizontal separator line with the specified width,
// character, and color. Useful for visually dividing sections of content in the UI.
func CreateSeparator(width int, char string, color lipgloss.AdaptiveColor) string {
return lipgloss.NewStyle().
Foreground(color).
@@ -178,7 +194,9 @@ func CreateSeparator(width int, char string, color lipgloss.AdaptiveColor) strin
Render(lipgloss.PlaceHorizontal(width, lipgloss.Center, char))
}
// CreateProgressBar creates a simple progress bar
// CreateProgressBar generates a visual progress bar with filled and empty segments
// based on the percentage complete. The bar uses Unicode block characters for smooth
// appearance and theme colors to indicate progress.
func CreateProgressBar(width int, percentage float64, theme Theme) string {
filled := int(float64(width) * percentage / 100)
empty := width - filled
@@ -194,7 +212,8 @@ func CreateProgressBar(width int, percentage float64, theme Theme) string {
return filledBar + emptyBar
}
// CreateBadge creates a styled badge
// CreateBadge generates a styled badge or label with inverted colors (text on
// colored background) for highlighting important tags, statuses, or categories.
func CreateBadge(text string, color lipgloss.AdaptiveColor) string {
return lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "#FFFFFF", Dark: "#000000"}).
@@ -204,7 +223,9 @@ func CreateBadge(text string, color lipgloss.AdaptiveColor) string {
Render(text)
}
// CreateGradientText creates text with gradient-like effect using different shades
// CreateGradientText creates styled text with a gradient-like effect. Currently
// implements a simplified version using the start color only, as true gradients
// require more complex terminal capabilities.
func CreateGradientText(text string, startColor, endColor lipgloss.AdaptiveColor) string {
// For now, just use the start color - true gradients would require more complex implementation
return lipgloss.NewStyle().
@@ -215,14 +236,16 @@ func CreateGradientText(text string, startColor, endColor lipgloss.AdaptiveColor
// Compact styling utilities
// StyleCompactSymbol creates a styled symbol for compact mode
// StyleCompactSymbol creates a lipgloss style for message type indicators in
// compact mode, using bold colored text to distinguish different message categories.
func StyleCompactSymbol(symbol string, color lipgloss.AdaptiveColor) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(color).
Bold(true)
}
// StyleCompactLabel creates a styled label for compact mode
// StyleCompactLabel creates a lipgloss style for message labels in compact mode
// with fixed width for alignment and bold colored text for readability.
func StyleCompactLabel(color lipgloss.AdaptiveColor) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(color).
@@ -230,13 +253,16 @@ func StyleCompactLabel(color lipgloss.AdaptiveColor) lipgloss.Style {
Width(8)
}
// StyleCompactContent creates basic content styling for compact mode
// StyleCompactContent creates a simple lipgloss style for message content in
// compact mode, applying only color without additional formatting.
func StyleCompactContent(color lipgloss.AdaptiveColor) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(color)
}
// FormatCompactLine formats a complete compact line with consistent spacing
// FormatCompactLine assembles a complete compact mode message line with consistent
// spacing and styling. Combines a symbol, fixed-width label, and content with their
// respective colors to create a uniform appearance across all message types.
func FormatCompactLine(symbol, label, content string, symbolColor, labelColor, contentColor lipgloss.AdaptiveColor) string {
styledSymbol := StyleCompactSymbol(symbol, symbolColor).Render(symbol)
styledLabel := StyleCompactLabel(labelColor).Render(label)
+9 -3
View File
@@ -8,14 +8,17 @@ import (
"github.com/mark3labs/mcphost/internal/models"
)
// AgentInterface defines the interface we need from agent to avoid import cycles
// AgentInterface defines the minimal interface required from the agent package
// to avoid circular dependencies while still accessing necessary agent functionality.
type AgentInterface interface {
GetLoadingMessage() string
GetTools() []any // Using any to avoid importing tool types
GetLoadedServerNames() []string // Add this method for debug config
}
// CLISetupOptions contains options for setting up CLI
// CLISetupOptions encapsulates all configuration parameters needed to initialize
// and set up a CLI instance, including display preferences, model information,
// and debugging settings.
type CLISetupOptions struct {
Agent AgentInterface
ModelString string
@@ -35,7 +38,10 @@ func parseModelName(modelString string) (provider, model string) {
return "unknown", "unknown"
}
// SetupCLI creates and configures CLI with standard info display
// SetupCLI creates, configures, and initializes a CLI instance with the provided
// options. It sets up model display, usage tracking for supported providers, and
// shows initial loading information. Returns nil in quiet mode or an initialized
// CLI instance ready for user interaction.
func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
if opts.Quiet {
return nil, nil // No CLI in quiet mode
+6 -2
View File
@@ -4,13 +4,17 @@ import (
"strings"
)
// FuzzyMatch represents a match result with score
// FuzzyMatch represents the result of a fuzzy string matching operation,
// containing the matched command and its relevance score. Higher scores
// indicate better matches.
type FuzzyMatch struct {
Command *SlashCommand
Score int
}
// FuzzyMatchCommands performs fuzzy matching on slash commands
// FuzzyMatchCommands performs fuzzy string matching on the provided slash commands
// based on the query string. Returns a slice of matches sorted by relevance score
// in descending order. An empty query returns all commands with zero scores.
func FuzzyMatchCommands(query string, commands []SlashCommand) []FuzzyMatch {
if query == "" || query == "/" {
// Return all commands when query is empty or just "/"
+59 -21
View File
@@ -10,7 +10,8 @@ import (
"github.com/charmbracelet/lipgloss"
)
// MessageType represents the type of message
// MessageType represents different categories of messages displayed in the UI,
// each with distinct visual styling and formatting rules.
type MessageType int
const (
@@ -22,7 +23,9 @@ const (
ErrorMessage // New type for error messages
)
// UIMessage represents a rendered message for display
// UIMessage encapsulates a fully rendered message ready for display in the UI,
// including its formatted content, display metrics, and metadata. Messages can
// be static or streaming (progressively updated).
type UIMessage struct {
ID string
Type MessageType
@@ -38,7 +41,9 @@ func getTheme() Theme {
return GetTheme()
}
// MessageRenderer handles rendering of messages with proper styling
// MessageRenderer handles the formatting and rendering of different message types
// with consistent styling, markdown support, and appropriate visual hierarchies
// for the standard (non-compact) display mode.
type MessageRenderer struct {
width int
debug bool
@@ -59,7 +64,9 @@ func getSystemUsername() string {
return "User"
}
// NewMessageRenderer creates a new message renderer
// NewMessageRenderer creates and initializes a new MessageRenderer with the specified
// terminal width and debug mode setting. The width parameter determines line wrapping
// and layout calculations.
func NewMessageRenderer(width int, debug bool) *MessageRenderer {
return &MessageRenderer{
width: width,
@@ -67,12 +74,15 @@ func NewMessageRenderer(width int, debug bool) *MessageRenderer {
}
}
// SetWidth updates the renderer width
// SetWidth updates the terminal width for the renderer, affecting how content
// is wrapped and formatted in subsequent render operations.
func (r *MessageRenderer) SetWidth(width int) {
r.width = width
}
// RenderUserMessage renders a user message with right border and background header
// RenderUserMessage renders a user's input message with distinctive right-aligned
// formatting, including the system username, timestamp, and markdown-rendered content.
// The message is displayed with a colored right border for visual distinction.
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
// Format timestamp and username
timeStr := timestamp.Local().Format("15:04")
@@ -106,7 +116,10 @@ func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time)
}
}
// RenderAssistantMessage renders an assistant message with left border and background header
// RenderAssistantMessage renders an AI assistant's response with left-aligned formatting,
// including the model name, timestamp, and markdown-rendered content. Empty responses
// are displayed with a special "Finished without output" message. The message features
// a colored left border for visual distinction.
func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
// Format timestamp and model info with better defaults
timeStr := timestamp.Local().Format("15:04")
@@ -151,7 +164,9 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
}
}
// RenderSystemMessage renders a system message with left border and background header
// RenderSystemMessage renders MCPHost system messages such as help text, command outputs,
// and informational notifications. These messages are displayed with a distinctive system
// color border and "MCPHost System" label to differentiate them from user and AI content.
func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
// Format timestamp
timeStr := timestamp.Local().Format("15:04")
@@ -193,7 +208,9 @@ func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Tim
}
}
// RenderDebugMessage renders debug messages with tool response block styling
// RenderDebugMessage renders diagnostic and debugging information with special formatting
// including a debug icon, colored border, and structured layout. Debug messages are only
// displayed when debug mode is enabled and help developers troubleshoot issues.
func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
baseStyle := lipgloss.NewStyle()
@@ -251,7 +268,9 @@ func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time
}
}
// RenderDebugConfigMessage renders debug configuration settings with tool response block styling
// RenderDebugConfigMessage renders configuration settings in a formatted debug display
// with key-value pairs shown in a structured layout. Used to display runtime configuration
// for debugging purposes with a distinctive icon and border styling.
func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage {
baseStyle := lipgloss.NewStyle()
@@ -311,7 +330,9 @@ func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timest
}
}
// RenderErrorMessage renders an error message with left border and background header
// RenderErrorMessage renders error notifications with distinctive red coloring and
// bold text to ensure visibility. Error messages include timestamp information and
// are displayed with an error-colored border for immediate recognition.
func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
// Format timestamp
timeStr := timestamp.Local().Format("15:04")
@@ -347,7 +368,9 @@ func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Tim
}
}
// RenderToolCallMessage renders a tool call in progress with left border and background header
// RenderToolCallMessage renders a notification that a tool is being executed, showing
// the tool name, formatted arguments (if any), and execution timestamp. The message
// uses tool-specific coloring to distinguish it from regular conversation messages.
func (r *MessageRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage {
// Format timestamp
timeStr := timestamp.Local().Format("15:04")
@@ -391,7 +414,10 @@ func (r *MessageRenderer) RenderToolCallMessage(toolName, toolArgs string, times
}
}
// RenderToolMessage renders a tool call message with proper styling
// RenderToolMessage renders the result of a tool execution, formatting the output
// based on the tool type and whether it succeeded or failed. Error results are
// displayed in red, while successful results are formatted according to the tool's
// output type (bash, file content, etc.).
func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage {
theme := getTheme()
@@ -595,7 +621,9 @@ func (r *MessageRenderer) renderMarkdown(content string, width int) string {
return strings.TrimSuffix(rendered, "\n")
}
// MessageContainer wraps multiple messages in a container
// MessageContainer manages a collection of UI messages, handling their display,
// updates, and layout within the terminal. It supports both standard and compact
// display modes and maintains state for streaming message updates.
type MessageContainer struct {
messages []UIMessage
width int
@@ -605,7 +633,9 @@ type MessageContainer struct {
wasCleared bool // Track if container was explicitly cleared
}
// NewMessageContainer creates a new message container
// NewMessageContainer creates and initializes a new MessageContainer with the
// specified dimensions and display mode. The container starts empty and will
// display a welcome message until the first message is added.
func NewMessageContainer(width, height int, compact bool) *MessageContainer {
return &MessageContainer{
messages: make([]UIMessage, 0),
@@ -615,18 +645,22 @@ func NewMessageContainer(width, height int, compact bool) *MessageContainer {
}
}
// AddMessage adds a message to the container
// AddMessage appends a new UIMessage to the container's collection and resets
// the cleared state flag. Messages are displayed in the order they were added.
func (c *MessageContainer) AddMessage(msg UIMessage) {
c.messages = append(c.messages, msg)
c.wasCleared = false // Reset the cleared flag when adding messages
}
// SetModelName sets the current model name for the container
// SetModelName updates the AI model name used for rendering assistant messages.
// This name is displayed in message headers to indicate which model is responding.
func (c *MessageContainer) SetModelName(modelName string) {
c.modelName = modelName
}
// UpdateLastMessage updates the content of the last message efficiently
// UpdateLastMessage efficiently updates the content of the most recent message
// in the container. This is primarily used for streaming responses where the
// assistant's message is progressively built. Only works for assistant messages.
func (c *MessageContainer) UpdateLastMessage(content string) {
if len(c.messages) == 0 {
return
@@ -651,19 +685,23 @@ func (c *MessageContainer) UpdateLastMessage(content string) {
}
}
// Clear clears all messages from the container
// Clear removes all messages from the container and sets a flag to prevent
// showing the welcome screen. Used when starting a fresh conversation.
func (c *MessageContainer) Clear() {
c.messages = make([]UIMessage, 0)
c.wasCleared = true
}
// SetSize updates the container size
// SetSize updates the container's dimensions, typically called when the terminal
// is resized. This affects how messages are wrapped and displayed.
func (c *MessageContainer) SetSize(width, height int) {
c.width = width
c.height = height
}
// Render renders all messages in the container
// Render generates the complete visual representation of all messages in the
// container. Returns an empty state display if no messages exist, or formats
// all messages according to the current display mode (standard or compact).
func (c *MessageContainer) Render() string {
if len(c.messages) == 0 {
// Don't show welcome box if explicitly cleared
+28 -10
View File
@@ -21,7 +21,9 @@ const (
maxWidth = 80
)
// OllamaPullProgress represents the progress information from Ollama pull API
// OllamaPullProgress represents the progress information received from Ollama's
// pull API when downloading model files. It includes status messages, digest
// information, and download progress counters.
type OllamaPullProgress struct {
Status string `json:"status"`
Digest string `json:"digest,omitempty"`
@@ -41,7 +43,9 @@ type progressErrMsg struct{ err error }
// progressCompleteMsg indicates completion
type progressCompleteMsg struct{}
// ProgressModel represents the progress bar model
// ProgressModel implements a tea.Model for displaying download progress with
// a visual progress bar and status messages. It handles progress updates,
// errors, and completion states for Ollama model downloads.
type ProgressModel struct {
progress progress.Model
status string
@@ -49,7 +53,8 @@ type ProgressModel struct {
complete bool
}
// NewProgressModel creates a new progress model
// NewProgressModel creates and initializes a new ProgressModel with a gradient
// progress bar and initial "Initializing..." status message.
func NewProgressModel() ProgressModel {
return ProgressModel{
progress: progress.New(progress.WithDefaultGradient()),
@@ -57,12 +62,15 @@ func NewProgressModel() ProgressModel {
}
}
// Init initializes the progress model
// Init implements the tea.Model interface, returning nil as no initial commands
// are needed for the progress display.
func (m ProgressModel) Init() tea.Cmd {
return nil
}
// Update handles progress updates
// Update implements the tea.Model interface, handling keyboard input, window
// resize events, and progress updates. It manages the progress bar state and
// triggers program exit on completion or cancellation.
func (m ProgressModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
@@ -108,7 +116,9 @@ func (m ProgressModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
}
// View renders the progress bar
// View implements the tea.Model interface, rendering the progress bar with
// status information and help text. Displays error messages if present or
// a completion message when the download finishes.
func (m ProgressModel) View() string {
if m.err != nil {
return fmt.Sprintf("Error: %s\n", m.err.Error())
@@ -128,7 +138,9 @@ func (m ProgressModel) View() string {
pad+helpStyle("Press 'q' or Ctrl+C to cancel"))
}
// ProgressReader wraps an io.Reader to provide progress updates for Ollama pull operations
// ProgressReader wraps an io.Reader to intercept and parse Ollama pull operation
// responses, extracting progress information and updating a visual progress bar.
// It manages a tea.Program for the UI and handles graceful shutdown.
type ProgressReader struct {
reader io.Reader
program *tea.Program
@@ -138,7 +150,9 @@ type ProgressReader struct {
wg sync.WaitGroup
}
// NewProgressReader creates a new progress reader for Ollama pull operations
// NewProgressReader creates and initializes a ProgressReader that wraps the provided
// io.Reader. It starts a tea.Program in a separate goroutine to display the progress
// bar UI while reading and parsing Ollama's streaming JSON responses.
func NewProgressReader(reader io.Reader) *ProgressReader {
model := NewProgressModel()
// Create program with standard settings
@@ -164,7 +178,9 @@ func NewProgressReader(reader io.Reader) *ProgressReader {
return pr
}
// Read implements io.Reader and parses Ollama streaming responses
// Read implements the io.Reader interface, passing through data from the wrapped
// reader while parsing JSON lines to extract progress information. Each complete
// JSON line is processed to update the progress bar display.
func (pr *ProgressReader) Read(p []byte) (n int, err error) {
n, err = pr.reader.Read(p)
if n > 0 {
@@ -239,7 +255,9 @@ func (pr *ProgressReader) parseProgressLine(line string) {
})
}
// Close stops the progress display and waits for cleanup
// Close gracefully shuts down the progress display, sending a completion message
// and waiting for the tea.Program to exit. If the program doesn't exit within
// 2 seconds, it is forcefully terminated to prevent hanging.
func (pr *ProgressReader) Close() error {
// Send completion message to trigger quit
pr.program.Send(progressCompleteMsg{})
+21 -8
View File
@@ -9,7 +9,9 @@ import (
"github.com/charmbracelet/lipgloss"
)
// SlashCommandInput is a custom input field with slash command autocomplete
// SlashCommandInput provides an interactive text input field with intelligent
// slash command autocomplete functionality. It displays a popup menu of matching
// commands as the user types, supporting fuzzy matching and keyboard navigation.
type SlashCommandInput struct {
textarea textarea.Model
commands []SlashCommand
@@ -26,7 +28,9 @@ type SlashCommandInput struct {
renderedLines int // Track how many lines were rendered
}
// NewSlashCommandInput creates a new slash command input field
// NewSlashCommandInput creates and initializes a new slash command input field with
// the specified width and title. The input supports multi-line text entry, command
// autocomplete, and is styled to match the application's theme.
func NewSlashCommandInput(width int, title string) *SlashCommandInput {
ta := textarea.New()
ta.Placeholder = "Type your message..."
@@ -54,12 +58,15 @@ func NewSlashCommandInput(width int, title string) *SlashCommandInput {
}
}
// Init implements tea.Model
// Init implements the tea.Model interface, returning the initial command to start
// the cursor blinking animation for the text input field.
func (s *SlashCommandInput) Init() tea.Cmd {
return textarea.Blink
}
// Update implements tea.Model
// Update implements the tea.Model interface, handling keyboard input for text entry,
// command selection, and navigation. Manages the autocomplete popup display and
// processes submission or cancellation actions.
func (s *SlashCommandInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
@@ -168,7 +175,9 @@ func (s *SlashCommandInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
}
// View implements tea.Model
// View implements the tea.Model interface, rendering the complete input field
// including the title, text area, autocomplete popup (when active), and help text.
// The view adapts based on whether single or multi-line input is detected.
func (s *SlashCommandInput) View() string {
// Add left padding to entire component (2 spaces like other UI elements)
containerStyle := lipgloss.NewStyle().PaddingLeft(2)
@@ -324,17 +333,21 @@ func (s *SlashCommandInput) renderPopup() string {
return popupStyle.Render(popupContent)
}
// Value returns the final value
// Value returns the final text value entered by the user after submission.
// This will be empty if the input was cancelled.
func (s *SlashCommandInput) Value() string {
return s.value
}
// Cancelled returns true if the user cancelled
// Cancelled returns true if the user cancelled the input operation (e.g., by
// pressing ESC or Ctrl+C) without submitting any text.
func (s *SlashCommandInput) Cancelled() bool {
return s.quitting && s.value == ""
}
// RenderedLines returns how many lines were rendered
// RenderedLines returns the total number of terminal lines used by the last
// rendered view, including the title, input area, popup, and help text. This
// is used for proper screen clearing when the input is dismissed.
func (s *SlashCommandInput) RenderedLines() int {
return s.renderedLines
}
+14 -5
View File
@@ -10,7 +10,9 @@ import (
"github.com/charmbracelet/lipgloss"
)
// Spinner wraps the bubbles spinner for both interactive and non-interactive mode
// Spinner provides an animated loading indicator that displays while long-running
// operations are in progress. It wraps the bubbles spinner component and manages
// its lifecycle through a tea.Program for proper terminal handling.
type Spinner struct {
model spinner.Model
done chan struct{}
@@ -72,7 +74,9 @@ func (m spinnerModel) View() string {
// quitMsg is sent when we want to quit the spinner
type quitMsg struct{}
// NewSpinner creates a new spinner with enhanced styling
// NewSpinner creates a new animated spinner with the specified message. The spinner
// uses the theme's primary color and a modern animation style. It runs in a separate
// tea.Program to avoid interfering with other terminal operations.
func NewSpinner(message string) *Spinner {
s := spinner.New()
s.Spinner = spinner.Points // More modern spinner style
@@ -97,7 +101,9 @@ func NewSpinner(message string) *Spinner {
}
}
// NewThemedSpinner creates a new spinner with the given message and color
// NewThemedSpinner creates a new animated spinner with custom color styling.
// This allows for different spinner colors based on the operation type or status.
// The spinner runs independently in its own tea.Program.
func NewThemedSpinner(message string, color lipgloss.AdaptiveColor) *Spinner {
s := spinner.New()
s.Spinner = spinner.Dot
@@ -121,7 +127,9 @@ func NewThemedSpinner(message string, color lipgloss.AdaptiveColor) *Spinner {
}
}
// Start begins the spinner animation
// Start begins the spinner animation in a separate goroutine. The spinner will
// continue animating until Stop is called. The animation runs in a separate
// tea.Program to maintain smooth animation independent of other operations.
func (s *Spinner) Start() {
go func() {
defer close(s.done)
@@ -136,7 +144,8 @@ func (s *Spinner) Start() {
}()
}
// Stop ends the spinner animation
// Stop halts the spinner animation and cleans up resources. This method blocks
// until the spinner has fully stopped and the terminal state is restored.
func (s *Spinner) Stop() {
s.cancel()
<-s.done
+6 -2
View File
@@ -15,12 +15,16 @@ func boolPtr(b bool) *bool { return &b }
func stringPtr(s string) *string { return &s }
func uintPtr(u uint) *uint { return &u }
// BaseStyle returns a basic lipgloss style
// BaseStyle returns a new, empty lipgloss style that can be customized with
// additional styling methods. This serves as the foundation for building more
// complex styled components.
func BaseStyle() lipgloss.Style {
return lipgloss.NewStyle()
}
// GetMarkdownRenderer returns a glamour TermRenderer configured for our use
// GetMarkdownRenderer creates and returns a configured glamour.TermRenderer for
// rendering markdown content with syntax highlighting and proper formatting. The
// renderer is customized with our theme colors and adapted to the specified width.
func GetMarkdownRenderer(width int) *glamour.TermRenderer {
r, _ := glamour.NewTermRenderer(
glamour.WithStyles(generateMarkdownStyleConfig()),
+41 -14
View File
@@ -9,7 +9,9 @@ import (
"github.com/mark3labs/mcphost/internal/tokens"
)
// UsageStats represents token and cost information for a single request/response
// UsageStats encapsulates detailed token usage and cost breakdown for a single
// LLM request/response cycle, including input, output, and cache token counts
// along with their associated costs.
type UsageStats struct {
InputTokens int
OutputTokens int
@@ -22,7 +24,9 @@ type UsageStats struct {
TotalCost float64
}
// SessionStats represents cumulative stats for the entire session
// SessionStats aggregates token usage and cost information across all requests
// in a session, providing totals and request counts for usage analysis and
// cost tracking.
type SessionStats struct {
TotalInputTokens int
TotalOutputTokens int
@@ -32,7 +36,9 @@ type SessionStats struct {
RequestCount int
}
// UsageTracker tracks token usage and costs for LLM interactions
// UsageTracker monitors and accumulates token usage statistics and associated costs
// for LLM interactions throughout a session. It provides real-time usage information
// and supports both estimated and actual token counts. OAuth users see $0 costs.
type UsageTracker struct {
mu sync.RWMutex
modelInfo *models.ModelInfo
@@ -43,7 +49,10 @@ type UsageTracker struct {
isOAuth bool // Whether OAuth credentials are being used (costs should be $0)
}
// NewUsageTracker creates a new usage tracker for the given model
// NewUsageTracker creates and initializes a new UsageTracker for the specified model.
// The tracker uses model-specific pricing information to calculate costs, unless OAuth
// credentials are being used (in which case costs are shown as $0). Width determines
// the display formatting.
func NewUsageTracker(modelInfo *models.ModelInfo, provider string, width int, isOAuth bool) *UsageTracker {
return &UsageTracker{
modelInfo: modelInfo,
@@ -53,15 +62,19 @@ func NewUsageTracker(modelInfo *models.ModelInfo, provider string, width int, is
}
}
// EstimateTokens provides a rough estimate of tokens in text
// This is a simple approximation - real token counting would require the actual tokenizer
// EstimateTokens provides a rough estimate of the number of tokens in the given text.
// This uses a simple heuristic of approximately 4 characters per token, which is a
// reasonable approximation for most models but not precise. Actual token counts may vary
// significantly based on the specific tokenizer used by each model.
func EstimateTokens(text string) int {
// Rough approximation: ~4 characters per token for most models
// This is not accurate but gives a reasonable estimate
return len(text) / 4
}
// UpdateUsage updates the tracker with new usage information
// UpdateUsage records new token usage data and calculates associated costs based on
// the model's pricing. Updates both the last request statistics and cumulative session
// totals. For OAuth users, costs are recorded as $0 while still tracking token counts.
func (ut *UsageTracker) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int) {
ut.mu.Lock()
defer ut.mu.Unlock()
@@ -107,21 +120,27 @@ func (ut *UsageTracker) UpdateUsage(inputTokens, outputTokens, cacheReadTokens,
ut.sessionStats.RequestCount++
}
// EstimateAndUpdateUsage estimates tokens from text and updates usage
// EstimateAndUpdateUsage estimates token counts from raw text strings and updates
// the usage statistics. This method is used when actual token counts are not available
// from the API response.
func (ut *UsageTracker) EstimateAndUpdateUsage(inputText, outputText string) {
inputTokens := tokens.EstimateTokens(inputText)
outputTokens := tokens.EstimateTokens(outputText)
ut.UpdateUsage(inputTokens, outputTokens, 0, 0)
}
// EstimateAndUpdateUsageFromText estimates tokens from text and updates usage
// EstimateAndUpdateUsageFromText is an alias for EstimateAndUpdateUsage, providing
// backward compatibility. It estimates token counts from text and updates usage statistics.
func (ut *UsageTracker) EstimateAndUpdateUsageFromText(inputText, outputText string) {
inputTokens := tokens.EstimateTokens(inputText)
outputTokens := tokens.EstimateTokens(outputText)
ut.UpdateUsage(inputTokens, outputTokens, 0, 0)
}
// RenderUsageInfo renders enhanced usage information with better styling
// RenderUsageInfo generates a formatted string displaying current usage statistics
// including token counts, context utilization percentage, and costs. The display
// adapts colors based on usage levels and formats large numbers with K/M suffixes
// for readability.
func (ut *UsageTracker) RenderUsageInfo() string {
ut.mu.RLock()
defer ut.mu.RUnlock()
@@ -199,14 +218,18 @@ func (ut *UsageTracker) RenderUsageInfo() string {
tokensLabel, tokensValue, percentageStr, costLabel, costStr)
}
// GetSessionStats returns a copy of the current session statistics
// GetSessionStats returns a copy of the cumulative session statistics including
// total token counts, costs, and request count. The returned copy is safe to use
// without additional synchronization.
func (ut *UsageTracker) GetSessionStats() SessionStats {
ut.mu.RLock()
defer ut.mu.RUnlock()
return ut.sessionStats
}
// GetLastRequestStats returns a copy of the last request statistics
// GetLastRequestStats returns a copy of the usage statistics from the most recent
// request, or nil if no requests have been made. The returned copy is safe to use
// without additional synchronization.
func (ut *UsageTracker) GetLastRequestStats() *UsageStats {
ut.mu.RLock()
defer ut.mu.RUnlock()
@@ -217,7 +240,9 @@ func (ut *UsageTracker) GetLastRequestStats() *UsageStats {
return &stats
}
// Reset clears all usage statistics
// Reset clears all accumulated usage statistics, resetting both session totals
// and last request information to their initial empty state. This is typically
// used when starting a new conversation or clearing usage history.
func (ut *UsageTracker) Reset() {
ut.mu.Lock()
defer ut.mu.Unlock()
@@ -225,7 +250,9 @@ func (ut *UsageTracker) Reset() {
ut.lastRequest = nil
}
// SetWidth updates the display width for rendering
// SetWidth updates the terminal width used for formatting usage information display.
// This should be called when the terminal is resized to ensure proper text wrapping
// and alignment.
func (ut *UsageTracker) SetWidth(width int) {
ut.mu.Lock()
defer ut.mu.Unlock()
+28 -11
View File
@@ -13,14 +13,18 @@ import (
"github.com/spf13/viper"
)
// MCPHost provides programmatic access to mcphost
// MCPHost provides programmatic access to mcphost functionality, allowing
// integration of MCP tools and LLM interactions into Go applications. It manages
// agents, sessions, and model configurations.
type MCPHost struct {
agent *agent.Agent
sessionMgr *session.Manager
modelString string
}
// Options for creating MCPHost (all optional - will use CLI defaults)
// Options configures MCPHost creation with optional overrides for model,
// prompts, configuration, and behavior settings. All fields are optional
// and will use CLI defaults if not specified.
type Options struct {
Model string // Override model (e.g., "anthropic:claude-3-sonnet")
SystemPrompt string // Override system prompt
@@ -30,7 +34,9 @@ type Options struct {
Quiet bool // Suppress debug output
}
// New creates MCPHost instance using the same initialization as CLI
// New creates an MCPHost instance using the same initialization as the CLI.
// It loads configuration, initializes MCP servers, creates the LLM model, and
// sets up the agent for interaction. Returns an error if initialization fails.
func New(ctx context.Context, opts *Options) (*MCPHost, error) {
if opts == nil {
opts = &Options{}
@@ -118,7 +124,9 @@ func New(ctx context.Context, opts *Options) (*MCPHost, error) {
}, nil
}
// Prompt sends a message and returns the response
// Prompt sends a message to the agent and returns the response. The agent may
// use tools as needed to generate the response. The conversation history is
// automatically maintained in the session. Returns an error if generation fails.
func (m *MCPHost) Prompt(ctx context.Context, message string) (string, error) {
// Get messages from session
messages := m.sessionMgr.GetMessages()
@@ -148,7 +156,9 @@ func (m *MCPHost) Prompt(ctx context.Context, message string) (string, error) {
return result.FinalResponse.Content, nil
}
// PromptWithCallbacks sends a message with callbacks for tool execution
// PromptWithCallbacks sends a message with callbacks for monitoring tool execution
// and streaming responses. The callbacks allow real-time observation of tool calls,
// results, and response generation. Returns the final response or an error.
func (m *MCPHost) PromptWithCallbacks(
ctx context.Context,
message string,
@@ -184,12 +194,14 @@ func (m *MCPHost) PromptWithCallbacks(
return result.FinalResponse.Content, nil
}
// GetSessionManager returns the current session manager
// GetSessionManager returns the current session manager for direct access
// to conversation history and session manipulation.
func (m *MCPHost) GetSessionManager() *session.Manager {
return m.sessionMgr
}
// LoadSession loads a session from file
// LoadSession loads a previously saved session from a file, restoring the
// conversation history. Returns an error if the file cannot be loaded or parsed.
func (m *MCPHost) LoadSession(path string) error {
s, err := session.LoadFromFile(path)
if err != nil {
@@ -199,22 +211,27 @@ func (m *MCPHost) LoadSession(path string) error {
return nil
}
// SaveSession saves the current session to file
// SaveSession saves the current session to a file for later restoration.
// Returns an error if the session cannot be written to the specified path.
func (m *MCPHost) SaveSession(path string) error {
return m.sessionMgr.GetSession().SaveToFile(path)
}
// ClearSession clears the current session history
// ClearSession clears the current session history, starting a new conversation
// with an empty message history.
func (m *MCPHost) ClearSession() {
m.sessionMgr = session.NewManager("")
}
// GetModelString returns the current model string
// GetModelString returns the current model string identifier (e.g.,
// "anthropic:claude-3-sonnet" or "openai:gpt-4") being used by the agent.
func (m *MCPHost) GetModelString() string {
return m.modelString
}
// Close cleans up resources
// Close cleans up resources including MCP server connections and model resources.
// Should be called when the MCPHost instance is no longer needed. Returns an
// error if cleanup fails.
func (m *MCPHost) Close() error {
return m.agent.Close()
}
+8 -4
View File
@@ -5,18 +5,22 @@ import (
"github.com/mark3labs/mcphost/internal/session"
)
// Message is an alias for session.Message for SDK users
// Message is an alias for session.Message providing SDK users with access
// to message structures for conversation history and tool interactions.
type Message = session.Message
// ToolCall is an alias for session.ToolCall
// ToolCall is an alias for session.ToolCall representing a tool invocation
// with its name, arguments, and result within a conversation.
type ToolCall = session.ToolCall
// ConvertToSchemaMessage converts SDK message to schema message
// ConvertToSchemaMessage converts an SDK message to the underlying schema message
// format used by the agent for LLM interactions.
func ConvertToSchemaMessage(msg *Message) *schema.Message {
return msg.ConvertToSchemaMessage()
}
// ConvertFromSchemaMessage converts schema message to SDK message
// ConvertFromSchemaMessage converts a schema message from the agent to an SDK
// message format for use in the SDK API.
func ConvertFromSchemaMessage(msg *schema.Message) Message {
return session.ConvertFromSchemaMessage(msg)
}