From 679709d0789a4abd6a4dd897b0fcb8f9bc39dd25 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 24 Jun 2025 09:12:09 +0300 Subject: [PATCH] Fix deadlocks --- internal/tools/mcp.go | 175 ++++++++++++++++++++++--------------- internal/tools/mcp_test.go | 90 +++++++++++++++++++ 2 files changed, 196 insertions(+), 69 deletions(-) create mode 100644 internal/tools/mcp_test.go diff --git a/internal/tools/mcp.go b/internal/tools/mcp.go index e4050821..54c13e4a 100644 --- a/internal/tools/mcp.go +++ b/internal/tools/mcp.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "strings" + "time" "github.com/bytedance/sonic" "github.com/cloudwego/eino/components/tool" @@ -47,82 +48,100 @@ func NewMCPToolManager() *MCPToolManager { // LoadTools loads tools from MCP servers based on configuration func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error { + var loadErrors []string + for serverName, serverConfig := range config.MCPServers { - client, err := m.createMCPClient(ctx, serverName, serverConfig) - if err != nil { - return fmt.Errorf("failed to create MCP client for %s: %v", serverName, err) + if err := m.loadServerTools(ctx, serverName, serverConfig); err != nil { + loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", serverName, err)) + fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", serverName, err) + continue } + } - m.clients[serverName] = client + // If all servers failed to load, return an error + if len(loadErrors) == len(config.MCPServers) && len(config.MCPServers) > 0 { + return fmt.Errorf("all MCP servers failed to load: %s", strings.Join(loadErrors, "; ")) + } - // Initialize the client - if err := m.initializeClient(ctx, client); err != nil { - return fmt.Errorf("failed to initialize MCP client for %s: %v", serverName, err) + return nil +} + +// loadServerTools loads tools from a single MCP server +func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) error { + client, err := m.createMCPClient(ctx, serverName, serverConfig) + if err != nil { + return fmt.Errorf("failed to create MCP client: %v", err) + } + + m.clients[serverName] = client + + // Initialize the client + if err := m.initializeClient(ctx, client); err != nil { + return fmt.Errorf("failed to initialize MCP client: %v", err) + } + + // Get tools from this server + listResults, err := client.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + return fmt.Errorf("failed to list tools: %v", err) + } + + // Create name set for allowed tools + var nameSet map[string]struct{} + if len(serverConfig.AllowedTools) > 0 { + nameSet = make(map[string]struct{}) + for _, name := range serverConfig.AllowedTools { + nameSet[name] = struct{}{} } + } - // Get tools from this server - listResults, err := client.ListTools(ctx, mcp.ListToolsRequest{}) - if err != nil { - return fmt.Errorf("failed to list tools from server %s: %v", serverName, err) - } - - // Create name set for allowed tools - var nameSet map[string]struct{} + // Convert MCP tools to eino tools with prefixed names + for _, mcpTool := range listResults.Tools { + // Filter tools based on allowedTools/excludedTools if len(serverConfig.AllowedTools) > 0 { - nameSet = make(map[string]struct{}) - for _, name := range serverConfig.AllowedTools { - nameSet[name] = struct{}{} - } - } - - // Convert MCP tools to eino tools with prefixed names - for _, mcpTool := range listResults.Tools { - // Filter tools based on allowedTools/excludedTools - if len(serverConfig.AllowedTools) > 0 { - if _, ok := nameSet[mcpTool.Name]; !ok { - continue - } - } - - // Check if tool should be excluded - if m.shouldExcludeTool(mcpTool.Name, serverConfig) { + if _, ok := nameSet[mcpTool.Name]; !ok { continue } - - // Convert schema - marshaledInputSchema, err := sonic.Marshal(mcpTool.InputSchema) - if err != nil { - return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name) - } - inputSchema := &openapi3.Schema{} - err = sonic.Unmarshal(marshaledInputSchema, inputSchema) - if err != nil { - return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name) - } - - // Create prefixed tool name - prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name) - - // Create tool mapping - mapping := &toolMapping{ - serverName: serverName, - originalName: mcpTool.Name, - client: client, - } - m.toolMap[prefixedName] = mapping - - // Create eino tool - einoTool := &mcpToolImpl{ - info: &schema.ToolInfo{ - Name: prefixedName, - Desc: mcpTool.Description, - ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(inputSchema), - }, - mapping: mapping, - } - - m.tools = append(m.tools, einoTool) } + + // Check if tool should be excluded + if m.shouldExcludeTool(mcpTool.Name, serverConfig) { + continue + } + + // Convert schema + marshaledInputSchema, err := sonic.Marshal(mcpTool.InputSchema) + if err != nil { + return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name) + } + inputSchema := &openapi3.Schema{} + err = sonic.Unmarshal(marshaledInputSchema, inputSchema) + if err != nil { + return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name) + } + + // Create prefixed tool name + prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name) + + // Create tool mapping + mapping := &toolMapping{ + serverName: serverName, + originalName: mcpTool.Name, + client: client, + } + m.toolMap[prefixedName] = mapping + + // Create eino tool + einoTool := &mcpToolImpl{ + info: &schema.ToolInfo{ + Name: prefixedName, + Desc: mcpTool.Description, + ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(inputSchema), + }, + mapping: mapping, + } + + m.tools = append(m.tools, einoTool) } return nil @@ -204,7 +223,18 @@ func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string, env = append(env, fmt.Sprintf("%s=%v", k, v)) } - return client.NewStdioMCPClient(serverConfig.Command, env, serverConfig.Args...) + stdioClient, err := client.NewStdioMCPClient(serverConfig.Command, env, serverConfig.Args...) + if err != nil { + return nil, fmt.Errorf("failed to create stdio client: %v", err) + } + + // Add a brief delay to allow the process to start and potentially fail + time.Sleep(100 * time.Millisecond) + + // TODO: Add process health check here if the mcp-go library exposes process info + // For now, we rely on the timeout in initializeClient to catch dead processes + + return stdioClient, nil case "sse": // SSE client @@ -258,6 +288,10 @@ func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string, } func (m *MCPToolManager) initializeClient(ctx context.Context, client client.MCPClient) error { + // Create a timeout context for initialization to prevent deadlocks + initCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ @@ -265,6 +299,9 @@ func (m *MCPToolManager) initializeClient(ctx context.Context, client client.MCP Version: "1.0.0", } - _, err := client.Initialize(ctx, initRequest) - return err + _, err := client.Initialize(initCtx, initRequest) + if err != nil { + return fmt.Errorf("initialization timeout or failed: %v", err) + } + return nil } diff --git a/internal/tools/mcp_test.go b/internal/tools/mcp_test.go new file mode 100644 index 00000000..bf2422f4 --- /dev/null +++ b/internal/tools/mcp_test.go @@ -0,0 +1,90 @@ +package tools + +import ( + "context" + "testing" + "time" + + "github.com/mark3labs/mcphost/internal/config" +) + +func TestMCPToolManager_LoadTools_WithTimeout(t *testing.T) { + manager := NewMCPToolManager() + + // Create a config with a non-existent command that should fail + cfg := &config.Config{ + MCPServers: map[string]config.MCPServerConfig{ + "test-server": { + Command: "non-existent-command", + Args: []string{"arg1", "arg2"}, + }, + }, + } + + // Create a context with a reasonable timeout + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + // This should not hang indefinitely and should return an error + start := time.Now() + err := manager.LoadTools(ctx, cfg) + duration := time.Since(start) + + // The operation should complete within our timeout + if duration > 14*time.Second { + t.Errorf("LoadTools took too long: %v, expected to complete within 14 seconds", duration) + } + + // We expect an error since the command doesn't exist, but it shouldn't be a timeout + if err == nil { + t.Error("Expected an error for non-existent command, but got nil") + } + + t.Logf("LoadTools completed in %v with error: %v", duration, err) +} + +func TestMCPToolManager_LoadTools_GracefulFailure(t *testing.T) { + manager := NewMCPToolManager() + + // Create a config with multiple servers, some good and some bad + cfg := &config.Config{ + MCPServers: map[string]config.MCPServerConfig{ + "bad-server-1": { + Command: "non-existent-command-1", + Args: []string{"arg1"}, + }, + "bad-server-2": { + Command: "non-existent-command-2", + Args: []string{"arg2"}, + }, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // This should fail gracefully and return an error since all servers failed + err := manager.LoadTools(ctx, cfg) + + // We expect an error since all servers failed + if err == nil { + t.Error("Expected an error when all servers fail, but got nil") + } + + // The error should mention that all servers failed + if err != nil && !contains(err.Error(), "all MCP servers failed") { + t.Errorf("Expected error message to mention all servers failed, got: %v", err) + } + + t.Logf("LoadTools failed gracefully with error: %v", err) +} + +// Helper function to check if a string contains a substring +func contains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +}