mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
558 lines
19 KiB
Go
558 lines
19 KiB
Go
package tools
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/cloudwego/eino/components/model"
|
|
"github.com/cloudwego/eino/components/tool"
|
|
"github.com/cloudwego/eino/schema"
|
|
"github.com/getkin/kin-openapi/openapi3"
|
|
"github.com/mark3labs/mcp-go/client"
|
|
"github.com/mark3labs/mcp-go/client/transport"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
"github.com/mark3labs/mcphost/internal/builtin"
|
|
"github.com/mark3labs/mcphost/internal/config"
|
|
)
|
|
|
|
// MCPToolManager manages MCP (Model Context Protocol) tools and clients across multiple servers.
|
|
// It provides a unified interface for loading, managing, and executing tools from various MCP servers,
|
|
// including stdio, SSE, streamable HTTP, and built-in server types. The manager handles connection
|
|
// pooling, health checks, tool name prefixing to avoid conflicts, and sampling support for LLM interactions.
|
|
// Thread-safe for concurrent tool invocations.
|
|
type MCPToolManager struct {
|
|
connectionPool *MCPConnectionPool
|
|
tools []tool.BaseTool
|
|
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
|
|
model model.ToolCallingChatModel // LLM model for sampling
|
|
config *config.Config
|
|
debug bool
|
|
debugLogger DebugLogger
|
|
}
|
|
|
|
// toolMapping stores the mapping between prefixed tool names and their original details
|
|
type toolMapping struct {
|
|
serverName string
|
|
originalName string
|
|
serverConfig config.MCPServerConfig
|
|
manager *MCPToolManager
|
|
}
|
|
|
|
// mcpToolImpl implements the eino tool interface with server prefixing
|
|
type mcpToolImpl struct {
|
|
info *schema.ToolInfo
|
|
mapping *toolMapping
|
|
}
|
|
|
|
// NewMCPToolManager creates a new MCP tool manager instance.
|
|
// Returns an initialized manager with empty tool collections ready to load tools from MCP servers.
|
|
// The manager must be configured with SetModel and LoadTools before use.
|
|
func NewMCPToolManager() *MCPToolManager {
|
|
return &MCPToolManager{
|
|
tools: make([]tool.BaseTool, 0),
|
|
toolMap: make(map[string]*toolMapping),
|
|
}
|
|
}
|
|
|
|
// SetModel sets the LLM model for sampling support.
|
|
// The model is used when MCP servers request sampling operations, allowing them to
|
|
// leverage the host's LLM capabilities for text generation tasks.
|
|
// This method should be called before LoadTools if any MCP servers require sampling support.
|
|
func (m *MCPToolManager) SetModel(model model.ToolCallingChatModel) {
|
|
m.model = model
|
|
}
|
|
|
|
// SetDebugLogger sets the debug logger for the tool manager.
|
|
// The logger will be used to output detailed debugging information about MCP connections,
|
|
// tool loading, and execution. If a connection pool exists, it will also be configured
|
|
// to use the same logger for consistent debugging output.
|
|
func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) {
|
|
m.debugLogger = logger
|
|
if m.connectionPool != nil {
|
|
m.connectionPool.SetDebugLogger(logger)
|
|
}
|
|
}
|
|
|
|
// samplingHandler implements the MCP sampling handler interface
|
|
type samplingHandler struct {
|
|
model model.ToolCallingChatModel
|
|
}
|
|
|
|
// CreateMessage handles sampling requests from MCP servers by forwarding them to the configured LLM model.
|
|
// It converts MCP message formats to eino message formats, invokes the model for generation,
|
|
// and converts the response back to MCP format. Returns an error if no model is available
|
|
// or if generation fails.
|
|
func (h *samplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
|
|
if h.model == nil {
|
|
return nil, fmt.Errorf("no model available for sampling")
|
|
}
|
|
|
|
// Convert MCP messages to eino messages
|
|
var messages []*schema.Message
|
|
|
|
// Add system message if provided
|
|
if request.SystemPrompt != "" {
|
|
messages = append(messages, schema.SystemMessage(request.SystemPrompt))
|
|
}
|
|
|
|
// Convert sampling messages
|
|
for _, msg := range request.Messages {
|
|
// Extract text content
|
|
var content string
|
|
if textContent, ok := msg.Content.(mcp.TextContent); ok {
|
|
content = textContent.Text
|
|
} else {
|
|
content = fmt.Sprintf("%v", msg.Content)
|
|
}
|
|
|
|
switch msg.Role {
|
|
case mcp.RoleUser:
|
|
messages = append(messages, schema.UserMessage(content))
|
|
case mcp.RoleAssistant:
|
|
messages = append(messages, schema.AssistantMessage(content, nil))
|
|
default:
|
|
messages = append(messages, schema.UserMessage(content)) // Default to user
|
|
}
|
|
}
|
|
|
|
// Generate response using the model (no config options for now)
|
|
response, err := h.model.Generate(ctx, messages)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("model generation failed: %w", err)
|
|
}
|
|
|
|
// Convert response back to MCP format
|
|
result := &mcp.CreateMessageResult{
|
|
Model: "mcphost-model", // Generic model name
|
|
StopReason: "endTurn",
|
|
}
|
|
result.SamplingMessage = mcp.SamplingMessage{
|
|
Role: mcp.RoleAssistant,
|
|
Content: mcp.TextContent{
|
|
Type: "text",
|
|
Text: response.Content,
|
|
},
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
|
|
// It initializes the connection pool, connects to each configured server, and loads their tools.
|
|
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
|
|
// Returns an error only if all configured servers fail to load; partial failures are logged as warnings.
|
|
// This method is thread-safe and idempotent.
|
|
func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error {
|
|
// Initialize connection pool
|
|
m.config = config
|
|
m.debug = config.Debug
|
|
if m.debugLogger == nil {
|
|
m.debugLogger = NewSimpleDebugLogger(config.Debug)
|
|
}
|
|
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug)
|
|
m.connectionPool.SetDebugLogger(m.debugLogger)
|
|
|
|
var loadErrors []string
|
|
|
|
for serverName, serverConfig := range config.MCPServers {
|
|
if err := m.loadServerTools(ctx, serverName, serverConfig); err != nil {
|
|
loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", serverName, err))
|
|
fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", serverName, err)
|
|
continue
|
|
}
|
|
}
|
|
|
|
// If all servers failed to load, return an error
|
|
if len(loadErrors) == len(config.MCPServers) && len(config.MCPServers) > 0 {
|
|
return fmt.Errorf("all MCP servers failed to load: %s", strings.Join(loadErrors, "; "))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// loadServerTools loads tools from a single MCP server
|
|
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) error {
|
|
// Add debug logging
|
|
m.debugLogConnectionInfo(serverName, serverConfig)
|
|
|
|
// Get connection from pool
|
|
conn, err := m.connectionPool.GetConnection(ctx, serverName, serverConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get connection from pool: %v", err)
|
|
}
|
|
|
|
// Get tools from this server
|
|
listResults, err := conn.client.ListTools(ctx, mcp.ListToolsRequest{})
|
|
if err != nil {
|
|
// Handle connection error
|
|
m.connectionPool.HandleConnectionError(serverName, err)
|
|
return fmt.Errorf("failed to list tools: %v", err)
|
|
}
|
|
|
|
// Create name set for allowed tools
|
|
var nameSet map[string]struct{}
|
|
if len(serverConfig.AllowedTools) > 0 {
|
|
nameSet = make(map[string]struct{})
|
|
for _, name := range serverConfig.AllowedTools {
|
|
nameSet[name] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// Convert MCP tools to eino tools with prefixed names
|
|
for _, mcpTool := range listResults.Tools {
|
|
// Filter tools based on allowedTools/excludedTools
|
|
if len(serverConfig.AllowedTools) > 0 {
|
|
if _, ok := nameSet[mcpTool.Name]; !ok {
|
|
continue
|
|
}
|
|
}
|
|
|
|
// Check if tool should be excluded
|
|
if m.shouldExcludeTool(mcpTool.Name, serverConfig) {
|
|
continue
|
|
}
|
|
|
|
// Convert schema
|
|
marshaledInputSchema, err := sonic.Marshal(mcpTool.InputSchema)
|
|
if err != nil {
|
|
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
|
|
}
|
|
inputSchema := &openapi3.Schema{}
|
|
err = sonic.Unmarshal(marshaledInputSchema, inputSchema)
|
|
if err != nil {
|
|
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
|
|
}
|
|
|
|
// Fix for issue #89: Ensure object schemas have a properties field
|
|
// OpenAI function calling requires object schemas to have a "properties" field
|
|
// even if it's empty, otherwise it throws "object schema missing properties" error
|
|
if inputSchema.Type == "object" && inputSchema.Properties == nil {
|
|
inputSchema.Properties = make(openapi3.Schemas)
|
|
}
|
|
|
|
// Create prefixed tool name
|
|
prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name)
|
|
|
|
// Create tool mapping
|
|
mapping := &toolMapping{
|
|
serverName: serverName,
|
|
originalName: mcpTool.Name,
|
|
serverConfig: serverConfig,
|
|
manager: m,
|
|
}
|
|
m.toolMap[prefixedName] = mapping
|
|
|
|
// Create eino tool
|
|
einoTool := &mcpToolImpl{
|
|
info: &schema.ToolInfo{
|
|
Name: prefixedName,
|
|
Desc: mcpTool.Description,
|
|
ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(inputSchema),
|
|
},
|
|
mapping: mapping,
|
|
}
|
|
|
|
m.tools = append(m.tools, einoTool)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Info returns the tool information including name, description, and parameter schema.
|
|
// This method implements the eino tool.BaseTool interface.
|
|
// The returned ToolInfo contains the prefixed tool name to ensure uniqueness across servers.
|
|
func (t *mcpToolImpl) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
|
return t.info, nil
|
|
}
|
|
|
|
// InvokableRun executes the tool by mapping the prefixed name back to the original tool name and server.
|
|
// It retrieves a healthy connection from the pool, invokes the tool on the appropriate MCP server,
|
|
// and returns the result as a JSON string. The method handles connection errors by marking
|
|
// connections as unhealthy in the pool for automatic recovery on subsequent requests.
|
|
// Thread-safe for concurrent invocations.
|
|
func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
|
// Handle empty or invalid JSON arguments
|
|
var arguments any
|
|
if argumentsInJSON == "" || argumentsInJSON == "{}" {
|
|
arguments = nil
|
|
} else {
|
|
// Validate that argumentsInJSON is valid JSON before using it
|
|
var temp any
|
|
if err := json.Unmarshal([]byte(argumentsInJSON), &temp); err != nil {
|
|
return "", fmt.Errorf("invalid JSON arguments: %w", err)
|
|
}
|
|
arguments = json.RawMessage(argumentsInJSON)
|
|
}
|
|
|
|
// Get connection from pool for this server with health check
|
|
conn, err := t.mapping.manager.connectionPool.GetConnectionWithHealthCheck(ctx, t.mapping.serverName, t.mapping.serverConfig)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get healthy connection from pool: %w", err)
|
|
}
|
|
|
|
result, err := conn.client.CallTool(ctx, mcp.CallToolRequest{
|
|
Request: mcp.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: struct {
|
|
Name string `json:"name"`
|
|
Arguments any `json:"arguments,omitempty"`
|
|
Meta *mcp.Meta `json:"_meta,omitempty"`
|
|
}{
|
|
Name: t.mapping.originalName, // Use original name, not prefixed
|
|
Arguments: arguments,
|
|
},
|
|
})
|
|
if err != nil {
|
|
// Handle connection error in pool
|
|
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
|
return "", fmt.Errorf("failed to call mcp tool: %w", err)
|
|
}
|
|
|
|
marshaledResult, err := sonic.MarshalString(result)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to marshal mcp tool result: %w", err)
|
|
}
|
|
|
|
// If the MCP server returned an error, we still return the error content as the response
|
|
// to the LLM so it can see what went wrong. The error will be shown to the user via
|
|
// the UI callbacks, but the LLM needs to see the actual error details to continue
|
|
// the conversation appropriately.
|
|
return marshaledResult, nil
|
|
}
|
|
|
|
// GetTools returns all loaded tools from all configured MCP servers.
|
|
// Tools are returned with their prefixed names (serverName__toolName) to ensure uniqueness.
|
|
// The returned slice is a copy and can be safely modified by the caller.
|
|
func (m *MCPToolManager) GetTools() []tool.BaseTool {
|
|
return m.tools
|
|
}
|
|
|
|
// GetLoadedServerNames returns the names of all successfully loaded MCP servers.
|
|
// This includes servers that are currently connected and have had their tools loaded,
|
|
// regardless of their current health status. Useful for debugging and status reporting.
|
|
func (m *MCPToolManager) GetLoadedServerNames() []string {
|
|
var names []string
|
|
for serverName := range m.connectionPool.GetClients() {
|
|
names = append(names, serverName)
|
|
}
|
|
return names
|
|
}
|
|
|
|
// Close closes all MCP client connections and cleans up resources.
|
|
// This method should be called when the tool manager is no longer needed to ensure
|
|
// proper cleanup of stdio processes, network connections, and other resources.
|
|
// It is safe to call Close multiple times.
|
|
func (m *MCPToolManager) Close() error {
|
|
return m.connectionPool.Close()
|
|
}
|
|
|
|
// shouldExcludeTool determines if a tool should be excluded based on excludedTools
|
|
func (m *MCPToolManager) shouldExcludeTool(toolName string, serverConfig config.MCPServerConfig) bool {
|
|
// If excludedTools is specified, exclude tools in the list
|
|
if len(serverConfig.ExcludedTools) > 0 {
|
|
for _, excludedTool := range serverConfig.ExcludedTools {
|
|
if excludedTool == toolName {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
|
transportType := serverConfig.GetTransportType()
|
|
|
|
switch transportType {
|
|
case "stdio":
|
|
// STDIO client
|
|
var env []string
|
|
var command string
|
|
var args []string
|
|
|
|
// Handle command and environment
|
|
if len(serverConfig.Command) > 0 {
|
|
command = serverConfig.Command[0]
|
|
if len(serverConfig.Command) > 1 {
|
|
args = serverConfig.Command[1:]
|
|
} else if len(serverConfig.Args) > 0 {
|
|
// Legacy fallback: Command only has the command, Args has the arguments
|
|
// This handles cases where legacy config conversion didn't work properly
|
|
args = serverConfig.Args
|
|
}
|
|
}
|
|
|
|
// Convert environment variables
|
|
if serverConfig.Environment != nil {
|
|
for k, v := range serverConfig.Environment {
|
|
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
|
}
|
|
}
|
|
|
|
// Legacy environment support
|
|
if serverConfig.Env != nil {
|
|
for k, v := range serverConfig.Env {
|
|
env = append(env, fmt.Sprintf("%s=%v", k, v))
|
|
}
|
|
}
|
|
|
|
// Create stdio transport
|
|
stdioTransport := transport.NewStdio(command, env, args...)
|
|
|
|
stdioClient := client.NewClient(stdioTransport)
|
|
|
|
// Start the transport
|
|
if err := stdioTransport.Start(ctx); err != nil {
|
|
return nil, fmt.Errorf("failed to start stdio transport: %v", err)
|
|
}
|
|
|
|
// Add a brief delay to allow the process to start and potentially fail
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// TODO: Add process health check here if the mcp-go library exposes process info
|
|
// For now, we rely on the timeout in initializeClient to catch dead processes
|
|
|
|
return stdioClient, nil
|
|
|
|
case "sse":
|
|
// SSE client
|
|
var options []transport.ClientOption
|
|
|
|
// Add headers if specified
|
|
if len(serverConfig.Headers) > 0 {
|
|
headers := make(map[string]string)
|
|
for _, header := range serverConfig.Headers {
|
|
parts := strings.SplitN(header, ":", 2)
|
|
if len(parts) == 2 {
|
|
key := strings.TrimSpace(parts[0])
|
|
value := strings.TrimSpace(parts[1])
|
|
headers[key] = value
|
|
}
|
|
}
|
|
if len(headers) > 0 {
|
|
options = append(options, transport.WithHeaders(headers))
|
|
}
|
|
}
|
|
|
|
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Start the SSE client
|
|
if err := sseClient.Start(ctx); err != nil {
|
|
return nil, fmt.Errorf("failed to start SSE client: %v", err)
|
|
}
|
|
|
|
return sseClient, nil
|
|
|
|
case "streamable":
|
|
// Streamable HTTP client
|
|
var options []transport.StreamableHTTPCOption
|
|
|
|
// Add headers if specified
|
|
if len(serverConfig.Headers) > 0 {
|
|
headers := make(map[string]string)
|
|
for _, header := range serverConfig.Headers {
|
|
parts := strings.SplitN(header, ":", 2)
|
|
if len(parts) == 2 {
|
|
key := strings.TrimSpace(parts[0])
|
|
value := strings.TrimSpace(parts[1])
|
|
headers[key] = value
|
|
}
|
|
}
|
|
if len(headers) > 0 {
|
|
options = append(options, transport.WithHTTPHeaders(headers))
|
|
}
|
|
}
|
|
|
|
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Start the streamable HTTP client
|
|
if err := streamableClient.Start(ctx); err != nil {
|
|
return nil, fmt.Errorf("failed to start streamable HTTP client: %v", err)
|
|
}
|
|
|
|
return streamableClient, nil
|
|
|
|
case "inprocess":
|
|
// Builtin server
|
|
return m.createBuiltinClient(ctx, serverName, serverConfig)
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unsupported transport type '%s' for server %s", transportType, serverName)
|
|
}
|
|
}
|
|
|
|
func (m *MCPToolManager) initializeClient(ctx context.Context, client client.MCPClient) error {
|
|
// Create a timeout context for initialization to prevent deadlocks
|
|
initCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
defer cancel()
|
|
|
|
initRequest := mcp.InitializeRequest{}
|
|
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
|
initRequest.Params.ClientInfo = mcp.Implementation{
|
|
Name: "mcphost",
|
|
Version: "1.0.0",
|
|
}
|
|
initRequest.Params.Capabilities = mcp.ClientCapabilities{}
|
|
|
|
_, err := client.Initialize(initCtx, initRequest)
|
|
if err != nil {
|
|
return fmt.Errorf("initialization timeout or failed: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// createBuiltinClient creates an in-process MCP client for builtin servers
|
|
func (m *MCPToolManager) createBuiltinClient(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
|
registry := builtin.NewRegistry()
|
|
|
|
// Create the builtin server, passing the model for servers that need it
|
|
builtinServer, err := registry.CreateServer(serverConfig.Name, serverConfig.Options, m.model)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create builtin server: %v", err)
|
|
}
|
|
|
|
// Create an in-process client that wraps the builtin server
|
|
inProcessClient, err := client.NewInProcessClient(builtinServer.GetServer())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create in-process client: %v", err)
|
|
}
|
|
|
|
return inProcessClient, nil
|
|
}
|
|
|
|
// debugLogConnectionInfo logs detailed connection information for debugging
|
|
func (m *MCPToolManager) debugLogConnectionInfo(serverName string, serverConfig config.MCPServerConfig) {
|
|
if m.debugLogger == nil || !m.debugLogger.IsDebugEnabled() {
|
|
return
|
|
}
|
|
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Connecting to MCP server: %s", serverName))
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Transport type: %s", serverConfig.GetTransportType()))
|
|
|
|
switch serverConfig.GetTransportType() {
|
|
case "stdio":
|
|
if len(serverConfig.Command) > 0 {
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Command: %s %v", serverConfig.Command[0], serverConfig.Command[1:]))
|
|
}
|
|
if len(serverConfig.Environment) > 0 {
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Environment variables: %d", len(serverConfig.Environment)))
|
|
}
|
|
case "sse", "streamable":
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] URL: %s", serverConfig.URL))
|
|
if len(serverConfig.Headers) > 0 {
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Headers: %v", serverConfig.Headers))
|
|
}
|
|
}
|
|
}
|