Files
kit/internal/tools/mcp.go
T
Ed Zynda 0703dd1602 fix: eliminate escape sequence leak from spinner tea.Program instances
Each spinner created a new tea.NewProgram which sent DECRQM queries for
synchronized output mode 2026. When the program exited and restored
cooked terminal mode, the terminal's DECRPM response leaked as visible
^[[?2026;2$y characters. Replace Bubble Tea spinner with a simple
goroutine animation loop writing directly to stderr via lipgloss.
2026-02-25 18:17:25 +03:00

547 lines
17 KiB
Go

package tools
import (
"context"
"encoding/json"
"fmt"
"slices"
"strings"
"time"
"charm.land/fantasy"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcphost/internal/builtin"
"github.com/mark3labs/mcphost/internal/config"
)
// MCPToolManager manages MCP (Model Context Protocol) tools and clients across multiple servers.
// LoadTools connects to every server in a config.Config (through an MCPConnectionPool) and
// exposes each server's tools as fantasy.AgentTool values. Each tool is registered under the
// prefixed name "serverName__toolName" so that tools from different servers cannot collide.
// Clients can be created for stdio, SSE, streamable HTTP, and in-process (built-in) servers.
// The optional model is handed to the connection pool and to built-in servers for sampling.
//
// NOTE(review): no method visible here takes a lock, and LoadTools mutates tools and toolMap
// directly. Do not call LoadTools concurrently with other methods. Whether concurrent tool
// invocation after loading is safe depends on code outside this view — confirm before relying on it.
type MCPToolManager struct {
// connectionPool is created by LoadTools; nil until then.
connectionPool *MCPConnectionPool
// tools holds every registered tool, in registration order.
tools []fantasy.AgentTool
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
model fantasy.LanguageModel // LLM model for sampling
// config is the configuration most recently passed to LoadTools.
config *config.Config
// debug mirrors config.Debug from the most recent LoadTools call.
debug bool
// debugLogger receives connection/debug output; LoadTools installs a default if unset.
debugLogger DebugLogger
}
// toolMapping records the origin of a tool registered under a prefixed name
// ("serverName__toolName"): the owning server, the tool's original name on that
// server, the server's configuration, and the manager that registered it.
type toolMapping struct {
// serverName is the server's key in config.MCPServers.
serverName string
// originalName is the tool name as reported by the server's ListTools.
originalName string
// serverConfig is the configuration of the owning server.
serverConfig config.MCPServerConfig
// manager is the MCPToolManager that registered this tool.
manager *MCPToolManager
}
// NewMCPToolManager returns a manager whose tool list and prefixed-name lookup
// table are allocated but empty. Call SetModel first if any server needs
// sampling, then LoadTools to connect to servers and populate the tools.
func NewMCPToolManager() *MCPToolManager {
	mgr := new(MCPToolManager)
	mgr.tools = []fantasy.AgentTool{}
	mgr.toolMap = map[string]*toolMapping{}
	return mgr
}
// SetModel stores the language model used to serve sampling requests from MCP
// servers. LoadTools passes this model to the connection pool it creates, and
// built-in servers receive it at creation time, so set it before LoadTools.
func (m *MCPToolManager) SetModel(model fantasy.LanguageModel) {
	m.model = model
}
// SetDebugLogger installs the logger used for connection and tool-loading
// diagnostics. If a connection pool already exists, it is switched to the same
// logger so all debug output goes to one place.
func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) {
	m.debugLogger = logger
	if pool := m.connectionPool; pool != nil {
		pool.SetDebugLogger(logger)
	}
}
// samplingHandler answers MCP sampling requests (see CreateMessage) by
// forwarding them to a fantasy LanguageModel.
type samplingHandler struct {
// model generates the responses; when nil, CreateMessage returns an error.
model fantasy.LanguageModel
}
// CreateMessage answers an MCP sampling request by forwarding it to h.model.
//
// The request's SystemPrompt (if non-empty) becomes a system message. Each
// sampling message is then mapped by role: assistant messages become assistant
// text parts, while user messages and any unrecognized role become user
// messages. The model's text output is returned as a single assistant
// TextContent with StopReason "endTurn".
//
// It returns an error if no model is configured or if generation fails.
// NOTE(review): only SystemPrompt and Messages are forwarded; other request
// parameters (e.g. token limits, temperature) are ignored.
func (h *samplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
	if h.model == nil {
		return nil, fmt.Errorf("no model available for sampling")
	}

	var messages []fantasy.Message
	if request.SystemPrompt != "" {
		messages = append(messages, fantasy.NewSystemMessage(request.SystemPrompt))
	}
	for _, msg := range request.Messages {
		content := samplingContentText(msg.Content)
		switch msg.Role {
		case mcp.RoleAssistant:
			messages = append(messages, fantasy.Message{
				Role:    fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{fantasy.TextPart{Text: content}},
			})
		default:
			// mcp.RoleUser, and any role we don't recognize, is sent as user input.
			messages = append(messages, fantasy.NewUserMessage(content))
		}
	}

	response, err := h.model.Generate(ctx, fantasy.Call{Prompt: fantasy.Prompt(messages)})
	if err != nil {
		return nil, fmt.Errorf("model generation failed: %w", err)
	}

	result := &mcp.CreateMessageResult{
		Model:      h.model.Model(),
		StopReason: "endTurn",
	}
	result.SamplingMessage = mcp.SamplingMessage{
		Role: mcp.RoleAssistant,
		Content: mcp.TextContent{
			Type: "text",
			Text: response.Content.Text(),
		},
	}
	return result, nil
}

// samplingContentText extracts the text of a sampling message's content.
// Text may arrive as an mcp.TextContent value or pointer, or as a generic
// map with a string "text" field if the content was decoded without concrete
// types. Anything else (e.g. image or audio content) falls back to its %v form.
func samplingContentText(content any) string {
	switch c := content.(type) {
	case mcp.TextContent:
		return c.Text
	case *mcp.TextContent:
		if c != nil {
			return c.Text
		}
	case map[string]any:
		if text, ok := c["text"].(string); ok {
			return text
		}
	}
	return fmt.Sprintf("%v", content)
}
// LoadTools connects to every MCP server in config and registers their tools
// under prefixed names of the form "serverName__toolName".
//
// Servers are processed in sorted name order, so the resulting tool list is
// deterministic across runs (Go map iteration order is randomized). A failure
// on one server is printed as a warning and does not stop the others. An error
// is returned only when config is nil, or when every configured server fails.
//
// Calling LoadTools again replaces the previous result: the old connection
// pool is closed (best effort) and the tool list is rebuilt from scratch, so
// repeated calls do not duplicate tools. The method takes no locks; do not
// call it concurrently with other methods on the same manager.
func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error {
	if config == nil {
		return fmt.Errorf("nil MCP configuration")
	}
	m.config = config
	m.debug = config.Debug
	if m.debugLogger == nil {
		m.debugLogger = NewSimpleDebugLogger(config.Debug)
	}

	// Release connections from a previous load so reloading does not leak
	// stdio processes or network connections, and start from an empty tool set.
	if m.connectionPool != nil {
		_ = m.connectionPool.Close() // best effort: the pool is being replaced
	}
	m.tools = make([]fantasy.AgentTool, 0)
	m.toolMap = make(map[string]*toolMapping)

	m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug)
	m.connectionPool.SetDebugLogger(m.debugLogger)

	serverNames := make([]string, 0, len(config.MCPServers))
	for name := range config.MCPServers {
		serverNames = append(serverNames, name)
	}
	slices.Sort(serverNames)

	var loadErrors []string
	for _, serverName := range serverNames {
		serverConfig := config.MCPServers[serverName]
		if err := m.loadServerTools(ctx, serverName, serverConfig); err != nil {
			loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", serverName, err))
			fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", serverName, err)
		}
	}

	if len(serverNames) > 0 && len(loadErrors) == len(serverNames) {
		return fmt.Errorf("all MCP servers failed to load: %s", strings.Join(loadErrors, "; "))
	}
	return nil
}
// loadServerTools fetches the tool list of one MCP server through the
// connection pool. Every tool that passes the server's AllowedTools and
// ExcludedTools filters is registered as "serverName__toolName".
//
// Registration is all-or-nothing. Tools are staged locally and only added to
// m.tools/m.toolMap after every schema has converted successfully, so an error
// part-way through never leaves a half-loaded server behind. Errors wrap their
// cause with %w.
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) error {
	m.debugLogConnectionInfo(serverName, serverConfig)

	conn, err := m.connectionPool.GetConnection(ctx, serverName, serverConfig)
	if err != nil {
		return fmt.Errorf("failed to get connection from pool: %w", err)
	}

	listResults, err := conn.client.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		// Report the failure to the pool before bailing out.
		m.connectionPool.HandleConnectionError(serverName, err)
		return fmt.Errorf("failed to list tools: %w", err)
	}

	// A non-nil allow-set means only the listed tools may be registered.
	var allowed map[string]struct{}
	if len(serverConfig.AllowedTools) > 0 {
		allowed = make(map[string]struct{}, len(serverConfig.AllowedTools))
		for _, name := range serverConfig.AllowedTools {
			allowed[name] = struct{}{}
		}
	}

	staged := make([]fantasy.AgentTool, 0, len(listResults.Tools))
	stagedMap := make(map[string]*toolMapping, len(listResults.Tools))
	for _, mcpTool := range listResults.Tools {
		if allowed != nil {
			if _, ok := allowed[mcpTool.Name]; !ok {
				continue
			}
		}
		if m.shouldExcludeTool(mcpTool.Name, serverConfig) {
			continue
		}

		parameters, required, err := mcpSchemaToParameters(mcpTool.InputSchema)
		if err != nil {
			return fmt.Errorf("%w, tool name: %s", err, mcpTool.Name)
		}

		prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name)
		mapping := &toolMapping{
			serverName:   serverName,
			originalName: mcpTool.Name,
			serverConfig: serverConfig,
			manager:      m,
		}
		stagedMap[prefixedName] = mapping
		staged = append(staged, &mcpFantasyTool{
			toolInfo: fantasy.ToolInfo{
				Name:        prefixedName,
				Description: mcpTool.Description,
				Parameters:  parameters,
				Required:    required,
			},
			mapping: mapping,
		})
	}

	// Every tool converted successfully: commit the staged registrations.
	if m.toolMap == nil {
		m.toolMap = make(map[string]*toolMapping, len(stagedMap))
	}
	for name, mapping := range stagedMap {
		m.toolMap[name] = mapping
	}
	m.tools = append(m.tools, staged...)
	return nil
}

// mcpSchemaToParameters converts an MCP tool input schema into the properties
// map and required-field list used by fantasy.ToolInfo. Numeric exclusive
// bounds are first rewritten to the draft-04 boolean form (JSON Schema
// draft-07 vs draft-04 compatibility). An object schema without properties
// yields an empty, non-nil parameters map, which fantasy handles fine.
func mcpSchemaToParameters(inputSchema any) (map[string]any, []string, error) {
	raw, err := json.Marshal(inputSchema)
	if err != nil {
		return nil, nil, fmt.Errorf("conv mcp tool input schema fail(marshal): %w", err)
	}
	raw = convertExclusiveBoundsToBoolean(raw)

	var schema map[string]any
	if err := json.Unmarshal(raw, &schema); err != nil {
		return nil, nil, fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w", err)
	}

	parameters := make(map[string]any)
	if props, ok := schema["properties"].(map[string]any); ok {
		parameters = props
	}

	var required []string
	if req, ok := schema["required"].([]any); ok {
		for _, r := range req {
			if s, ok := r.(string); ok {
				required = append(required, s)
			}
		}
	}
	return parameters, required, nil
}
// GetTools returns all loaded tools from all configured MCP servers, named
// "serverName__toolName". The result is a copy of the manager's list. Callers
// may append to or reorder it without corrupting the manager's backing array.
// A nil list stays nil.
func (m *MCPToolManager) GetTools() []fantasy.AgentTool {
	return slices.Clone(m.tools)
}
// GetLoadedServerNames returns, sorted, the names of the servers for which the
// connection pool currently holds a client (as reported by GetClients). It
// returns nil before LoadTools has created a pool.
// NOTE(review): whether a server whose tool listing failed stays in the pool
// depends on MCPConnectionPool — confirm if exact "loaded" semantics matter.
func (m *MCPToolManager) GetLoadedServerNames() []string {
	if m.connectionPool == nil {
		return nil
	}
	var names []string
	for serverName := range m.connectionPool.GetClients() {
		names = append(names, serverName)
	}
	// Map iteration order is random; sort so status output is stable.
	slices.Sort(names)
	return names
}
// Close shuts down all MCP client connections held by the connection pool,
// including stdio processes and network connections. It returns nil without
// doing anything if LoadTools never created a pool.
// NOTE(review): whether repeated calls are safe depends on
// MCPConnectionPool.Close — confirm.
func (m *MCPToolManager) Close() error {
	if m.connectionPool == nil {
		return nil
	}
	return m.connectionPool.Close()
}
// shouldExcludeTool reports whether toolName appears in the server's
// ExcludedTools list. An empty or nil list excludes nothing.
func (m *MCPToolManager) shouldExcludeTool(toolName string, serverConfig config.MCPServerConfig) bool {
	return slices.Contains(serverConfig.ExcludedTools, toolName)
}
// createMCPClient builds an MCP client for one server using the transport
// selected by serverConfig.GetTransportType():
//
//   - "stdio": spawns serverConfig.Command[0]. Inline arguments in Command win;
//     Args is used only when Command holds just the executable. Environment
//     and Env entries are passed as KEY=VALUE pairs.
//   - "sse" / "streamable": connects to serverConfig.URL, sending the
//     "Name: value" entries of serverConfig.Headers as HTTP headers.
//   - "inprocess": delegates to createBuiltinClient.
//
// Stdio, SSE and streamable clients are started here; no client is
// initialized (that is initializeClient's job). Errors wrap their cause.
func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
	transportType := serverConfig.GetTransportType()
	switch transportType {
	case "stdio":
		if len(serverConfig.Command) == 0 {
			return nil, fmt.Errorf("no command configured for stdio server %s", serverName)
		}
		command := serverConfig.Command[0]
		var args []string
		switch {
		case len(serverConfig.Command) > 1:
			args = serverConfig.Command[1:]
		case len(serverConfig.Args) > 0:
			args = serverConfig.Args
		}

		// Leave env nil (not empty) when nothing is configured; the transport
		// may treat a nil and an empty environment differently.
		var env []string
		for k, v := range serverConfig.Environment {
			env = append(env, fmt.Sprintf("%s=%s", k, v))
		}
		for k, v := range serverConfig.Env {
			env = append(env, fmt.Sprintf("%s=%v", k, v))
		}

		stdioTransport := transport.NewStdio(command, env, args...)
		stdioClient := client.NewClient(stdioTransport)
		if err := stdioTransport.Start(ctx); err != nil {
			return nil, fmt.Errorf("failed to start stdio transport: %w", err)
		}
		// Short pause after spawning, presumably to let the process come up
		// before initialization (NOTE(review): confirm still needed). Unlike a
		// bare time.Sleep, this honors cancellation and closes the client.
		select {
		case <-time.After(100 * time.Millisecond):
		case <-ctx.Done():
			_ = stdioClient.Close() // best effort; ctx.Err() is the meaningful error
			return nil, ctx.Err()
		}
		return stdioClient, nil

	case "sse":
		var options []transport.ClientOption
		if headers := parseMCPServerHeaders(serverConfig.Headers); len(headers) > 0 {
			options = append(options, transport.WithHeaders(headers))
		}
		sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
		if err != nil {
			return nil, err
		}
		if err := sseClient.Start(ctx); err != nil {
			return nil, fmt.Errorf("failed to start SSE client: %w", err)
		}
		return sseClient, nil

	case "streamable":
		var options []transport.StreamableHTTPCOption
		if headers := parseMCPServerHeaders(serverConfig.Headers); len(headers) > 0 {
			options = append(options, transport.WithHTTPHeaders(headers))
		}
		streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
		if err != nil {
			return nil, err
		}
		if err := streamableClient.Start(ctx); err != nil {
			return nil, fmt.Errorf("failed to start streamable HTTP client: %w", err)
		}
		return streamableClient, nil

	case "inprocess":
		return m.createBuiltinClient(ctx, serverName, serverConfig)

	default:
		return nil, fmt.Errorf("unsupported transport type '%s' for server %s", transportType, serverName)
	}
}

// parseMCPServerHeaders converts "Name: value" strings into a header map.
// The string is split at the first colon, and name and value are trimmed of
// surrounding whitespace. Entries without a colon are ignored, and a later
// duplicate name overwrites an earlier one.
func parseMCPServerHeaders(lines []string) map[string]string {
	headers := make(map[string]string, len(lines))
	for _, line := range lines {
		name, value, ok := strings.Cut(line, ":")
		if !ok {
			continue
		}
		headers[strings.TrimSpace(name)] = strings.TrimSpace(value)
	}
	return headers
}
// initializeClient performs the MCP initialize handshake on client. It
// announces this host as "mcphost" version "1.0.0", requests
// mcp.LATEST_PROTOCOL_VERSION and declares empty client capabilities. The
// handshake is bounded by a 30-second timeout derived from ctx.
//
// The returned error wraps the underlying cause with %w. If the client
// propagates the context error, callers can detect a timeout with
// errors.Is(err, context.DeadlineExceeded).
func (m *MCPToolManager) initializeClient(ctx context.Context, client client.MCPClient) error {
	initCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	var initRequest mcp.InitializeRequest
	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
	initRequest.Params.ClientInfo = mcp.Implementation{
		Name:    "mcphost",
		Version: "1.0.0",
	}
	initRequest.Params.Capabilities = mcp.ClientCapabilities{}

	if _, err := client.Initialize(initCtx, initRequest); err != nil {
		return fmt.Errorf("initialization timeout or failed: %w", err)
	}
	return nil
}
// createBuiltinClient creates an in-process MCP client for the built-in server
// named serverConfig.Name. The server is configured with serverConfig.Options
// and the manager's model, which may be nil. ctx and serverName are currently
// unused. Errors wrap their cause with %w.
func (m *MCPToolManager) createBuiltinClient(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
	registry := builtin.NewRegistry()
	builtinServer, err := registry.CreateServer(serverConfig.Name, serverConfig.Options, m.model)
	if err != nil {
		return nil, fmt.Errorf("failed to create builtin server: %w", err)
	}
	inProcessClient, err := client.NewInProcessClient(builtinServer.GetServer())
	if err != nil {
		return nil, fmt.Errorf("failed to create in-process client: %w", err)
	}
	return inProcessClient, nil
}
// debugLogConnectionInfo logs how a server is about to be connected: the
// transport type, plus the command (stdio) or URL (sse/streamable). Nothing is
// logged unless a debug logger is set and enabled.
//
// Header values are redacted because they commonly carry credentials such as
// Authorization tokens. Only the environment-variable count is logged, never
// the values.
func (m *MCPToolManager) debugLogConnectionInfo(serverName string, serverConfig config.MCPServerConfig) {
	if m.debugLogger == nil || !m.debugLogger.IsDebugEnabled() {
		return
	}
	transportType := serverConfig.GetTransportType()
	m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Connecting to MCP server: %s", serverName))
	m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Transport type: %s", transportType))

	switch transportType {
	case "stdio":
		if len(serverConfig.Command) > 0 {
			m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Command: %s %v", serverConfig.Command[0], serverConfig.Command[1:]))
		}
		// Both Environment and Env are passed to the child process.
		if n := len(serverConfig.Environment) + len(serverConfig.Env); n > 0 {
			m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Environment variables: %d", n))
		}
	case "sse", "streamable":
		m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] URL: %s", serverConfig.URL))
		if len(serverConfig.Headers) > 0 {
			m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Headers: %v", redactMCPHeaderValues(serverConfig.Headers)))
		}
	}
}

// redactMCPHeaderValues returns a copy of "Name: value" header strings with each
// value replaced by "<redacted>". An entry without a colon cannot be split
// safely, so it is redacted entirely.
func redactMCPHeaderValues(headers []string) []string {
	redacted := make([]string, 0, len(headers))
	for _, h := range headers {
		if name, _, ok := strings.Cut(h, ":"); ok {
			redacted = append(redacted, strings.TrimSpace(name)+": <redacted>")
		} else {
			redacted = append(redacted, "<redacted>")
		}
	}
	return redacted
}
// convertExclusiveBoundsToBoolean rewrites a JSON Schema document so that
// numeric exclusiveMinimum/exclusiveMaximum values (draft-06 and later, e.g.
// draft-07) become the draft-04 form: a boolean flag on minimum/maximum.
// The rewrite is done by convertSchemaRecursive. Input that does not decode
// as a JSON object, or whose rewritten form cannot be re-encoded, is returned
// unchanged.
func convertExclusiveBoundsToBoolean(schemaJSON []byte) []byte {
	var root map[string]any
	if json.Unmarshal(schemaJSON, &root) != nil {
		return schemaJSON
	}
	convertSchemaRecursive(root)
	if converted, err := json.Marshal(root); err == nil {
		return converted
	}
	return schemaJSON
}
// convertSchemaRecursive rewrites, in place, numeric (draft-06+)
// exclusiveMinimum/exclusiveMaximum keywords into the draft-04 form: the bound
// moves to minimum/maximum, and the exclusive keyword becomes the boolean true.
// If the schema already has an inclusive minimum (maximum) that is strictly
// tighter than the exclusive bound, the exclusive keyword is dropped instead,
// so the tighter bound is not overwritten.
//
// The function recurses into every subschema location listed below:
//   - maps of schemas: properties, patternProperties, definitions, $defs,
//     dependentSchemas
//   - single schemas: items, additionalItems, additionalProperties, contains,
//     propertyNames, not, if, then, else
//   - arrays of schemas: allOf, anyOf, oneOf, prefixItems, and the array
//     (tuple) form of items
//
// Values that are not JSON objects (e.g. boolean schemas) are left untouched.
// Data-bearing keywords such as default, const and enum are never descended
// into.
func convertSchemaRecursive(schema map[string]any) {
	if exMin, ok := schema["exclusiveMinimum"].(float64); ok {
		if incMin, ok := schema["minimum"].(float64); ok && incMin > exMin {
			// x >= incMin already implies x > exMin; keep the tighter bound.
			delete(schema, "exclusiveMinimum")
		} else {
			schema["minimum"] = exMin
			schema["exclusiveMinimum"] = true
		}
	}
	if exMax, ok := schema["exclusiveMaximum"].(float64); ok {
		if incMax, ok := schema["maximum"].(float64); ok && incMax < exMax {
			// x <= incMax already implies x < exMax; keep the tighter bound.
			delete(schema, "exclusiveMaximum")
		} else {
			schema["maximum"] = exMax
			schema["exclusiveMaximum"] = true
		}
	}

	for _, key := range []string{"properties", "patternProperties", "definitions", "$defs", "dependentSchemas"} {
		if children, ok := schema[key].(map[string]any); ok {
			for _, child := range children {
				if childSchema, ok := child.(map[string]any); ok {
					convertSchemaRecursive(childSchema)
				}
			}
		}
	}
	for _, key := range []string{"items", "additionalItems", "additionalProperties", "contains", "propertyNames", "not", "if", "then", "else"} {
		if child, ok := schema[key].(map[string]any); ok {
			convertSchemaRecursive(child)
		}
	}
	for _, key := range []string{"allOf", "anyOf", "oneOf", "prefixItems", "items"} {
		if children, ok := schema[key].([]any); ok {
			for _, child := range children {
				if childSchema, ok := child.(map[string]any); ok {
					convertSchemaRecursive(childSchema)
				}
			}
		}
	}
}