mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-13 19:20:06 +00:00
godoc
This commit is contained in:
+13
@@ -10,6 +10,10 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// authCmd represents the auth command for managing AI provider authentication.
|
||||
// This command provides subcommands for login, logout, and status checking
|
||||
// of authentication credentials for various AI providers, with OAuth support
|
||||
// for providers like Anthropic.
|
||||
var authCmd = &cobra.Command{
|
||||
Use: "auth",
|
||||
Short: "Manage authentication credentials for AI providers",
|
||||
@@ -27,6 +31,9 @@ Examples:
|
||||
mcphost auth status`,
|
||||
}
|
||||
|
||||
// authLoginCmd represents the login subcommand for authenticating with AI providers.
|
||||
// It handles OAuth flow for supported providers, opening a browser for authentication
|
||||
// and securely storing the resulting credentials for future use.
|
||||
var authLoginCmd = &cobra.Command{
|
||||
Use: "login [provider]",
|
||||
Short: "Authenticate with an AI provider using OAuth",
|
||||
@@ -45,6 +52,9 @@ Example:
|
||||
RunE: runAuthLogin,
|
||||
}
|
||||
|
||||
// authLogoutCmd represents the logout subcommand for removing stored authentication credentials.
|
||||
// This command removes stored API keys or OAuth tokens for specified providers,
|
||||
// requiring the user to authenticate again or use environment variables.
|
||||
var authLogoutCmd = &cobra.Command{
|
||||
Use: "logout [provider]",
|
||||
Short: "Remove stored authentication credentials for a provider",
|
||||
@@ -62,6 +72,9 @@ Example:
|
||||
RunE: runAuthLogout,
|
||||
}
|
||||
|
||||
// authStatusCmd represents the status subcommand for checking authentication status.
|
||||
// It displays which providers have stored credentials, their types (OAuth vs API key),
|
||||
// creation dates, and expiration status without revealing the actual credentials.
|
||||
var authStatusCmd = &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show authentication status for all providers",
|
||||
|
||||
@@ -10,12 +10,18 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// hooksCmd represents the hooks command for managing MCPHost hook configurations.
|
||||
// Hooks allow users to execute custom scripts or commands at various points
|
||||
// during MCPHost execution, such as before/after tool use or when prompts are submitted.
|
||||
var hooksCmd = &cobra.Command{
|
||||
Use: "hooks",
|
||||
Short: "Manage MCPHost hooks",
|
||||
Long: "Commands for managing and testing MCPHost hooks configuration",
|
||||
}
|
||||
|
||||
// hooksListCmd represents the list subcommand for displaying all configured hooks.
|
||||
// It shows a formatted table of hook events, matchers, commands, and timeouts
|
||||
// to help users understand their current hook configuration.
|
||||
var hooksListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List all configured hooks",
|
||||
@@ -45,6 +51,9 @@ var hooksListCmd = &cobra.Command{
|
||||
},
|
||||
}
|
||||
|
||||
// hooksValidateCmd represents the validate subcommand for checking hook configuration validity.
|
||||
// It loads and validates the hooks configuration file, ensuring proper syntax,
|
||||
// valid event types, and correct matcher patterns before use.
|
||||
var hooksValidateCmd = &cobra.Command{
|
||||
Use: "validate",
|
||||
Short: "Validate hooks configuration",
|
||||
@@ -64,6 +73,9 @@ var hooksValidateCmd = &cobra.Command{
|
||||
},
|
||||
}
|
||||
|
||||
// hooksInitCmd represents the init subcommand for generating an example hooks configuration.
|
||||
// It creates a .mcphost/hooks.yml file with sample hook configurations demonstrating
|
||||
// various hook events and common use cases like logging commands and tool usage.
|
||||
var hooksInitCmd = &cobra.Command{
|
||||
Use: "init",
|
||||
Short: "Generate example hooks configuration",
|
||||
|
||||
+23
-3
@@ -84,6 +84,10 @@ func (a *agentUIAdapter) GetLoadedServerNames() []string {
|
||||
return a.agent.GetLoadedServerNames()
|
||||
}
|
||||
|
||||
// rootCmd represents the base command when called without any subcommands.
|
||||
// This is the main entry point for the MCPHost CLI application, providing
|
||||
// an interface to interact with various AI models through a unified interface
|
||||
// with support for MCP servers and tool integration.
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "mcphost",
|
||||
Short: "Chat with AI models through a unified interface",
|
||||
@@ -120,12 +124,21 @@ Examples:
|
||||
},
|
||||
}
|
||||
|
||||
// GetRootCommand returns the root command with the version set
|
||||
// GetRootCommand returns the root command with the version set.
|
||||
// This function is the main entry point for the MCPHost CLI and should be
|
||||
// called from main.go with the appropriate version string.
|
||||
func GetRootCommand(v string) *cobra.Command {
|
||||
rootCmd.Version = v
|
||||
return rootCmd
|
||||
}
|
||||
|
||||
// InitConfig initializes the configuration for MCPHost by loading config files,
|
||||
// environment variables, and hooks configuration. It follows this priority order:
|
||||
// 1. Command-line specified config file (--config flag)
|
||||
// 2. Current directory config file (.mcphost or .mcp)
|
||||
// 3. Home directory config file (~/.mcphost or ~/.mcp)
|
||||
// 4. Environment variables (MCPHOST_* prefix)
|
||||
// This function is automatically called by cobra before command execution.
|
||||
func InitConfig() {
|
||||
if configFile != "" {
|
||||
// Use config file from the flag
|
||||
@@ -202,7 +215,12 @@ func InitConfig() {
|
||||
|
||||
}
|
||||
|
||||
// LoadConfigWithEnvSubstitution loads a config file with environment variable substitution
|
||||
// LoadConfigWithEnvSubstitution loads a config file with environment variable substitution.
|
||||
// It reads the config file, replaces any ${ENV_VAR} patterns with their corresponding
|
||||
// environment variable values, and then parses the resulting configuration using viper.
|
||||
// The function automatically detects JSON or YAML format based on file extension.
|
||||
// Returns an error if the file cannot be read, environment variable substitution fails,
|
||||
// or the configuration cannot be parsed.
|
||||
func LoadConfigWithEnvSubstitution(configPath string) error {
|
||||
// Read raw config file content
|
||||
rawContent, err := os.ReadFile(configPath)
|
||||
@@ -728,7 +746,9 @@ func runNormalMode(ctx context.Context) error {
|
||||
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor)
|
||||
}
|
||||
|
||||
// AgenticLoopConfig configures the behavior of the unified agentic loop
|
||||
// AgenticLoopConfig configures the behavior of the unified agentic loop.
|
||||
// This struct controls how the main interaction loop operates, whether in
|
||||
// interactive or non-interactive mode, and manages various UI and session options.
|
||||
type AgenticLoopConfig struct {
|
||||
// Mode configuration
|
||||
IsInteractive bool // true for interactive mode, false for non-interactive
|
||||
|
||||
+10
-4
@@ -21,6 +21,10 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// scriptCmd represents the script command for executing MCPHost script files.
|
||||
// Script files can contain YAML frontmatter configuration followed by a prompt,
|
||||
// allowing for reproducible AI interactions with custom configurations and
|
||||
// variable substitution support.
|
||||
var scriptCmd = &cobra.Command{
|
||||
Use: "script <script-file>",
|
||||
Short: "Execute a script file with YAML frontmatter configuration",
|
||||
@@ -413,11 +417,13 @@ func parseScriptContent(content string, variables map[string]string) (*config.Co
|
||||
return &scriptConfig, nil
|
||||
}
|
||||
|
||||
// Variable represents a script variable with optional default value
|
||||
// Variable represents a script variable with optional default value.
|
||||
// Variables can be declared in scripts using ${variable} syntax for required variables
|
||||
// or ${variable:-default} syntax for variables with default values.
|
||||
type Variable struct {
|
||||
Name string
|
||||
DefaultValue string
|
||||
HasDefault bool
|
||||
Name string // The name of the variable as it appears in the script
|
||||
DefaultValue string // The default value if specified using ${variable:-default} syntax
|
||||
HasDefault bool // Whether this variable has a default value
|
||||
}
|
||||
|
||||
// findVariables extracts all unique variable names from ${variable} patterns in content
|
||||
|
||||
+50
-22
@@ -16,35 +16,49 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// AgentConfig is the config for agent.
|
||||
// AgentConfig holds configuration options for creating a new Agent.
|
||||
// It includes model configuration, MCP settings, and various behavioral options.
|
||||
type AgentConfig struct {
|
||||
ModelConfig *models.ProviderConfig
|
||||
MCPConfig *config.Config
|
||||
SystemPrompt string
|
||||
MaxSteps int
|
||||
// ModelConfig specifies the LLM provider and model to use
|
||||
ModelConfig *models.ProviderConfig
|
||||
// MCPConfig contains MCP server configurations
|
||||
MCPConfig *config.Config
|
||||
// SystemPrompt is the initial system message for the agent
|
||||
SystemPrompt string
|
||||
// MaxSteps limits the number of tool calls (0 for unlimited)
|
||||
MaxSteps int
|
||||
// StreamingEnabled controls whether responses are streamed
|
||||
StreamingEnabled bool
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
// DebugLogger is an optional logger for debugging MCP communications
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
}
|
||||
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
// It receives the tool name and its arguments when a tool is about to be invoked.
|
||||
type ToolCallHandler func(toolName, toolArgs string)
|
||||
|
||||
// ToolExecutionHandler is a function type for handling tool execution start/end
|
||||
// ToolExecutionHandler is a function type for handling tool execution start/end events.
|
||||
// The isStarting parameter indicates whether the tool is starting (true) or finished (false).
|
||||
type ToolExecutionHandler func(toolName string, isStarting bool)
|
||||
|
||||
// ToolResultHandler is a function type for handling tool results
|
||||
// ToolResultHandler is a function type for handling tool results.
|
||||
// It receives the tool name, arguments, result, and whether the result is an error.
|
||||
type ToolResultHandler func(toolName, toolArgs, result string, isError bool)
|
||||
|
||||
// ResponseHandler is a function type for handling LLM responses
|
||||
// ResponseHandler is a function type for handling LLM responses.
|
||||
// It receives the complete response content from the model.
|
||||
type ResponseHandler func(content string)
|
||||
|
||||
// StreamingResponseHandler is a function type for handling streaming LLM responses
|
||||
// StreamingResponseHandler is a function type for handling streaming LLM responses.
|
||||
// It receives content chunks as they are streamed from the model.
|
||||
type StreamingResponseHandler func(content string)
|
||||
|
||||
// ToolCallContentHandler is a function type for handling content that accompanies tool calls
|
||||
// ToolCallContentHandler is a function type for handling content that accompanies tool calls.
|
||||
// It receives any text content that the model generates alongside tool calls.
|
||||
type ToolCallContentHandler func(content string)
|
||||
|
||||
// Agent is the agent with real-time tool call display.
|
||||
// Agent represents an AI agent with MCP tool integration and real-time tool call display.
|
||||
// It manages the interaction between an LLM and various tools through the MCP protocol.
|
||||
type Agent struct {
|
||||
toolManager *tools.MCPToolManager
|
||||
model model.ToolCallingChatModel
|
||||
@@ -55,7 +69,10 @@ type Agent struct {
|
||||
streamingEnabled bool // Whether streaming is enabled
|
||||
}
|
||||
|
||||
// NewAgent creates an agent with MCP tool integration and real-time tool call display
|
||||
// NewAgent creates a new Agent with MCP tool integration and streaming support.
|
||||
// It initializes the LLM provider, loads MCP tools, and configures the agent
|
||||
// based on the provided configuration. Returns an error if provider creation
|
||||
// or tool loading fails.
|
||||
func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
|
||||
// Create the LLM provider
|
||||
providerResult, err := models.CreateProvider(ctx, config.ModelConfig)
|
||||
@@ -98,20 +115,27 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateWithLoopResult contains the result and conversation history
|
||||
// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
|
||||
// It includes both the final response and the complete message history with tool interactions.
|
||||
type GenerateWithLoopResult struct {
|
||||
FinalResponse *schema.Message
|
||||
// FinalResponse is the last message generated by the model
|
||||
FinalResponse *schema.Message
|
||||
// ConversationMessages contains all messages in the conversation including tool calls and results
|
||||
ConversationMessages []*schema.Message // All messages in the conversation (including tool calls and results)
|
||||
}
|
||||
|
||||
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time
|
||||
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time.
|
||||
// It handles the conversation flow, executing tools as needed and invoking callbacks for various events.
|
||||
// This method does not support streaming responses; use GenerateWithLoopAndStreaming for streaming support.
|
||||
func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) {
|
||||
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks
|
||||
// GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks.
|
||||
// It handles the conversation flow, executing tools as needed and invoking callbacks for various events including streaming chunks.
|
||||
// The onStreamingResponse callback is invoked for each content chunk during streaming if streaming is enabled.
|
||||
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*schema.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler) (*GenerateWithLoopResult, error) {
|
||||
|
||||
@@ -256,17 +280,20 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*sc
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetTools returns the list of available tools
|
||||
// GetTools returns the list of available tools loaded in the agent.
|
||||
// These tools are available for the model to use during interactions.
|
||||
func (a *Agent) GetTools() []tool.BaseTool {
|
||||
return a.toolManager.GetTools()
|
||||
}
|
||||
|
||||
// GetLoadingMessage returns the loading message from provider creation (e.g., GPU fallback info)
|
||||
// GetLoadingMessage returns the loading message from provider creation.
|
||||
// This may contain information about GPU fallback or other provider-specific initialization details.
|
||||
func (a *Agent) GetLoadingMessage() string {
|
||||
return a.loadingMessage
|
||||
}
|
||||
|
||||
// GetLoadedServerNames returns the names of successfully loaded MCP servers
|
||||
// GetLoadedServerNames returns the names of successfully loaded MCP servers.
|
||||
// This includes both builtin servers and external MCP server configurations.
|
||||
func (a *Agent) GetLoadedServerNames() []string {
|
||||
return a.toolManager.GetLoadedServerNames()
|
||||
}
|
||||
@@ -486,7 +513,8 @@ func (a *Agent) listenForESC(stopChan chan bool, readyChan chan bool) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the agent and cleans up resources
|
||||
// Close closes the agent and cleans up resources.
|
||||
// It ensures all MCP connections are properly closed and resources are released.
|
||||
func (a *Agent) Close() error {
|
||||
return a.toolManager.Close()
|
||||
}
|
||||
|
||||
+27
-12
@@ -10,23 +10,36 @@ import (
|
||||
"github.com/mark3labs/mcphost/internal/tools"
|
||||
)
|
||||
|
||||
// SpinnerFunc is a function type for showing spinners during agent creation
|
||||
// SpinnerFunc is a function type for showing spinners during agent creation.
|
||||
// It executes the provided function while displaying a spinner with the given message.
|
||||
type SpinnerFunc func(message string, fn func() error) error
|
||||
|
||||
// AgentCreationOptions contains options for creating an agent
|
||||
// AgentCreationOptions contains options for creating an agent.
|
||||
// It extends AgentConfig with UI-related options for showing progress during creation.
|
||||
type AgentCreationOptions struct {
|
||||
ModelConfig *models.ProviderConfig
|
||||
MCPConfig *config.Config
|
||||
SystemPrompt string
|
||||
MaxSteps int
|
||||
// ModelConfig specifies the LLM provider and model to use
|
||||
ModelConfig *models.ProviderConfig
|
||||
// MCPConfig contains MCP server configurations
|
||||
MCPConfig *config.Config
|
||||
// SystemPrompt is the initial system message for the agent
|
||||
SystemPrompt string
|
||||
// MaxSteps limits the number of tool calls (0 for unlimited)
|
||||
MaxSteps int
|
||||
// StreamingEnabled controls whether responses are streamed
|
||||
StreamingEnabled bool
|
||||
ShowSpinner bool // For Ollama models
|
||||
Quiet bool // Skip spinner if quiet
|
||||
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
// ShowSpinner indicates whether to show a spinner for Ollama models during loading
|
||||
ShowSpinner bool // For Ollama models
|
||||
// Quiet suppresses the spinner even if ShowSpinner is true
|
||||
Quiet bool // Skip spinner if quiet
|
||||
// SpinnerFunc is the function to show spinner, provided by the caller
|
||||
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
|
||||
// DebugLogger is an optional logger for debugging MCP communications
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
}
|
||||
|
||||
// CreateAgent creates an agent with optional spinner for Ollama models
|
||||
// CreateAgent creates an agent with optional spinner for Ollama models.
|
||||
// It shows a loading spinner for Ollama models if ShowSpinner is true and not in quiet mode.
|
||||
// Returns the created agent or an error if creation fails.
|
||||
func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error) {
|
||||
agentConfig := &AgentConfig{
|
||||
ModelConfig: opts.ModelConfig,
|
||||
@@ -57,7 +70,9 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
// ParseModelName extracts provider and model name from model string
|
||||
// ParseModelName extracts provider and model name from a model string.
|
||||
// Model strings are formatted as "provider:model" (e.g., "anthropic:claude-3-5-sonnet-20241022").
|
||||
// If the string doesn't contain a colon, returns "unknown" for both provider and model.
|
||||
func ParseModelName(modelString string) (provider, model string) {
|
||||
parts := strings.SplitN(modelString, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
|
||||
@@ -8,7 +8,8 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// StreamWithCallback streams content with real-time callbacks and returns complete response
|
||||
// StreamWithCallback streams content with real-time callbacks and returns the complete response.
|
||||
// It accumulates content and tool calls from the stream, invoking the callback for each content chunk.
|
||||
// IMPORTANT: Tool calls are only processed after EOF is reached to ensure we have the complete
|
||||
// and final tool call information. This prevents premature tool execution on partial data.
|
||||
// Handles different provider streaming patterns:
|
||||
|
||||
@@ -6,7 +6,11 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// OpenBrowser opens the default browser to the specified URL
|
||||
// OpenBrowser opens the default web browser to the specified URL.
|
||||
// It automatically detects the operating system and uses the appropriate
|
||||
// command to launch the browser (xdg-open on Linux, rundll32 on Windows,
|
||||
// open on macOS). Returns an error if the platform is unsupported or if
|
||||
// the browser fails to launch.
|
||||
func OpenBrowser(url string) error {
|
||||
var err error
|
||||
|
||||
@@ -24,7 +28,9 @@ func OpenBrowser(url string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// TryOpenBrowser attempts to open the browser but doesn't fail if it can't
|
||||
// TryOpenBrowser attempts to open the default web browser to the specified URL
|
||||
// but silently ignores any errors. This is useful when browser access is optional
|
||||
// and users can manually copy and paste the URL if automatic browser launching fails.
|
||||
func TryOpenBrowser(url string) {
|
||||
// Silently ignore errors - user can still copy/paste the URL
|
||||
_ = OpenBrowser(url)
|
||||
|
||||
@@ -9,12 +9,16 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CredentialStore holds all stored credentials
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
// Currently supports Anthropic credentials with both OAuth and API key authentication methods.
|
||||
type CredentialStore struct {
|
||||
Anthropic *AnthropicCredentials `json:"anthropic,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicCredentials holds Anthropic API credentials
|
||||
// AnthropicCredentials holds Anthropic API credentials supporting both OAuth
|
||||
// and API key authentication methods. The Type field indicates which authentication
|
||||
// method is being used. For OAuth, tokens are stored with expiration timestamps
|
||||
// for automatic refresh. For API keys, only the key itself is stored.
|
||||
type AnthropicCredentials struct {
|
||||
Type string `json:"type"` // "oauth" or "api_key"
|
||||
APIKey string `json:"api_key,omitempty"` // For API key auth
|
||||
@@ -24,7 +28,8 @@ type AnthropicCredentials struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) IsExpired() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
@@ -32,7 +37,10 @@ func (c *AnthropicCredentials) IsExpired() bool {
|
||||
return time.Now().Unix() >= c.ExpiresAt
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh (5 minutes before expiry)
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
// will expire within the next 5 minutes. This allows for proactive token refresh
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) NeedsRefresh() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
@@ -40,12 +48,17 @@ func (c *AnthropicCredentials) NeedsRefresh() bool {
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
}
|
||||
|
||||
// CredentialManager handles credential storage and retrieval
|
||||
// CredentialManager handles secure storage and retrieval of authentication credentials.
|
||||
// It manages a JSON file stored in the user's config directory with appropriate
|
||||
// file permissions for security.
|
||||
type CredentialManager struct {
|
||||
credentialsPath string
|
||||
}
|
||||
|
||||
// NewCredentialManager creates a new credential manager
|
||||
// NewCredentialManager creates a new credential manager instance. It determines
|
||||
// the appropriate credentials path based on XDG_CONFIG_HOME or falls back to
|
||||
// ~/.config/.mcphost/credentials.json. Returns an error if the home directory
|
||||
// cannot be determined.
|
||||
func NewCredentialManager() (*CredentialManager, error) {
|
||||
credentialsPath, err := getCredentialsPath()
|
||||
if err != nil {
|
||||
@@ -73,7 +86,9 @@ func getCredentialsPath() (string, error) {
|
||||
return filepath.Join(homeDir, ".config", ".mcphost", "credentials.json"), nil
|
||||
}
|
||||
|
||||
// LoadCredentials loads credentials from the file
|
||||
// LoadCredentials loads credentials from the JSON file. If the file doesn't exist,
|
||||
// it returns an empty CredentialStore instead of an error, allowing for graceful
|
||||
// initialization. Returns an error if the file exists but cannot be read or parsed.
|
||||
func (cm *CredentialManager) LoadCredentials() (*CredentialStore, error) {
|
||||
// If file doesn't exist, return empty store
|
||||
if _, err := os.Stat(cm.credentialsPath); os.IsNotExist(err) {
|
||||
@@ -93,7 +108,10 @@ func (cm *CredentialManager) LoadCredentials() (*CredentialStore, error) {
|
||||
return &store, nil
|
||||
}
|
||||
|
||||
// SaveCredentials saves credentials to the file
|
||||
// SaveCredentials saves credentials to the JSON file with secure permissions (0600).
|
||||
// It creates the parent directory if it doesn't exist. The file is written atomically
|
||||
// to prevent corruption. Returns an error if the directory cannot be created or the
|
||||
// file cannot be written.
|
||||
func (cm *CredentialManager) SaveCredentials(store *CredentialStore) error {
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(cm.credentialsPath)
|
||||
@@ -114,7 +132,10 @@ func (cm *CredentialManager) SaveCredentials(store *CredentialStore) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetAnthropicCredentials stores Anthropic API credentials (for API key auth)
|
||||
// SetAnthropicCredentials stores Anthropic API key credentials. It validates the
|
||||
// API key format before storing. The API key must start with "sk-ant-" and be
|
||||
// at least 20 characters long. Returns an error if the API key is invalid or
|
||||
// if storage fails.
|
||||
func (cm *CredentialManager) SetAnthropicCredentials(apiKey string) error {
|
||||
if err := validateAnthropicAPIKey(apiKey); err != nil {
|
||||
return err
|
||||
@@ -134,7 +155,9 @@ func (cm *CredentialManager) SetAnthropicCredentials(apiKey string) error {
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetAnthropicCredentials retrieves Anthropic API credentials
|
||||
// GetAnthropicCredentials retrieves stored Anthropic credentials. Returns nil if
|
||||
// no credentials are stored. The returned credentials may be either OAuth or API
|
||||
// key type, check the Type field to determine which.
|
||||
func (cm *CredentialManager) GetAnthropicCredentials() (*AnthropicCredentials, error) {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
@@ -144,7 +167,9 @@ func (cm *CredentialManager) GetAnthropicCredentials() (*AnthropicCredentials, e
|
||||
return store.Anthropic, nil
|
||||
}
|
||||
|
||||
// RemoveAnthropicCredentials removes stored Anthropic credentials
|
||||
// RemoveAnthropicCredentials removes stored Anthropic credentials from storage.
|
||||
// If this was the only credential stored, the entire credentials file is removed.
|
||||
// Returns an error if the removal fails.
|
||||
func (cm *CredentialManager) RemoveAnthropicCredentials() error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
@@ -164,7 +189,9 @@ func (cm *CredentialManager) RemoveAnthropicCredentials() error {
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// HasAnthropicCredentials checks if Anthropic credentials are stored
|
||||
// HasAnthropicCredentials checks if valid Anthropic credentials are stored.
|
||||
// Returns true if either a non-empty OAuth access token or API key is present,
|
||||
// false otherwise. Returns an error if credentials cannot be loaded.
|
||||
func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
|
||||
creds, err := cm.GetAnthropicCredentials()
|
||||
if err != nil {
|
||||
@@ -185,7 +212,8 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetCredentialsPath returns the path to the credentials file
|
||||
// GetCredentialsPath returns the absolute path to the credentials JSON file.
|
||||
// This is useful for debugging or displaying the storage location to users.
|
||||
func (cm *CredentialManager) GetCredentialsPath() string {
|
||||
return cm.credentialsPath
|
||||
}
|
||||
@@ -210,8 +238,12 @@ func validateAnthropicAPIKey(apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAnthropicAPIKey is a convenience function that checks stored credentials first,
|
||||
// then falls back to environment variables and flags
|
||||
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
|
||||
// 1. Command-line flag value (highest priority)
|
||||
// 2. Stored credentials (OAuth or API key)
|
||||
// 3. ANTHROPIC_API_KEY environment variable (lowest priority)
|
||||
// Returns the API key, a description of its source, and any error encountered.
|
||||
// For OAuth credentials, it automatically refreshes expired tokens.
|
||||
func GetAnthropicAPIKey(flagValue string) (string, string, error) {
|
||||
// 1. Check flag value first (highest priority)
|
||||
if flagValue != "" {
|
||||
|
||||
+34
-9
@@ -13,7 +13,9 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthClient handles OAuth authentication with Anthropic
|
||||
// OAuthClient handles OAuth 2.0 authentication flow with Anthropic using the
|
||||
// PKCE (Proof Key for Code Exchange) extension for enhanced security in public clients.
|
||||
// It manages the authorization URL generation, code exchange, and token refresh operations.
|
||||
type OAuthClient struct {
|
||||
ClientID string
|
||||
AuthorizeURL string
|
||||
@@ -22,13 +24,18 @@ type OAuthClient struct {
|
||||
Scopes string
|
||||
}
|
||||
|
||||
// AuthData contains authorization URL and PKCE verifier
|
||||
// AuthData contains the authorization URL for user authentication and the PKCE
|
||||
// verifier needed for the subsequent code exchange. The verifier must be stored
|
||||
// securely and used when exchanging the authorization code for tokens.
|
||||
type AuthData struct {
|
||||
URL string
|
||||
Verifier string
|
||||
}
|
||||
|
||||
// NewOAuthClient creates a new OAuth client with Anthropic configuration
|
||||
// NewOAuthClient creates a new OAuth client configured for Anthropic's OAuth service.
|
||||
// The client uses a public client ID (as per OAuth 2.0 public client specification)
|
||||
// with PKCE for security. The configuration includes the authorization endpoint,
|
||||
// token endpoint, redirect URI, and required scopes for API key creation and inference.
|
||||
func NewOAuthClient() *OAuthClient {
|
||||
return &OAuthClient{
|
||||
// OAuth client ID is public by design for CLI applications (OAuth public clients).
|
||||
@@ -42,7 +49,11 @@ func NewOAuthClient() *OAuthClient {
|
||||
}
|
||||
}
|
||||
|
||||
// GeneratePKCE generates PKCE verifier and challenge for OAuth flow
|
||||
// GeneratePKCE generates a cryptographically secure PKCE verifier and challenge pair
|
||||
// for the OAuth 2.0 PKCE flow. The verifier is a random 32-byte string encoded as
|
||||
// base64url, and the challenge is the SHA256 hash of the verifier, also base64url encoded.
|
||||
// Returns the verifier (to be stored securely), challenge (to be sent with auth request),
|
||||
// and any error encountered during generation.
|
||||
func GeneratePKCE() (verifier, challenge string, err error) {
|
||||
// Generate 32 bytes of random data
|
||||
verifierBytes := make([]byte, 32)
|
||||
@@ -60,7 +71,10 @@ func GeneratePKCE() (verifier, challenge string, err error) {
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
// GetAuthorizationURL generates the authorization URL with PKCE parameters
|
||||
// GetAuthorizationURL generates a complete authorization URL for the OAuth flow with
|
||||
// PKCE parameters. The URL includes the client ID, redirect URI, requested scopes,
|
||||
// and PKCE challenge. Returns an AuthData structure containing the URL for user
|
||||
// authentication and the PKCE verifier for the subsequent code exchange.
|
||||
func (c *OAuthClient) GetAuthorizationURL() (*AuthData, error) {
|
||||
verifier, challenge, err := GeneratePKCE()
|
||||
if err != nil {
|
||||
@@ -86,7 +100,10 @@ func (c *OAuthClient) GetAuthorizationURL() (*AuthData, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges an authorization code for tokens
|
||||
// ExchangeCode exchanges an authorization code for access and refresh tokens.
|
||||
// The code parameter should be the authorization code received from the OAuth callback.
|
||||
// The verifier parameter must be the same PKCE verifier generated during GetAuthorizationURL.
|
||||
// Returns AnthropicCredentials containing the tokens and expiration information.
|
||||
func (c *OAuthClient) ExchangeCode(code, verifier string) (*AnthropicCredentials, error) {
|
||||
// Parse code and state
|
||||
parsedCode, parsedState := c.parseCodeAndState(code)
|
||||
@@ -109,7 +126,10 @@ func (c *OAuthClient) ExchangeCode(code, verifier string) (*AnthropicCredentials
|
||||
return c.makeTokenRequest(reqBody)
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an access token using a refresh token
|
||||
// RefreshToken refreshes an expired or expiring access token using a refresh token.
|
||||
// Returns new AnthropicCredentials with updated access token, refresh token (may be
|
||||
// rotated), and new expiration timestamp. Returns an error if the refresh fails or
|
||||
// the refresh token is invalid.
|
||||
func (c *OAuthClient) RefreshToken(refreshToken string) (*AnthropicCredentials, error) {
|
||||
reqBody := map[string]interface{}{
|
||||
"grant_type": "refresh_token",
|
||||
@@ -179,7 +199,9 @@ func (c *OAuthClient) parseCodeAndState(code string) (parsedCode, parsedState st
|
||||
return
|
||||
}
|
||||
|
||||
// SetOAuthCredentials stores OAuth credentials
|
||||
// SetOAuthCredentials stores OAuth credentials in the credential manager's secure storage.
|
||||
// The credentials should include access token, refresh token, and expiration information.
|
||||
// Returns an error if the credentials cannot be saved.
|
||||
func (cm *CredentialManager) SetOAuthCredentials(creds *AnthropicCredentials) error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
@@ -190,7 +212,10 @@ func (cm *CredentialManager) SetOAuthCredentials(creds *AnthropicCredentials) er
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetValidAccessToken returns a valid access token, refreshing if necessary
|
||||
// GetValidAccessToken returns a valid access token for API requests. For OAuth credentials,
|
||||
// it automatically refreshes the token if it's expired or about to expire. For API key
|
||||
// credentials, it simply returns the API key. Returns an error if no credentials are found,
|
||||
// if token refresh fails, or if the credential type is unknown.
|
||||
func (cm *CredentialManager) GetValidAccessToken() (string, error) {
|
||||
creds, err := cm.GetAnthropicCredentials()
|
||||
if err != nil {
|
||||
|
||||
@@ -37,7 +37,10 @@ var bannedCommands = []string{
|
||||
"safari",
|
||||
}
|
||||
|
||||
// NewBashServer creates a new bash MCP server
|
||||
// NewBashServer creates a new MCP server that provides bash command execution capabilities.
|
||||
// The server includes a single tool "run_shell_cmd" that executes shell commands with
|
||||
// security restrictions, timeout controls, and output truncation. Returns an error if
|
||||
// server initialization fails.
|
||||
func NewBashServer() (*server.MCPServer, error) {
|
||||
s := server.NewMCPServer("bash-server", "1.0.0", server.WithToolCapabilities(true))
|
||||
|
||||
|
||||
@@ -21,7 +21,9 @@ const (
|
||||
maxFetchTimeout = 120 * time.Second
|
||||
)
|
||||
|
||||
// NewFetchServer creates a new fetch MCP server
|
||||
// NewFetchServer creates a new MCP server that provides web content fetching capabilities.
|
||||
// The server includes a single tool "fetch" that retrieves content from URLs and converts
|
||||
// it to text, markdown, or HTML format. Returns an error if server initialization fails.
|
||||
func NewFetchServer() (*server.MCPServer, error) {
|
||||
s := server.NewMCPServer("fetch-server", "1.0.0", server.WithToolCapabilities(true))
|
||||
|
||||
|
||||
@@ -28,7 +28,11 @@ const (
|
||||
// httpServerModel holds the model for the HTTP server
|
||||
var httpServerModel model.ToolCallingChatModel
|
||||
|
||||
// NewHTTPServer creates a new HTTP MCP server
|
||||
// NewHTTPServer creates a new MCP server providing advanced HTTP fetching capabilities.
|
||||
// The server includes tools for fetching web content, summarizing pages, extracting
|
||||
// specific information, and filtering JSON responses. If an LLM model is provided,
|
||||
// AI-powered summarization and extraction tools are enabled. Returns an error if
|
||||
// server initialization fails.
|
||||
func NewHTTPServer(llmModel model.ToolCallingChatModel) (*server.MCPServer, error) {
|
||||
// Store the model globally for use in tool handlers
|
||||
httpServerModel = llmModel
|
||||
|
||||
@@ -9,28 +9,36 @@ import (
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// BuiltinServerWrapper wraps an external MCP server for builtin use
|
||||
// BuiltinServerWrapper wraps an external MCP server for builtin use, providing
|
||||
// a consistent interface for all builtin servers regardless of their implementation.
|
||||
type BuiltinServerWrapper struct {
|
||||
server *server.MCPServer
|
||||
}
|
||||
|
||||
// Initialize initializes the wrapped server
|
||||
// Initialize initializes the wrapped server. For builtin servers, this is typically
|
||||
// a no-op as the server is initialized during creation. Returns an error if
|
||||
// initialization fails.
|
||||
func (w *BuiltinServerWrapper) Initialize() error {
|
||||
// The server is already initialized when created
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServer returns the wrapped MCP server
|
||||
// GetServer returns the wrapped MCP server instance that can be used to handle
|
||||
// tool calls and other MCP protocol operations.
|
||||
func (w *BuiltinServerWrapper) GetServer() *server.MCPServer {
|
||||
return w.server
|
||||
}
|
||||
|
||||
// Registry holds all available builtin servers
|
||||
// Registry holds all available builtin servers and their factory functions.
|
||||
// It provides a centralized registry for creating instances of builtin MCP servers
|
||||
// with their respective configurations.
|
||||
type Registry struct {
|
||||
servers map[string]func(options map[string]any, model model.ToolCallingChatModel) (*BuiltinServerWrapper, error)
|
||||
}
|
||||
|
||||
// NewRegistry creates a new builtin server registry
|
||||
// NewRegistry creates a new builtin server registry with all available builtin
|
||||
// servers registered. The registry includes filesystem (fs), bash, todo, fetch,
|
||||
// and HTTP servers.
|
||||
func NewRegistry() *Registry {
|
||||
r := &Registry{
|
||||
servers: make(map[string]func(options map[string]any, model model.ToolCallingChatModel) (*BuiltinServerWrapper, error)),
|
||||
@@ -46,7 +54,10 @@ func NewRegistry() *Registry {
|
||||
return r
|
||||
}
|
||||
|
||||
// CreateServer creates a new instance of a builtin server
|
||||
// CreateServer creates a new instance of a builtin server by name. The options
|
||||
// parameter provides server-specific configuration, and the model parameter provides
|
||||
// an optional LLM for AI-powered features. Returns an error if the server name
|
||||
// is unknown or if creation fails.
|
||||
func (r *Registry) CreateServer(name string, options map[string]any, model model.ToolCallingChatModel) (*BuiltinServerWrapper, error) {
|
||||
factory, exists := r.servers[name]
|
||||
if !exists {
|
||||
@@ -56,7 +67,8 @@ func (r *Registry) CreateServer(name string, options map[string]any, model model
|
||||
return factory(options, model)
|
||||
}
|
||||
|
||||
// ListServers returns a list of available builtin server names
|
||||
// ListServers returns a list of all available builtin server names that can be
|
||||
// created using CreateServer. The order of names is not guaranteed.
|
||||
func (r *Registry) ListServers() []string {
|
||||
names := make([]string, 0, len(r.servers))
|
||||
for name := range r.servers {
|
||||
|
||||
@@ -11,7 +11,9 @@ import (
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// TodoInfo represents a single todo item
|
||||
// TodoInfo represents a single todo item with content, status, priority, and ID.
|
||||
// Status can be "pending", "in_progress", or "completed". Priority can be "high",
|
||||
// "medium", or "low". Each todo must have a unique ID.
|
||||
type TodoInfo struct {
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
@@ -19,13 +21,18 @@ type TodoInfo struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// TodoServer implements a todo management MCP server with in-memory storage
|
||||
// TodoServer implements a todo management MCP server with in-memory storage.
|
||||
// It provides thread-safe operations for reading and writing todo lists, with
|
||||
// support for task status tracking and priority levels.
|
||||
type TodoServer struct {
|
||||
todos []TodoInfo
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewTodoServer creates a new todo MCP server with in-memory storage
|
||||
// NewTodoServer creates a new MCP server that provides todo list management capabilities.
|
||||
// The server includes two tools: "todowrite" for updating the todo list and "todoread"
|
||||
// for retrieving the current list. Todos are stored in memory and not persisted.
|
||||
// Returns an error if server initialization fails.
|
||||
func NewTodoServer() (*server.MCPServer, error) {
|
||||
todoServer := &TodoServer{
|
||||
todos: make([]TodoInfo, 0),
|
||||
|
||||
@@ -11,7 +11,9 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// MCPServerConfig represents configuration for an MCP server
|
||||
// MCPServerConfig represents configuration for an MCP server, supporting both
|
||||
// local (stdio), remote (StreamableHTTP/SSE), and builtin (in-process) server types.
|
||||
// It maintains backward compatibility with legacy configuration formats.
|
||||
type MCPServerConfig struct {
|
||||
Type string `json:"type"`
|
||||
Command []string `json:"command,omitempty"`
|
||||
@@ -29,7 +31,9 @@ type MCPServerConfig struct {
|
||||
Headers []string `json:"headers,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON handles both new and legacy config formats
|
||||
// UnmarshalJSON handles both new and legacy config formats for backward compatibility.
|
||||
// New format uses "type" field with "local", "remote", or "builtin" values.
|
||||
// Legacy format uses "transport", "command", "args", and "env" fields.
|
||||
func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
// First try to unmarshal as the new format
|
||||
type newFormat struct {
|
||||
@@ -100,11 +104,15 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdaptiveColor represents a color that adapts to light and dark themes.
|
||||
// Either light or dark can be specified, or both for theme-aware coloring.
|
||||
type AdaptiveColor struct {
|
||||
Light string `json:"light,omitempty" yaml:"light,omitempty"`
|
||||
Dark string `json:"dark,omitempty" yaml:"dark,omitempty"`
|
||||
}
|
||||
|
||||
// Theme defines the color scheme for the application UI with adaptive colors
|
||||
// that support both light and dark modes.
|
||||
type Theme struct {
|
||||
Primary AdaptiveColor `json:"primary" yaml:"primary"`
|
||||
Secondary AdaptiveColor `json:"secondary" yaml:"secondary"`
|
||||
@@ -124,6 +132,8 @@ type Theme struct {
|
||||
Highlight AdaptiveColor `json:"highlight" yaml:"highlight"`
|
||||
}
|
||||
|
||||
// MarkdownTheme defines the color scheme for markdown rendering with syntax
|
||||
// highlighting support and adaptive colors for light and dark modes.
|
||||
type MarkdownTheme struct {
|
||||
Text AdaptiveColor `json:"text" yaml:"text"`
|
||||
Muted AdaptiveColor `json:"muted" yaml:"muted"`
|
||||
@@ -139,7 +149,9 @@ type MarkdownTheme struct {
|
||||
Comment AdaptiveColor `json:"comment" yaml:"comment"`
|
||||
}
|
||||
|
||||
// Config represents the application configuration
|
||||
// Config represents the complete application configuration including MCP servers,
|
||||
// model settings, UI preferences, and API credentials. It supports both command-line
|
||||
// flags and configuration file settings.
|
||||
type Config struct {
|
||||
MCPServers map[string]MCPServerConfig `json:"mcpServers" yaml:"mcpServers"`
|
||||
Model string `json:"model,omitempty" yaml:"model,omitempty"`
|
||||
@@ -166,7 +178,9 @@ type Config struct {
|
||||
TLSSkipVerify bool `json:"tls-skip-verify,omitempty" yaml:"tls-skip-verify,omitempty"`
|
||||
}
|
||||
|
||||
// GetTransportType returns the transport type for the server config
|
||||
// GetTransportType returns the transport type for the server config, mapping
|
||||
// simplified type names to actual transport protocols. Supports legacy format
|
||||
// detection and automatic type inference from configuration.
|
||||
func (s *MCPServerConfig) GetTransportType() string {
|
||||
// Legacy format support - check explicit transport first
|
||||
if s.Transport != "" {
|
||||
@@ -197,7 +211,9 @@ func (s *MCPServerConfig) GetTransportType() string {
|
||||
return "stdio" // default
|
||||
}
|
||||
|
||||
// Validate validates the configuration
|
||||
// Validate validates the configuration, ensuring required fields are present
|
||||
// for each server type and that tool filters are used correctly. Returns an
|
||||
// error describing any validation failures.
|
||||
func (c *Config) Validate() error {
|
||||
for serverName, serverConfig := range c.MCPServers {
|
||||
if len(serverConfig.AllowedTools) > 0 && len(serverConfig.ExcludedTools) > 0 {
|
||||
@@ -226,7 +242,9 @@ func (c *Config) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadSystemPrompt loads system prompt from file or returns the string directly
|
||||
// LoadSystemPrompt loads system prompt from file or returns the string directly.
|
||||
// If input is a path to an existing file, its contents are read and returned.
|
||||
// Otherwise, the input string is returned as-is.
|
||||
func LoadSystemPrompt(input string) (string, error) {
|
||||
if input == "" {
|
||||
return "", nil
|
||||
@@ -246,7 +264,9 @@ func LoadSystemPrompt(input string) (string, error) {
|
||||
return input, nil
|
||||
}
|
||||
|
||||
// EnsureConfigExists checks if a config file exists and creates a default one if not
|
||||
// EnsureConfigExists checks if a config file exists and creates a default one if not.
|
||||
// It searches for .mcphost.{yml,yaml,json} or legacy .mcp.{yml,yaml,json} files in
|
||||
// the user's home directory. If none exist, creates a default .mcphost.yml with examples.
|
||||
func EnsureConfigExists() error {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
@@ -378,6 +398,10 @@ mcpServers:
|
||||
return nil
|
||||
}
|
||||
|
||||
// FilepathOr reads a configuration value that can be either a direct value or a
|
||||
// filepath to a JSON/YAML file containing the value. If the value is a string
|
||||
// starting with "~/" or a relative path, it's expanded to an absolute path.
|
||||
// The contents of the file are then unmarshaled into the provided value pointer.
|
||||
func FilepathOr[T any](key string, value *T) error {
|
||||
var field any
|
||||
err := viper.UnmarshalKey(key, &field)
|
||||
@@ -428,6 +452,9 @@ func FilepathOr[T any](key string, value *T) error {
|
||||
|
||||
var configPath string
|
||||
|
||||
// SetConfigPath sets the configuration file path for resolving relative paths
|
||||
// in configuration values. This should be called when the configuration file
|
||||
// location is known.
|
||||
func SetConfigPath(path string) {
|
||||
configPath = path
|
||||
}
|
||||
|
||||
@@ -7,7 +7,9 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// MergeConfigs merges script frontmatter config with base config
|
||||
// MergeConfigs merges script frontmatter config with base config, allowing scripts
|
||||
// to override MCP server configurations. The script config takes precedence over
|
||||
// the base config for any fields that are specified.
|
||||
func MergeConfigs(baseConfig *Config, scriptConfig *Config) *Config {
|
||||
merged := *baseConfig // Copy base config
|
||||
|
||||
@@ -20,7 +22,9 @@ func MergeConfigs(baseConfig *Config, scriptConfig *Config) *Config {
|
||||
return &merged
|
||||
}
|
||||
|
||||
// LoadAndValidateConfig loads config from viper and validates it
|
||||
// LoadAndValidateConfig loads configuration from viper, fixes environment variable
|
||||
// casing issues, and validates the configuration. Returns an error if loading or
|
||||
// validation fails.
|
||||
func LoadAndValidateConfig() (*Config, error) {
|
||||
config := &Config{
|
||||
MCPServers: make(map[string]MCPServerConfig),
|
||||
|
||||
@@ -24,10 +24,13 @@ func parseVariableWithDefault(varPart string) (varName, defaultValue string, has
|
||||
return varPart, "", false
|
||||
}
|
||||
|
||||
// EnvSubstituter handles environment variable substitution
|
||||
// EnvSubstituter handles environment variable substitution in configuration strings,
|
||||
// supporting both ${env://VAR} and ${env://VAR:-default} patterns.
|
||||
type EnvSubstituter struct{}
|
||||
|
||||
// SubstituteEnvVars replaces ${env://VAR} and ${env://VAR:-default} patterns with environment variables
|
||||
// SubstituteEnvVars replaces ${env://VAR} and ${env://VAR:-default} patterns with environment variables.
|
||||
// If a variable is not set and has a default value, the default is used. Returns an error
|
||||
// if required variables (those without defaults) are not set.
|
||||
func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) {
|
||||
var errors []string
|
||||
|
||||
@@ -57,17 +60,21 @@ func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ArgsSubstituter handles script argument substitution
|
||||
// ArgsSubstituter handles script argument substitution in configuration strings,
|
||||
// supporting both ${VAR} and ${VAR:-default} patterns for template variable replacement.
|
||||
type ArgsSubstituter struct {
|
||||
args map[string]string
|
||||
}
|
||||
|
||||
// NewArgsSubstituter creates a new args substituter with the given arguments
|
||||
// NewArgsSubstituter creates a new args substituter with the given arguments map.
|
||||
// The arguments are used to replace template variables in configuration strings.
|
||||
func NewArgsSubstituter(args map[string]string) *ArgsSubstituter {
|
||||
return &ArgsSubstituter{args: args}
|
||||
}
|
||||
|
||||
// SubstituteArgs replaces ${VAR} and ${VAR:-default} patterns with script arguments
|
||||
// SubstituteArgs replaces ${VAR} and ${VAR:-default} patterns with script arguments.
|
||||
// If an argument is not provided and has a default value, the default is used.
|
||||
// Returns an error if required arguments (those without defaults) are not provided.
|
||||
func (a *ArgsSubstituter) SubstituteArgs(content string) (string, error) {
|
||||
var errors []string
|
||||
|
||||
@@ -97,12 +104,14 @@ func (a *ArgsSubstituter) SubstituteArgs(content string) (string, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// HasEnvVars checks if content contains environment variable patterns
|
||||
// HasEnvVars checks if content contains environment variable patterns (${env://...}).
|
||||
// This is useful for determining if substitution is needed before processing.
|
||||
func HasEnvVars(content string) bool {
|
||||
return envVarPattern.MatchString(content)
|
||||
}
|
||||
|
||||
// HasScriptArgs checks if content contains script argument patterns
|
||||
// HasScriptArgs checks if content contains script argument patterns (${...}).
|
||||
// This is useful for determining if argument substitution is needed before processing.
|
||||
func HasScriptArgs(content string) bool {
|
||||
return scriptArgsPattern.MatchString(content)
|
||||
}
|
||||
|
||||
@@ -9,26 +9,35 @@ import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// HookConfig represents the complete hooks configuration
|
||||
// HookConfig represents the complete hooks configuration containing event-triggered
|
||||
// hooks for tool execution lifecycle events.
|
||||
type HookConfig struct {
|
||||
Hooks map[HookEvent][]HookMatcher `yaml:"hooks" json:"hooks"`
|
||||
}
|
||||
|
||||
// HookMatcher matches specific tools and defines hooks to execute
|
||||
// HookMatcher matches specific tools and defines hooks to execute. The Matcher field
|
||||
// contains a pattern to match tool names, and Hooks contains the commands to run
|
||||
// when a match occurs. The Merge field controls how this matcher combines with others.
|
||||
type HookMatcher struct {
|
||||
Matcher string `yaml:"matcher,omitempty" json:"matcher,omitempty"`
|
||||
Merge string `yaml:"_merge,omitempty" json:"_merge,omitempty"`
|
||||
Hooks []HookEntry `yaml:"hooks" json:"hooks"`
|
||||
}
|
||||
|
||||
// HookEntry defines a single hook command
|
||||
// HookEntry defines a single hook command to execute. Type specifies the command
|
||||
// type (e.g., "bash"), Command contains the actual command to run, and Timeout
|
||||
// optionally specifies the maximum execution time in seconds.
|
||||
type HookEntry struct {
|
||||
Type string `yaml:"type" json:"type"`
|
||||
Command string `yaml:"command" json:"command"`
|
||||
Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
// LoadHooksConfig loads and merges hook configurations from multiple sources
|
||||
// LoadHooksConfig loads and merges hook configurations from multiple sources.
|
||||
// It searches for hooks.{json,yml} files in standard locations (XDG config directory,
|
||||
// local .mcphost directory) and any custom paths provided. Configurations are merged
|
||||
// with later sources taking precedence. Environment variable substitution is applied
|
||||
// to all loaded configurations.
|
||||
func LoadHooksConfig(customPaths ...string) (*HookConfig, error) {
|
||||
// Get config directory following XDG Base Directory specification
|
||||
configDir := getConfigDir()
|
||||
|
||||
@@ -1,23 +1,25 @@
|
||||
package hooks
|
||||
|
||||
// HookEvent represents a point in MCPHost's lifecycle where hooks can be executed
|
||||
// HookEvent represents a point in MCPHost's lifecycle where hooks can be executed.
|
||||
// Events can be tool-related (requiring matchers) or lifecycle-related.
|
||||
type HookEvent string
|
||||
|
||||
const (
|
||||
// PreToolUse fires before any tool execution
|
||||
// PreToolUse fires before any tool execution, allowing pre-processing or validation
|
||||
PreToolUse HookEvent = "PreToolUse"
|
||||
|
||||
// PostToolUse fires after tool execution completes
|
||||
// PostToolUse fires after tool execution completes, allowing post-processing or logging
|
||||
PostToolUse HookEvent = "PostToolUse"
|
||||
|
||||
// UserPromptSubmit fires when user submits a prompt
|
||||
// UserPromptSubmit fires when user submits a prompt, before agent processing
|
||||
UserPromptSubmit HookEvent = "UserPromptSubmit"
|
||||
|
||||
// Stop fires when the main agent finishes responding
|
||||
// Stop fires when the main agent finishes responding to a user prompt
|
||||
Stop HookEvent = "Stop"
|
||||
)
|
||||
|
||||
// IsValid returns true if the event is a valid hook event
|
||||
// IsValid returns true if the event is a valid hook event.
|
||||
// Valid events are PreToolUse, PostToolUse, UserPromptSubmit, and Stop.
|
||||
func (e HookEvent) IsValid() bool {
|
||||
switch e {
|
||||
case PreToolUse, PostToolUse, UserPromptSubmit, Stop:
|
||||
@@ -26,7 +28,9 @@ func (e HookEvent) IsValid() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// RequiresMatcher returns true if the event uses tool matchers
|
||||
// RequiresMatcher returns true if the event uses tool matchers.
|
||||
// PreToolUse and PostToolUse events require matchers to determine which
|
||||
// tools trigger the hooks. Other events apply globally without matchers.
|
||||
func (e HookEvent) RequiresMatcher() bool {
|
||||
return e == PreToolUse || e == PostToolUse
|
||||
}
|
||||
|
||||
@@ -12,7 +12,9 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Executor handles hook execution
|
||||
// Executor handles hook execution for MCPHost lifecycle events. It manages
|
||||
// hook configuration, executes matching hooks in parallel, and processes
|
||||
// their outputs to determine application behavior.
|
||||
type Executor struct {
|
||||
config *HookConfig
|
||||
sessionID string
|
||||
@@ -22,7 +24,9 @@ type Executor struct {
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewExecutor creates a new hook executor
|
||||
// NewExecutor creates a new hook executor with the given configuration,
|
||||
// session ID, and transcript path. The executor manages hook execution
|
||||
// throughout the application lifecycle.
|
||||
func NewExecutor(config *HookConfig, sessionID, transcriptPath string) *Executor {
|
||||
return &Executor{
|
||||
config: config,
|
||||
@@ -31,21 +35,25 @@ func NewExecutor(config *HookConfig, sessionID, transcriptPath string) *Executor
|
||||
}
|
||||
}
|
||||
|
||||
// SetModel sets the model name for hook context
|
||||
// SetModel sets the model name for hook context. This information is passed
|
||||
// to hooks as part of their input data for context-aware processing.
|
||||
func (e *Executor) SetModel(model string) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.model = model
|
||||
}
|
||||
|
||||
// SetInteractive sets whether we're in interactive mode
|
||||
// SetInteractive sets whether the application is running in interactive mode.
|
||||
// This information is passed to hooks for mode-specific behavior.
|
||||
func (e *Executor) SetInteractive(interactive bool) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.interactive = interactive
|
||||
}
|
||||
|
||||
// PopulateCommonFields fills in the common fields for any hook input
|
||||
// PopulateCommonFields fills in the common fields for any hook input, including
|
||||
// session ID, transcript path, working directory, event name, timestamp, model,
|
||||
// and interactive mode. These fields provide context to hooks regardless of event type.
|
||||
func (e *Executor) PopulateCommonFields(event HookEvent) CommonInput {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
@@ -62,7 +70,10 @@ func (e *Executor) PopulateCommonFields(event HookEvent) CommonInput {
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteHooks runs all matching hooks for an event
|
||||
// ExecuteHooks runs all matching hooks for an event. For tool-related events,
|
||||
// it matches hooks based on tool name patterns. Hooks are executed in parallel
|
||||
// with configurable timeouts. Returns a combined HookOutput from all executed
|
||||
// hooks, with blocking decisions taking precedence.
|
||||
func (e *Executor) ExecuteHooks(ctx context.Context, event HookEvent, input interface{}) (*HookOutput, error) {
|
||||
matchers, ok := e.config.Hooks[event]
|
||||
if !ok || len(matchers) == 0 {
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// CommonInput contains fields common to all hook inputs
|
||||
// CommonInput contains fields common to all hook inputs, providing context
|
||||
// information that is available to every hook regardless of the event type.
|
||||
// These fields help hooks understand the execution environment and session state.
|
||||
type CommonInput struct {
|
||||
SessionID string `json:"session_id"` // Unique session identifier
|
||||
TranscriptPath string `json:"transcript_path"` // Path to transcript file (if enabled)
|
||||
@@ -15,14 +17,18 @@ type CommonInput struct {
|
||||
Interactive bool `json:"interactive"` // Whether in interactive mode
|
||||
}
|
||||
|
||||
// PreToolUseInput is passed to PreToolUse hooks
|
||||
// PreToolUseInput is passed to PreToolUse hooks before a tool is executed.
|
||||
// It contains the tool name and input parameters, allowing hooks to validate,
|
||||
// modify, or block tool execution.
|
||||
type PreToolUseInput struct {
|
||||
CommonInput
|
||||
ToolName string `json:"tool_name"`
|
||||
ToolInput json.RawMessage `json:"tool_input"`
|
||||
}
|
||||
|
||||
// PostToolUseInput is passed to PostToolUse hooks
|
||||
// PostToolUseInput is passed to PostToolUse hooks after a tool has been executed.
|
||||
// It contains the tool name, input parameters, and the tool's response, allowing
|
||||
// hooks to log, analyze, or react to tool execution results.
|
||||
type PostToolUseInput struct {
|
||||
CommonInput
|
||||
ToolName string `json:"tool_name"`
|
||||
@@ -30,13 +36,17 @@ type PostToolUseInput struct {
|
||||
ToolResponse json.RawMessage `json:"tool_response"`
|
||||
}
|
||||
|
||||
// UserPromptSubmitInput is passed to UserPromptSubmit hooks
|
||||
// UserPromptSubmitInput is passed to UserPromptSubmit hooks when a user submits
|
||||
// a prompt. It contains the user's input text, allowing hooks to validate,
|
||||
// modify, or log user interactions before processing.
|
||||
type UserPromptSubmitInput struct {
|
||||
CommonInput
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
// StopInput is passed to Stop hooks
|
||||
// StopInput is passed to Stop hooks when the agent finishes responding to a prompt.
|
||||
// It contains the final response, completion reason, and optional metadata about
|
||||
// the interaction, allowing hooks to perform cleanup or logging operations.
|
||||
type StopInput struct {
|
||||
CommonInput
|
||||
StopHookActive bool `json:"stop_hook_active"`
|
||||
@@ -45,7 +55,10 @@ type StopInput struct {
|
||||
Meta json.RawMessage `json:"meta,omitempty"` // Additional metadata (e.g., token usage, model info)
|
||||
}
|
||||
|
||||
// HookOutput represents the JSON output from a hook
|
||||
// HookOutput represents the JSON output from a hook that controls MCPHost behavior.
|
||||
// Hooks can decide whether to continue execution, provide reasons for stopping,
|
||||
// suppress output, or block tool execution. The Decision field can be "approve",
|
||||
// "block", or empty (default behavior).
|
||||
type HookOutput struct {
|
||||
Continue *bool `json:"continue,omitempty"`
|
||||
StopReason string `json:"stopReason,omitempty"`
|
||||
|
||||
@@ -68,7 +68,10 @@ func containsDangerousPattern(command string) bool {
|
||||
return separatorCount > 2
|
||||
}
|
||||
|
||||
// ValidateHookConfig validates the entire hook configuration
|
||||
// ValidateHookConfig validates the entire hook configuration for correctness
|
||||
// and security. It checks event validity, regex patterns, hook definitions,
|
||||
// and performs security validation on all commands. Returns an error describing
|
||||
// any validation failures.
|
||||
func ValidateHookConfig(config *HookConfig) error {
|
||||
if config == nil {
|
||||
return fmt.Errorf("nil configuration")
|
||||
|
||||
@@ -13,17 +13,42 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// CustomChatModel wraps the eino-ext Claude model with custom tool schema handling
|
||||
// CustomChatModel wraps the eino-ext Claude model with custom tool schema handling.
|
||||
// It provides a compatibility layer that fixes malformed JSON in tool calls and
|
||||
// ensures proper schema validation for Anthropic's API requirements.
|
||||
// This wrapper is necessary to handle edge cases where the underlying library
|
||||
// may generate invalid JSON for empty tool inputs or missing properties.
|
||||
type CustomChatModel struct {
|
||||
// wrapped is the underlying eino-ext Claude model instance
|
||||
wrapped *einoclaude.ChatModel
|
||||
}
|
||||
|
||||
// CustomRoundTripper intercepts HTTP requests to fix Anthropic function schemas
|
||||
// CustomRoundTripper intercepts HTTP requests to fix Anthropic function schemas.
|
||||
// It acts as a middleware that modifies requests before they reach the Anthropic API,
|
||||
// ensuring that tool schemas and function calls are properly formatted.
|
||||
// This is particularly important for handling edge cases like empty tool inputs
|
||||
// or missing schema properties that would otherwise cause API errors.
|
||||
type CustomRoundTripper struct {
|
||||
// wrapped is the underlying HTTP transport to use for actual requests
|
||||
wrapped http.RoundTripper
|
||||
}
|
||||
|
||||
// NewCustomChatModel creates a new custom Anthropic chat model
|
||||
// NewCustomChatModel creates a new custom Anthropic chat model.
|
||||
// It wraps the standard eino-ext Claude model with additional request
|
||||
// preprocessing to ensure compatibility with Anthropic's API requirements.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the operation
|
||||
// - config: Configuration for the Claude model including API key, model name, and parameters
|
||||
//
|
||||
// Returns:
|
||||
// - *CustomChatModel: A wrapped Claude model with enhanced compatibility
|
||||
// - error: Returns an error if model creation fails
|
||||
//
|
||||
// The custom model automatically:
|
||||
// - Fixes malformed JSON in tool calls
|
||||
// - Ensures tool schemas have required properties
|
||||
// - Handles empty or missing input fields in function calls
|
||||
func NewCustomChatModel(ctx context.Context, config *einoclaude.Config) (*CustomChatModel, error) {
|
||||
// Create a custom HTTP client that intercepts requests
|
||||
if config.HTTPClient == nil {
|
||||
@@ -49,7 +74,21 @@ func NewCustomChatModel(ctx context.Context, config *einoclaude.Config) (*Custom
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper to intercept and fix requests
|
||||
// RoundTrip implements http.RoundTripper to intercept and fix requests.
|
||||
// It preprocesses outgoing requests to the Anthropic API to ensure
|
||||
// they meet the API's requirements for tool schemas and function calls.
|
||||
//
|
||||
// Parameters:
|
||||
// - req: The HTTP request to be sent to the Anthropic API
|
||||
//
|
||||
// Returns:
|
||||
// - *http.Response: The response from the Anthropic API
|
||||
// - error: Any error that occurred during the request
|
||||
//
|
||||
// The method performs the following fixes:
|
||||
// - Ensures tool input_schema properties are not null
|
||||
// - Fixes malformed JSON patterns in tool_use content
|
||||
// - Validates and corrects empty or invalid function call inputs
|
||||
func (rt *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Only process Anthropic API requests
|
||||
if !strings.Contains(req.URL.Host, "anthropic.com") {
|
||||
@@ -191,17 +230,47 @@ func (rt *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
return rt.wrapped.RoundTrip(req)
|
||||
}
|
||||
|
||||
// Generate implements the model.BaseChatModel interface
|
||||
// Generate implements the model.BaseChatModel interface.
|
||||
// It generates a single response from the model based on the input messages.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the operation, supporting cancellation and deadlines
|
||||
// - input: The conversation history as a slice of messages
|
||||
// - opts: Optional configuration options for the generation
|
||||
//
|
||||
// Returns:
|
||||
// - *schema.Message: The generated response message
|
||||
// - error: Any error that occurred during generation
|
||||
func (m *CustomChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
|
||||
return m.wrapped.Generate(ctx, input, opts...)
|
||||
}
|
||||
|
||||
// Stream implements the model.BaseChatModel interface
|
||||
// Stream implements the model.BaseChatModel interface.
|
||||
// It generates a streaming response from the model, allowing incremental
|
||||
// processing of the model's output as it's generated.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the operation, supporting cancellation and deadlines
|
||||
// - input: The conversation history as a slice of messages
|
||||
// - opts: Optional configuration options for the generation
|
||||
//
|
||||
// Returns:
|
||||
// - *schema.StreamReader[*schema.Message]: A reader for the streaming response
|
||||
// - error: Any error that occurred during stream setup
|
||||
func (m *CustomChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
|
||||
return m.wrapped.Stream(ctx, input, opts...)
|
||||
}
|
||||
|
||||
// WithTools implements the model.ToolCallingChatModel interface
|
||||
// WithTools implements the model.ToolCallingChatModel interface.
|
||||
// It creates a new model instance with the specified tools available for function calling.
|
||||
// The original model instance remains unchanged.
|
||||
//
|
||||
// Parameters:
|
||||
// - tools: A slice of tool definitions that the model can use
|
||||
//
|
||||
// Returns:
|
||||
// - model.ToolCallingChatModel: A new model instance with tools enabled
|
||||
// - error: Returns an error if tool binding fails
|
||||
func (m *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
|
||||
wrappedWithTools, err := m.wrapped.WithTools(tools)
|
||||
if err != nil {
|
||||
|
||||
@@ -17,21 +17,27 @@ import (
|
||||
|
||||
var _ model.ToolCallingChatModel = (*ChatModel)(nil)
|
||||
|
||||
// NewChatModel creates a new Gemini chat model instance
|
||||
// NewChatModel creates a new Gemini chat model instance.
|
||||
// It initializes a Google Gemini model with the specified configuration,
|
||||
// supporting both text generation and tool calling capabilities.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the operation
|
||||
// - cfg: Configuration for the Gemini model
|
||||
// - ctx: The context for the operation (currently unused but kept for interface consistency)
|
||||
// - cfg: Configuration for the Gemini model including client, model name, and parameters
|
||||
//
|
||||
// Returns:
|
||||
// - model.ChatModel: A chat model interface implementation
|
||||
// - *ChatModel: A Gemini chat model instance implementing ToolCallingChatModel
|
||||
// - error: Any error that occurred during creation
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// client, _ := genai.NewClient(ctx, &genai.ClientConfig{
|
||||
// APIKey: "your-api-key",
|
||||
// })
|
||||
// model, err := gemini.NewChatModel(ctx, &gemini.Config{
|
||||
// Client: client,
|
||||
// Model: "gemini-pro",
|
||||
// MaxTokens: &maxTokens,
|
||||
// })
|
||||
func NewChatModel(_ context.Context, cfg *Config) (*ChatModel, error) {
|
||||
return &ChatModel{
|
||||
@@ -90,28 +96,58 @@ type Config struct {
|
||||
SafetySettings []*genai.SafetySetting
|
||||
}
|
||||
|
||||
// options contains Gemini-specific options for model configuration
|
||||
// options contains Gemini-specific options for model configuration.
|
||||
// These are options that are specific to the Gemini API and not part
|
||||
// of the common model options interface.
|
||||
type options struct {
|
||||
TopK *int32
|
||||
// TopK limits the number of tokens to sample from
|
||||
TopK *int32
|
||||
// ResponseSchema defines the expected JSON structure for responses
|
||||
ResponseSchema *openapi3.Schema
|
||||
}
|
||||
|
||||
// ChatModel implements the Gemini chat model for the eino framework.
|
||||
// It provides integration with Google's Gemini API, supporting both
|
||||
// text generation and tool calling capabilities.
|
||||
type ChatModel struct {
|
||||
// cli is the Gemini API client instance
|
||||
cli *genai.Client
|
||||
|
||||
model string
|
||||
maxTokens *int
|
||||
topP *float32
|
||||
temperature *float32
|
||||
topK *int32
|
||||
responseSchema *openapi3.Schema
|
||||
tools []*genai.Tool
|
||||
origTools []*schema.ToolInfo
|
||||
toolChoice *schema.ToolChoice
|
||||
// model specifies which Gemini model to use
|
||||
model string
|
||||
// maxTokens limits the response length
|
||||
maxTokens *int
|
||||
// topP controls nucleus sampling
|
||||
topP *float32
|
||||
// temperature controls randomness
|
||||
temperature *float32
|
||||
// topK limits token sampling
|
||||
topK *int32
|
||||
// responseSchema for structured JSON output
|
||||
responseSchema *openapi3.Schema
|
||||
// tools converted to Gemini format
|
||||
tools []*genai.Tool
|
||||
// origTools stores the original tool definitions
|
||||
origTools []*schema.ToolInfo
|
||||
// toolChoice controls how tools are used
|
||||
toolChoice *schema.ToolChoice
|
||||
// enableCodeExecution allows code execution (use with caution)
|
||||
enableCodeExecution bool
|
||||
safetySettings []*genai.SafetySetting
|
||||
// safetySettings for content filtering
|
||||
safetySettings []*genai.SafetySetting
|
||||
}
|
||||
|
||||
// Generate generates a single response from the Gemini model.
|
||||
// It processes the input messages and returns a complete response.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the operation, supporting cancellation and callbacks
|
||||
// - input: The conversation history as a slice of messages
|
||||
// - opts: Optional configuration options for the generation
|
||||
//
|
||||
// Returns:
|
||||
// - *schema.Message: The generated response message with content and metadata
|
||||
// - error: Any error that occurred during generation
|
||||
func (cm *ChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (message *schema.Message, err error) {
|
||||
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
|
||||
|
||||
@@ -154,6 +190,17 @@ func (cm *ChatModel) Generate(ctx context.Context, input []*schema.Message, opts
|
||||
return message, nil
|
||||
}
|
||||
|
||||
// Stream generates a streaming response from the Gemini model.
|
||||
// It allows incremental processing of the model's output as it's generated.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the operation, supporting cancellation and callbacks
|
||||
// - input: The conversation history as a slice of messages
|
||||
// - opts: Optional configuration options for the generation
|
||||
//
|
||||
// Returns:
|
||||
// - *schema.StreamReader[*schema.Message]: A reader for the streaming response
|
||||
// - error: Any error that occurred during stream setup
|
||||
func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (result *schema.StreamReader[*schema.Message], err error) {
|
||||
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
|
||||
|
||||
@@ -218,6 +265,16 @@ func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts .
|
||||
}), nil
|
||||
}
|
||||
|
||||
// WithTools creates a new model instance with the specified tools available.
|
||||
// It returns a new ChatModel with tools configured for function calling.
|
||||
// The original model instance remains unchanged.
|
||||
//
|
||||
// Parameters:
|
||||
// - tools: A slice of tool definitions that the model can use
|
||||
//
|
||||
// Returns:
|
||||
// - model.ToolCallingChatModel: A new model instance with tools enabled
|
||||
// - error: Returns an error if no tools provided or conversion fails
|
||||
func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
|
||||
if len(tools) == 0 {
|
||||
return nil, errors.New("no tools to bind")
|
||||
@@ -235,6 +292,15 @@ func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatM
|
||||
return &ncm, nil
|
||||
}
|
||||
|
||||
// BindTools binds tools to the current model instance.
|
||||
// Unlike WithTools, this modifies the current instance rather than
|
||||
// creating a new one. Tools are set to "allowed" mode by default.
|
||||
//
|
||||
// Parameters:
|
||||
// - tools: A slice of tool definitions to bind to the model
|
||||
//
|
||||
// Returns:
|
||||
// - error: Returns an error if no tools provided or conversion fails
|
||||
func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error {
|
||||
if len(tools) == 0 {
|
||||
return errors.New("no tools to bind")
|
||||
@@ -251,6 +317,15 @@ func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BindForcedTools binds tools to the current model instance in forced mode.
|
||||
// This ensures the model will always use one of the provided tools
|
||||
// rather than generating a text response.
|
||||
//
|
||||
// Parameters:
|
||||
// - tools: A slice of tool definitions to bind to the model
|
||||
//
|
||||
// Returns:
|
||||
// - error: Returns an error if no tools provided or conversion fails
|
||||
func (cm *ChatModel) BindForcedTools(tools []*schema.ToolInfo) error {
|
||||
if len(tools) == 0 {
|
||||
return errors.New("no tools to bind")
|
||||
@@ -679,12 +754,24 @@ func (cm *ChatModel) convertCallbackOutput(message *schema.Message, conf *model.
|
||||
return callbackOutput
|
||||
}
|
||||
|
||||
// IsCallbacksEnabled indicates whether this model supports callbacks.
|
||||
// For the Gemini model, callbacks are always enabled to support
|
||||
// token usage tracking and other monitoring features.
|
||||
//
|
||||
// Returns:
|
||||
// - bool: Always returns true for Gemini models
|
||||
func (cm *ChatModel) IsCallbacksEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
const typ = "Gemini"
|
||||
|
||||
// GetType returns the type identifier for this model.
|
||||
// This is used for logging and debugging purposes to identify
|
||||
// which model implementation is being used.
|
||||
//
|
||||
// Returns:
|
||||
// - string: Returns "Gemini" as the model type
|
||||
func (cm *ChatModel) GetType() string {
|
||||
return typ
|
||||
}
|
||||
|
||||
@@ -12,37 +12,61 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ModelInfo represents information about a specific model
|
||||
// ModelInfo represents information about a specific model.
|
||||
// This struct is used during code generation to parse model data
|
||||
// from the models.dev API and generate the static Go code.
|
||||
type ModelInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Attachment bool `json:"attachment"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
Temperature bool `json:"temperature"`
|
||||
Cost Cost `json:"cost"`
|
||||
Limit Limit `json:"limit"`
|
||||
// ID is the unique identifier for the model
|
||||
ID string `json:"id"`
|
||||
// Name is the human-readable name of the model
|
||||
Name string `json:"name"`
|
||||
// Attachment indicates whether the model supports file attachments
|
||||
Attachment bool `json:"attachment"`
|
||||
// Reasoning indicates whether this is a reasoning/chain-of-thought model
|
||||
Reasoning bool `json:"reasoning"`
|
||||
// Temperature indicates whether the model supports temperature parameter
|
||||
Temperature bool `json:"temperature"`
|
||||
// Cost contains the pricing information for the model
|
||||
Cost Cost `json:"cost"`
|
||||
// Limit contains the context and output token limits
|
||||
Limit Limit `json:"limit"`
|
||||
}
|
||||
|
||||
// Cost represents the pricing information for a model
|
||||
// Cost represents the pricing information for a model.
|
||||
// Used during code generation to parse pricing data from models.dev.
|
||||
type Cost struct {
|
||||
Input float64 `json:"input"`
|
||||
Output float64 `json:"output"`
|
||||
CacheRead *float64 `json:"cache_read,omitempty"`
|
||||
// Input is the cost per million input tokens
|
||||
Input float64 `json:"input"`
|
||||
// Output is the cost per million output tokens
|
||||
Output float64 `json:"output"`
|
||||
// CacheRead is the cost per million cached read tokens (optional)
|
||||
CacheRead *float64 `json:"cache_read,omitempty"`
|
||||
// CacheWrite is the cost per million cached write tokens (optional)
|
||||
CacheWrite *float64 `json:"cache_write,omitempty"`
|
||||
}
|
||||
|
||||
// Limit represents the context and output limits for a model
|
||||
// Limit represents the context and output limits for a model.
|
||||
// Used during code generation to parse token limit data from models.dev.
|
||||
type Limit struct {
|
||||
// Context is the maximum number of input tokens
|
||||
Context int `json:"context"`
|
||||
Output int `json:"output"`
|
||||
// Output is the maximum number of output tokens
|
||||
Output int `json:"output"`
|
||||
}
|
||||
|
||||
// ProviderInfo represents information about a model provider
|
||||
// ProviderInfo represents information about a model provider.
|
||||
// Used during code generation to parse provider data from models.dev
|
||||
// and generate the static provider registry.
|
||||
type ProviderInfo struct {
|
||||
ID string `json:"id"`
|
||||
Env []string `json:"env"`
|
||||
NPM string `json:"npm"`
|
||||
Name string `json:"name"`
|
||||
// ID is the unique identifier for the provider
|
||||
ID string `json:"id"`
|
||||
// Env lists the environment variables for API credentials
|
||||
Env []string `json:"env"`
|
||||
// NPM is the NPM package name (for reference)
|
||||
NPM string `json:"npm"`
|
||||
// Name is the human-readable provider name
|
||||
Name string `json:"name"`
|
||||
// Models maps model IDs to their information
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
}
|
||||
|
||||
|
||||
@@ -3,41 +3,99 @@
|
||||
|
||||
package models
|
||||
|
||||
// ModelInfo represents information about a specific model
|
||||
// ModelInfo represents information about a specific model.
|
||||
// It contains comprehensive metadata about a model's capabilities,
|
||||
// pricing, and limitations sourced from models.dev.
|
||||
type ModelInfo struct {
|
||||
ID string
|
||||
Name string
|
||||
Attachment bool
|
||||
Reasoning bool
|
||||
// ID is the unique identifier for the model
|
||||
// Example: "claude-3-sonnet-20240620" or "gpt-4"
|
||||
ID string
|
||||
|
||||
// Name is the human-readable name of the model
|
||||
// Example: "Claude 3 Sonnet" or "GPT-4"
|
||||
Name string
|
||||
|
||||
// Attachment indicates whether the model supports file attachments
|
||||
Attachment bool
|
||||
|
||||
// Reasoning indicates whether this is a reasoning/chain-of-thought model
|
||||
// Example: OpenAI's o1 models have this set to true
|
||||
Reasoning bool
|
||||
|
||||
// Temperature indicates whether the model supports temperature parameter
|
||||
Temperature bool
|
||||
Cost Cost
|
||||
Limit Limit
|
||||
|
||||
// Cost contains the pricing information for input/output tokens
|
||||
Cost Cost
|
||||
|
||||
// Limit contains the context window and output token limits
|
||||
Limit Limit
|
||||
}
|
||||
|
||||
// Cost represents the pricing information for a model
|
||||
// Cost represents the pricing information for a model.
|
||||
// Prices are typically in USD per million tokens.
|
||||
type Cost struct {
|
||||
Input float64
|
||||
Output float64
|
||||
CacheRead *float64
|
||||
// Input is the cost per million input tokens
|
||||
Input float64
|
||||
|
||||
// Output is the cost per million output tokens
|
||||
Output float64
|
||||
|
||||
// CacheRead is the cost per million cached read tokens (optional)
|
||||
// Only applicable for models that support prompt caching
|
||||
CacheRead *float64
|
||||
|
||||
// CacheWrite is the cost per million cached write tokens (optional)
|
||||
// Only applicable for models that support prompt caching
|
||||
CacheWrite *float64
|
||||
}
|
||||
|
||||
// Limit represents the context and output limits for a model
|
||||
// Limit represents the context and output limits for a model.
|
||||
// These define the maximum number of tokens the model can process
|
||||
// and generate in a single interaction.
|
||||
type Limit struct {
|
||||
// Context is the maximum number of input tokens (context window size)
|
||||
Context int
|
||||
Output int
|
||||
|
||||
// Output is the maximum number of output tokens that can be generated
|
||||
Output int
|
||||
}
|
||||
|
||||
// ProviderInfo represents information about a model provider
|
||||
// ProviderInfo represents information about a model provider.
|
||||
// It contains metadata about the provider and all models it offers.
|
||||
type ProviderInfo struct {
|
||||
ID string
|
||||
Env []string
|
||||
NPM string
|
||||
Name string
|
||||
// ID is the unique identifier for the provider
|
||||
// Example: "anthropic", "openai", "google"
|
||||
ID string
|
||||
|
||||
// Env lists the environment variables checked for API credentials
|
||||
// Example: ["ANTHROPIC_API_KEY"] or ["OPENAI_API_KEY"]
|
||||
Env []string
|
||||
|
||||
// NPM is the NPM package name used by the provider (for reference)
|
||||
NPM string
|
||||
|
||||
// Name is the human-readable name of the provider
|
||||
// Example: "Anthropic", "OpenAI", "Google"
|
||||
Name string
|
||||
|
||||
// Models maps model IDs to their detailed information
|
||||
Models map[string]ModelInfo
|
||||
}
|
||||
|
||||
// GetModelsData returns the static models data from models.dev
|
||||
// GetModelsData returns the static models data from models.dev.
|
||||
// This data is automatically generated from the models.dev API
|
||||
// and provides comprehensive information about all supported
|
||||
// LLM providers and their models.
|
||||
//
|
||||
// The data includes:
|
||||
// - Provider information (ID, name, environment variables)
|
||||
// - Model capabilities (reasoning, attachments, temperature support)
|
||||
// - Pricing information (input/output costs, cache costs)
|
||||
// - Token limits (context window, max output tokens)
|
||||
//
|
||||
// Returns:
|
||||
// - map[string]ProviderInfo: A map of provider IDs to their information
|
||||
func GetModelsData() map[string]ProviderInfo {
|
||||
return map[string]ProviderInfo{
|
||||
"alibaba": {
|
||||
|
||||
@@ -14,17 +14,42 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// CustomChatModel wraps the eino-ext OpenAI model with custom tool schema handling
|
||||
// CustomChatModel wraps the eino-ext OpenAI model with custom tool schema handling.
|
||||
// It provides a compatibility layer that ensures proper JSON schema formatting
|
||||
// for OpenAI's function calling feature. This wrapper addresses cases where
|
||||
// tool schemas might have missing or empty properties that would cause API errors.
|
||||
type CustomChatModel struct {
|
||||
// wrapped is the underlying eino-ext OpenAI model instance
|
||||
wrapped *einoopenai.ChatModel
|
||||
}
|
||||
|
||||
// CustomRoundTripper intercepts HTTP requests to fix OpenAI function schemas
|
||||
// CustomRoundTripper intercepts HTTP requests to fix OpenAI function schemas.
|
||||
// It acts as middleware that modifies outgoing requests to ensure that
|
||||
// function/tool schemas are properly formatted according to OpenAI's requirements.
|
||||
// This is particularly important for handling edge cases where tool schemas
|
||||
// might have missing or empty properties fields.
|
||||
type CustomRoundTripper struct {
|
||||
// wrapped is the underlying HTTP transport to use for actual requests
|
||||
wrapped http.RoundTripper
|
||||
}
|
||||
|
||||
// NewCustomChatModel creates a new custom OpenAI chat model
|
||||
// NewCustomChatModel creates a new custom OpenAI chat model.
|
||||
// It wraps the standard eino-ext OpenAI model with additional request
|
||||
// preprocessing to ensure compatibility with OpenAI's API requirements,
|
||||
// particularly for function calling and tool schemas.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the operation
|
||||
// - config: Configuration for the OpenAI model including API key, model name, and parameters
|
||||
//
|
||||
// Returns:
|
||||
// - *CustomChatModel: A wrapped OpenAI model with enhanced compatibility
|
||||
// - error: Returns an error if model creation fails
|
||||
//
|
||||
// The custom model automatically:
|
||||
// - Ensures function parameter schemas have properties fields
|
||||
// - Fixes missing or empty properties in tool schemas
|
||||
// - Maintains compatibility with OpenAI's function calling requirements
|
||||
func NewCustomChatModel(ctx context.Context, config *einoopenai.ChatModelConfig) (*CustomChatModel, error) {
|
||||
// Create a custom HTTP client that intercepts requests
|
||||
if config.HTTPClient == nil {
|
||||
@@ -49,7 +74,20 @@ func NewCustomChatModel(ctx context.Context, config *einoopenai.ChatModelConfig)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper to intercept and fix OpenAI requests
|
||||
// RoundTrip implements http.RoundTripper to intercept and fix OpenAI requests.
|
||||
// It preprocesses outgoing requests to the OpenAI API to ensure tool/function
|
||||
// schemas meet the API's requirements.
|
||||
//
|
||||
// Parameters:
|
||||
// - req: The HTTP request to be sent to the OpenAI API
|
||||
//
|
||||
// Returns:
|
||||
// - *http.Response: The response from the OpenAI API
|
||||
// - error: Any error that occurred during the request
|
||||
//
|
||||
// The method performs the following fixes:
|
||||
// - Ensures function parameter schemas of type "object" have a properties field
|
||||
// - Adds empty properties object if missing to prevent API validation errors
|
||||
func (c *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Only intercept OpenAI chat completions requests
|
||||
if !strings.Contains(req.URL.Path, "/chat/completions") {
|
||||
@@ -109,17 +147,47 @@ func (c *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
|
||||
return c.wrapped.RoundTrip(req)
|
||||
}
|
||||
|
||||
// Generate implements model.ChatModel
|
||||
// Generate implements model.ChatModel interface.
|
||||
// It generates a single response from the OpenAI model based on the input messages.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the operation, supporting cancellation and deadlines
|
||||
// - in: The conversation history as a slice of messages
|
||||
// - opts: Optional configuration options for the generation
|
||||
//
|
||||
// Returns:
|
||||
// - *schema.Message: The generated response message
|
||||
// - error: Any error that occurred during generation
|
||||
func (c *CustomChatModel) Generate(ctx context.Context, in []*schema.Message, opts ...model.Option) (*schema.Message, error) {
|
||||
return c.wrapped.Generate(ctx, in, opts...)
|
||||
}
|
||||
|
||||
// Stream implements model.ChatModel
|
||||
// Stream implements model.ChatModel interface.
|
||||
// It generates a streaming response from the OpenAI model, allowing
|
||||
// incremental processing of the model's output as it's generated.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the operation, supporting cancellation and deadlines
|
||||
// - in: The conversation history as a slice of messages
|
||||
// - opts: Optional configuration options for the generation
|
||||
//
|
||||
// Returns:
|
||||
// - *schema.StreamReader[*schema.Message]: A reader for the streaming response
|
||||
// - error: Any error that occurred during stream setup
|
||||
func (c *CustomChatModel) Stream(ctx context.Context, in []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
|
||||
return c.wrapped.Stream(ctx, in, opts...)
|
||||
}
|
||||
|
||||
// WithTools implements model.ToolCallingChatModel
|
||||
// WithTools implements model.ToolCallingChatModel interface.
|
||||
// It creates a new model instance with the specified tools available for function calling.
|
||||
// The original model instance remains unchanged.
|
||||
//
|
||||
// Parameters:
|
||||
// - tools: A slice of tool definitions that the model can use
|
||||
//
|
||||
// Returns:
|
||||
// - model.ToolCallingChatModel: A new model instance with tools enabled
|
||||
// - error: Returns an error if tool binding fails
|
||||
func (c *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
|
||||
wrappedWithTools, err := c.wrapped.WithTools(tools)
|
||||
if err != nil {
|
||||
@@ -135,22 +203,47 @@ func (c *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCalling
|
||||
return &CustomChatModel{wrapped: wrappedChatModel}, nil
|
||||
}
|
||||
|
||||
// BindTools implements model.ToolCallingChatModel
|
||||
// BindTools implements model.ToolCallingChatModel interface.
|
||||
// It binds tools to the current model instance, modifying it in place
|
||||
// rather than creating a new instance.
|
||||
//
|
||||
// Parameters:
|
||||
// - tools: A slice of tool definitions to bind to the model
|
||||
//
|
||||
// Returns:
|
||||
// - error: Returns an error if tool binding fails
|
||||
func (c *CustomChatModel) BindTools(tools []*schema.ToolInfo) error {
|
||||
return c.wrapped.BindTools(tools)
|
||||
}
|
||||
|
||||
// BindForcedTools implements model.ToolCallingChatModel
|
||||
// BindForcedTools implements model.ToolCallingChatModel interface.
|
||||
// It binds tools to the current model instance in forced mode,
|
||||
// ensuring the model will always use one of the provided tools.
|
||||
//
|
||||
// Parameters:
|
||||
// - tools: A slice of tool definitions to bind to the model
|
||||
//
|
||||
// Returns:
|
||||
// - error: Returns an error if tool binding fails
|
||||
func (c *CustomChatModel) BindForcedTools(tools []*schema.ToolInfo) error {
|
||||
return c.wrapped.BindForcedTools(tools)
|
||||
}
|
||||
|
||||
// GetType implements model.ChatModel
|
||||
// GetType implements model.ChatModel interface.
|
||||
// It returns the type identifier for this model implementation.
|
||||
//
|
||||
// Returns:
|
||||
// - string: Returns "CustomOpenAI" as the model type identifier
|
||||
func (c *CustomChatModel) GetType() string {
|
||||
return "CustomOpenAI"
|
||||
}
|
||||
|
||||
// IsCallbacksEnabled implements model.ChatModel
|
||||
// IsCallbacksEnabled implements model.ChatModel interface.
|
||||
// It indicates whether this model supports callbacks for monitoring
|
||||
// and tracking purposes.
|
||||
//
|
||||
// Returns:
|
||||
// - bool: Returns the callback enabled status from the wrapped model
|
||||
func (c *CustomChatModel) IsCallbacksEnabled() bool {
|
||||
return c.wrapped.IsCallbacksEnabled()
|
||||
}
|
||||
|
||||
@@ -27,7 +27,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// ClaudeCodePrompt is the required system prompt for OAuth authentication
|
||||
// ClaudeCodePrompt is the required system prompt for OAuth authentication.
|
||||
// This prompt must be included as the first system message when using OAuth-based
|
||||
// authentication with Anthropic's API to properly identify the application.
|
||||
ClaudeCodePrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
)
|
||||
|
||||
@@ -62,35 +64,93 @@ func resolveModelAlias(provider, modelName string) string {
|
||||
return modelName
|
||||
}
|
||||
|
||||
// ProviderConfig holds configuration for creating LLM providers
|
||||
// ProviderConfig holds configuration for creating LLM providers.
|
||||
// It contains all necessary settings to initialize and configure
|
||||
// various LLM providers including API keys, model parameters, and
|
||||
// provider-specific settings.
|
||||
type ProviderConfig struct {
|
||||
ModelString string
|
||||
SystemPrompt string
|
||||
ProviderAPIKey string // API key for OpenAI and Anthropic
|
||||
ProviderURL string // Base URL for OpenAI, Anthropic, and Ollama
|
||||
// ModelString specifies the model in the format "provider:model"
|
||||
// Example: "anthropic:claude-3-sonnet-20240620" or "openai:gpt-4"
|
||||
ModelString string
|
||||
|
||||
// SystemPrompt sets the system message/instructions for the model
|
||||
SystemPrompt string
|
||||
|
||||
// ProviderAPIKey is the API key for authentication with the provider
|
||||
// Used for OpenAI, Anthropic, Google, and Azure providers
|
||||
ProviderAPIKey string
|
||||
|
||||
// ProviderURL is the base URL for the provider's API endpoint
|
||||
// Can be used to specify custom endpoints for OpenAI, Anthropic, Ollama, or Azure
|
||||
ProviderURL string
|
||||
|
||||
// Model generation parameters
|
||||
MaxTokens int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
|
||||
// MaxTokens limits the maximum number of tokens in the response
|
||||
MaxTokens int
|
||||
|
||||
// Temperature controls randomness in generation (0.0 to 1.0)
|
||||
// Lower values make output more focused and deterministic
|
||||
Temperature *float32
|
||||
|
||||
// TopP implements nucleus sampling, controlling diversity
|
||||
// Value between 0.0 and 1.0, where 1.0 considers all tokens
|
||||
TopP *float32
|
||||
|
||||
// TopK limits the number of tokens to sample from
|
||||
// Used by Anthropic and Google providers
|
||||
TopK *int32
|
||||
|
||||
// StopSequences is a list of sequences that will stop generation
|
||||
StopSequences []string
|
||||
|
||||
// Ollama-specific parameters
|
||||
NumGPU *int32
|
||||
|
||||
// NumGPU specifies the number of GPUs to use for inference
|
||||
NumGPU *int32
|
||||
|
||||
// MainGPU specifies which GPU to use as the primary device
|
||||
MainGPU *int32
|
||||
|
||||
// TLS configuration
|
||||
TLSSkipVerify bool // Skip TLS certificate verification (insecure)
|
||||
|
||||
// TLSSkipVerify skips TLS certificate verification (insecure)
|
||||
// Should only be used for development or with self-signed certificates
|
||||
TLSSkipVerify bool
|
||||
}
|
||||
|
||||
// ProviderResult contains the result of provider creation
|
||||
// ProviderResult contains the result of provider creation.
|
||||
// It includes both the created model instance and any informational
|
||||
// messages that should be displayed to the user.
|
||||
type ProviderResult struct {
|
||||
Model model.ToolCallingChatModel
|
||||
Message string // Optional message for user feedback (e.g., GPU fallback info)
|
||||
// Model is the created LLM provider instance that implements
|
||||
// the ToolCallingChatModel interface for tool-enabled conversations
|
||||
Model model.ToolCallingChatModel
|
||||
|
||||
// Message contains optional feedback for the user
|
||||
// Example: "Insufficient GPU memory, falling back to CPU inference"
|
||||
Message string
|
||||
}
|
||||
|
||||
// CreateProvider creates an eino ToolCallingChatModel based on the provider configuration
|
||||
// CreateProvider creates an eino ToolCallingChatModel based on the provider configuration.
|
||||
// It validates the model, checks required environment variables, and initializes
|
||||
// the appropriate provider (Anthropic, OpenAI, Google, Ollama, or Azure).
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the operation, used for cancellation and deadlines
|
||||
// - config: Provider configuration containing model details and parameters
|
||||
//
|
||||
// Returns:
|
||||
// - *ProviderResult: Contains the created model and any informational messages
|
||||
// - error: Returns an error if provider creation fails, model validation fails,
|
||||
// or required credentials are missing
|
||||
//
|
||||
// Supported providers:
|
||||
// - anthropic: Claude models with API key or OAuth authentication
|
||||
// - openai: GPT models including reasoning models like o1
|
||||
// - google: Gemini models
|
||||
// - ollama: Local models with automatic GPU/CPU fallback
|
||||
// - azure: Azure OpenAI Service deployments
|
||||
func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) {
|
||||
parts := strings.SplitN(config.ModelString, ":", 2)
|
||||
if len(parts) < 2 {
|
||||
@@ -407,9 +467,17 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return gemini.NewChatModel(ctx, geminiConfig)
|
||||
}
|
||||
|
||||
// OllamaLoadingResult contains the result of model loading with actual settings used
|
||||
// OllamaLoadingResult contains the result of model loading with actual settings used.
|
||||
// It provides information about how the model was loaded, including any fallback
|
||||
// behavior that occurred during initialization.
|
||||
type OllamaLoadingResult struct {
|
||||
// Options contains the actual Ollama options used for loading
|
||||
// May differ from requested options if fallback occurred (e.g., CPU instead of GPU)
|
||||
Options *api.Options
|
||||
|
||||
// Message describes the loading result
|
||||
// Example: "Model loaded successfully on GPU" or
|
||||
// "Insufficient GPU memory, falling back to CPU inference"
|
||||
Message string
|
||||
}
|
||||
|
||||
|
||||
@@ -8,19 +8,39 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ModelsRegistry provides validation and information about models
|
||||
// ModelsRegistry provides validation and information about models.
|
||||
// It maintains a registry of all supported LLM providers and their models,
|
||||
// including capabilities, pricing, and configuration requirements.
|
||||
// The registry data is generated from models.dev and provides a single
|
||||
// source of truth for model validation and discovery.
|
||||
type ModelsRegistry struct {
|
||||
// providers maps provider IDs to their information and available models
|
||||
providers map[string]ProviderInfo
|
||||
}
|
||||
|
||||
// NewModelsRegistry creates a new models registry with static data
|
||||
// NewModelsRegistry creates a new models registry with static data.
|
||||
// The registry is populated with model information generated from models.dev,
|
||||
// providing comprehensive metadata about available models across all supported providers.
|
||||
//
|
||||
// Returns:
|
||||
// - *ModelsRegistry: A new registry instance populated with current model data
|
||||
func NewModelsRegistry() *ModelsRegistry {
|
||||
return &ModelsRegistry{
|
||||
providers: GetModelsData(),
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateModel validates if a model exists and returns detailed information
|
||||
// ValidateModel validates if a model exists and returns detailed information.
|
||||
// It checks whether a specific model is available for a given provider and
|
||||
// returns comprehensive information about the model's capabilities and limits.
|
||||
//
|
||||
// Parameters:
|
||||
// - provider: The provider ID (e.g., "anthropic", "openai", "google")
|
||||
// - modelID: The specific model ID (e.g., "claude-3-sonnet-20240620", "gpt-4")
|
||||
//
|
||||
// Returns:
|
||||
// - *ModelInfo: Detailed information about the model including pricing, limits, and capabilities
|
||||
// - error: Returns an error if the provider is unsupported or model is not found
|
||||
func (r *ModelsRegistry) ValidateModel(provider, modelID string) (*ModelInfo, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
@@ -35,7 +55,21 @@ func (r *ModelsRegistry) ValidateModel(provider, modelID string) (*ModelInfo, er
|
||||
return &modelInfo, nil
|
||||
}
|
||||
|
||||
// GetRequiredEnvVars returns the required environment variables for a provider
|
||||
// GetRequiredEnvVars returns the required environment variables for a provider.
|
||||
// These are the environment variable names that should contain API keys or
|
||||
// other authentication credentials for the specified provider.
|
||||
//
|
||||
// Parameters:
|
||||
// - provider: The provider ID (e.g., "anthropic", "openai", "google")
|
||||
//
|
||||
// Returns:
|
||||
// - []string: List of environment variable names the provider checks for credentials
|
||||
// - error: Returns an error if the provider is unsupported
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// For "anthropic", returns ["ANTHROPIC_API_KEY"]
|
||||
// For "google", returns ["GOOGLE_API_KEY", "GEMINI_API_KEY", "GOOGLE_GENERATIVE_AI_API_KEY"]
|
||||
func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
@@ -45,7 +79,16 @@ func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) {
|
||||
return providerInfo.Env, nil
|
||||
}
|
||||
|
||||
// ValidateEnvironment checks if required environment variables are set
|
||||
// ValidateEnvironment checks if required environment variables are set.
|
||||
// It verifies that at least one of the provider's required environment variables
|
||||
// contains an API key, unless an API key is explicitly provided via configuration.
|
||||
//
|
||||
// Parameters:
|
||||
// - provider: The provider ID to validate environment for
|
||||
// - apiKey: An API key provided via configuration (if empty, checks environment variables)
|
||||
//
|
||||
// Returns:
|
||||
// - error: Returns nil if validation passes, or an error describing missing credentials
|
||||
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
|
||||
envVars, err := r.GetRequiredEnvVars(provider)
|
||||
if err != nil {
|
||||
@@ -74,7 +117,16 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// SuggestModels returns similar model names when an invalid model is provided
|
||||
// SuggestModels returns similar model names when an invalid model is provided.
|
||||
// It helps users discover the correct model ID by finding models that partially
|
||||
// match the provided input, useful for correcting typos or finding alternatives.
|
||||
//
|
||||
// Parameters:
|
||||
// - provider: The provider ID to search within
|
||||
// - invalidModel: The invalid or misspelled model name to find suggestions for
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A list of up to 5 suggested model IDs that partially match the input
|
||||
func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
@@ -105,7 +157,12 @@ func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
|
||||
return suggestions
|
||||
}
|
||||
|
||||
// GetSupportedProviders returns a list of all supported providers
|
||||
// GetSupportedProviders returns a list of all supported providers.
|
||||
// This includes all providers that have models registered in the system,
|
||||
// such as "anthropic", "openai", "google", "alibaba", etc.
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A list of all provider IDs available in the registry
|
||||
func (r *ModelsRegistry) GetSupportedProviders() []string {
|
||||
var providers []string
|
||||
for providerID := range r.providers {
|
||||
@@ -114,7 +171,16 @@ func (r *ModelsRegistry) GetSupportedProviders() []string {
|
||||
return providers
|
||||
}
|
||||
|
||||
// GetModelsForProvider returns all models for a specific provider
|
||||
// GetModelsForProvider returns all models for a specific provider.
|
||||
// This is useful for listing available models when a user wants to see
|
||||
// all options for a particular provider.
|
||||
//
|
||||
// Parameters:
|
||||
// - provider: The provider ID to get models for
|
||||
//
|
||||
// Returns:
|
||||
// - map[string]ModelInfo: A map of model IDs to their detailed information
|
||||
// - error: Returns an error if the provider is unsupported
|
||||
func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]ModelInfo, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
@@ -127,7 +193,13 @@ func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]Model
|
||||
// Global registry instance
|
||||
var globalRegistry = NewModelsRegistry()
|
||||
|
||||
// GetGlobalRegistry returns the global models registry instance
|
||||
// GetGlobalRegistry returns the global models registry instance.
|
||||
// This provides a singleton registry that can be accessed throughout
|
||||
// the application for model validation and information retrieval.
|
||||
// The registry is initialized once with data from models.dev.
|
||||
//
|
||||
// Returns:
|
||||
// - *ModelsRegistry: The global registry instance
|
||||
func GetGlobalRegistry() *ModelsRegistry {
|
||||
return globalRegistry
|
||||
}
|
||||
|
||||
+57
-12
@@ -7,14 +7,21 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// Manager manages session state and auto-saving
|
||||
// Manager manages session state and auto-saving functionality.
|
||||
// It provides thread-safe operations for managing a conversation session,
|
||||
// including automatic persistence to disk after each modification.
|
||||
// The Manager ensures that all session operations are synchronized and
|
||||
// that the session file is kept up-to-date with any changes.
|
||||
type Manager struct {
|
||||
session *Session
|
||||
filePath string
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewManager creates a new session manager
|
||||
// NewManager creates a new session manager with a fresh session.
|
||||
// The filePath parameter specifies where the session will be auto-saved.
|
||||
// If filePath is empty, the session will not be automatically saved to disk.
|
||||
// Returns a Manager instance ready to track conversation messages.
|
||||
func NewManager(filePath string) *Manager {
|
||||
return &Manager{
|
||||
session: NewSession(),
|
||||
@@ -22,7 +29,11 @@ func NewManager(filePath string) *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// NewManagerWithSession creates a new session manager with an existing session
|
||||
// NewManagerWithSession creates a new session manager with an existing session.
|
||||
// This is useful when loading a session from a file and wanting to continue
|
||||
// managing it with auto-save functionality.
|
||||
// The session parameter is the existing session to manage.
|
||||
// The filePath parameter specifies where the session will be auto-saved.
|
||||
func NewManagerWithSession(session *Session, filePath string) *Manager {
|
||||
return &Manager{
|
||||
session: session,
|
||||
@@ -30,7 +41,12 @@ func NewManagerWithSession(session *Session, filePath string) *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// AddMessage adds a message to the session and auto-saves
|
||||
// AddMessage adds a message to the session and auto-saves.
|
||||
// The message is converted from schema.Message format to the internal
|
||||
// session Message format before being added. If a filePath was specified
|
||||
// when creating the Manager, the session is automatically saved to disk.
|
||||
// This operation is thread-safe.
|
||||
// Returns an error if auto-saving fails, nil otherwise.
|
||||
func (m *Manager) AddMessage(msg *schema.Message) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -45,7 +61,11 @@ func (m *Manager) AddMessage(msg *schema.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMessages adds multiple messages to the session and auto-saves
|
||||
// AddMessages adds multiple messages to the session and auto-saves.
|
||||
// All messages are added in order and then the session is saved once.
|
||||
// This is more efficient than calling AddMessage multiple times when
|
||||
// adding several messages at once. The operation is thread-safe.
|
||||
// Returns an error if auto-saving fails, nil otherwise.
|
||||
func (m *Manager) AddMessages(msgs []*schema.Message) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -62,7 +82,12 @@ func (m *Manager) AddMessages(msgs []*schema.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReplaceAllMessages replaces all messages in the session with the provided messages
|
||||
// ReplaceAllMessages replaces all messages in the session with the provided messages.
|
||||
// This method completely clears the existing message history and replaces it with
|
||||
// the new set of messages. Useful for resetting a conversation or loading a
|
||||
// different conversation context. The operation is thread-safe and triggers
|
||||
// an auto-save if a filePath is configured.
|
||||
// Returns an error if auto-saving fails, nil otherwise.
|
||||
func (m *Manager) ReplaceAllMessages(msgs []*schema.Message) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -83,7 +108,11 @@ func (m *Manager) ReplaceAllMessages(msgs []*schema.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMetadata sets the session metadata
|
||||
// SetMetadata sets the session metadata.
|
||||
// This updates the session's metadata with information about the provider,
|
||||
// model, and MCPHost version. The operation is thread-safe and triggers
|
||||
// an auto-save if a filePath is configured.
|
||||
// Returns an error if auto-saving fails, nil otherwise.
|
||||
func (m *Manager) SetMetadata(metadata Metadata) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -97,7 +126,11 @@ func (m *Manager) SetMetadata(metadata Metadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMessages returns all messages as schema.Message slice
|
||||
// GetMessages returns all messages as a schema.Message slice.
|
||||
// This method converts all stored session messages to the schema format
|
||||
// used by LLM providers. The returned slice is a new allocation, so
|
||||
// modifications to it won't affect the stored session. This operation
|
||||
// is thread-safe for concurrent reads.
|
||||
func (m *Manager) GetMessages() []*schema.Message {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
@@ -110,7 +143,11 @@ func (m *Manager) GetMessages() []*schema.Message {
|
||||
return messages
|
||||
}
|
||||
|
||||
// GetSession returns a copy of the current session
|
||||
// GetSession returns a copy of the current session.
|
||||
// The returned session is a deep copy, including all messages, so
|
||||
// modifications to it won't affect the managed session. This is useful
|
||||
// for safely inspecting the session state without risk of concurrent
|
||||
// modification. This operation is thread-safe for concurrent reads.
|
||||
func (m *Manager) GetSession() *Session {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
@@ -123,7 +160,11 @@ func (m *Manager) GetSession() *Session {
|
||||
return &sessionCopy
|
||||
}
|
||||
|
||||
// Save manually saves the session to file
|
||||
// Save manually saves the session to file.
|
||||
// This forces a save operation even if no changes have been made.
|
||||
// Useful for ensuring the session is persisted at specific points.
|
||||
// Returns an error if no filePath was specified when creating the
|
||||
// Manager, or if the save operation fails.
|
||||
func (m *Manager) Save() error {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
@@ -135,12 +176,16 @@ func (m *Manager) Save() error {
|
||||
return m.session.SaveToFile(m.filePath)
|
||||
}
|
||||
|
||||
// GetFilePath returns the file path for this session
|
||||
// GetFilePath returns the file path for this session.
|
||||
// Returns the path where the session is being auto-saved, or an
|
||||
// empty string if no auto-save path was configured.
|
||||
func (m *Manager) GetFilePath() string {
|
||||
return m.filePath
|
||||
}
|
||||
|
||||
// MessageCount returns the number of messages in the session
|
||||
// MessageCount returns the number of messages in the session.
|
||||
// This provides a quick way to check the conversation length without
|
||||
// retrieving all messages. This operation is thread-safe for concurrent reads.
|
||||
func (m *Manager) MessageCount() int {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
+82
-25
@@ -11,40 +11,71 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// Session represents a complete conversation session with metadata
|
||||
// Session represents a complete conversation session with metadata.
|
||||
// It stores all messages exchanged during a conversation along with
|
||||
// contextual information about the session such as the provider, model,
|
||||
// and timestamps. Sessions can be saved to and loaded from JSON files
|
||||
// for persistence across program runs.
|
||||
type Session struct {
|
||||
Version string `json:"version"`
|
||||
// Version indicates the session format version for compatibility
|
||||
Version string `json:"version"`
|
||||
// CreatedAt is the timestamp when the session was first created
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
// UpdatedAt is the timestamp when the session was last modified
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Metadata Metadata `json:"metadata"`
|
||||
Messages []Message `json:"messages"`
|
||||
// Metadata contains contextual information about the session
|
||||
Metadata Metadata `json:"metadata"`
|
||||
// Messages is the ordered list of all messages in this session
|
||||
Messages []Message `json:"messages"`
|
||||
}
|
||||
|
||||
// Metadata contains session metadata
|
||||
// Metadata contains session metadata that provides context about the
|
||||
// environment and configuration used during the conversation. This helps
|
||||
// with debugging and understanding the session's context when reviewing
|
||||
// conversation history.
|
||||
type Metadata struct {
|
||||
// MCPHostVersion is the version of MCPHost used for this session
|
||||
MCPHostVersion string `json:"mcphost_version"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
// Provider is the LLM provider used (e.g., "anthropic", "openai", "gemini")
|
||||
Provider string `json:"provider"`
|
||||
// Model is the specific model identifier used for the conversation
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// Message represents a single message in the session
|
||||
// Message represents a single message in the conversation session.
|
||||
// Messages can be from different roles (user, assistant, tool) and may
|
||||
// include tool calls for assistant messages or tool results for tool messages.
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"` // For tool result messages
|
||||
// ID is a unique identifier for this message, auto-generated if not provided
|
||||
ID string `json:"id"`
|
||||
// Role indicates who sent the message ("user", "assistant", "tool", or "system")
|
||||
Role string `json:"role"`
|
||||
// Content is the text content of the message
|
||||
Content string `json:"content"`
|
||||
// Timestamp is when the message was created
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
// ToolCalls contains any tool invocations made by the assistant in this message
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
// ToolCallID links a tool result message to its corresponding tool call
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// ToolCall represents a tool call within a message
|
||||
// ToolCall represents a tool invocation within an assistant message.
|
||||
// When the assistant decides to use a tool, it creates a ToolCall with
|
||||
// the necessary information to execute that tool.
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Arguments any `json:"arguments"`
|
||||
// ID is a unique identifier for this tool call, used to link results
|
||||
ID string `json:"id"`
|
||||
// Name is the name of the tool being invoked
|
||||
Name string `json:"name"`
|
||||
// Arguments contains the parameters passed to the tool, typically as JSON
|
||||
Arguments any `json:"arguments"`
|
||||
}
|
||||
|
||||
// NewSession creates a new session with default values
|
||||
// NewSession creates a new session with default values.
|
||||
// It initializes a session with version 1.0, current timestamps,
|
||||
// empty message list, and empty metadata. The returned session
|
||||
// is ready to receive messages and can be saved to a file.
|
||||
func NewSession() *Session {
|
||||
return &Session{
|
||||
Version: "1.0",
|
||||
@@ -55,7 +86,10 @@ func NewSession() *Session {
|
||||
}
|
||||
}
|
||||
|
||||
// AddMessage adds a message to the session
|
||||
// AddMessage adds a message to the session.
|
||||
// If the message doesn't have an ID, one will be auto-generated.
|
||||
// If the message doesn't have a timestamp, the current time will be used.
|
||||
// The session's UpdatedAt timestamp is automatically updated.
|
||||
func (s *Session) AddMessage(msg Message) {
|
||||
if msg.ID == "" {
|
||||
msg.ID = generateMessageID()
|
||||
@@ -68,13 +102,21 @@ func (s *Session) AddMessage(msg Message) {
|
||||
s.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// SetMetadata sets the session metadata
|
||||
// SetMetadata sets the session metadata.
|
||||
// This replaces the existing metadata with the provided metadata
|
||||
// and updates the session's UpdatedAt timestamp. Use this to record
|
||||
// information about the provider, model, and MCPHost version.
|
||||
func (s *Session) SetMetadata(metadata Metadata) {
|
||||
s.Metadata = metadata
|
||||
s.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// SaveToFile saves the session to a JSON file
|
||||
// SaveToFile saves the session to a JSON file.
|
||||
// The session is serialized as indented JSON for readability.
|
||||
// The UpdatedAt timestamp is automatically updated before saving.
|
||||
// The file is created with 0644 permissions if it doesn't exist,
|
||||
// or overwritten if it does exist.
|
||||
// Returns an error if marshaling fails or file writing fails.
|
||||
func (s *Session) SaveToFile(filePath string) error {
|
||||
s.UpdatedAt = time.Now()
|
||||
|
||||
@@ -86,7 +128,12 @@ func (s *Session) SaveToFile(filePath string) error {
|
||||
return os.WriteFile(filePath, data, 0644)
|
||||
}
|
||||
|
||||
// LoadFromFile loads a session from a JSON file
|
||||
// LoadFromFile loads a session from a JSON file.
|
||||
// It reads the file at the specified path and deserializes it into
|
||||
// a Session struct. This is useful for resuming previous conversations
|
||||
// or reviewing session history.
|
||||
// Returns the loaded session on success, or an error if the file
|
||||
// cannot be read or the JSON is invalid.
|
||||
func LoadFromFile(filePath string) (*Session, error) {
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
@@ -101,7 +148,12 @@ func LoadFromFile(filePath string) (*Session, error) {
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// ConvertFromSchemaMessage converts a schema.Message to a session Message
|
||||
// ConvertFromSchemaMessage converts a schema.Message to a session Message.
|
||||
// This function bridges between the eino schema message format and the
|
||||
// session's internal message format. It preserves role, content, and
|
||||
// tool-related information while adding a timestamp.
|
||||
// Tool calls from assistant messages and tool call IDs from tool messages
|
||||
// are properly converted and preserved.
|
||||
func ConvertFromSchemaMessage(msg *schema.Message) Message {
|
||||
sessionMsg := Message{
|
||||
Role: string(msg.Role),
|
||||
@@ -129,7 +181,12 @@ func ConvertFromSchemaMessage(msg *schema.Message) Message {
|
||||
return sessionMsg
|
||||
}
|
||||
|
||||
// ConvertToSchemaMessage converts a session Message to a schema.Message
|
||||
// ConvertToSchemaMessage converts a session Message to a schema.Message.
|
||||
// This method bridges between the session's internal message format and
|
||||
// the eino schema message format used by the LLM providers.
|
||||
// It properly handles tool calls for assistant messages and tool call IDs
|
||||
// for tool result messages. Arguments are converted to string format as
|
||||
// required by the schema.
|
||||
func (m *Message) ConvertToSchemaMessage() *schema.Message {
|
||||
msg := &schema.Message{
|
||||
Role: schema.RoleType(m.Role),
|
||||
|
||||
@@ -1 +1,16 @@
|
||||
// Package tokens provides token counting and estimation functionality for
|
||||
// various language model providers. It includes utilities for estimating
|
||||
// token counts in text, as well as provider-specific implementations for
|
||||
// more accurate token counting.
|
||||
//
|
||||
// The package supports multiple approaches to token counting:
|
||||
// - Quick estimation using character-based heuristics
|
||||
// - Provider-specific tokenizers for accurate counts
|
||||
// - Initialization functions for setting up token counters
|
||||
//
|
||||
// Token counting is essential for:
|
||||
// - Managing API rate limits
|
||||
// - Calculating costs for API usage
|
||||
// - Ensuring prompts fit within model context windows
|
||||
// - Optimizing prompt engineering and response handling
|
||||
package tokens
|
||||
|
||||
@@ -1,6 +1,23 @@
|
||||
package tokens
|
||||
|
||||
// EstimateTokens provides a rough estimate of tokens in text
|
||||
// EstimateTokens estimates the number of tokens in the given text string.
|
||||
// It uses a rough approximation of 4 characters per token, which is a common
|
||||
// heuristic for most language models. This function provides a quick estimation
|
||||
// without requiring model-specific tokenizers.
|
||||
//
|
||||
// The estimation may not be accurate for all models or text types, particularly
|
||||
// for texts with many special characters, non-English languages, or code snippets.
|
||||
// For more accurate token counting, use model-specific tokenizers when available.
|
||||
//
|
||||
// Parameters:
|
||||
// - text: The input text string to estimate tokens for
|
||||
//
|
||||
// Returns:
|
||||
// - int: The estimated number of tokens in the text
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// count := EstimateTokens("Hello, world!") // Returns approximately 3
|
||||
func EstimateTokens(text string) int {
|
||||
// Rough approximation: ~4 characters per token for most models
|
||||
return len(text) / 4
|
||||
|
||||
+43
-2
@@ -1,11 +1,52 @@
|
||||
package tokens
|
||||
|
||||
// InitializeTokenCounters registers all available token counters
|
||||
// InitializeTokenCounters registers all available token counters for various
|
||||
// language model providers. This function should be called during application
|
||||
// startup to ensure that token counting functionality is available for all
|
||||
// supported models.
|
||||
//
|
||||
// Currently, this function is a placeholder for future provider-specific
|
||||
// token counter implementations. As new providers are added (OpenAI, Anthropic,
|
||||
// Google, etc.), their respective token counters will be registered here.
|
||||
//
|
||||
// This function does not require any API keys and will only initialize
|
||||
// counters that can work without authentication.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// func main() {
|
||||
// tokens.InitializeTokenCounters()
|
||||
// // Token counting is now available
|
||||
// }
|
||||
func InitializeTokenCounters() {
|
||||
// Future provider-specific counters can be registered here
|
||||
}
|
||||
|
||||
// InitializeTokenCountersWithKeys registers token counters with provided API keys
|
||||
// InitializeTokenCountersWithKeys registers token counters for various language
|
||||
// model providers using the provided API keys. This function enables more
|
||||
// accurate token counting by allowing access to provider-specific tokenization
|
||||
// endpoints or libraries that require authentication.
|
||||
//
|
||||
// This function should be called during application startup after API keys
|
||||
// have been loaded from configuration or environment variables. It will
|
||||
// initialize token counters for providers where API keys are available,
|
||||
// enabling precise token counting that matches the provider's actual
|
||||
// tokenization logic.
|
||||
//
|
||||
// The function will silently skip providers for which no API keys are
|
||||
// configured, allowing the application to continue with partial token
|
||||
// counting capabilities.
|
||||
//
|
||||
// Future implementations will accept provider-specific API keys through
|
||||
// parameters or read them from a configuration context.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// func main() {
|
||||
// // Load API keys from environment or config
|
||||
// tokens.InitializeTokenCountersWithKeys()
|
||||
// // Provider-specific token counting is now available
|
||||
// }
|
||||
func InitializeTokenCountersWithKeys() {
|
||||
// Future provider-specific counters can be registered here
|
||||
}
|
||||
|
||||
@@ -4,14 +4,19 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// BufferedDebugLogger stores debug messages until they can be displayed
|
||||
// BufferedDebugLogger implements DebugLogger by storing debug messages in memory
|
||||
// until they can be retrieved and displayed. This is useful when debug output
|
||||
// needs to be deferred or batch-processed rather than immediately displayed.
|
||||
// All methods are thread-safe for concurrent use.
|
||||
type BufferedDebugLogger struct {
|
||||
enabled bool
|
||||
messages []string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewBufferedDebugLogger creates a new buffered debug logger
|
||||
// NewBufferedDebugLogger creates a new buffered debug logger instance.
|
||||
// The enabled parameter determines whether debug messages will be stored.
|
||||
// If enabled is false, all LogDebug calls become no-ops for performance.
|
||||
func NewBufferedDebugLogger(enabled bool) *BufferedDebugLogger {
|
||||
return &BufferedDebugLogger{
|
||||
enabled: enabled,
|
||||
@@ -19,7 +24,10 @@ func NewBufferedDebugLogger(enabled bool) *BufferedDebugLogger {
|
||||
}
|
||||
}
|
||||
|
||||
// LogDebug stores a debug message
|
||||
// LogDebug stores a debug message in the internal buffer if debug logging is enabled.
|
||||
// Messages are appended to the buffer and retained until GetMessages is called.
|
||||
// If debug logging is disabled, this method is a no-op.
|
||||
// Thread-safe for concurrent calls.
|
||||
func (l *BufferedDebugLogger) LogDebug(message string) {
|
||||
if !l.enabled {
|
||||
return
|
||||
@@ -29,12 +37,17 @@ func (l *BufferedDebugLogger) LogDebug(message string) {
|
||||
l.messages = append(l.messages, message)
|
||||
}
|
||||
|
||||
// IsDebugEnabled returns whether debug logging is enabled
|
||||
// IsDebugEnabled returns whether debug logging is enabled for this logger.
|
||||
// This can be used to conditionally execute expensive debug operations
|
||||
// only when debugging is actually enabled.
|
||||
func (l *BufferedDebugLogger) IsDebugEnabled() bool {
|
||||
return l.enabled
|
||||
}
|
||||
|
||||
// GetMessages returns all buffered messages and clears the buffer
|
||||
// GetMessages returns all buffered debug messages and clears the internal buffer.
|
||||
// The returned slice contains all messages logged since the last call to GetMessages.
|
||||
// After this call, the internal buffer is empty and ready to accumulate new messages.
|
||||
// Thread-safe for concurrent calls.
|
||||
func (l *BufferedDebugLogger) GetMessages() []string {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
@@ -15,16 +15,20 @@ import (
|
||||
"github.com/mark3labs/mcphost/internal/config"
|
||||
)
|
||||
|
||||
// ConnectionPoolConfig configuration for connection pool
|
||||
// ConnectionPoolConfig defines configuration parameters for the MCP connection pool.
|
||||
// It controls connection lifecycle, health checking, and error handling behaviors.
|
||||
type ConnectionPoolConfig struct {
|
||||
MaxIdleTime time.Duration
|
||||
MaxRetries int
|
||||
HealthCheckInterval time.Duration
|
||||
MaxErrorCount int
|
||||
ReconnectDelay time.Duration
|
||||
MaxIdleTime time.Duration // Maximum time a connection can remain idle before being marked unhealthy
|
||||
MaxRetries int // Maximum number of retry attempts for failed operations
|
||||
HealthCheckInterval time.Duration // Interval between background health checks of all connections
|
||||
MaxErrorCount int // Maximum consecutive errors before marking a connection unhealthy
|
||||
ReconnectDelay time.Duration // Delay before attempting to reconnect after connection failure
|
||||
}
|
||||
|
||||
// DefaultConnectionPoolConfig returns default configuration
|
||||
// DefaultConnectionPoolConfig returns a connection pool configuration with sensible defaults.
|
||||
// Default values: 5 minute max idle time, 3 retries, 30 second health check interval,
|
||||
// 3 max errors before marking unhealthy, and 2 second reconnect delay.
|
||||
// These defaults are suitable for most MCP server deployments.
|
||||
func DefaultConnectionPoolConfig() *ConnectionPoolConfig {
|
||||
return &ConnectionPoolConfig{
|
||||
MaxIdleTime: 5 * time.Minute,
|
||||
@@ -35,7 +39,10 @@ func DefaultConnectionPoolConfig() *ConnectionPoolConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// MCPConnection represents an MCP connection
|
||||
// MCPConnection represents a single MCP client connection with health tracking and metadata.
|
||||
// It wraps an MCP client and maintains state about connection health, usage patterns,
|
||||
// and error history. Access to connection state is protected by a read-write mutex
|
||||
// for thread-safe concurrent access.
|
||||
type MCPConnection struct {
|
||||
client client.MCPClient
|
||||
serverName string
|
||||
@@ -47,7 +54,11 @@ type MCPConnection struct {
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// MCPConnectionPool manages MCP connections
|
||||
// MCPConnectionPool manages a pool of MCP client connections with automatic health checking,
|
||||
// connection reuse, and failure recovery. It provides thread-safe connection management
|
||||
// across multiple MCP servers, automatically handling connection lifecycle including
|
||||
// creation, health monitoring, and cleanup. The pool runs background health checks
|
||||
// to proactively identify and remove unhealthy connections.
|
||||
type MCPConnectionPool struct {
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
@@ -59,7 +70,11 @@ type MCPConnectionPool struct {
|
||||
debugLogger DebugLogger
|
||||
}
|
||||
|
||||
// NewMCPConnectionPool creates a new connection pool
|
||||
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
|
||||
// If config is nil, default configuration values will be used. The pool starts a background
|
||||
// goroutine for periodic health checks that runs until Close is called.
|
||||
// The model parameter is used for MCP servers that require sampling support.
|
||||
// Thread-safe for concurrent use immediately after creation.
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model model.ToolCallingChatModel, debug bool) *MCPConnectionPool {
|
||||
if config == nil {
|
||||
config = DefaultConnectionPoolConfig()
|
||||
@@ -79,14 +94,20 @@ func NewMCPConnectionPool(config *ConnectionPoolConfig, model model.ToolCallingC
|
||||
return pool
|
||||
}
|
||||
|
||||
// SetDebugLogger sets the debug logger for the connection pool
|
||||
// SetDebugLogger sets the debug logger for the connection pool.
|
||||
// The logger will be used to output detailed information about connection lifecycle,
|
||||
// health checks, and error conditions. Thread-safe and can be called at any time.
|
||||
func (p *MCPConnectionPool) SetDebugLogger(logger DebugLogger) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.debugLogger = logger
|
||||
}
|
||||
|
||||
// GetConnection gets a connection from the pool
|
||||
// GetConnection retrieves or creates a connection for the specified MCP server.
|
||||
// If a healthy, non-idle connection exists in the pool, it will be reused.
|
||||
// Otherwise, a new connection is created and added to the pool.
|
||||
// Returns an error if connection creation or initialization fails.
|
||||
// Thread-safe for concurrent calls.
|
||||
func (p *MCPConnectionPool) GetConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
@@ -127,7 +148,12 @@ func (p *MCPConnectionPool) GetConnection(ctx context.Context, serverName string
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// GetConnectionWithHealthCheck gets a connection from the pool with proactive health check
|
||||
// GetConnectionWithHealthCheck retrieves a connection with an additional proactive health check.
|
||||
// Unlike GetConnection, this method performs a health check on existing connections before
|
||||
// returning them, ensuring the connection is truly healthy. This is useful for critical
|
||||
// operations where connection reliability is paramount. Creates a new connection if the
|
||||
// existing one fails the health check or doesn't exist.
|
||||
// Thread-safe for concurrent calls.
|
||||
func (p *MCPConnectionPool) GetConnectionWithHealthCheck(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
@@ -430,7 +456,11 @@ func (p *MCPConnectionPool) checkConnectionsHealth() {
|
||||
}
|
||||
}
|
||||
|
||||
// HandleConnectionError handles connection errors
|
||||
// HandleConnectionError records and handles errors for a specific connection.
|
||||
// It increments the error count and may mark the connection as unhealthy based on
|
||||
// the error type and configured thresholds. Connection errors (network, transport, 404)
|
||||
// immediately mark the connection as unhealthy for removal on next access.
|
||||
// Thread-safe for concurrent error reporting.
|
||||
func (p *MCPConnectionPool) HandleConnectionError(serverName string, err error) {
|
||||
p.mu.RLock()
|
||||
conn, exists := p.connections[serverName]
|
||||
@@ -460,7 +490,10 @@ func (p *MCPConnectionPool) HandleConnectionError(serverName string, err error)
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnectionStats returns connection statistics
|
||||
// GetConnectionStats returns detailed statistics for all connections in the pool.
|
||||
// The returned map includes health status, last usage time, error counts, and
|
||||
// last error for each connection. Useful for monitoring and debugging connection
|
||||
// pool behavior. The returned data is a snapshot and safe for concurrent access.
|
||||
func (p *MCPConnectionPool) GetConnectionStats() map[string]interface{} {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
@@ -480,12 +513,17 @@ func (p *MCPConnectionPool) GetConnectionStats() map[string]interface{} {
|
||||
return stats
|
||||
}
|
||||
|
||||
// ServerName returns the server name for this connection
|
||||
// ServerName returns the server name associated with this MCP connection.
|
||||
// This is the configured name from the MCPHost configuration, not necessarily
|
||||
// the actual server implementation name.
|
||||
func (c *MCPConnection) ServerName() string {
|
||||
return c.serverName
|
||||
}
|
||||
|
||||
// GetClients returns all client names in the pool
|
||||
// GetClients returns a map of all MCP clients currently in the pool.
|
||||
// The map keys are server names and values are the corresponding MCP client instances.
|
||||
// The returned map is a copy and modifications won't affect the pool.
|
||||
// Note that clients may be unhealthy; use GetConnectionStats to check health status.
|
||||
func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
@@ -497,7 +535,11 @@ func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient {
|
||||
return clients
|
||||
}
|
||||
|
||||
// Close closes the connection pool
|
||||
// Close gracefully shuts down the connection pool, closing all client connections
|
||||
// and stopping the background health check goroutine. It attempts to close all
|
||||
// connections even if some fail, logging any errors encountered.
|
||||
// Safe to call multiple times; subsequent calls are no-ops.
|
||||
// Always call Close when done with the pool to prevent resource leaks.
|
||||
func (p *MCPConnectionPool) Close() error {
|
||||
p.cancel()
|
||||
|
||||
|
||||
@@ -1,28 +1,45 @@
|
||||
package tools
|
||||
|
||||
// DebugLogger interface for debug logging
|
||||
// DebugLogger defines the interface for debug logging in the MCP tools package.
|
||||
// Implementations can provide different strategies for handling debug output,
|
||||
// such as immediate console output, buffering, or file logging.
|
||||
// All implementations must be thread-safe for concurrent use.
|
||||
type DebugLogger interface {
|
||||
// LogDebug logs a debug message. Implementations determine how the message is handled.
|
||||
LogDebug(message string)
|
||||
// IsDebugEnabled returns true if debug logging is enabled, allowing callers
|
||||
// to skip expensive debug operations when debugging is disabled.
|
||||
IsDebugEnabled() bool
|
||||
}
|
||||
|
||||
// SimpleDebugLogger is a simple implementation that prints to stdout
|
||||
// SimpleDebugLogger provides a minimal implementation of the DebugLogger interface.
|
||||
// It is intentionally silent by default to prevent duplicate or unstyled debug output
|
||||
// during initialization. Debug messages are only displayed when using the CLI debug logger
|
||||
// which provides proper formatting and styling.
|
||||
type SimpleDebugLogger struct {
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// NewSimpleDebugLogger creates a new simple debug logger
|
||||
// NewSimpleDebugLogger creates a new simple debug logger instance.
|
||||
// The enabled parameter determines whether IsDebugEnabled will return true.
|
||||
// Note that LogDebug is intentionally a no-op to avoid unstyled output;
|
||||
// actual debug output is handled by the CLI's debug logger.
|
||||
func NewSimpleDebugLogger(enabled bool) *SimpleDebugLogger {
|
||||
return &SimpleDebugLogger{enabled: enabled}
|
||||
}
|
||||
|
||||
// LogDebug logs a debug message
|
||||
// LogDebug is intentionally a no-op in SimpleDebugLogger.
|
||||
// Debug messages are only displayed when using the CLI debug logger which provides
|
||||
// proper formatting and styling. This prevents duplicate or unstyled debug output
|
||||
// during initialization and ensures consistent debug output presentation.
|
||||
func (l *SimpleDebugLogger) LogDebug(message string) {
|
||||
// Silent by default - messages will only appear when using CLI debug logger
|
||||
// This prevents duplicate or unstyled debug output during initialization
|
||||
}
|
||||
|
||||
// IsDebugEnabled returns whether debug logging is enabled
|
||||
// IsDebugEnabled returns whether debug logging is enabled for this logger.
|
||||
// This allows code to conditionally execute expensive debug operations
|
||||
// only when debugging is active, improving performance in production.
|
||||
func (l *SimpleDebugLogger) IsDebugEnabled() bool {
|
||||
return l.enabled
|
||||
}
|
||||
|
||||
+43
-11
@@ -19,7 +19,11 @@ import (
|
||||
"github.com/mark3labs/mcphost/internal/config"
|
||||
)
|
||||
|
||||
// MCPToolManager manages MCP tools and clients
|
||||
// MCPToolManager manages MCP (Model Context Protocol) tools and clients across multiple servers.
|
||||
// It provides a unified interface for loading, managing, and executing tools from various MCP servers,
|
||||
// including stdio, SSE, streamable HTTP, and built-in server types. The manager handles connection
|
||||
// pooling, health checks, tool name prefixing to avoid conflicts, and sampling support for LLM interactions.
|
||||
// Thread-safe for concurrent tool invocations.
|
||||
type MCPToolManager struct {
|
||||
connectionPool *MCPConnectionPool
|
||||
tools []tool.BaseTool
|
||||
@@ -44,7 +48,9 @@ type mcpToolImpl struct {
|
||||
mapping *toolMapping
|
||||
}
|
||||
|
||||
// NewMCPToolManager creates a new MCP tool manager
|
||||
// NewMCPToolManager creates a new MCP tool manager instance.
|
||||
// Returns an initialized manager with empty tool collections ready to load tools from MCP servers.
|
||||
// The manager must be configured with SetModel and LoadTools before use.
|
||||
func NewMCPToolManager() *MCPToolManager {
|
||||
return &MCPToolManager{
|
||||
tools: make([]tool.BaseTool, 0),
|
||||
@@ -52,12 +58,18 @@ func NewMCPToolManager() *MCPToolManager {
|
||||
}
|
||||
}
|
||||
|
||||
// SetModel sets the LLM model for sampling support
|
||||
// SetModel sets the LLM model for sampling support.
|
||||
// The model is used when MCP servers request sampling operations, allowing them to
|
||||
// leverage the host's LLM capabilities for text generation tasks.
|
||||
// This method should be called before LoadTools if any MCP servers require sampling support.
|
||||
func (m *MCPToolManager) SetModel(model model.ToolCallingChatModel) {
|
||||
m.model = model
|
||||
}
|
||||
|
||||
// SetDebugLogger sets the debug logger
|
||||
// SetDebugLogger sets the debug logger for the tool manager.
|
||||
// The logger will be used to output detailed debugging information about MCP connections,
|
||||
// tool loading, and execution. If a connection pool exists, it will also be configured
|
||||
// to use the same logger for consistent debugging output.
|
||||
func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) {
|
||||
m.debugLogger = logger
|
||||
if m.connectionPool != nil {
|
||||
@@ -70,7 +82,10 @@ type samplingHandler struct {
|
||||
model model.ToolCallingChatModel
|
||||
}
|
||||
|
||||
// CreateMessage handles sampling requests from MCP servers
|
||||
// CreateMessage handles sampling requests from MCP servers by forwarding them to the configured LLM model.
|
||||
// It converts MCP message formats to eino message formats, invokes the model for generation,
|
||||
// and converts the response back to MCP format. Returns an error if no model is available
|
||||
// or if generation fails.
|
||||
func (h *samplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
|
||||
if h.model == nil {
|
||||
return nil, fmt.Errorf("no model available for sampling")
|
||||
@@ -126,7 +141,11 @@ func (h *samplingHandler) CreateMessage(ctx context.Context, request mcp.CreateM
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// LoadTools loads tools from MCP servers based on configuration
|
||||
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
|
||||
// It initializes the connection pool, connects to each configured server, and loads their tools.
|
||||
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
|
||||
// Returns an error only if all configured servers fail to load; partial failures are logged as warnings.
|
||||
// This method is thread-safe and idempotent.
|
||||
func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error {
|
||||
// Initialize connection pool
|
||||
m.config = config
|
||||
@@ -243,12 +262,18 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
|
||||
return nil
|
||||
}
|
||||
|
||||
// Info returns the tool information
|
||||
// Info returns the tool information including name, description, and parameter schema.
|
||||
// This method implements the eino tool.BaseTool interface.
|
||||
// The returned ToolInfo contains the prefixed tool name to ensure uniqueness across servers.
|
||||
func (t *mcpToolImpl) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return t.info, nil
|
||||
}
|
||||
|
||||
// InvokableRun executes the tool by mapping back to the original name and server
|
||||
// InvokableRun executes the tool by mapping the prefixed name back to the original tool name and server.
|
||||
// It retrieves a healthy connection from the pool, invokes the tool on the appropriate MCP server,
|
||||
// and returns the result as a JSON string. The method handles connection errors by marking
|
||||
// connections as unhealthy in the pool for automatic recovery on subsequent requests.
|
||||
// Thread-safe for concurrent invocations.
|
||||
func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||
// Handle empty or invalid JSON arguments
|
||||
var arguments any
|
||||
@@ -300,12 +325,16 @@ func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string,
|
||||
return marshaledResult, nil
|
||||
}
|
||||
|
||||
// GetTools returns all loaded tools
|
||||
// GetTools returns all loaded tools from all configured MCP servers.
|
||||
// Tools are returned with their prefixed names (serverName__toolName) to ensure uniqueness.
|
||||
// The returned slice is a copy and can be safely modified by the caller.
|
||||
func (m *MCPToolManager) GetTools() []tool.BaseTool {
|
||||
return m.tools
|
||||
}
|
||||
|
||||
// GetLoadedServerNames returns the names of successfully loaded MCP servers
|
||||
// GetLoadedServerNames returns the names of all successfully loaded MCP servers.
|
||||
// This includes servers that are currently connected and have had their tools loaded,
|
||||
// regardless of their current health status. Useful for debugging and status reporting.
|
||||
func (m *MCPToolManager) GetLoadedServerNames() []string {
|
||||
var names []string
|
||||
for serverName := range m.connectionPool.GetClients() {
|
||||
@@ -314,7 +343,10 @@ func (m *MCPToolManager) GetLoadedServerNames() []string {
|
||||
return names
|
||||
}
|
||||
|
||||
// Close closes all MCP clients
|
||||
// Close closes all MCP client connections and cleans up resources.
|
||||
// This method should be called when the tool manager is no longer needed to ensure
|
||||
// proper cleanup of stdio processes, network connections, and other resources.
|
||||
// It is safe to call Close multiple times.
|
||||
func (m *MCPToolManager) Close() error {
|
||||
return m.connectionPool.Close()
|
||||
}
|
||||
|
||||
@@ -21,70 +21,90 @@ type blockRenderer struct {
|
||||
// renderingOption configures block rendering
|
||||
type renderingOption func(*blockRenderer)
|
||||
|
||||
// WithFullWidth makes the block take full available width
|
||||
// WithFullWidth returns a renderingOption that configures the block renderer
|
||||
// to expand to the full available width of its container. When enabled, the
|
||||
// block will fill the entire horizontal space rather than sizing to its content.
|
||||
func WithFullWidth() renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.fullWidth = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAlign sets the horizontal alignment of the block
|
||||
// WithAlign returns a renderingOption that sets the horizontal alignment
|
||||
// of the block content within its container. The align parameter accepts
|
||||
// lipgloss.Left, lipgloss.Center, or lipgloss.Right positions.
|
||||
func WithAlign(align lipgloss.Position) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.align = &align
|
||||
}
|
||||
}
|
||||
|
||||
// WithBorderColor sets the border color
|
||||
// WithBorderColor returns a renderingOption that sets the border color
|
||||
// for the block. The color parameter uses lipgloss.AdaptiveColor to support
|
||||
// both light and dark terminal themes automatically.
|
||||
func WithBorderColor(color lipgloss.AdaptiveColor) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.borderColor = &color
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginTop sets the top margin
|
||||
// WithMarginTop returns a renderingOption that sets the top margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space above the block.
|
||||
func WithMarginTop(margin int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.marginTop = margin
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginBottom sets the bottom margin
|
||||
// WithMarginBottom returns a renderingOption that sets the bottom margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space below the block.
|
||||
func WithMarginBottom(margin int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.marginBottom = margin
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingLeft sets the left padding
|
||||
// WithPaddingLeft returns a renderingOption that sets the left padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the left border and the content.
|
||||
func WithPaddingLeft(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingLeft = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingRight sets the right padding
|
||||
// WithPaddingRight returns a renderingOption that sets the right padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the content and the right border.
|
||||
func WithPaddingRight(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingRight = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingTop sets the top padding
|
||||
// WithPaddingTop returns a renderingOption that sets the top padding
|
||||
// for the block content. The padding is specified in number of lines
|
||||
// and adds vertical space between the top border and the content.
|
||||
func WithPaddingTop(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingTop = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingBottom sets the bottom padding
|
||||
// WithPaddingBottom returns a renderingOption that sets the bottom padding
|
||||
// for the block content. The padding is specified in number of lines
|
||||
// and adds vertical space between the content and the bottom border.
|
||||
func WithPaddingBottom(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingBottom = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithWidth sets a specific width for the block
|
||||
// WithWidth returns a renderingOption that sets a specific width for the block
|
||||
// in characters. This overrides the default container width and allows precise
|
||||
// control over the block's horizontal dimensions.
|
||||
func WithWidth(width int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.width = width
|
||||
|
||||
@@ -9,7 +9,11 @@ import (
|
||||
utilCallbacks "github.com/cloudwego/eino/utils/callbacks"
|
||||
)
|
||||
|
||||
// CreateCallbackHandler creates a callback handler using HandlerHelper
|
||||
// CreateCallbackHandler creates and returns a callbacks.Handler that manages
|
||||
// tool execution callbacks for the CLI. The handler displays tool calls,
|
||||
// handles errors, and manages streaming output for interactive tool operations.
|
||||
// It integrates with the eino callback system to provide real-time UI feedback
|
||||
// during tool execution.
|
||||
func (c *CLI) CreateCallbackHandler() callbacks.Handler {
|
||||
toolHandler := &utilCallbacks.ToolCallbackHandler{
|
||||
OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *tool.CallbackInput) context.Context {
|
||||
|
||||
+95
-31
@@ -17,7 +17,10 @@ var (
|
||||
promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12"))
|
||||
)
|
||||
|
||||
// CLI handles the command line interface with improved message rendering
|
||||
// CLI manages the command-line interface for MCPHost, providing message rendering,
|
||||
// user input handling, and display management. It supports both standard and compact
|
||||
// display modes, handles streaming responses, tracks token usage, and manages the
|
||||
// overall conversation flow between the user and AI assistants.
|
||||
type CLI struct {
|
||||
messageRenderer *MessageRenderer
|
||||
compactRenderer *CompactRenderer // Add compact renderer
|
||||
@@ -32,7 +35,10 @@ type CLI struct {
|
||||
usageDisplayed bool // track if usage info was displayed after last assistant message
|
||||
}
|
||||
|
||||
// NewCLI creates a new CLI instance with message container
|
||||
// NewCLI creates and initializes a new CLI instance with the specified display modes.
|
||||
// The debug parameter enables debug message rendering, while compact enables a more
|
||||
// condensed display format. Returns an initialized CLI ready for interaction or an
|
||||
// error if initialization fails.
|
||||
func NewCLI(debug bool, compact bool) (*CLI, error) {
|
||||
cli := &CLI{
|
||||
compactMode: compact,
|
||||
@@ -46,7 +52,9 @@ func NewCLI(debug bool, compact bool) (*CLI, error) {
|
||||
return cli, nil
|
||||
}
|
||||
|
||||
// SetUsageTracker sets the usage tracker for the CLI
|
||||
// SetUsageTracker attaches a usage tracker to the CLI for monitoring token
|
||||
// consumption and costs. The tracker will be automatically updated with the
|
||||
// current display width for proper rendering.
|
||||
func (c *CLI) SetUsageTracker(tracker *UsageTracker) {
|
||||
c.usageTracker = tracker
|
||||
if c.usageTracker != nil {
|
||||
@@ -54,12 +62,14 @@ func (c *CLI) SetUsageTracker(tracker *UsageTracker) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetDebugLogger returns a debug logger that uses the CLI for rendering
|
||||
// GetDebugLogger returns a CLIDebugLogger instance that routes debug output
|
||||
// through the CLI's rendering system for consistent message formatting and display.
|
||||
func (c *CLI) GetDebugLogger() *CLIDebugLogger {
|
||||
return NewCLIDebugLogger(c)
|
||||
}
|
||||
|
||||
// SetModelName sets the current model name for the CLI
|
||||
// SetModelName updates the current AI model name being used in the conversation.
|
||||
// This name is displayed in message headers to indicate which model is responding.
|
||||
func (c *CLI) SetModelName(modelName string) {
|
||||
c.modelName = modelName
|
||||
if c.messageContainer != nil {
|
||||
@@ -67,7 +77,10 @@ func (c *CLI) SetModelName(modelName string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetPrompt gets user input using the huh library with divider and padding
|
||||
// GetPrompt displays an interactive prompt and waits for user input. It provides
|
||||
// slash command support, multi-line editing, and cancellation handling. Returns
|
||||
// the user's input as a string, or an error if the operation was cancelled or
|
||||
// failed. Returns io.EOF for clean exit signals.
|
||||
func (c *CLI) GetPrompt() (string, error) {
|
||||
// Usage info is now displayed immediately after responses via DisplayUsageAfterResponse()
|
||||
// No need to display it here to avoid duplication
|
||||
@@ -107,7 +120,9 @@ func (c *CLI) GetPrompt() (string, error) {
|
||||
return "", fmt.Errorf("unexpected model type")
|
||||
}
|
||||
|
||||
// ShowSpinner displays a spinner with the given message and executes the action
|
||||
// ShowSpinner displays an animated spinner with the specified message while
|
||||
// executing the provided action function. The spinner automatically stops when
|
||||
// the action completes. Returns any error returned by the action function.
|
||||
func (c *CLI) ShowSpinner(message string, action func() error) error {
|
||||
spinner := NewSpinner(message)
|
||||
spinner.Start()
|
||||
@@ -119,7 +134,9 @@ func (c *CLI) ShowSpinner(message string, action func() error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// DisplayUserMessage displays the user's message using the appropriate renderer
|
||||
// DisplayUserMessage renders and displays a user's message with appropriate
|
||||
// formatting based on the current display mode (standard or compact). The message
|
||||
// is timestamped and styled according to the active theme.
|
||||
func (c *CLI) DisplayUserMessage(message string) {
|
||||
var msg UIMessage
|
||||
if c.compactMode {
|
||||
@@ -131,12 +148,16 @@ func (c *CLI) DisplayUserMessage(message string) {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayAssistantMessage displays the assistant's message using the new renderer
|
||||
// DisplayAssistantMessage renders and displays an AI assistant's response message
|
||||
// with appropriate formatting. This method delegates to DisplayAssistantMessageWithModel
|
||||
// with an empty model name for backward compatibility.
|
||||
func (c *CLI) DisplayAssistantMessage(message string) error {
|
||||
return c.DisplayAssistantMessageWithModel(message, "")
|
||||
}
|
||||
|
||||
// DisplayAssistantMessageWithModel displays the assistant's message with model info
|
||||
// DisplayAssistantMessageWithModel renders and displays an AI assistant's response
|
||||
// with the specified model name shown in the message header. The message is
|
||||
// formatted according to the current display mode and includes timestamp information.
|
||||
func (c *CLI) DisplayAssistantMessageWithModel(message, modelName string) error {
|
||||
var msg UIMessage
|
||||
if c.compactMode {
|
||||
@@ -149,7 +170,9 @@ func (c *CLI) DisplayAssistantMessageWithModel(message, modelName string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisplayToolCallMessage displays a tool call in progress
|
||||
// DisplayToolCallMessage renders and displays a message indicating that a tool
|
||||
// is being executed. Shows the tool name and its arguments formatted appropriately
|
||||
// for the current display mode. This is typically shown while a tool is running.
|
||||
func (c *CLI) DisplayToolCallMessage(toolName, toolArgs string) {
|
||||
|
||||
c.messageContainer.messages = nil // clear previous messages (they should have been printed already)
|
||||
@@ -167,7 +190,9 @@ func (c *CLI) DisplayToolCallMessage(toolName, toolArgs string) {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayToolMessage displays a tool call message
|
||||
// DisplayToolMessage renders and displays the complete result of a tool execution,
|
||||
// including the tool name, arguments, and result. The isError parameter determines
|
||||
// whether the result should be displayed as an error or success message.
|
||||
func (c *CLI) DisplayToolMessage(toolName, toolArgs, toolResult string, isError bool) {
|
||||
var msg UIMessage
|
||||
if c.compactMode {
|
||||
@@ -181,7 +206,9 @@ func (c *CLI) DisplayToolMessage(toolName, toolArgs, toolResult string, isError
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// StartStreamingMessage starts a streaming assistant message
|
||||
// StartStreamingMessage initializes a new streaming message display for real-time
|
||||
// AI responses. The message will be progressively updated as content arrives.
|
||||
// The modelName parameter indicates which AI model is generating the response.
|
||||
func (c *CLI) StartStreamingMessage(modelName string) {
|
||||
// Add an empty assistant message that we'll update during streaming
|
||||
var msg UIMessage
|
||||
@@ -196,14 +223,18 @@ func (c *CLI) StartStreamingMessage(modelName string) {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// UpdateStreamingMessage updates the streaming message with new content
|
||||
// UpdateStreamingMessage updates the currently streaming message with new content.
|
||||
// This method should be called after StartStreamingMessage to progressively display
|
||||
// AI responses as they are generated in real-time.
|
||||
func (c *CLI) UpdateStreamingMessage(content string) {
|
||||
// Update the last message (which should be the streaming assistant message)
|
||||
c.messageContainer.UpdateLastMessage(content)
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayError displays an error message using the appropriate renderer
|
||||
// DisplayError renders and displays an error message with distinctive formatting
|
||||
// to ensure visibility. The error is timestamped and styled according to the
|
||||
// current display mode's error theme.
|
||||
func (c *CLI) DisplayError(err error) {
|
||||
var msg UIMessage
|
||||
if c.compactMode {
|
||||
@@ -215,7 +246,9 @@ func (c *CLI) DisplayError(err error) {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayInfo displays an informational message using the appropriate renderer
|
||||
// DisplayInfo renders and displays an informational system message. These messages
|
||||
// are typically used for status updates, notifications, or other non-error system
|
||||
// communications to the user.
|
||||
func (c *CLI) DisplayInfo(message string) {
|
||||
var msg UIMessage
|
||||
if c.compactMode {
|
||||
@@ -227,7 +260,8 @@ func (c *CLI) DisplayInfo(message string) {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayCancellation displays a cancellation message
|
||||
// DisplayCancellation displays a system message indicating that the current
|
||||
// AI generation has been cancelled by the user (typically via ESC key).
|
||||
func (c *CLI) DisplayCancellation() {
|
||||
var msg UIMessage
|
||||
if c.compactMode {
|
||||
@@ -239,7 +273,9 @@ func (c *CLI) DisplayCancellation() {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayDebugMessage displays debug messages using the appropriate renderer
|
||||
// DisplayDebugMessage renders and displays a debug message if debug mode is enabled.
|
||||
// Debug messages are formatted distinctively and only shown when the CLI is
|
||||
// initialized with debug=true.
|
||||
func (c *CLI) DisplayDebugMessage(message string) {
|
||||
if !c.debug {
|
||||
return
|
||||
@@ -254,7 +290,9 @@ func (c *CLI) DisplayDebugMessage(message string) {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayDebugConfig displays configuration settings using the appropriate renderer
|
||||
// DisplayDebugConfig renders and displays configuration settings in a formatted
|
||||
// debug message. The config parameter should contain key-value pairs representing
|
||||
// configuration options that will be displayed for debugging purposes.
|
||||
func (c *CLI) DisplayDebugConfig(config map[string]any) {
|
||||
var msg UIMessage
|
||||
if c.compactMode {
|
||||
@@ -266,7 +304,9 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayHelp displays help information in a message block
|
||||
// DisplayHelp renders and displays comprehensive help information showing all
|
||||
// available slash commands, keyboard shortcuts, and usage instructions in a
|
||||
// formatted system message block.
|
||||
func (c *CLI) DisplayHelp() {
|
||||
help := `## Available Commands
|
||||
|
||||
@@ -288,7 +328,9 @@ You can also just type your message to chat with the AI assistant.`
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayTools displays available tools in a message block
|
||||
// DisplayTools renders and displays a formatted list of all available tools
|
||||
// that can be used by the AI assistant. Each tool is numbered and shown in
|
||||
// a system message block for easy reference.
|
||||
func (c *CLI) DisplayTools(tools []string) {
|
||||
var content strings.Builder
|
||||
content.WriteString("## Available Tools\n\n")
|
||||
@@ -307,7 +349,9 @@ func (c *CLI) DisplayTools(tools []string) {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayServers displays configured MCP servers in a message block
|
||||
// DisplayServers renders and displays a formatted list of all configured MCP
|
||||
// (Model Context Protocol) servers. Each server is numbered and shown in a
|
||||
// system message block for easy reference.
|
||||
func (c *CLI) DisplayServers(servers []string) {
|
||||
var content strings.Builder
|
||||
content.WriteString("## Configured MCP Servers\n\n")
|
||||
@@ -326,18 +370,25 @@ func (c *CLI) DisplayServers(servers []string) {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// IsSlashCommand checks if the input is a slash command
|
||||
// IsSlashCommand determines whether the provided input string is a slash command
|
||||
// by checking if it starts with a forward slash (/). Returns true for commands
|
||||
// like "/help", "/tools", etc.
|
||||
func (c *CLI) IsSlashCommand(input string) bool {
|
||||
return strings.HasPrefix(input, "/")
|
||||
}
|
||||
|
||||
// SlashCommandResult represents the result of handling a slash command
|
||||
// SlashCommandResult encapsulates the outcome of processing a slash command,
|
||||
// indicating whether the command was recognized and handled, and whether the
|
||||
// conversation history should be cleared as a result of the command.
|
||||
type SlashCommandResult struct {
|
||||
Handled bool
|
||||
ClearHistory bool
|
||||
}
|
||||
|
||||
// HandleSlashCommand handles slash commands and returns the result
|
||||
// HandleSlashCommand processes and executes slash commands, returning a result
|
||||
// that indicates whether the command was handled and any side effects. The servers
|
||||
// and tools parameters provide context for commands that display available resources.
|
||||
// Supported commands include /help, /tools, /servers, /clear, /usage, /reset-usage, and /quit.
|
||||
func (c *CLI) HandleSlashCommand(input string, servers []string, tools []string) SlashCommandResult {
|
||||
switch input {
|
||||
case "/help":
|
||||
@@ -369,7 +420,9 @@ func (c *CLI) HandleSlashCommand(input string, servers []string, tools []string)
|
||||
}
|
||||
}
|
||||
|
||||
// ClearMessages clears all messages from the container
|
||||
// ClearMessages removes all messages from the display container and refreshes
|
||||
// the screen. This is typically used when starting a new conversation or
|
||||
// clearing the chat history.
|
||||
func (c *CLI) ClearMessages() {
|
||||
c.messageContainer.Clear()
|
||||
c.displayContainer()
|
||||
@@ -422,14 +475,19 @@ func (c *CLI) displayContainer() {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateUsage updates the usage tracker with token counts and costs
|
||||
// UpdateUsage estimates and records token usage based on input and output text.
|
||||
// This method uses text-based estimation when actual token counts are not available
|
||||
// from the AI provider's response metadata.
|
||||
func (c *CLI) UpdateUsage(inputText, outputText string) {
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.EstimateAndUpdateUsage(inputText, outputText)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateUsageFromResponse updates the usage tracker using token usage from response metadata
|
||||
// UpdateUsageFromResponse records token usage using metadata from the AI provider's
|
||||
// response when available. Falls back to text-based estimation if the metadata is
|
||||
// missing or appears unreliable. This provides more accurate usage tracking when
|
||||
// providers supply token count information.
|
||||
func (c *CLI) UpdateUsageFromResponse(response *schema.Message, inputText string) {
|
||||
if c.usageTracker == nil {
|
||||
return
|
||||
@@ -461,7 +519,9 @@ func (c *CLI) UpdateUsageFromResponse(response *schema.Message, inputText string
|
||||
}
|
||||
}
|
||||
|
||||
// DisplayUsageStats displays current usage statistics
|
||||
// DisplayUsageStats renders and displays comprehensive token usage statistics
|
||||
// including the last request's token counts and costs, as well as session totals.
|
||||
// Shows a message if usage tracking is not available for the current model.
|
||||
func (c *CLI) DisplayUsageStats() {
|
||||
if c.usageTracker == nil {
|
||||
c.DisplayInfo("Usage tracking is not available for this model.")
|
||||
@@ -492,7 +552,9 @@ func (c *CLI) DisplayUsageStats() {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// ResetUsageStats resets the usage tracking statistics
|
||||
// ResetUsageStats clears all accumulated usage statistics, resetting token counts
|
||||
// and costs to zero. Displays a confirmation message after resetting or an info
|
||||
// message if usage tracking is not available.
|
||||
func (c *CLI) ResetUsageStats() {
|
||||
if c.usageTracker == nil {
|
||||
c.DisplayInfo("Usage tracking is not available for this model.")
|
||||
@@ -503,7 +565,9 @@ func (c *CLI) ResetUsageStats() {
|
||||
c.DisplayInfo("Usage statistics have been reset.")
|
||||
}
|
||||
|
||||
// DisplayUsageAfterResponse displays usage information immediately after a response
|
||||
// DisplayUsageAfterResponse renders and displays token usage information immediately
|
||||
// following an AI response. This provides real-time feedback about the cost and
|
||||
// token consumption of each interaction.
|
||||
func (c *CLI) DisplayUsageAfterResponse() {
|
||||
if c.usageTracker == nil {
|
||||
return
|
||||
|
||||
+12
-4
@@ -1,6 +1,8 @@
|
||||
package ui
|
||||
|
||||
// SlashCommand represents a slash command with its metadata
|
||||
// SlashCommand represents a user-invokable slash command with its metadata.
|
||||
// Commands can have multiple aliases and are organized by category for better
|
||||
// discoverability and help display.
|
||||
type SlashCommand struct {
|
||||
Name string
|
||||
Description string
|
||||
@@ -8,7 +10,9 @@ type SlashCommand struct {
|
||||
Category string // e.g., "Navigation", "System", "Info"
|
||||
}
|
||||
|
||||
// SlashCommands is the registry of all available slash commands
|
||||
// SlashCommands provides the global registry of all available slash commands
|
||||
// in the application. Commands are organized by category (Info, System) and
|
||||
// include their primary names, descriptions, and alternative aliases.
|
||||
var SlashCommands = []SlashCommand{
|
||||
{
|
||||
Name: "/help",
|
||||
@@ -55,7 +59,9 @@ var SlashCommands = []SlashCommand{
|
||||
},
|
||||
}
|
||||
|
||||
// GetCommandByName returns a command by its name or alias
|
||||
// GetCommandByName looks up a slash command by its primary name or any of its
|
||||
// aliases. Returns a pointer to the matching SlashCommand, or nil if no command
|
||||
// matches the provided name.
|
||||
func GetCommandByName(name string) *SlashCommand {
|
||||
for i := range SlashCommands {
|
||||
cmd := &SlashCommands[i]
|
||||
@@ -71,7 +77,9 @@ func GetCommandByName(name string) *SlashCommand {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllCommandNames returns all command names and aliases
|
||||
// GetAllCommandNames returns a complete list of all command names and their aliases.
|
||||
// This is useful for command completion, validation, and help display. The returned
|
||||
// slice contains both primary command names and all alternative aliases.
|
||||
func GetAllCommandNames() []string {
|
||||
var names []string
|
||||
for _, cmd := range SlashCommands {
|
||||
|
||||
@@ -8,13 +8,17 @@ import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// CompactRenderer handles rendering messages in compact format
|
||||
// CompactRenderer handles rendering messages in a space-efficient compact format,
|
||||
// optimized for terminals with limited vertical space. It displays messages with
|
||||
// minimal decorations while maintaining readability and essential information.
|
||||
type CompactRenderer struct {
|
||||
width int
|
||||
debug bool
|
||||
}
|
||||
|
||||
// NewCompactRenderer creates a new compact message renderer
|
||||
// NewCompactRenderer creates and initializes a new CompactRenderer with the specified
|
||||
// terminal width and debug mode setting. The width parameter determines line wrapping,
|
||||
// while debug enables additional diagnostic output in rendered messages.
|
||||
func NewCompactRenderer(width int, debug bool) *CompactRenderer {
|
||||
return &CompactRenderer{
|
||||
width: width,
|
||||
@@ -22,12 +26,16 @@ func NewCompactRenderer(width int, debug bool) *CompactRenderer {
|
||||
}
|
||||
}
|
||||
|
||||
// SetWidth updates the renderer width
|
||||
// SetWidth updates the terminal width for the renderer, affecting how content
|
||||
// is wrapped and formatted in subsequent render operations.
|
||||
func (r *CompactRenderer) SetWidth(width int) {
|
||||
r.width = width
|
||||
}
|
||||
|
||||
// RenderUserMessage renders a user message in compact format
|
||||
// RenderUserMessage renders a user's input message in compact format with a
|
||||
// distinctive symbol (>) and label. The content is formatted to preserve structure
|
||||
// while minimizing vertical space usage. Returns a UIMessage with formatted content
|
||||
// and metadata.
|
||||
func (r *CompactRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Secondary).Render(">")
|
||||
@@ -58,7 +66,9 @@ func (r *CompactRenderer) RenderUserMessage(content string, timestamp time.Time)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderAssistantMessage renders an assistant message in compact format
|
||||
// RenderAssistantMessage renders an AI assistant's response in compact format with
|
||||
// a distinctive symbol (<) and the model name as label. Empty content is displayed
|
||||
// as "(no output)". Returns a UIMessage with formatted content and metadata.
|
||||
func (r *CompactRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Primary).Render("<")
|
||||
@@ -97,7 +107,9 @@ func (r *CompactRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolCallMessage renders a tool call in progress in compact format
|
||||
// RenderToolCallMessage renders a tool call notification in compact format, showing
|
||||
// the tool being executed with its arguments in a single line. The tool name is
|
||||
// highlighted and arguments are displayed in a muted color for visual distinction.
|
||||
func (r *CompactRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("[")
|
||||
@@ -119,7 +131,9 @@ func (r *CompactRenderer) RenderToolCallMessage(toolName, toolArgs string, times
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolMessage renders a tool result in compact format
|
||||
// RenderToolMessage renders the result of a tool execution in compact format,
|
||||
// displaying the outcome with appropriate styling based on success or error status.
|
||||
// Results are limited to 5 lines to maintain compact display while preserving key information.
|
||||
func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Muted).Render("]")
|
||||
@@ -165,7 +179,9 @@ func (r *CompactRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
}
|
||||
}
|
||||
|
||||
// RenderSystemMessage renders a system message in compact format
|
||||
// RenderSystemMessage renders a system notification or informational message in
|
||||
// compact format with a distinctive symbol (*) and "System" label. Content is
|
||||
// formatted to fit on a single line for minimal space usage.
|
||||
func (r *CompactRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.System).Render("*")
|
||||
@@ -183,7 +199,9 @@ func (r *CompactRenderer) RenderSystemMessage(content string, timestamp time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
// RenderErrorMessage renders an error message in compact format
|
||||
// RenderErrorMessage renders an error notification in compact format with a
|
||||
// distinctive error symbol (!) and styling to ensure visibility. The error
|
||||
// content is displayed in a single line with appropriate color highlighting.
|
||||
func (r *CompactRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Error).Render("!")
|
||||
@@ -201,7 +219,9 @@ func (r *CompactRenderer) RenderErrorMessage(errorMsg string, timestamp time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugMessage renders debug messages in compact format
|
||||
// RenderDebugMessage renders diagnostic information in compact format when debug
|
||||
// mode is enabled. Messages are truncated if they exceed the available width to
|
||||
// maintain single-line display.
|
||||
func (r *CompactRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*")
|
||||
@@ -223,7 +243,9 @@ func (r *CompactRenderer) RenderDebugMessage(message string, timestamp time.Time
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugConfigMessage renders debug config in compact format
|
||||
// RenderDebugConfigMessage renders configuration settings in compact format for
|
||||
// debugging purposes. Config entries are displayed as key=value pairs separated
|
||||
// by commas, truncated if necessary to fit on a single line.
|
||||
func (r *CompactRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage {
|
||||
theme := getTheme()
|
||||
symbol := lipgloss.NewStyle().Foreground(theme.Tool).Render("*")
|
||||
|
||||
@@ -6,17 +6,25 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering
|
||||
// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering.
|
||||
// It provides debug logging functionality that integrates with the CLI's display
|
||||
// system, ensuring debug messages are properly formatted and displayed alongside
|
||||
// other conversation content.
|
||||
type CLIDebugLogger struct {
|
||||
cli *CLI
|
||||
}
|
||||
|
||||
// NewCLIDebugLogger creates a new CLI debug logger
|
||||
// NewCLIDebugLogger creates and returns a new CLIDebugLogger instance that routes
|
||||
// debug output through the provided CLI instance. The logger will respect the CLI's
|
||||
// debug mode setting and display format preferences.
|
||||
func NewCLIDebugLogger(cli *CLI) *CLIDebugLogger {
|
||||
return &CLIDebugLogger{cli: cli}
|
||||
}
|
||||
|
||||
// LogDebug logs a debug message using the CLI's debug message renderer
|
||||
// LogDebug processes and displays a debug message through the CLI's rendering system.
|
||||
// Messages are formatted with appropriate emojis and tags based on their content type
|
||||
// (DEBUG, POOL, etc.) and only displayed when debug mode is enabled. The method handles
|
||||
// multi-line debug output and connection pool status messages with context-aware formatting.
|
||||
func (l *CLIDebugLogger) LogDebug(message string) {
|
||||
if l.cli == nil || !l.cli.debug {
|
||||
return
|
||||
@@ -70,7 +78,9 @@ func (l *CLIDebugLogger) LogDebug(message string) {
|
||||
l.cli.displayContainer()
|
||||
}
|
||||
|
||||
// IsDebugEnabled returns whether debug logging is enabled
|
||||
// IsDebugEnabled checks whether debug logging is currently active. Returns true
|
||||
// if the CLI instance exists and has debug mode enabled, allowing callers to
|
||||
// conditionally perform expensive debug operations only when necessary.
|
||||
func (l *CLIDebugLogger) IsDebugEnabled() bool {
|
||||
return l.cli != nil && l.cli.debug
|
||||
}
|
||||
|
||||
@@ -11,17 +11,21 @@ import (
|
||||
// Global theme instance
|
||||
var currentTheme = DefaultTheme()
|
||||
|
||||
// GetTheme returns the current theme
|
||||
// GetTheme returns the currently active UI theme. The theme controls all color
|
||||
// and styling decisions throughout the application's interface.
|
||||
func GetTheme() Theme {
|
||||
return currentTheme
|
||||
}
|
||||
|
||||
// SetTheme sets the current theme
|
||||
// SetTheme updates the global UI theme, affecting all subsequent rendering
|
||||
// operations. This allows runtime theme switching for different visual preferences.
|
||||
func SetTheme(theme Theme) {
|
||||
currentTheme = theme
|
||||
}
|
||||
|
||||
// Theme represents a complete UI theme
|
||||
// Theme defines a comprehensive color scheme for the application's UI, supporting
|
||||
// both light and dark terminal modes through adaptive colors. It includes semantic
|
||||
// colors for different message types and UI elements, based on the Catppuccin color palette.
|
||||
type Theme struct {
|
||||
Primary lipgloss.AdaptiveColor
|
||||
Secondary lipgloss.AdaptiveColor
|
||||
@@ -41,7 +45,9 @@ type Theme struct {
|
||||
Highlight lipgloss.AdaptiveColor
|
||||
}
|
||||
|
||||
// DefaultTheme returns the default MCPHost theme (Catppuccin Mocha)
|
||||
// DefaultTheme creates and returns the default MCPHost theme based on the Catppuccin
|
||||
// Mocha (dark) and Latte (light) color palettes. This theme provides a cohesive,
|
||||
// pleasant visual experience with carefully selected colors for different UI elements.
|
||||
func DefaultTheme() Theme {
|
||||
return Theme{
|
||||
Primary: lipgloss.AdaptiveColor{
|
||||
@@ -111,7 +117,9 @@ func DefaultTheme() Theme {
|
||||
}
|
||||
}
|
||||
|
||||
// StyleCard creates a styled card container
|
||||
// StyleCard creates a lipgloss style for card-like containers with rounded borders,
|
||||
// padding, and appropriate width. Used for grouping related content in a visually
|
||||
// distinct box.
|
||||
func StyleCard(width int, theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
@@ -121,56 +129,64 @@ func StyleCard(width int, theme Theme) lipgloss.Style {
|
||||
MarginBottom(1)
|
||||
}
|
||||
|
||||
// StyleHeader creates a styled header
|
||||
// StyleHeader creates a lipgloss style for primary headers using the theme's
|
||||
// primary color with bold text for emphasis and hierarchy.
|
||||
func StyleHeader(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Primary).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleSubheader creates a styled subheader
|
||||
// StyleSubheader creates a lipgloss style for secondary headers using the theme's
|
||||
// secondary color with bold text, providing visual hierarchy below primary headers.
|
||||
func StyleSubheader(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleMuted creates muted text styling
|
||||
// StyleMuted creates a lipgloss style for de-emphasized text using muted colors
|
||||
// and italic formatting, suitable for supplementary or less important information.
|
||||
func StyleMuted(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true)
|
||||
}
|
||||
|
||||
// StyleSuccess creates success text styling
|
||||
// StyleSuccess creates a lipgloss style for success messages using green colors
|
||||
// with bold text to indicate successful operations or positive outcomes.
|
||||
func StyleSuccess(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleError creates error text styling
|
||||
// StyleError creates a lipgloss style for error messages using red colors
|
||||
// with bold text to ensure visibility of problems or failures.
|
||||
func StyleError(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleWarning creates warning text styling
|
||||
// StyleWarning creates a lipgloss style for warning messages using yellow/amber
|
||||
// colors with bold text to draw attention to potential issues or cautions.
|
||||
func StyleWarning(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Warning).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleInfo creates info text styling
|
||||
// StyleInfo creates a lipgloss style for informational messages using blue colors
|
||||
// with bold text for general notifications and status updates.
|
||||
func StyleInfo(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// CreateSeparator creates a styled separator line
|
||||
// CreateSeparator generates a horizontal separator line with the specified width,
|
||||
// character, and color. Useful for visually dividing sections of content in the UI.
|
||||
func CreateSeparator(width int, char string, color lipgloss.AdaptiveColor) string {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(color).
|
||||
@@ -178,7 +194,9 @@ func CreateSeparator(width int, char string, color lipgloss.AdaptiveColor) strin
|
||||
Render(lipgloss.PlaceHorizontal(width, lipgloss.Center, char))
|
||||
}
|
||||
|
||||
// CreateProgressBar creates a simple progress bar
|
||||
// CreateProgressBar generates a visual progress bar with filled and empty segments
|
||||
// based on the percentage complete. The bar uses Unicode block characters for smooth
|
||||
// appearance and theme colors to indicate progress.
|
||||
func CreateProgressBar(width int, percentage float64, theme Theme) string {
|
||||
filled := int(float64(width) * percentage / 100)
|
||||
empty := width - filled
|
||||
@@ -194,7 +212,8 @@ func CreateProgressBar(width int, percentage float64, theme Theme) string {
|
||||
return filledBar + emptyBar
|
||||
}
|
||||
|
||||
// CreateBadge creates a styled badge
|
||||
// CreateBadge generates a styled badge or label with inverted colors (text on
|
||||
// colored background) for highlighting important tags, statuses, or categories.
|
||||
func CreateBadge(text string, color lipgloss.AdaptiveColor) string {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "#FFFFFF", Dark: "#000000"}).
|
||||
@@ -204,7 +223,9 @@ func CreateBadge(text string, color lipgloss.AdaptiveColor) string {
|
||||
Render(text)
|
||||
}
|
||||
|
||||
// CreateGradientText creates text with gradient-like effect using different shades
|
||||
// CreateGradientText creates styled text with a gradient-like effect. Currently
|
||||
// implements a simplified version using the start color only, as true gradients
|
||||
// require more complex terminal capabilities.
|
||||
func CreateGradientText(text string, startColor, endColor lipgloss.AdaptiveColor) string {
|
||||
// For now, just use the start color - true gradients would require more complex implementation
|
||||
return lipgloss.NewStyle().
|
||||
@@ -215,14 +236,16 @@ func CreateGradientText(text string, startColor, endColor lipgloss.AdaptiveColor
|
||||
|
||||
// Compact styling utilities
|
||||
|
||||
// StyleCompactSymbol creates a styled symbol for compact mode
|
||||
// StyleCompactSymbol creates a lipgloss style for message type indicators in
|
||||
// compact mode, using bold colored text to distinguish different message categories.
|
||||
func StyleCompactSymbol(symbol string, color lipgloss.AdaptiveColor) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(color).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleCompactLabel creates a styled label for compact mode
|
||||
// StyleCompactLabel creates a lipgloss style for message labels in compact mode
|
||||
// with fixed width for alignment and bold colored text for readability.
|
||||
func StyleCompactLabel(color lipgloss.AdaptiveColor) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(color).
|
||||
@@ -230,13 +253,16 @@ func StyleCompactLabel(color lipgloss.AdaptiveColor) lipgloss.Style {
|
||||
Width(8)
|
||||
}
|
||||
|
||||
// StyleCompactContent creates basic content styling for compact mode
|
||||
// StyleCompactContent creates a simple lipgloss style for message content in
|
||||
// compact mode, applying only color without additional formatting.
|
||||
func StyleCompactContent(color lipgloss.AdaptiveColor) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(color)
|
||||
}
|
||||
|
||||
// FormatCompactLine formats a complete compact line with consistent spacing
|
||||
// FormatCompactLine assembles a complete compact mode message line with consistent
|
||||
// spacing and styling. Combines a symbol, fixed-width label, and content with their
|
||||
// respective colors to create a uniform appearance across all message types.
|
||||
func FormatCompactLine(symbol, label, content string, symbolColor, labelColor, contentColor lipgloss.AdaptiveColor) string {
|
||||
styledSymbol := StyleCompactSymbol(symbol, symbolColor).Render(symbol)
|
||||
styledLabel := StyleCompactLabel(labelColor).Render(label)
|
||||
|
||||
@@ -8,14 +8,17 @@ import (
|
||||
"github.com/mark3labs/mcphost/internal/models"
|
||||
)
|
||||
|
||||
// AgentInterface defines the interface we need from agent to avoid import cycles
|
||||
// AgentInterface defines the minimal interface required from the agent package
|
||||
// to avoid circular dependencies while still accessing necessary agent functionality.
|
||||
type AgentInterface interface {
|
||||
GetLoadingMessage() string
|
||||
GetTools() []any // Using any to avoid importing tool types
|
||||
GetLoadedServerNames() []string // Add this method for debug config
|
||||
}
|
||||
|
||||
// CLISetupOptions contains options for setting up CLI
|
||||
// CLISetupOptions encapsulates all configuration parameters needed to initialize
|
||||
// and set up a CLI instance, including display preferences, model information,
|
||||
// and debugging settings.
|
||||
type CLISetupOptions struct {
|
||||
Agent AgentInterface
|
||||
ModelString string
|
||||
@@ -35,7 +38,10 @@ func parseModelName(modelString string) (provider, model string) {
|
||||
return "unknown", "unknown"
|
||||
}
|
||||
|
||||
// SetupCLI creates and configures CLI with standard info display
|
||||
// SetupCLI creates, configures, and initializes a CLI instance with the provided
|
||||
// options. It sets up model display, usage tracking for supported providers, and
|
||||
// shows initial loading information. Returns nil in quiet mode or an initialized
|
||||
// CLI instance ready for user interaction.
|
||||
func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
|
||||
if opts.Quiet {
|
||||
return nil, nil // No CLI in quiet mode
|
||||
|
||||
@@ -4,13 +4,17 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FuzzyMatch represents a match result with score
|
||||
// FuzzyMatch represents the result of a fuzzy string matching operation,
|
||||
// containing the matched command and its relevance score. Higher scores
|
||||
// indicate better matches.
|
||||
type FuzzyMatch struct {
|
||||
Command *SlashCommand
|
||||
Score int
|
||||
}
|
||||
|
||||
// FuzzyMatchCommands performs fuzzy matching on slash commands
|
||||
// FuzzyMatchCommands performs fuzzy string matching on the provided slash commands
|
||||
// based on the query string. Returns a slice of matches sorted by relevance score
|
||||
// in descending order. An empty query returns all commands with zero scores.
|
||||
func FuzzyMatchCommands(query string, commands []SlashCommand) []FuzzyMatch {
|
||||
if query == "" || query == "/" {
|
||||
// Return all commands when query is empty or just "/"
|
||||
|
||||
+59
-21
@@ -10,7 +10,8 @@ import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// MessageType represents the type of message
|
||||
// MessageType represents different categories of messages displayed in the UI,
|
||||
// each with distinct visual styling and formatting rules.
|
||||
type MessageType int
|
||||
|
||||
const (
|
||||
@@ -22,7 +23,9 @@ const (
|
||||
ErrorMessage // New type for error messages
|
||||
)
|
||||
|
||||
// UIMessage represents a rendered message for display
|
||||
// UIMessage encapsulates a fully rendered message ready for display in the UI,
|
||||
// including its formatted content, display metrics, and metadata. Messages can
|
||||
// be static or streaming (progressively updated).
|
||||
type UIMessage struct {
|
||||
ID string
|
||||
Type MessageType
|
||||
@@ -38,7 +41,9 @@ func getTheme() Theme {
|
||||
return GetTheme()
|
||||
}
|
||||
|
||||
// MessageRenderer handles rendering of messages with proper styling
|
||||
// MessageRenderer handles the formatting and rendering of different message types
|
||||
// with consistent styling, markdown support, and appropriate visual hierarchies
|
||||
// for the standard (non-compact) display mode.
|
||||
type MessageRenderer struct {
|
||||
width int
|
||||
debug bool
|
||||
@@ -59,7 +64,9 @@ func getSystemUsername() string {
|
||||
return "User"
|
||||
}
|
||||
|
||||
// NewMessageRenderer creates a new message renderer
|
||||
// NewMessageRenderer creates and initializes a new MessageRenderer with the specified
|
||||
// terminal width and debug mode setting. The width parameter determines line wrapping
|
||||
// and layout calculations.
|
||||
func NewMessageRenderer(width int, debug bool) *MessageRenderer {
|
||||
return &MessageRenderer{
|
||||
width: width,
|
||||
@@ -67,12 +74,15 @@ func NewMessageRenderer(width int, debug bool) *MessageRenderer {
|
||||
}
|
||||
}
|
||||
|
||||
// SetWidth updates the renderer width
|
||||
// SetWidth updates the terminal width for the renderer, affecting how content
|
||||
// is wrapped and formatted in subsequent render operations.
|
||||
func (r *MessageRenderer) SetWidth(width int) {
|
||||
r.width = width
|
||||
}
|
||||
|
||||
// RenderUserMessage renders a user message with right border and background header
|
||||
// RenderUserMessage renders a user's input message with distinctive right-aligned
|
||||
// formatting, including the system username, timestamp, and markdown-rendered content.
|
||||
// The message is displayed with a colored right border for visual distinction.
|
||||
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
// Format timestamp and username
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
@@ -106,7 +116,10 @@ func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderAssistantMessage renders an assistant message with left border and background header
|
||||
// RenderAssistantMessage renders an AI assistant's response with left-aligned formatting,
|
||||
// including the model name, timestamp, and markdown-rendered content. Empty responses
|
||||
// are displayed with a special "Finished without output" message. The message features
|
||||
// a colored left border for visual distinction.
|
||||
func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
|
||||
// Format timestamp and model info with better defaults
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
@@ -151,7 +164,9 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
}
|
||||
}
|
||||
|
||||
// RenderSystemMessage renders a system message with left border and background header
|
||||
// RenderSystemMessage renders MCPHost system messages such as help text, command outputs,
|
||||
// and informational notifications. These messages are displayed with a distinctive system
|
||||
// color border and "MCPHost System" label to differentiate them from user and AI content.
|
||||
func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
|
||||
// Format timestamp
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
@@ -193,7 +208,9 @@ func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugMessage renders debug messages with tool response block styling
|
||||
// RenderDebugMessage renders diagnostic and debugging information with special formatting
|
||||
// including a debug icon, colored border, and structured layout. Debug messages are only
|
||||
// displayed when debug mode is enabled and help developers troubleshoot issues.
|
||||
func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
@@ -251,7 +268,9 @@ func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugConfigMessage renders debug configuration settings with tool response block styling
|
||||
// RenderDebugConfigMessage renders configuration settings in a formatted debug display
|
||||
// with key-value pairs shown in a structured layout. Used to display runtime configuration
|
||||
// for debugging purposes with a distinctive icon and border styling.
|
||||
func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage {
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
@@ -311,7 +330,9 @@ func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timest
|
||||
}
|
||||
}
|
||||
|
||||
// RenderErrorMessage renders an error message with left border and background header
|
||||
// RenderErrorMessage renders error notifications with distinctive red coloring and
|
||||
// bold text to ensure visibility. Error messages include timestamp information and
|
||||
// are displayed with an error-colored border for immediate recognition.
|
||||
func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
|
||||
// Format timestamp
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
@@ -347,7 +368,9 @@ func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolCallMessage renders a tool call in progress with left border and background header
|
||||
// RenderToolCallMessage renders a notification that a tool is being executed, showing
|
||||
// the tool name, formatted arguments (if any), and execution timestamp. The message
|
||||
// uses tool-specific coloring to distinguish it from regular conversation messages.
|
||||
func (r *MessageRenderer) RenderToolCallMessage(toolName, toolArgs string, timestamp time.Time) UIMessage {
|
||||
// Format timestamp
|
||||
timeStr := timestamp.Local().Format("15:04")
|
||||
@@ -391,7 +414,10 @@ func (r *MessageRenderer) RenderToolCallMessage(toolName, toolArgs string, times
|
||||
}
|
||||
}
|
||||
|
||||
// RenderToolMessage renders a tool call message with proper styling
|
||||
// RenderToolMessage renders the result of a tool execution, formatting the output
|
||||
// based on the tool type and whether it succeeded or failed. Error results are
|
||||
// displayed in red, while successful results are formatted according to the tool's
|
||||
// output type (bash, file content, etc.).
|
||||
func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage {
|
||||
theme := getTheme()
|
||||
|
||||
@@ -595,7 +621,9 @@ func (r *MessageRenderer) renderMarkdown(content string, width int) string {
|
||||
return strings.TrimSuffix(rendered, "\n")
|
||||
}
|
||||
|
||||
// MessageContainer wraps multiple messages in a container
|
||||
// MessageContainer manages a collection of UI messages, handling their display,
|
||||
// updates, and layout within the terminal. It supports both standard and compact
|
||||
// display modes and maintains state for streaming message updates.
|
||||
type MessageContainer struct {
|
||||
messages []UIMessage
|
||||
width int
|
||||
@@ -605,7 +633,9 @@ type MessageContainer struct {
|
||||
wasCleared bool // Track if container was explicitly cleared
|
||||
}
|
||||
|
||||
// NewMessageContainer creates a new message container
|
||||
// NewMessageContainer creates and initializes a new MessageContainer with the
|
||||
// specified dimensions and display mode. The container starts empty and will
|
||||
// display a welcome message until the first message is added.
|
||||
func NewMessageContainer(width, height int, compact bool) *MessageContainer {
|
||||
return &MessageContainer{
|
||||
messages: make([]UIMessage, 0),
|
||||
@@ -615,18 +645,22 @@ func NewMessageContainer(width, height int, compact bool) *MessageContainer {
|
||||
}
|
||||
}
|
||||
|
||||
// AddMessage adds a message to the container
|
||||
// AddMessage appends a new UIMessage to the container's collection and resets
|
||||
// the cleared state flag. Messages are displayed in the order they were added.
|
||||
func (c *MessageContainer) AddMessage(msg UIMessage) {
|
||||
c.messages = append(c.messages, msg)
|
||||
c.wasCleared = false // Reset the cleared flag when adding messages
|
||||
}
|
||||
|
||||
// SetModelName sets the current model name for the container
|
||||
// SetModelName updates the AI model name used for rendering assistant messages.
|
||||
// This name is displayed in message headers to indicate which model is responding.
|
||||
func (c *MessageContainer) SetModelName(modelName string) {
|
||||
c.modelName = modelName
|
||||
}
|
||||
|
||||
// UpdateLastMessage updates the content of the last message efficiently
|
||||
// UpdateLastMessage efficiently updates the content of the most recent message
|
||||
// in the container. This is primarily used for streaming responses where the
|
||||
// assistant's message is progressively built. Only works for assistant messages.
|
||||
func (c *MessageContainer) UpdateLastMessage(content string) {
|
||||
if len(c.messages) == 0 {
|
||||
return
|
||||
@@ -651,19 +685,23 @@ func (c *MessageContainer) UpdateLastMessage(content string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Clear clears all messages from the container
|
||||
// Clear removes all messages from the container and sets a flag to prevent
|
||||
// showing the welcome screen. Used when starting a fresh conversation.
|
||||
func (c *MessageContainer) Clear() {
|
||||
c.messages = make([]UIMessage, 0)
|
||||
c.wasCleared = true
|
||||
}
|
||||
|
||||
// SetSize updates the container size
|
||||
// SetSize updates the container's dimensions, typically called when the terminal
|
||||
// is resized. This affects how messages are wrapped and displayed.
|
||||
func (c *MessageContainer) SetSize(width, height int) {
|
||||
c.width = width
|
||||
c.height = height
|
||||
}
|
||||
|
||||
// Render renders all messages in the container
|
||||
// Render generates the complete visual representation of all messages in the
|
||||
// container. Returns an empty state display if no messages exist, or formats
|
||||
// all messages according to the current display mode (standard or compact).
|
||||
func (c *MessageContainer) Render() string {
|
||||
if len(c.messages) == 0 {
|
||||
// Don't show welcome box if explicitly cleared
|
||||
|
||||
@@ -21,7 +21,9 @@ const (
|
||||
maxWidth = 80
|
||||
)
|
||||
|
||||
// OllamaPullProgress represents the progress information from Ollama pull API
|
||||
// OllamaPullProgress represents the progress information received from Ollama's
|
||||
// pull API when downloading model files. It includes status messages, digest
|
||||
// information, and download progress counters.
|
||||
type OllamaPullProgress struct {
|
||||
Status string `json:"status"`
|
||||
Digest string `json:"digest,omitempty"`
|
||||
@@ -41,7 +43,9 @@ type progressErrMsg struct{ err error }
|
||||
// progressCompleteMsg indicates completion
|
||||
type progressCompleteMsg struct{}
|
||||
|
||||
// ProgressModel represents the progress bar model
|
||||
// ProgressModel implements a tea.Model for displaying download progress with
|
||||
// a visual progress bar and status messages. It handles progress updates,
|
||||
// errors, and completion states for Ollama model downloads.
|
||||
type ProgressModel struct {
|
||||
progress progress.Model
|
||||
status string
|
||||
@@ -49,7 +53,8 @@ type ProgressModel struct {
|
||||
complete bool
|
||||
}
|
||||
|
||||
// NewProgressModel creates a new progress model
|
||||
// NewProgressModel creates and initializes a new ProgressModel with a gradient
|
||||
// progress bar and initial "Initializing..." status message.
|
||||
func NewProgressModel() ProgressModel {
|
||||
return ProgressModel{
|
||||
progress: progress.New(progress.WithDefaultGradient()),
|
||||
@@ -57,12 +62,15 @@ func NewProgressModel() ProgressModel {
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the progress model
|
||||
// Init implements the tea.Model interface, returning nil as no initial commands
|
||||
// are needed for the progress display.
|
||||
func (m ProgressModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update handles progress updates
|
||||
// Update implements the tea.Model interface, handling keyboard input, window
|
||||
// resize events, and progress updates. It manages the progress bar state and
|
||||
// triggers program exit on completion or cancellation.
|
||||
func (m ProgressModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
@@ -108,7 +116,9 @@ func (m ProgressModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
}
|
||||
|
||||
// View renders the progress bar
|
||||
// View implements the tea.Model interface, rendering the progress bar with
|
||||
// status information and help text. Displays error messages if present or
|
||||
// a completion message when the download finishes.
|
||||
func (m ProgressModel) View() string {
|
||||
if m.err != nil {
|
||||
return fmt.Sprintf("Error: %s\n", m.err.Error())
|
||||
@@ -128,7 +138,9 @@ func (m ProgressModel) View() string {
|
||||
pad+helpStyle("Press 'q' or Ctrl+C to cancel"))
|
||||
}
|
||||
|
||||
// ProgressReader wraps an io.Reader to provide progress updates for Ollama pull operations
|
||||
// ProgressReader wraps an io.Reader to intercept and parse Ollama pull operation
|
||||
// responses, extracting progress information and updating a visual progress bar.
|
||||
// It manages a tea.Program for the UI and handles graceful shutdown.
|
||||
type ProgressReader struct {
|
||||
reader io.Reader
|
||||
program *tea.Program
|
||||
@@ -138,7 +150,9 @@ type ProgressReader struct {
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewProgressReader creates a new progress reader for Ollama pull operations
|
||||
// NewProgressReader creates and initializes a ProgressReader that wraps the provided
|
||||
// io.Reader. It starts a tea.Program in a separate goroutine to display the progress
|
||||
// bar UI while reading and parsing Ollama's streaming JSON responses.
|
||||
func NewProgressReader(reader io.Reader) *ProgressReader {
|
||||
model := NewProgressModel()
|
||||
// Create program with standard settings
|
||||
@@ -164,7 +178,9 @@ func NewProgressReader(reader io.Reader) *ProgressReader {
|
||||
return pr
|
||||
}
|
||||
|
||||
// Read implements io.Reader and parses Ollama streaming responses
|
||||
// Read implements the io.Reader interface, passing through data from the wrapped
|
||||
// reader while parsing JSON lines to extract progress information. Each complete
|
||||
// JSON line is processed to update the progress bar display.
|
||||
func (pr *ProgressReader) Read(p []byte) (n int, err error) {
|
||||
n, err = pr.reader.Read(p)
|
||||
if n > 0 {
|
||||
@@ -239,7 +255,9 @@ func (pr *ProgressReader) parseProgressLine(line string) {
|
||||
})
|
||||
}
|
||||
|
||||
// Close stops the progress display and waits for cleanup
|
||||
// Close gracefully shuts down the progress display, sending a completion message
|
||||
// and waiting for the tea.Program to exit. If the program doesn't exit within
|
||||
// 2 seconds, it is forcefully terminated to prevent hanging.
|
||||
func (pr *ProgressReader) Close() error {
|
||||
// Send completion message to trigger quit
|
||||
pr.program.Send(progressCompleteMsg{})
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// SlashCommandInput is a custom input field with slash command autocomplete
|
||||
// SlashCommandInput provides an interactive text input field with intelligent
|
||||
// slash command autocomplete functionality. It displays a popup menu of matching
|
||||
// commands as the user types, supporting fuzzy matching and keyboard navigation.
|
||||
type SlashCommandInput struct {
|
||||
textarea textarea.Model
|
||||
commands []SlashCommand
|
||||
@@ -26,7 +28,9 @@ type SlashCommandInput struct {
|
||||
renderedLines int // Track how many lines were rendered
|
||||
}
|
||||
|
||||
// NewSlashCommandInput creates a new slash command input field
|
||||
// NewSlashCommandInput creates and initializes a new slash command input field with
|
||||
// the specified width and title. The input supports multi-line text entry, command
|
||||
// autocomplete, and is styled to match the application's theme.
|
||||
func NewSlashCommandInput(width int, title string) *SlashCommandInput {
|
||||
ta := textarea.New()
|
||||
ta.Placeholder = "Type your message..."
|
||||
@@ -54,12 +58,15 @@ func NewSlashCommandInput(width int, title string) *SlashCommandInput {
|
||||
}
|
||||
}
|
||||
|
||||
// Init implements tea.Model
|
||||
// Init implements the tea.Model interface, returning the initial command to start
|
||||
// the cursor blinking animation for the text input field.
|
||||
func (s *SlashCommandInput) Init() tea.Cmd {
|
||||
return textarea.Blink
|
||||
}
|
||||
|
||||
// Update implements tea.Model
|
||||
// Update implements the tea.Model interface, handling keyboard input for text entry,
|
||||
// command selection, and navigation. Manages the autocomplete popup display and
|
||||
// processes submission or cancellation actions.
|
||||
func (s *SlashCommandInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
|
||||
@@ -168,7 +175,9 @@ func (s *SlashCommandInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
}
|
||||
|
||||
// View implements tea.Model
|
||||
// View implements the tea.Model interface, rendering the complete input field
|
||||
// including the title, text area, autocomplete popup (when active), and help text.
|
||||
// The view adapts based on whether single or multi-line input is detected.
|
||||
func (s *SlashCommandInput) View() string {
|
||||
// Add left padding to entire component (2 spaces like other UI elements)
|
||||
containerStyle := lipgloss.NewStyle().PaddingLeft(2)
|
||||
@@ -324,17 +333,21 @@ func (s *SlashCommandInput) renderPopup() string {
|
||||
return popupStyle.Render(popupContent)
|
||||
}
|
||||
|
||||
// Value returns the final value
|
||||
// Value returns the final text value entered by the user after submission.
|
||||
// This will be empty if the input was cancelled.
|
||||
func (s *SlashCommandInput) Value() string {
|
||||
return s.value
|
||||
}
|
||||
|
||||
// Cancelled returns true if the user cancelled
|
||||
// Cancelled returns true if the user cancelled the input operation (e.g., by
|
||||
// pressing ESC or Ctrl+C) without submitting any text.
|
||||
func (s *SlashCommandInput) Cancelled() bool {
|
||||
return s.quitting && s.value == ""
|
||||
}
|
||||
|
||||
// RenderedLines returns how many lines were rendered
|
||||
// RenderedLines returns the total number of terminal lines used by the last
|
||||
// rendered view, including the title, input area, popup, and help text. This
|
||||
// is used for proper screen clearing when the input is dismissed.
|
||||
func (s *SlashCommandInput) RenderedLines() int {
|
||||
return s.renderedLines
|
||||
}
|
||||
|
||||
+14
-5
@@ -10,7 +10,9 @@ import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// Spinner wraps the bubbles spinner for both interactive and non-interactive mode
|
||||
// Spinner provides an animated loading indicator that displays while long-running
|
||||
// operations are in progress. It wraps the bubbles spinner component and manages
|
||||
// its lifecycle through a tea.Program for proper terminal handling.
|
||||
type Spinner struct {
|
||||
model spinner.Model
|
||||
done chan struct{}
|
||||
@@ -72,7 +74,9 @@ func (m spinnerModel) View() string {
|
||||
// quitMsg is sent when we want to quit the spinner
|
||||
type quitMsg struct{}
|
||||
|
||||
// NewSpinner creates a new spinner with enhanced styling
|
||||
// NewSpinner creates a new animated spinner with the specified message. The spinner
|
||||
// uses the theme's primary color and a modern animation style. It runs in a separate
|
||||
// tea.Program to avoid interfering with other terminal operations.
|
||||
func NewSpinner(message string) *Spinner {
|
||||
s := spinner.New()
|
||||
s.Spinner = spinner.Points // More modern spinner style
|
||||
@@ -97,7 +101,9 @@ func NewSpinner(message string) *Spinner {
|
||||
}
|
||||
}
|
||||
|
||||
// NewThemedSpinner creates a new spinner with the given message and color
|
||||
// NewThemedSpinner creates a new animated spinner with custom color styling.
|
||||
// This allows for different spinner colors based on the operation type or status.
|
||||
// The spinner runs independently in its own tea.Program.
|
||||
func NewThemedSpinner(message string, color lipgloss.AdaptiveColor) *Spinner {
|
||||
s := spinner.New()
|
||||
s.Spinner = spinner.Dot
|
||||
@@ -121,7 +127,9 @@ func NewThemedSpinner(message string, color lipgloss.AdaptiveColor) *Spinner {
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the spinner animation
|
||||
// Start begins the spinner animation in a separate goroutine. The spinner will
|
||||
// continue animating until Stop is called. The animation runs in a separate
|
||||
// tea.Program to maintain smooth animation independent of other operations.
|
||||
func (s *Spinner) Start() {
|
||||
go func() {
|
||||
defer close(s.done)
|
||||
@@ -136,7 +144,8 @@ func (s *Spinner) Start() {
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop ends the spinner animation
|
||||
// Stop halts the spinner animation and cleans up resources. This method blocks
|
||||
// until the spinner has fully stopped and the terminal state is restored.
|
||||
func (s *Spinner) Stop() {
|
||||
s.cancel()
|
||||
<-s.done
|
||||
|
||||
@@ -15,12 +15,16 @@ func boolPtr(b bool) *bool { return &b }
|
||||
func stringPtr(s string) *string { return &s }
|
||||
func uintPtr(u uint) *uint { return &u }
|
||||
|
||||
// BaseStyle returns a basic lipgloss style
|
||||
// BaseStyle returns a new, empty lipgloss style that can be customized with
|
||||
// additional styling methods. This serves as the foundation for building more
|
||||
// complex styled components.
|
||||
func BaseStyle() lipgloss.Style {
|
||||
return lipgloss.NewStyle()
|
||||
}
|
||||
|
||||
// GetMarkdownRenderer returns a glamour TermRenderer configured for our use
|
||||
// GetMarkdownRenderer creates and returns a configured glamour.TermRenderer for
|
||||
// rendering markdown content with syntax highlighting and proper formatting. The
|
||||
// renderer is customized with our theme colors and adapted to the specified width.
|
||||
func GetMarkdownRenderer(width int) *glamour.TermRenderer {
|
||||
r, _ := glamour.NewTermRenderer(
|
||||
glamour.WithStyles(generateMarkdownStyleConfig()),
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"github.com/mark3labs/mcphost/internal/tokens"
|
||||
)
|
||||
|
||||
// UsageStats represents token and cost information for a single request/response
|
||||
// UsageStats encapsulates detailed token usage and cost breakdown for a single
|
||||
// LLM request/response cycle, including input, output, and cache token counts
|
||||
// along with their associated costs.
|
||||
type UsageStats struct {
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
@@ -22,7 +24,9 @@ type UsageStats struct {
|
||||
TotalCost float64
|
||||
}
|
||||
|
||||
// SessionStats represents cumulative stats for the entire session
|
||||
// SessionStats aggregates token usage and cost information across all requests
|
||||
// in a session, providing totals and request counts for usage analysis and
|
||||
// cost tracking.
|
||||
type SessionStats struct {
|
||||
TotalInputTokens int
|
||||
TotalOutputTokens int
|
||||
@@ -32,7 +36,9 @@ type SessionStats struct {
|
||||
RequestCount int
|
||||
}
|
||||
|
||||
// UsageTracker tracks token usage and costs for LLM interactions
|
||||
// UsageTracker monitors and accumulates token usage statistics and associated costs
|
||||
// for LLM interactions throughout a session. It provides real-time usage information
|
||||
// and supports both estimated and actual token counts. OAuth users see $0 costs.
|
||||
type UsageTracker struct {
|
||||
mu sync.RWMutex
|
||||
modelInfo *models.ModelInfo
|
||||
@@ -43,7 +49,10 @@ type UsageTracker struct {
|
||||
isOAuth bool // Whether OAuth credentials are being used (costs should be $0)
|
||||
}
|
||||
|
||||
// NewUsageTracker creates a new usage tracker for the given model
|
||||
// NewUsageTracker creates and initializes a new UsageTracker for the specified model.
|
||||
// The tracker uses model-specific pricing information to calculate costs, unless OAuth
|
||||
// credentials are being used (in which case costs are shown as $0). Width determines
|
||||
// the display formatting.
|
||||
func NewUsageTracker(modelInfo *models.ModelInfo, provider string, width int, isOAuth bool) *UsageTracker {
|
||||
return &UsageTracker{
|
||||
modelInfo: modelInfo,
|
||||
@@ -53,15 +62,19 @@ func NewUsageTracker(modelInfo *models.ModelInfo, provider string, width int, is
|
||||
}
|
||||
}
|
||||
|
||||
// EstimateTokens provides a rough estimate of tokens in text
|
||||
// This is a simple approximation - real token counting would require the actual tokenizer
|
||||
// EstimateTokens provides a rough estimate of the number of tokens in the given text.
|
||||
// This uses a simple heuristic of approximately 4 characters per token, which is a
|
||||
// reasonable approximation for most models but not precise. Actual token counts may vary
|
||||
// significantly based on the specific tokenizer used by each model.
|
||||
func EstimateTokens(text string) int {
|
||||
// Rough approximation: ~4 characters per token for most models
|
||||
// This is not accurate but gives a reasonable estimate
|
||||
return len(text) / 4
|
||||
}
|
||||
|
||||
// UpdateUsage updates the tracker with new usage information
|
||||
// UpdateUsage records new token usage data and calculates associated costs based on
|
||||
// the model's pricing. Updates both the last request statistics and cumulative session
|
||||
// totals. For OAuth users, costs are recorded as $0 while still tracking token counts.
|
||||
func (ut *UsageTracker) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
@@ -107,21 +120,27 @@ func (ut *UsageTracker) UpdateUsage(inputTokens, outputTokens, cacheReadTokens,
|
||||
ut.sessionStats.RequestCount++
|
||||
}
|
||||
|
||||
// EstimateAndUpdateUsage estimates tokens from text and updates usage
|
||||
// EstimateAndUpdateUsage estimates token counts from raw text strings and updates
|
||||
// the usage statistics. This method is used when actual token counts are not available
|
||||
// from the API response.
|
||||
func (ut *UsageTracker) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
inputTokens := tokens.EstimateTokens(inputText)
|
||||
outputTokens := tokens.EstimateTokens(outputText)
|
||||
ut.UpdateUsage(inputTokens, outputTokens, 0, 0)
|
||||
}
|
||||
|
||||
// EstimateAndUpdateUsageFromText estimates tokens from text and updates usage
|
||||
// EstimateAndUpdateUsageFromText is an alias for EstimateAndUpdateUsage, providing
|
||||
// backward compatibility. It estimates token counts from text and updates usage statistics.
|
||||
func (ut *UsageTracker) EstimateAndUpdateUsageFromText(inputText, outputText string) {
|
||||
inputTokens := tokens.EstimateTokens(inputText)
|
||||
outputTokens := tokens.EstimateTokens(outputText)
|
||||
ut.UpdateUsage(inputTokens, outputTokens, 0, 0)
|
||||
}
|
||||
|
||||
// RenderUsageInfo renders enhanced usage information with better styling
|
||||
// RenderUsageInfo generates a formatted string displaying current usage statistics
|
||||
// including token counts, context utilization percentage, and costs. The display
|
||||
// adapts colors based on usage levels and formats large numbers with K/M suffixes
|
||||
// for readability.
|
||||
func (ut *UsageTracker) RenderUsageInfo() string {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
@@ -199,14 +218,18 @@ func (ut *UsageTracker) RenderUsageInfo() string {
|
||||
tokensLabel, tokensValue, percentageStr, costLabel, costStr)
|
||||
}
|
||||
|
||||
// GetSessionStats returns a copy of the current session statistics
|
||||
// GetSessionStats returns a copy of the cumulative session statistics including
|
||||
// total token counts, costs, and request count. The returned copy is safe to use
|
||||
// without additional synchronization.
|
||||
func (ut *UsageTracker) GetSessionStats() SessionStats {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
return ut.sessionStats
|
||||
}
|
||||
|
||||
// GetLastRequestStats returns a copy of the last request statistics
|
||||
// GetLastRequestStats returns a copy of the usage statistics from the most recent
|
||||
// request, or nil if no requests have been made. The returned copy is safe to use
|
||||
// without additional synchronization.
|
||||
func (ut *UsageTracker) GetLastRequestStats() *UsageStats {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
@@ -217,7 +240,9 @@ func (ut *UsageTracker) GetLastRequestStats() *UsageStats {
|
||||
return &stats
|
||||
}
|
||||
|
||||
// Reset clears all usage statistics
|
||||
// Reset clears all accumulated usage statistics, resetting both session totals
|
||||
// and last request information to their initial empty state. This is typically
|
||||
// used when starting a new conversation or clearing usage history.
|
||||
func (ut *UsageTracker) Reset() {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
@@ -225,7 +250,9 @@ func (ut *UsageTracker) Reset() {
|
||||
ut.lastRequest = nil
|
||||
}
|
||||
|
||||
// SetWidth updates the display width for rendering
|
||||
// SetWidth updates the terminal width used for formatting usage information display.
|
||||
// This should be called when the terminal is resized to ensure proper text wrapping
|
||||
// and alignment.
|
||||
func (ut *UsageTracker) SetWidth(width int) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
|
||||
+28
-11
@@ -13,14 +13,18 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// MCPHost provides programmatic access to mcphost
|
||||
// MCPHost provides programmatic access to mcphost functionality, allowing
|
||||
// integration of MCP tools and LLM interactions into Go applications. It manages
|
||||
// agents, sessions, and model configurations.
|
||||
type MCPHost struct {
|
||||
agent *agent.Agent
|
||||
sessionMgr *session.Manager
|
||||
modelString string
|
||||
}
|
||||
|
||||
// Options for creating MCPHost (all optional - will use CLI defaults)
|
||||
// Options configures MCPHost creation with optional overrides for model,
|
||||
// prompts, configuration, and behavior settings. All fields are optional
|
||||
// and will use CLI defaults if not specified.
|
||||
type Options struct {
|
||||
Model string // Override model (e.g., "anthropic:claude-3-sonnet")
|
||||
SystemPrompt string // Override system prompt
|
||||
@@ -30,7 +34,9 @@ type Options struct {
|
||||
Quiet bool // Suppress debug output
|
||||
}
|
||||
|
||||
// New creates MCPHost instance using the same initialization as CLI
|
||||
// New creates an MCPHost instance using the same initialization as the CLI.
|
||||
// It loads configuration, initializes MCP servers, creates the LLM model, and
|
||||
// sets up the agent for interaction. Returns an error if initialization fails.
|
||||
func New(ctx context.Context, opts *Options) (*MCPHost, error) {
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
@@ -118,7 +124,9 @@ func New(ctx context.Context, opts *Options) (*MCPHost, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Prompt sends a message and returns the response
|
||||
// Prompt sends a message to the agent and returns the response. The agent may
|
||||
// use tools as needed to generate the response. The conversation history is
|
||||
// automatically maintained in the session. Returns an error if generation fails.
|
||||
func (m *MCPHost) Prompt(ctx context.Context, message string) (string, error) {
|
||||
// Get messages from session
|
||||
messages := m.sessionMgr.GetMessages()
|
||||
@@ -148,7 +156,9 @@ func (m *MCPHost) Prompt(ctx context.Context, message string) (string, error) {
|
||||
return result.FinalResponse.Content, nil
|
||||
}
|
||||
|
||||
// PromptWithCallbacks sends a message with callbacks for tool execution
|
||||
// PromptWithCallbacks sends a message with callbacks for monitoring tool execution
|
||||
// and streaming responses. The callbacks allow real-time observation of tool calls,
|
||||
// results, and response generation. Returns the final response or an error.
|
||||
func (m *MCPHost) PromptWithCallbacks(
|
||||
ctx context.Context,
|
||||
message string,
|
||||
@@ -184,12 +194,14 @@ func (m *MCPHost) PromptWithCallbacks(
|
||||
return result.FinalResponse.Content, nil
|
||||
}
|
||||
|
||||
// GetSessionManager returns the current session manager
|
||||
// GetSessionManager returns the current session manager for direct access
|
||||
// to conversation history and session manipulation.
|
||||
func (m *MCPHost) GetSessionManager() *session.Manager {
|
||||
return m.sessionMgr
|
||||
}
|
||||
|
||||
// LoadSession loads a session from file
|
||||
// LoadSession loads a previously saved session from a file, restoring the
|
||||
// conversation history. Returns an error if the file cannot be loaded or parsed.
|
||||
func (m *MCPHost) LoadSession(path string) error {
|
||||
s, err := session.LoadFromFile(path)
|
||||
if err != nil {
|
||||
@@ -199,22 +211,27 @@ func (m *MCPHost) LoadSession(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveSession saves the current session to file
|
||||
// SaveSession saves the current session to a file for later restoration.
|
||||
// Returns an error if the session cannot be written to the specified path.
|
||||
func (m *MCPHost) SaveSession(path string) error {
|
||||
return m.sessionMgr.GetSession().SaveToFile(path)
|
||||
}
|
||||
|
||||
// ClearSession clears the current session history
|
||||
// ClearSession clears the current session history, starting a new conversation
|
||||
// with an empty message history.
|
||||
func (m *MCPHost) ClearSession() {
|
||||
m.sessionMgr = session.NewManager("")
|
||||
}
|
||||
|
||||
// GetModelString returns the current model string
|
||||
// GetModelString returns the current model string identifier (e.g.,
|
||||
// "anthropic:claude-3-sonnet" or "openai:gpt-4") being used by the agent.
|
||||
func (m *MCPHost) GetModelString() string {
|
||||
return m.modelString
|
||||
}
|
||||
|
||||
// Close cleans up resources
|
||||
// Close cleans up resources including MCP server connections and model resources.
|
||||
// Should be called when the MCPHost instance is no longer needed. Returns an
|
||||
// error if cleanup fails.
|
||||
func (m *MCPHost) Close() error {
|
||||
return m.agent.Close()
|
||||
}
|
||||
|
||||
+8
-4
@@ -5,18 +5,22 @@ import (
|
||||
"github.com/mark3labs/mcphost/internal/session"
|
||||
)
|
||||
|
||||
// Message is an alias for session.Message for SDK users
|
||||
// Message is an alias for session.Message providing SDK users with access
|
||||
// to message structures for conversation history and tool interactions.
|
||||
type Message = session.Message
|
||||
|
||||
// ToolCall is an alias for session.ToolCall
|
||||
// ToolCall is an alias for session.ToolCall representing a tool invocation
|
||||
// with its name, arguments, and result within a conversation.
|
||||
type ToolCall = session.ToolCall
|
||||
|
||||
// ConvertToSchemaMessage converts SDK message to schema message
|
||||
// ConvertToSchemaMessage converts an SDK message to the underlying schema message
|
||||
// format used by the agent for LLM interactions.
|
||||
func ConvertToSchemaMessage(msg *Message) *schema.Message {
|
||||
return msg.ConvertToSchemaMessage()
|
||||
}
|
||||
|
||||
// ConvertFromSchemaMessage converts schema message to SDK message
|
||||
// ConvertFromSchemaMessage converts a schema message from the agent to an SDK
|
||||
// message format for use in the SDK API.
|
||||
func ConvertFromSchemaMessage(msg *schema.Message) Message {
|
||||
return session.ConvertFromSchemaMessage(msg)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user