From 57efdb5332d33cb4ebff2cf109a82eba612e81df Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 10 Jun 2025 13:44:17 +0300 Subject: [PATCH] cleanup --- go.mod | 11 +- go.sum | 14 +- internal/agent/agent.go | 317 +--------------------------------------- internal/tools/mcp.go | 200 +++++++++++++------------ 4 files changed, 129 insertions(+), 413 deletions(-) diff --git a/go.mod b/go.mod index cb542f8c..cfe90e9e 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.4 toolchain go1.23.9 require ( + github.com/bytedance/sonic v1.13.3 github.com/charmbracelet/huh v0.3.0 github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 github.com/cloudwego/eino v0.3.41 @@ -12,6 +13,7 @@ require ( github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18 github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18 github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18 + github.com/getkin/kin-openapi v0.118.0 github.com/google/generative-ai-go v0.19.0 github.com/mark3labs/mcp-go v0.31.0 github.com/spf13/cobra v1.8.1 @@ -46,7 +48,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.33.9 // indirect github.com/aws/smithy-go v1.22.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect - github.com/bytedance/sonic v1.13.2 // indirect github.com/bytedance/sonic/loader v0.2.4 // indirect github.com/catppuccin/go v0.2.0 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect @@ -59,11 +60,10 @@ require ( github.com/evanphx/json-patch v0.5.2 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.8.0 // indirect - github.com/getkin/kin-openapi v0.118.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/go-openapi/jsonpointer v0.19.5 // indirect - github.com/go-openapi/swag v0.19.5 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect + github.com/go-openapi/swag v0.23.0 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect @@ -85,7 +85,7 @@ require ( github.com/nikolalohinski/gonja v1.5.3 // indirect github.com/ollama/ollama v0.5.12 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect - github.com/perimeterx/marshmallow v1.1.4 // indirect + github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect @@ -121,7 +121,6 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 // indirect google.golang.org/grpc v1.71.0 // indirect google.golang.org/protobuf v1.36.6 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect ) require ( diff --git a/go.sum b/go.sum index b00aa7b9..71c2c6b4 100644 --- a/go.sum +++ b/go.sum @@ -63,8 +63,8 @@ github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqR github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= github.com/bytedance/mockey v1.2.13 h1:jokWZAm/pUEbD939Rhznz615MKUCZNuvCFQlJ2+ntoo= github.com/bytedance/mockey v1.2.13/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= -github.com/bytedance/sonic v1.13.2 h1:8/H1FempDZqC4VqjptGo14QQlJx8VdZJegxs6wwfqpQ= -github.com/bytedance/sonic v1.13.2/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4= +github.com/bytedance/sonic v1.13.3 h1:MS8gmaH16Gtirygw7jV91pDCN33NyMrPbN7qiYhEsF0= +github.com/bytedance/sonic v1.13.3/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY= github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= @@ -137,10 +137,12 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= -github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= @@ -243,8 +245,9 @@ github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= -github.com/perimeterx/marshmallow v1.1.4 h1:pZLDH9RjlLGGorbXhcaQLhfuV0pFMNfPO55FuFkxqLw= github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= +github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -395,7 +398,6 @@ gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMy gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 55ef2dcd..9409c647 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -3,32 +3,15 @@ package agent import ( "context" "fmt" - "io" - "sync" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" - "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/flow/agent" "github.com/cloudwego/eino/schema" "github.com/mark3labs/mcphost/internal/config" "github.com/mark3labs/mcphost/internal/models" "github.com/mark3labs/mcphost/internal/tools" ) -type state struct { - Messages []*schema.Message - ReturnDirectlyToolCallID string -} - -const ( - nodeKeyTools = "tools" - nodeKeyModel = "chat" -) - -// MessageModifier modify the input messages before the model is called. -type MessageModifier func(ctx context.Context, input []*schema.Message) []*schema.Message - // AgentConfig is the config for agent. type AgentConfig struct { ModelConfig *models.ProviderConfig @@ -36,17 +19,6 @@ type AgentConfig struct { SystemPrompt string MaxSteps int MessageWindow int - - // MessageModifier. - // modify the input messages before the model is called, it's useful when you want to add some system prompt or other messages. - MessageModifier MessageModifier - - // Tools that will make agent return directly when the tool is called. - // When multiple tools are called and more than one tool is in the return directly list, only the first one will be returned. - ToolReturnDirectly map[string]struct{} - - // StreamOutputHandler is a function to determine whether the model's streaming output contains tool calls. - StreamToolCallChecker func(ctx context.Context, modelOutput *schema.StreamReader[*schema.Message]) (bool, error) } // ToolCallHandler is a function type for handling tool calls as they happen @@ -64,49 +36,14 @@ type ResponseHandler func(content string) // ToolCallContentHandler is a function type for handling content that accompanies tool calls type ToolCallContentHandler func(content string) -func firstChunkStreamToolCallChecker(_ context.Context, sr *schema.StreamReader[*schema.Message]) (bool, error) { - defer sr.Close() - - for { - msg, err := sr.Recv() - if err == io.EOF { - return false, nil - } - if err != nil { - return false, err - } - - if len(msg.ToolCalls) > 0 { - return true, nil - } - - if len(msg.Content) == 0 { // skip empty chunks at the front - continue - } - - return false, nil - } -} - -const ( - GraphName = "Agent" - ModelNodeName = "ChatModel" - ToolsNodeName = "Tools" -) - // Agent is the agent with real-time tool call display. type Agent struct { - runnable compose.Runnable[[]*schema.Message, *schema.Message] - graph *compose.Graph[[]*schema.Message, *schema.Message] - graphAddNodeOpts []compose.GraphAddNodeOpt - toolManager *tools.MCPToolManager - model model.ToolCallingChatModel - maxSteps int - systemPrompt string + toolManager *tools.MCPToolManager + model model.ToolCallingChatModel + maxSteps int + systemPrompt string } -var registerStateOnce sync.Once - // NewAgent creates an agent with MCP tool integration and real-time tool call display func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) { // Create the LLM provider @@ -121,252 +58,19 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) { return nil, fmt.Errorf("failed to load MCP tools: %v", err) } - var ( - toolsNode *compose.ToolsNode - toolInfos []*schema.ToolInfo - toolCallChecker = config.StreamToolCallChecker - messageModifier = config.MessageModifier - ) - - registerStateOnce.Do(func() { - err = compose.RegisterSerializableType[state]("_eino_agent_state") - }) - if err != nil { - return nil, err - } - - if toolCallChecker == nil { - toolCallChecker = firstChunkStreamToolCallChecker - } - - // Create tools config - toolsConfig := compose.ToolsNodeConfig{ - Tools: toolManager.GetTools(), - } - - // Only set up tools if we have any - hasTools := len(toolsConfig.Tools) > 0 - - if hasTools { - if toolInfos, err = genToolInfos(ctx, toolsConfig); err != nil { - return nil, err - } - - if toolsNode, err = compose.NewToolNode(ctx, &toolsConfig); err != nil { - return nil, err - } - } - - chatModel, err := agent.ChatModelWithTools(nil, model, toolInfos) - if err != nil { - // If binding tools fails and we have no tools, just use the model directly - if !hasTools { - chatModel = model - } else { - return nil, err - } - } - maxSteps := config.MaxSteps if maxSteps == 0 { maxSteps = 20 } - graph := compose.NewGraph[[]*schema.Message, *schema.Message](compose.WithGenLocalState(func(ctx context.Context) *state { - return &state{Messages: make([]*schema.Message, 0, maxSteps+1)} - })) - - modelPreHandle := func(ctx context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) { - state.Messages = append(state.Messages, input...) - - // Add system prompt if provided and not already present - if config.SystemPrompt != "" { - hasSystemMessage := false - if len(state.Messages) > 0 && state.Messages[0].Role == schema.System { - hasSystemMessage = true - } - - if !hasSystemMessage { - systemMsg := schema.SystemMessage(config.SystemPrompt) - state.Messages = append([]*schema.Message{systemMsg}, state.Messages...) - } - } - - if messageModifier == nil { - return state.Messages, nil - } - - modifiedInput := make([]*schema.Message, len(state.Messages)) - copy(modifiedInput, state.Messages) - return messageModifier(ctx, modifiedInput), nil - } - - if err = graph.AddChatModelNode(nodeKeyModel, chatModel, compose.WithStatePreHandler(modelPreHandle), compose.WithNodeName(ModelNodeName)); err != nil { - return nil, err - } - - if err = graph.AddEdge(compose.START, nodeKeyModel); err != nil { - return nil, err - } - - // Only add tools node and related logic if we have tools - if hasTools { - toolsNodePreHandle := func(ctx context.Context, input *schema.Message, state *state) (*schema.Message, error) { - if input == nil { - return state.Messages[len(state.Messages)-1], nil // used for rerun interrupt resume - } - state.Messages = append(state.Messages, input) - state.ReturnDirectlyToolCallID = getReturnDirectlyToolCallID(input, config.ToolReturnDirectly) - return input, nil - } - if err = graph.AddToolsNode(nodeKeyTools, toolsNode, compose.WithStatePreHandler(toolsNodePreHandle), compose.WithNodeName(ToolsNodeName)); err != nil { - return nil, err - } - - modelPostBranchCondition := func(_ context.Context, sr *schema.StreamReader[*schema.Message]) (endNode string, err error) { - if isToolCall, err := toolCallChecker(ctx, sr); err != nil { - return "", err - } else if isToolCall { - return nodeKeyTools, nil - } - return compose.END, nil - } - - if err = graph.AddBranch(nodeKeyModel, compose.NewStreamGraphBranch(modelPostBranchCondition, map[string]bool{nodeKeyTools: true, compose.END: true})); err != nil { - return nil, err - } - - if len(config.ToolReturnDirectly) > 0 { - if err = buildReturnDirectly(graph); err != nil { - return nil, err - } - } else if err = graph.AddEdge(nodeKeyTools, nodeKeyModel); err != nil { - return nil, err - } - } else { - // No tools, so model goes directly to END - if err = graph.AddEdge(nodeKeyModel, compose.END); err != nil { - return nil, err - } - } - - compileOpts := []compose.GraphCompileOption{compose.WithMaxRunSteps(maxSteps), compose.WithNodeTriggerMode(compose.AnyPredecessor), compose.WithGraphName(GraphName)} - runnable, err := graph.Compile(ctx, compileOpts...) - if err != nil { - return nil, err - } - return &Agent{ - runnable: runnable, - graph: graph, - graphAddNodeOpts: []compose.GraphAddNodeOpt{compose.WithGraphCompileOptions(compileOpts...)}, - toolManager: toolManager, - model: model, - maxSteps: maxSteps, - systemPrompt: config.SystemPrompt, + toolManager: toolManager, + model: model, + maxSteps: maxSteps, + systemPrompt: config.SystemPrompt, }, nil } -func buildReturnDirectly(graph *compose.Graph[[]*schema.Message, *schema.Message]) (err error) { - directReturn := func(ctx context.Context, msgs *schema.StreamReader[[]*schema.Message]) (*schema.StreamReader[*schema.Message], error) { - return schema.StreamReaderWithConvert(msgs, func(msgs []*schema.Message) (*schema.Message, error) { - var msg *schema.Message - err = compose.ProcessState[*state](ctx, func(_ context.Context, state *state) error { - for i := range msgs { - if msgs[i] != nil && msgs[i].ToolCallID == state.ReturnDirectlyToolCallID { - msg = msgs[i] - return nil - } - } - return nil - }) - if err != nil { - return nil, err - } - if msg == nil { - return nil, schema.ErrNoValue - } - return msg, nil - }), nil - } - - nodeKeyDirectReturn := "direct_return" - if err = graph.AddLambdaNode(nodeKeyDirectReturn, compose.TransformableLambda(directReturn)); err != nil { - return err - } - - // this branch checks if the tool called should return directly. It either leads to END or back to ChatModel - err = graph.AddBranch(nodeKeyTools, compose.NewStreamGraphBranch(func(ctx context.Context, msgsStream *schema.StreamReader[[]*schema.Message]) (endNode string, err error) { - msgsStream.Close() - - err = compose.ProcessState[*state](ctx, func(_ context.Context, state *state) error { - if len(state.ReturnDirectlyToolCallID) > 0 { - endNode = nodeKeyDirectReturn - } else { - endNode = nodeKeyModel - } - return nil - }) - if err != nil { - return "", err - } - return endNode, nil - }, map[string]bool{nodeKeyModel: true, nodeKeyDirectReturn: true})) - if err != nil { - return err - } - - return graph.AddEdge(nodeKeyDirectReturn, compose.END) -} - -func genToolInfos(ctx context.Context, config compose.ToolsNodeConfig) ([]*schema.ToolInfo, error) { - toolInfos := make([]*schema.ToolInfo, 0, len(config.Tools)) - for _, t := range config.Tools { - tl, err := t.Info(ctx) - if err != nil { - return nil, err - } - - toolInfos = append(toolInfos, tl) - } - - return toolInfos, nil -} - -func getReturnDirectlyToolCallID(input *schema.Message, toolReturnDirectly map[string]struct{}) string { - if len(toolReturnDirectly) == 0 { - return "" - } - - for _, toolCall := range input.ToolCalls { - if _, ok := toolReturnDirectly[toolCall.Function.Name]; ok { - return toolCall.ID - } - } - - return "" -} - -// Generate generates a response from the agent. -func (a *Agent) Generate(ctx context.Context, input []*schema.Message, opts ...compose.Option) (*schema.Message, error) { - // Convert compose options to agent options - agentOpts := []agent.AgentOption{} - if len(opts) > 0 { - agentOpts = append(agentOpts, agent.WithComposeOptions(opts...)) - } - return a.runnable.Invoke(ctx, input, agent.GetComposeOptions(agentOpts...)...) -} - -// Stream calls the agent and returns a stream response. -func (a *Agent) Stream(ctx context.Context, input []*schema.Message, opts ...compose.Option) (output *schema.StreamReader[*schema.Message], err error) { - // Convert compose options to agent options - agentOpts := []agent.AgentOption{} - if len(opts) > 0 { - agentOpts = append(agentOpts, agent.WithComposeOptions(opts...)) - } - return a.runnable.Stream(ctx, input, agent.GetComposeOptions(agentOpts...)...) -} - // GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message, onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*schema.Message, error) { @@ -489,8 +193,3 @@ func (a *Agent) GetTools() []tool.BaseTool { func (a *Agent) Close() error { return a.toolManager.Close() } - -// ExportGraph exports the underlying graph from Agent, along with the []compose.GraphAddNodeOpt to be used when adding this graph to another graph. -func (a *Agent) ExportGraph() (compose.AnyGraph, []compose.GraphAddNodeOpt) { - return a.graph, a.graphAddNodeOpts -} diff --git a/internal/tools/mcp.go b/internal/tools/mcp.go index 697d9ecf..3e737b87 100644 --- a/internal/tools/mcp.go +++ b/internal/tools/mcp.go @@ -5,84 +5,33 @@ import ( "encoding/json" "fmt" + "github.com/bytedance/sonic" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" + "github.com/getkin/kin-openapi/openapi3" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcphost/internal/config" ) -// MCPTool wraps an MCP tool to implement eino's InvokableTool interface -type MCPTool struct { - client client.MCPClient - toolInfo *mcp.Tool - name string -} - -// Info returns the tool information for eino -func (t *MCPTool) Info(ctx context.Context) (*schema.ToolInfo, error) { - // Convert MCP tool schema to eino schema - properties := make(map[string]*schema.ParameterInfo) - - // Handle the input schema - if t.toolInfo.InputSchema.Properties != nil { - for name, prop := range t.toolInfo.InputSchema.Properties { - if propMap, ok := prop.(map[string]interface{}); ok { - paramInfo := &schema.ParameterInfo{ - Type: schema.String, // Default type - } - if typeVal, ok := propMap["type"].(string); ok { - paramInfo.Type = schema.DataType(typeVal) - } - if desc, ok := propMap["description"].(string); ok { - paramInfo.Desc = desc - } - properties[name] = paramInfo - } - } - } - - return &schema.ToolInfo{ - Name: t.name, - Desc: t.toolInfo.Description, - ParamsOneOf: schema.NewParamsOneOfByParams(properties), - }, nil -} - -// InvokableRun implements the InvokableTool interface -func (t *MCPTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { - var args map[string]interface{} - if err := json.Unmarshal([]byte(argumentsInJSON), &args); err != nil { - return "", fmt.Errorf("failed to parse arguments: %v", err) - } - - req := mcp.CallToolRequest{} - req.Params.Name = t.toolInfo.Name - req.Params.Arguments = args - - result, err := t.client.CallTool(ctx, req) - if err != nil { - return "", fmt.Errorf("failed to call tool: %v", err) - } - - // Convert result to string - if result.Content != nil { - var resultText string - for _, item := range result.Content { - if textContent, ok := item.(mcp.TextContent); ok { - resultText += textContent.Text + " " - } - } - return resultText, nil - } - - return "", nil -} - // MCPToolManager manages MCP tools and clients type MCPToolManager struct { - clients map[string]client.MCPClient - tools []tool.BaseTool + clients map[string]client.MCPClient + tools []tool.BaseTool + toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name +} + +// toolMapping stores the mapping between prefixed tool names and their original details +type toolMapping struct { + serverName string + originalName string + client client.MCPClient +} + +// mcpToolImpl implements the eino tool interface with server prefixing +type mcpToolImpl struct { + info *schema.ToolInfo + mapping *toolMapping } // NewMCPToolManager creates a new MCP tool manager @@ -90,6 +39,7 @@ func NewMCPToolManager() *MCPToolManager { return &MCPToolManager{ clients: make(map[string]client.MCPClient), tools: make([]tool.BaseTool, 0), + toolMap: make(map[string]*toolMapping), } } @@ -109,23 +59,66 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) e } // Get tools from this server - toolsResult, err := client.ListTools(ctx, mcp.ListToolsRequest{}) + listResults, err := client.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { return fmt.Errorf("failed to list tools from server %s: %v", serverName, err) } - // Convert MCP tools to eino tools - for _, mcpTool := range toolsResult.Tools { + // Create name set for allowed tools + var nameSet map[string]struct{} + if len(serverConfig.AllowedTools) > 0 { + nameSet = make(map[string]struct{}) + for _, name := range serverConfig.AllowedTools { + nameSet[name] = struct{}{} + } + } + + // Convert MCP tools to eino tools with prefixed names + for _, mcpTool := range listResults.Tools { // Filter tools based on allowedTools/excludedTools - if !m.shouldIncludeTool(mcpTool.Name, serverConfig) { + if len(serverConfig.AllowedTools) > 0 { + if _, ok := nameSet[mcpTool.Name]; !ok { + continue + } + } + + // Check if tool should be excluded + if m.shouldExcludeTool(mcpTool.Name, serverConfig) { continue } - einoTool := &MCPTool{ - client: client, - toolInfo: &mcpTool, - name: fmt.Sprintf("%s__%s", serverName, mcpTool.Name), + // Convert schema + marshaledInputSchema, err := sonic.Marshal(mcpTool.InputSchema) + if err != nil { + return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name) } + inputSchema := &openapi3.Schema{} + err = sonic.Unmarshal(marshaledInputSchema, inputSchema) + if err != nil { + return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name) + } + + // Create prefixed tool name + prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name) + + // Create tool mapping + mapping := &toolMapping{ + serverName: serverName, + originalName: mcpTool.Name, + client: client, + } + m.toolMap[prefixedName] = mapping + + // Create eino tool + einoTool := &mcpToolImpl{ + info: &schema.ToolInfo{ + Name: prefixedName, + Desc: mcpTool.Description, + ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(inputSchema), + }, + mapping: mapping, + } + m.tools = append(m.tools, einoTool) } } @@ -133,6 +126,40 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) e return nil } +// Info returns the tool information +func (t *mcpToolImpl) Info(ctx context.Context) (*schema.ToolInfo, error) { + return t.info, nil +} + +// InvokableRun executes the tool by mapping back to the original name and server +func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + result, err := t.mapping.client.CallTool(ctx, mcp.CallToolRequest{ + Request: mcp.Request{ + Method: "tools/call", + }, + Params: struct { + Name string `json:"name"` + Arguments any `json:"arguments,omitempty"` + Meta *mcp.Meta `json:"_meta,omitempty"` + }{ + Name: t.mapping.originalName, // Use original name, not prefixed + Arguments: json.RawMessage(argumentsInJSON), + }, + }) + if err != nil { + return "", fmt.Errorf("failed to call mcp tool: %w", err) + } + + marshaledResult, err := sonic.MarshalString(result) + if err != nil { + return "", fmt.Errorf("failed to marshal mcp tool result: %w", err) + } + if result.IsError { + return "", fmt.Errorf("failed to call mcp tool, mcp server return error: %s", marshaledResult) + } + return marshaledResult, nil +} + // GetTools returns all loaded tools func (m *MCPToolManager) GetTools() []tool.BaseTool { return m.tools @@ -148,29 +175,18 @@ func (m *MCPToolManager) Close() error { return nil } -// shouldIncludeTool determines if a tool should be included based on allowedTools/excludedTools -func (m *MCPToolManager) shouldIncludeTool(toolName string, serverConfig config.MCPServerConfig) bool { - // If allowedTools is specified, only include tools in the list - if len(serverConfig.AllowedTools) > 0 { - for _, allowedTool := range serverConfig.AllowedTools { - if allowedTool == toolName { - return true - } - } - return false - } - +// shouldExcludeTool determines if a tool should be excluded based on excludedTools +func (m *MCPToolManager) shouldExcludeTool(toolName string, serverConfig config.MCPServerConfig) bool { // If excludedTools is specified, exclude tools in the list if len(serverConfig.ExcludedTools) > 0 { for _, excludedTool := range serverConfig.ExcludedTools { if excludedTool == toolName { - return false + return true } } } - // Include by default - return true + return false } func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (client.MCPClient, error) {