mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Fix deadlocks
This commit is contained in:
+106
-69
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
@@ -47,82 +48,100 @@ func NewMCPToolManager() *MCPToolManager {
|
||||
|
||||
// LoadTools loads tools from MCP servers based on configuration
|
||||
func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error {
|
||||
var loadErrors []string
|
||||
|
||||
for serverName, serverConfig := range config.MCPServers {
|
||||
client, err := m.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create MCP client for %s: %v", serverName, err)
|
||||
if err := m.loadServerTools(ctx, serverName, serverConfig); err != nil {
|
||||
loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", serverName, err))
|
||||
fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", serverName, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
m.clients[serverName] = client
|
||||
// If all servers failed to load, return an error
|
||||
if len(loadErrors) == len(config.MCPServers) && len(config.MCPServers) > 0 {
|
||||
return fmt.Errorf("all MCP servers failed to load: %s", strings.Join(loadErrors, "; "))
|
||||
}
|
||||
|
||||
// Initialize the client
|
||||
if err := m.initializeClient(ctx, client); err != nil {
|
||||
return fmt.Errorf("failed to initialize MCP client for %s: %v", serverName, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadServerTools loads tools from a single MCP server
|
||||
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) error {
|
||||
client, err := m.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create MCP client: %v", err)
|
||||
}
|
||||
|
||||
m.clients[serverName] = client
|
||||
|
||||
// Initialize the client
|
||||
if err := m.initializeClient(ctx, client); err != nil {
|
||||
return fmt.Errorf("failed to initialize MCP client: %v", err)
|
||||
}
|
||||
|
||||
// Get tools from this server
|
||||
listResults, err := client.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list tools: %v", err)
|
||||
}
|
||||
|
||||
// Create name set for allowed tools
|
||||
var nameSet map[string]struct{}
|
||||
if len(serverConfig.AllowedTools) > 0 {
|
||||
nameSet = make(map[string]struct{})
|
||||
for _, name := range serverConfig.AllowedTools {
|
||||
nameSet[name] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Get tools from this server
|
||||
listResults, err := client.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list tools from server %s: %v", serverName, err)
|
||||
}
|
||||
|
||||
// Create name set for allowed tools
|
||||
var nameSet map[string]struct{}
|
||||
// Convert MCP tools to eino tools with prefixed names
|
||||
for _, mcpTool := range listResults.Tools {
|
||||
// Filter tools based on allowedTools/excludedTools
|
||||
if len(serverConfig.AllowedTools) > 0 {
|
||||
nameSet = make(map[string]struct{})
|
||||
for _, name := range serverConfig.AllowedTools {
|
||||
nameSet[name] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert MCP tools to eino tools with prefixed names
|
||||
for _, mcpTool := range listResults.Tools {
|
||||
// Filter tools based on allowedTools/excludedTools
|
||||
if len(serverConfig.AllowedTools) > 0 {
|
||||
if _, ok := nameSet[mcpTool.Name]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Check if tool should be excluded
|
||||
if m.shouldExcludeTool(mcpTool.Name, serverConfig) {
|
||||
if _, ok := nameSet[mcpTool.Name]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert schema
|
||||
marshaledInputSchema, err := sonic.Marshal(mcpTool.InputSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
|
||||
}
|
||||
inputSchema := &openapi3.Schema{}
|
||||
err = sonic.Unmarshal(marshaledInputSchema, inputSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
|
||||
}
|
||||
|
||||
// Create prefixed tool name
|
||||
prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name)
|
||||
|
||||
// Create tool mapping
|
||||
mapping := &toolMapping{
|
||||
serverName: serverName,
|
||||
originalName: mcpTool.Name,
|
||||
client: client,
|
||||
}
|
||||
m.toolMap[prefixedName] = mapping
|
||||
|
||||
// Create eino tool
|
||||
einoTool := &mcpToolImpl{
|
||||
info: &schema.ToolInfo{
|
||||
Name: prefixedName,
|
||||
Desc: mcpTool.Description,
|
||||
ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(inputSchema),
|
||||
},
|
||||
mapping: mapping,
|
||||
}
|
||||
|
||||
m.tools = append(m.tools, einoTool)
|
||||
}
|
||||
|
||||
// Check if tool should be excluded
|
||||
if m.shouldExcludeTool(mcpTool.Name, serverConfig) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert schema
|
||||
marshaledInputSchema, err := sonic.Marshal(mcpTool.InputSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
|
||||
}
|
||||
inputSchema := &openapi3.Schema{}
|
||||
err = sonic.Unmarshal(marshaledInputSchema, inputSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
|
||||
}
|
||||
|
||||
// Create prefixed tool name
|
||||
prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name)
|
||||
|
||||
// Create tool mapping
|
||||
mapping := &toolMapping{
|
||||
serverName: serverName,
|
||||
originalName: mcpTool.Name,
|
||||
client: client,
|
||||
}
|
||||
m.toolMap[prefixedName] = mapping
|
||||
|
||||
// Create eino tool
|
||||
einoTool := &mcpToolImpl{
|
||||
info: &schema.ToolInfo{
|
||||
Name: prefixedName,
|
||||
Desc: mcpTool.Description,
|
||||
ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(inputSchema),
|
||||
},
|
||||
mapping: mapping,
|
||||
}
|
||||
|
||||
m.tools = append(m.tools, einoTool)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -204,7 +223,18 @@ func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string,
|
||||
env = append(env, fmt.Sprintf("%s=%v", k, v))
|
||||
}
|
||||
|
||||
return client.NewStdioMCPClient(serverConfig.Command, env, serverConfig.Args...)
|
||||
stdioClient, err := client.NewStdioMCPClient(serverConfig.Command, env, serverConfig.Args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create stdio client: %v", err)
|
||||
}
|
||||
|
||||
// Add a brief delay to allow the process to start and potentially fail
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// TODO: Add process health check here if the mcp-go library exposes process info
|
||||
// For now, we rely on the timeout in initializeClient to catch dead processes
|
||||
|
||||
return stdioClient, nil
|
||||
|
||||
case "sse":
|
||||
// SSE client
|
||||
@@ -258,6 +288,10 @@ func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string,
|
||||
}
|
||||
|
||||
func (m *MCPToolManager) initializeClient(ctx context.Context, client client.MCPClient) error {
|
||||
// Create a timeout context for initialization to prevent deadlocks
|
||||
initCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
initRequest := mcp.InitializeRequest{}
|
||||
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
||||
initRequest.Params.ClientInfo = mcp.Implementation{
|
||||
@@ -265,6 +299,9 @@ func (m *MCPToolManager) initializeClient(ctx context.Context, client client.MCP
|
||||
Version: "1.0.0",
|
||||
}
|
||||
|
||||
_, err := client.Initialize(ctx, initRequest)
|
||||
return err
|
||||
_, err := client.Initialize(initCtx, initRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialization timeout or failed: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcphost/internal/config"
|
||||
)
|
||||
|
||||
func TestMCPToolManager_LoadTools_WithTimeout(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
// Create a config with a non-existent command that should fail
|
||||
cfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"test-server": {
|
||||
Command: "non-existent-command",
|
||||
Args: []string{"arg1", "arg2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create a context with a reasonable timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// This should not hang indefinitely and should return an error
|
||||
start := time.Now()
|
||||
err := manager.LoadTools(ctx, cfg)
|
||||
duration := time.Since(start)
|
||||
|
||||
// The operation should complete within our timeout
|
||||
if duration > 14*time.Second {
|
||||
t.Errorf("LoadTools took too long: %v, expected to complete within 14 seconds", duration)
|
||||
}
|
||||
|
||||
// We expect an error since the command doesn't exist, but it shouldn't be a timeout
|
||||
if err == nil {
|
||||
t.Error("Expected an error for non-existent command, but got nil")
|
||||
}
|
||||
|
||||
t.Logf("LoadTools completed in %v with error: %v", duration, err)
|
||||
}
|
||||
|
||||
func TestMCPToolManager_LoadTools_GracefulFailure(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
// Create a config with multiple servers, some good and some bad
|
||||
cfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"bad-server-1": {
|
||||
Command: "non-existent-command-1",
|
||||
Args: []string{"arg1"},
|
||||
},
|
||||
"bad-server-2": {
|
||||
Command: "non-existent-command-2",
|
||||
Args: []string{"arg2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// This should fail gracefully and return an error since all servers failed
|
||||
err := manager.LoadTools(ctx, cfg)
|
||||
|
||||
// We expect an error since all servers failed
|
||||
if err == nil {
|
||||
t.Error("Expected an error when all servers fail, but got nil")
|
||||
}
|
||||
|
||||
// The error should mention that all servers failed
|
||||
if err != nil && !contains(err.Error(), "all MCP servers failed") {
|
||||
t.Errorf("Expected error message to mention all servers failed, got: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("LoadTools failed gracefully with error: %v", err)
|
||||
}
|
||||
|
||||
// Helper function to check if a string contains a substring
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user