mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-13 19:20:06 +00:00
0703dd1602
Each spinner created a new tea.NewProgram, which sent DECRQM queries for synchronized output mode 2026. When the program exited and restored cooked terminal mode, the terminal's DECRPM response leaked as visible ^[[?2026;2$y characters. Replace the Bubble Tea spinner with a simple goroutine animation loop that writes directly to stderr via lipgloss.
547 lines
17 KiB
Go
547 lines
17 KiB
Go
package tools
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/mark3labs/mcp-go/client"
|
|
"github.com/mark3labs/mcp-go/client/transport"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
"github.com/mark3labs/mcphost/internal/builtin"
|
|
"github.com/mark3labs/mcphost/internal/config"
|
|
)
|
|
|
|
// MCPToolManager manages MCP (Model Context Protocol) tools and clients across multiple servers.
// It loads tools from stdio, SSE, streamable HTTP, and in-process (builtin) servers and
// exposes them as fantasy AgentTools under "<server>__<tool>" prefixed names so tools
// from different servers do not collide. It also holds the LLM model that is handed to
// the connection pool and to builtin servers for sampling support.
// Connections are obtained through an MCPConnectionPool that LoadTools creates.
//
// NOTE(review): the struct holds no mutex, and LoadTools/SetModel/SetDebugLogger write
// its fields without locking. Concurrent tool invocations after LoadTools has returned
// are presumably safe only because they read toolMap and rely on the pool's own
// locking — confirm before depending on it.
type MCPToolManager struct {
	connectionPool *MCPConnectionPool      // created by LoadTools; nil until then
	tools          []fantasy.AgentTool     // AgentTools registered by LoadTools, in load order
	toolMap        map[string]*toolMapping // maps prefixed tool names to their server and original name
	model          fantasy.LanguageModel   // LLM model for sampling
	config         *config.Config          // config most recently passed to LoadTools
	debug          bool                    // mirrors config.Debug from the last LoadTools call
	debugLogger    DebugLogger             // defaulted by LoadTools when unset
}
|
|
|
|
// toolMapping stores the mapping between a prefixed tool name ("<server>__<tool>")
// and the details needed to route a call back to the originating MCP server.
// One mapping is created per registered tool and shared between the manager's
// toolMap and the corresponding mcpFantasyTool.
type toolMapping struct {
	serverName   string                 // key of the server in config.MCPServers
	originalName string                 // tool name as reported by the server (unprefixed)
	serverConfig config.MCPServerConfig // configuration used to connect to the server
	manager      *MCPToolManager        // back-reference to the owning manager
}
|
|
|
|
// NewMCPToolManager creates a new MCP tool manager instance.
|
|
// Returns an initialized manager with empty tool collections ready to load tools from MCP servers.
|
|
// The manager must be configured with SetModel and LoadTools before use.
|
|
func NewMCPToolManager() *MCPToolManager {
|
|
return &MCPToolManager{
|
|
tools: make([]fantasy.AgentTool, 0),
|
|
toolMap: make(map[string]*toolMapping),
|
|
}
|
|
}
|
|
|
|
// SetModel sets the LLM model for sampling support.
|
|
// The model is used when MCP servers request sampling operations, allowing them to
|
|
// leverage the host's LLM capabilities for text generation tasks.
|
|
// This method should be called before LoadTools if any MCP servers require sampling support.
|
|
func (m *MCPToolManager) SetModel(model fantasy.LanguageModel) {
|
|
m.model = model
|
|
}
|
|
|
|
// SetDebugLogger sets the debug logger for the tool manager.
|
|
// The logger will be used to output detailed debugging information about MCP connections,
|
|
// tool loading, and execution. If a connection pool exists, it will also be configured
|
|
// to use the same logger for consistent debugging output.
|
|
func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) {
|
|
m.debugLogger = logger
|
|
if m.connectionPool != nil {
|
|
m.connectionPool.SetDebugLogger(logger)
|
|
}
|
|
}
|
|
|
|
// samplingHandler implements the MCP sampling handler interface using a fantasy LanguageModel.
// A nil model is allowed; CreateMessage then rejects every request with an error.
type samplingHandler struct {
	model fantasy.LanguageModel // model that serves sampling requests; may be nil
}
|
|
|
|
// CreateMessage handles sampling requests from MCP servers by forwarding them to the configured LLM model.
|
|
// It converts MCP message formats to fantasy message formats, invokes the model for generation,
|
|
// and converts the response back to MCP format. Returns an error if no model is available
|
|
// or if generation fails.
|
|
func (h *samplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
|
|
if h.model == nil {
|
|
return nil, fmt.Errorf("no model available for sampling")
|
|
}
|
|
|
|
// Build fantasy messages from MCP sampling request
|
|
var messages []fantasy.Message
|
|
|
|
// Add system message if provided
|
|
if request.SystemPrompt != "" {
|
|
messages = append(messages, fantasy.NewSystemMessage(request.SystemPrompt))
|
|
}
|
|
|
|
// Convert sampling messages
|
|
for _, msg := range request.Messages {
|
|
var content string
|
|
if textContent, ok := msg.Content.(mcp.TextContent); ok {
|
|
content = textContent.Text
|
|
} else {
|
|
content = fmt.Sprintf("%v", msg.Content)
|
|
}
|
|
|
|
switch msg.Role {
|
|
case mcp.RoleUser:
|
|
messages = append(messages, fantasy.NewUserMessage(content))
|
|
case mcp.RoleAssistant:
|
|
messages = append(messages, fantasy.Message{
|
|
Role: fantasy.MessageRoleAssistant,
|
|
Content: []fantasy.MessagePart{fantasy.TextPart{Text: content}},
|
|
})
|
|
default:
|
|
messages = append(messages, fantasy.NewUserMessage(content))
|
|
}
|
|
}
|
|
|
|
// Generate response using the fantasy model
|
|
call := fantasy.Call{
|
|
Prompt: fantasy.Prompt(messages),
|
|
}
|
|
response, err := h.model.Generate(ctx, call)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("model generation failed: %w", err)
|
|
}
|
|
|
|
// Convert response back to MCP format
|
|
result := &mcp.CreateMessageResult{
|
|
Model: h.model.Model(),
|
|
StopReason: "endTurn",
|
|
}
|
|
result.SamplingMessage = mcp.SamplingMessage{
|
|
Role: mcp.RoleAssistant,
|
|
Content: mcp.TextContent{
|
|
Type: "text",
|
|
Text: response.Content.Text(),
|
|
},
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
|
|
// It initializes the connection pool, connects to each configured server, and loads their tools.
|
|
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
|
|
// Returns an error only if all configured servers fail to load; partial failures are logged as warnings.
|
|
// This method is thread-safe and idempotent.
|
|
func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error {
|
|
// Initialize connection pool
|
|
m.config = config
|
|
m.debug = config.Debug
|
|
if m.debugLogger == nil {
|
|
m.debugLogger = NewSimpleDebugLogger(config.Debug)
|
|
}
|
|
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug)
|
|
m.connectionPool.SetDebugLogger(m.debugLogger)
|
|
|
|
var loadErrors []string
|
|
|
|
for serverName, serverConfig := range config.MCPServers {
|
|
if err := m.loadServerTools(ctx, serverName, serverConfig); err != nil {
|
|
loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", serverName, err))
|
|
fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", serverName, err)
|
|
continue
|
|
}
|
|
}
|
|
|
|
// If all servers failed to load, return an error
|
|
if len(loadErrors) == len(config.MCPServers) && len(config.MCPServers) > 0 {
|
|
return fmt.Errorf("all MCP servers failed to load: %s", strings.Join(loadErrors, "; "))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// loadServerTools loads tools from a single MCP server
|
|
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) error {
|
|
// Add debug logging
|
|
m.debugLogConnectionInfo(serverName, serverConfig)
|
|
|
|
// Get connection from pool
|
|
conn, err := m.connectionPool.GetConnection(ctx, serverName, serverConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get connection from pool: %v", err)
|
|
}
|
|
|
|
// Get tools from this server
|
|
listResults, err := conn.client.ListTools(ctx, mcp.ListToolsRequest{})
|
|
if err != nil {
|
|
// Handle connection error
|
|
m.connectionPool.HandleConnectionError(serverName, err)
|
|
return fmt.Errorf("failed to list tools: %v", err)
|
|
}
|
|
|
|
// Create name set for allowed tools
|
|
var nameSet map[string]struct{}
|
|
if len(serverConfig.AllowedTools) > 0 {
|
|
nameSet = make(map[string]struct{})
|
|
for _, name := range serverConfig.AllowedTools {
|
|
nameSet[name] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// Convert MCP tools to fantasy AgentTools with prefixed names
|
|
for _, mcpTool := range listResults.Tools {
|
|
// Filter tools based on allowedTools/excludedTools
|
|
if len(serverConfig.AllowedTools) > 0 {
|
|
if _, ok := nameSet[mcpTool.Name]; !ok {
|
|
continue
|
|
}
|
|
}
|
|
|
|
// Check if tool should be excluded
|
|
if m.shouldExcludeTool(mcpTool.Name, serverConfig) {
|
|
continue
|
|
}
|
|
|
|
// Convert MCP InputSchema to map[string]any for fantasy ToolInfo
|
|
marshaledSchema, err := json.Marshal(mcpTool.InputSchema)
|
|
if err != nil {
|
|
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
|
|
}
|
|
|
|
// Fix for JSON Schema draft-07 vs draft-04 compatibility
|
|
marshaledSchema = convertExclusiveBoundsToBoolean(marshaledSchema)
|
|
|
|
// Parse into map[string]any for fantasy's parameters format
|
|
var schemaMap map[string]any
|
|
if err := json.Unmarshal(marshaledSchema, &schemaMap); err != nil {
|
|
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
|
|
}
|
|
|
|
// Extract properties and required from the schema
|
|
parameters := make(map[string]any)
|
|
var required []string
|
|
|
|
if props, ok := schemaMap["properties"].(map[string]any); ok {
|
|
parameters = props
|
|
}
|
|
|
|
// Fix for issue #89: Ensure object schemas have a properties field
|
|
if schemaType, ok := schemaMap["type"].(string); ok && schemaType == "object" && len(parameters) == 0 {
|
|
// Keep empty parameters map - fantasy handles this fine
|
|
}
|
|
|
|
if req, ok := schemaMap["required"].([]any); ok {
|
|
for _, r := range req {
|
|
if s, ok := r.(string); ok {
|
|
required = append(required, s)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create prefixed tool name
|
|
prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name)
|
|
|
|
// Create tool mapping
|
|
mapping := &toolMapping{
|
|
serverName: serverName,
|
|
originalName: mcpTool.Name,
|
|
serverConfig: serverConfig,
|
|
manager: m,
|
|
}
|
|
m.toolMap[prefixedName] = mapping
|
|
|
|
// Create fantasy AgentTool
|
|
fantasyTool := &mcpFantasyTool{
|
|
toolInfo: fantasy.ToolInfo{
|
|
Name: prefixedName,
|
|
Description: mcpTool.Description,
|
|
Parameters: parameters,
|
|
Required: required,
|
|
},
|
|
mapping: mapping,
|
|
}
|
|
|
|
m.tools = append(m.tools, fantasyTool)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetTools returns all loaded tools as fantasy AgentTools from all configured MCP servers.
|
|
// Tools are returned with their prefixed names (serverName__toolName) to ensure uniqueness.
|
|
func (m *MCPToolManager) GetTools() []fantasy.AgentTool {
|
|
return m.tools
|
|
}
|
|
|
|
// GetLoadedServerNames returns the names of all successfully loaded MCP servers.
|
|
// This includes servers that are currently connected and have had their tools loaded,
|
|
// regardless of their current health status. Useful for debugging and status reporting.
|
|
func (m *MCPToolManager) GetLoadedServerNames() []string {
|
|
var names []string
|
|
for serverName := range m.connectionPool.GetClients() {
|
|
names = append(names, serverName)
|
|
}
|
|
return names
|
|
}
|
|
|
|
// Close closes all MCP client connections and cleans up resources.
|
|
// This method should be called when the tool manager is no longer needed to ensure
|
|
// proper cleanup of stdio processes, network connections, and other resources.
|
|
// It is safe to call Close multiple times.
|
|
func (m *MCPToolManager) Close() error {
|
|
return m.connectionPool.Close()
|
|
}
|
|
|
|
// shouldExcludeTool determines if a tool should be excluded based on excludedTools
|
|
func (m *MCPToolManager) shouldExcludeTool(toolName string, serverConfig config.MCPServerConfig) bool {
|
|
if len(serverConfig.ExcludedTools) > 0 {
|
|
if slices.Contains(serverConfig.ExcludedTools, toolName) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
|
transportType := serverConfig.GetTransportType()
|
|
|
|
switch transportType {
|
|
case "stdio":
|
|
var env []string
|
|
var command string
|
|
var args []string
|
|
|
|
if len(serverConfig.Command) > 0 {
|
|
command = serverConfig.Command[0]
|
|
if len(serverConfig.Command) > 1 {
|
|
args = serverConfig.Command[1:]
|
|
} else if len(serverConfig.Args) > 0 {
|
|
args = serverConfig.Args
|
|
}
|
|
}
|
|
|
|
if serverConfig.Environment != nil {
|
|
for k, v := range serverConfig.Environment {
|
|
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
|
}
|
|
}
|
|
|
|
if serverConfig.Env != nil {
|
|
for k, v := range serverConfig.Env {
|
|
env = append(env, fmt.Sprintf("%s=%v", k, v))
|
|
}
|
|
}
|
|
|
|
stdioTransport := transport.NewStdio(command, env, args...)
|
|
stdioClient := client.NewClient(stdioTransport)
|
|
|
|
if err := stdioTransport.Start(ctx); err != nil {
|
|
return nil, fmt.Errorf("failed to start stdio transport: %v", err)
|
|
}
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
return stdioClient, nil
|
|
|
|
case "sse":
|
|
var options []transport.ClientOption
|
|
|
|
if len(serverConfig.Headers) > 0 {
|
|
headers := make(map[string]string)
|
|
for _, header := range serverConfig.Headers {
|
|
parts := strings.SplitN(header, ":", 2)
|
|
if len(parts) == 2 {
|
|
key := strings.TrimSpace(parts[0])
|
|
value := strings.TrimSpace(parts[1])
|
|
headers[key] = value
|
|
}
|
|
}
|
|
if len(headers) > 0 {
|
|
options = append(options, transport.WithHeaders(headers))
|
|
}
|
|
}
|
|
|
|
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := sseClient.Start(ctx); err != nil {
|
|
return nil, fmt.Errorf("failed to start SSE client: %v", err)
|
|
}
|
|
|
|
return sseClient, nil
|
|
|
|
case "streamable":
|
|
var options []transport.StreamableHTTPCOption
|
|
|
|
if len(serverConfig.Headers) > 0 {
|
|
headers := make(map[string]string)
|
|
for _, header := range serverConfig.Headers {
|
|
parts := strings.SplitN(header, ":", 2)
|
|
if len(parts) == 2 {
|
|
key := strings.TrimSpace(parts[0])
|
|
value := strings.TrimSpace(parts[1])
|
|
headers[key] = value
|
|
}
|
|
}
|
|
if len(headers) > 0 {
|
|
options = append(options, transport.WithHTTPHeaders(headers))
|
|
}
|
|
}
|
|
|
|
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := streamableClient.Start(ctx); err != nil {
|
|
return nil, fmt.Errorf("failed to start streamable HTTP client: %v", err)
|
|
}
|
|
|
|
return streamableClient, nil
|
|
|
|
case "inprocess":
|
|
return m.createBuiltinClient(ctx, serverName, serverConfig)
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unsupported transport type '%s' for server %s", transportType, serverName)
|
|
}
|
|
}
|
|
|
|
func (m *MCPToolManager) initializeClient(ctx context.Context, client client.MCPClient) error {
|
|
initCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
defer cancel()
|
|
|
|
initRequest := mcp.InitializeRequest{}
|
|
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
|
initRequest.Params.ClientInfo = mcp.Implementation{
|
|
Name: "mcphost",
|
|
Version: "1.0.0",
|
|
}
|
|
initRequest.Params.Capabilities = mcp.ClientCapabilities{}
|
|
|
|
_, err := client.Initialize(initCtx, initRequest)
|
|
if err != nil {
|
|
return fmt.Errorf("initialization timeout or failed: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// createBuiltinClient creates an in-process MCP client for builtin servers
|
|
func (m *MCPToolManager) createBuiltinClient(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
|
registry := builtin.NewRegistry()
|
|
|
|
builtinServer, err := registry.CreateServer(serverConfig.Name, serverConfig.Options, m.model)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create builtin server: %v", err)
|
|
}
|
|
|
|
inProcessClient, err := client.NewInProcessClient(builtinServer.GetServer())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create in-process client: %v", err)
|
|
}
|
|
|
|
return inProcessClient, nil
|
|
}
|
|
|
|
// debugLogConnectionInfo logs detailed connection information for debugging
|
|
func (m *MCPToolManager) debugLogConnectionInfo(serverName string, serverConfig config.MCPServerConfig) {
|
|
if m.debugLogger == nil || !m.debugLogger.IsDebugEnabled() {
|
|
return
|
|
}
|
|
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Connecting to MCP server: %s", serverName))
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Transport type: %s", serverConfig.GetTransportType()))
|
|
|
|
switch serverConfig.GetTransportType() {
|
|
case "stdio":
|
|
if len(serverConfig.Command) > 0 {
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Command: %s %v", serverConfig.Command[0], serverConfig.Command[1:]))
|
|
}
|
|
if len(serverConfig.Environment) > 0 {
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Environment variables: %d", len(serverConfig.Environment)))
|
|
}
|
|
case "sse", "streamable":
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] URL: %s", serverConfig.URL))
|
|
if len(serverConfig.Headers) > 0 {
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Headers: %v", serverConfig.Headers))
|
|
}
|
|
}
|
|
}
|
|
|
|
// convertExclusiveBoundsToBoolean converts JSON Schema draft-07 style exclusive bounds
|
|
// (where exclusiveMinimum/exclusiveMaximum are numbers) to draft-04 style
|
|
// (where they are booleans that modify minimum/maximum).
|
|
func convertExclusiveBoundsToBoolean(schemaJSON []byte) []byte {
|
|
var data map[string]any
|
|
if err := json.Unmarshal(schemaJSON, &data); err != nil {
|
|
return schemaJSON
|
|
}
|
|
|
|
convertSchemaRecursive(data)
|
|
|
|
result, err := json.Marshal(data)
|
|
if err != nil {
|
|
return schemaJSON
|
|
}
|
|
return result
|
|
}
|
|
|
|
// convertSchemaRecursive rewrites, in place, every numeric exclusiveMinimum /
// exclusiveMaximum in schema and its nested subschemas into the draft-04
// boolean form. For example, {"exclusiveMinimum": 5} becomes
// {"minimum": 5, "exclusiveMinimum": true}.
//
// If the schema already has an inclusive bound, the stricter constraint wins.
// An existing "minimum" greater than the numeric exclusiveMinimum, or an
// existing "maximum" smaller than the numeric exclusiveMaximum, is kept and
// the exclusive keyword is dropped. Overwriting it would loosen the schema.
//
// Subschemas are visited under keywords that hold:
//   - a single schema: items, additionalProperties, additionalItems, not, if,
//     then, else, contains, propertyNames
//   - an array of schemas: allOf, anyOf, oneOf, items, prefixItems
//   - a map of named schemas: properties, patternProperties, $defs,
//     definitions, dependentSchemas, dependencies
//
// Values that are not JSON objects are skipped. A nil map is a no-op.
func convertSchemaRecursive(schema map[string]any) {
	// convertBound rewrites one numeric exclusive keyword. inclusiveWins
	// reports whether an existing inclusive bound is strictly tighter.
	convertBound := func(exclusiveKey, inclusiveKey string, inclusiveWins func(inclusive, exclusive float64) bool) {
		exclusive, ok := schema[exclusiveKey].(float64)
		if !ok {
			return // absent, or already the draft-04 boolean form
		}
		if inclusive, ok := schema[inclusiveKey].(float64); ok && inclusiveWins(inclusive, exclusive) {
			delete(schema, exclusiveKey)
			return
		}
		schema[inclusiveKey] = exclusive
		schema[exclusiveKey] = true
	}
	convertBound("exclusiveMinimum", "minimum", func(inclusive, exclusive float64) bool { return inclusive > exclusive })
	convertBound("exclusiveMaximum", "maximum", func(inclusive, exclusive float64) bool { return inclusive < exclusive })

	visit := func(v any) {
		if sub, ok := v.(map[string]any); ok {
			convertSchemaRecursive(sub)
		}
	}

	for _, key := range []string{"items", "additionalProperties", "additionalItems", "not", "if", "then", "else", "contains", "propertyNames"} {
		visit(schema[key])
	}

	for _, key := range []string{"allOf", "anyOf", "oneOf", "items", "prefixItems"} {
		if arr, ok := schema[key].([]any); ok {
			for _, item := range arr {
				visit(item)
			}
		}
	}

	for _, key := range []string{"properties", "patternProperties", "$defs", "definitions", "dependentSchemas", "dependencies"} {
		if named, ok := schema[key].(map[string]any); ok {
			for _, sub := range named {
				visit(sub)
			}
		}
	}
}
|