Fix deadlocks

This commit is contained in:
Ed Zynda
2025-06-24 09:12:09 +03:00
parent ff49415679
commit 679709d078
2 changed files with 196 additions and 69 deletions
+106 -69
View File
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/components/tool"
@@ -47,82 +48,100 @@ func NewMCPToolManager() *MCPToolManager {
// LoadTools loads tools from MCP servers based on configuration
func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error {
var loadErrors []string
for serverName, serverConfig := range config.MCPServers {
client, err := m.createMCPClient(ctx, serverName, serverConfig)
if err != nil {
return fmt.Errorf("failed to create MCP client for %s: %v", serverName, err)
if err := m.loadServerTools(ctx, serverName, serverConfig); err != nil {
loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", serverName, err))
fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", serverName, err)
continue
}
}
m.clients[serverName] = client
// If all servers failed to load, return an error
if len(loadErrors) == len(config.MCPServers) && len(config.MCPServers) > 0 {
return fmt.Errorf("all MCP servers failed to load: %s", strings.Join(loadErrors, "; "))
}
// Initialize the client
if err := m.initializeClient(ctx, client); err != nil {
return fmt.Errorf("failed to initialize MCP client for %s: %v", serverName, err)
return nil
}
// loadServerTools loads tools from a single MCP server
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) error {
client, err := m.createMCPClient(ctx, serverName, serverConfig)
if err != nil {
return fmt.Errorf("failed to create MCP client: %v", err)
}
m.clients[serverName] = client
// Initialize the client
if err := m.initializeClient(ctx, client); err != nil {
return fmt.Errorf("failed to initialize MCP client: %v", err)
}
// Get tools from this server
listResults, err := client.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
return fmt.Errorf("failed to list tools: %v", err)
}
// Create name set for allowed tools
var nameSet map[string]struct{}
if len(serverConfig.AllowedTools) > 0 {
nameSet = make(map[string]struct{})
for _, name := range serverConfig.AllowedTools {
nameSet[name] = struct{}{}
}
}
// Get tools from this server
listResults, err := client.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
return fmt.Errorf("failed to list tools from server %s: %v", serverName, err)
}
// Create name set for allowed tools
var nameSet map[string]struct{}
// Convert MCP tools to eino tools with prefixed names
for _, mcpTool := range listResults.Tools {
// Filter tools based on allowedTools/excludedTools
if len(serverConfig.AllowedTools) > 0 {
nameSet = make(map[string]struct{})
for _, name := range serverConfig.AllowedTools {
nameSet[name] = struct{}{}
}
}
// Convert MCP tools to eino tools with prefixed names
for _, mcpTool := range listResults.Tools {
// Filter tools based on allowedTools/excludedTools
if len(serverConfig.AllowedTools) > 0 {
if _, ok := nameSet[mcpTool.Name]; !ok {
continue
}
}
// Check if tool should be excluded
if m.shouldExcludeTool(mcpTool.Name, serverConfig) {
if _, ok := nameSet[mcpTool.Name]; !ok {
continue
}
// Convert schema
marshaledInputSchema, err := sonic.Marshal(mcpTool.InputSchema)
if err != nil {
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
}
inputSchema := &openapi3.Schema{}
err = sonic.Unmarshal(marshaledInputSchema, inputSchema)
if err != nil {
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
}
// Create prefixed tool name
prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name)
// Create tool mapping
mapping := &toolMapping{
serverName: serverName,
originalName: mcpTool.Name,
client: client,
}
m.toolMap[prefixedName] = mapping
// Create eino tool
einoTool := &mcpToolImpl{
info: &schema.ToolInfo{
Name: prefixedName,
Desc: mcpTool.Description,
ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(inputSchema),
},
mapping: mapping,
}
m.tools = append(m.tools, einoTool)
}
// Check if tool should be excluded
if m.shouldExcludeTool(mcpTool.Name, serverConfig) {
continue
}
// Convert schema
marshaledInputSchema, err := sonic.Marshal(mcpTool.InputSchema)
if err != nil {
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
}
inputSchema := &openapi3.Schema{}
err = sonic.Unmarshal(marshaledInputSchema, inputSchema)
if err != nil {
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
}
// Create prefixed tool name
prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name)
// Create tool mapping
mapping := &toolMapping{
serverName: serverName,
originalName: mcpTool.Name,
client: client,
}
m.toolMap[prefixedName] = mapping
// Create eino tool
einoTool := &mcpToolImpl{
info: &schema.ToolInfo{
Name: prefixedName,
Desc: mcpTool.Description,
ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(inputSchema),
},
mapping: mapping,
}
m.tools = append(m.tools, einoTool)
}
return nil
@@ -204,7 +223,18 @@ func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string,
env = append(env, fmt.Sprintf("%s=%v", k, v))
}
return client.NewStdioMCPClient(serverConfig.Command, env, serverConfig.Args...)
stdioClient, err := client.NewStdioMCPClient(serverConfig.Command, env, serverConfig.Args...)
if err != nil {
return nil, fmt.Errorf("failed to create stdio client: %v", err)
}
// Add a brief delay to allow the process to start and potentially fail
time.Sleep(100 * time.Millisecond)
// TODO: Add process health check here if the mcp-go library exposes process info
// For now, we rely on the timeout in initializeClient to catch dead processes
return stdioClient, nil
case "sse":
// SSE client
@@ -258,6 +288,10 @@ func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string,
}
func (m *MCPToolManager) initializeClient(ctx context.Context, client client.MCPClient) error {
// Create a timeout context for initialization to prevent deadlocks
initCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
@@ -265,6 +299,9 @@ func (m *MCPToolManager) initializeClient(ctx context.Context, client client.MCP
Version: "1.0.0",
}
_, err := client.Initialize(ctx, initRequest)
return err
_, err := client.Initialize(initCtx, initRequest)
if err != nil {
return fmt.Errorf("initialization timeout or failed: %v", err)
}
return nil
}
+90
View File
@@ -0,0 +1,90 @@
package tools
import (
"context"
"testing"
"time"
"github.com/mark3labs/mcphost/internal/config"
)
func TestMCPToolManager_LoadTools_WithTimeout(t *testing.T) {
manager := NewMCPToolManager()
// Create a config with a non-existent command that should fail
cfg := &config.Config{
MCPServers: map[string]config.MCPServerConfig{
"test-server": {
Command: "non-existent-command",
Args: []string{"arg1", "arg2"},
},
},
}
// Create a context with a reasonable timeout
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
// This should not hang indefinitely and should return an error
start := time.Now()
err := manager.LoadTools(ctx, cfg)
duration := time.Since(start)
// The operation should complete within our timeout
if duration > 14*time.Second {
t.Errorf("LoadTools took too long: %v, expected to complete within 14 seconds", duration)
}
// We expect an error since the command doesn't exist, but it shouldn't be a timeout
if err == nil {
t.Error("Expected an error for non-existent command, but got nil")
}
t.Logf("LoadTools completed in %v with error: %v", duration, err)
}
func TestMCPToolManager_LoadTools_GracefulFailure(t *testing.T) {
manager := NewMCPToolManager()
// Create a config with multiple servers, some good and some bad
cfg := &config.Config{
MCPServers: map[string]config.MCPServerConfig{
"bad-server-1": {
Command: "non-existent-command-1",
Args: []string{"arg1"},
},
"bad-server-2": {
Command: "non-existent-command-2",
Args: []string{"arg2"},
},
},
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// This should fail gracefully and return an error since all servers failed
err := manager.LoadTools(ctx, cfg)
// We expect an error since all servers failed
if err == nil {
t.Error("Expected an error when all servers fail, but got nil")
}
// The error should mention that all servers failed
if err != nil && !contains(err.Error(), "all MCP servers failed") {
t.Errorf("Expected error message to mention all servers failed, got: %v", err)
}
t.Logf("LoadTools failed gracefully with error: %v", err)
}
// Helper function to check if a string contains a substring
func contains(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}