This commit is contained in:
Ed Zynda
2025-06-10 13:44:17 +03:00
parent 13ede07ea5
commit 57efdb5332
4 changed files with 129 additions and 413 deletions
+5 -6
View File
@@ -5,6 +5,7 @@ go 1.23.4
toolchain go1.23.9
require (
github.com/bytedance/sonic v1.13.3
github.com/charmbracelet/huh v0.3.0
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834
github.com/cloudwego/eino v0.3.41
@@ -12,6 +13,7 @@ require (
github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18
github.com/getkin/kin-openapi v0.118.0
github.com/google/generative-ai-go v0.19.0
github.com/mark3labs/mcp-go v0.31.0
github.com/spf13/cobra v1.8.1
@@ -46,7 +48,6 @@ require (
github.com/aws/aws-sdk-go-v2/service/sts v1.33.9 // indirect
github.com/aws/smithy-go v1.22.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/bytedance/sonic v1.13.2 // indirect
github.com/bytedance/sonic/loader v0.2.4 // indirect
github.com/catppuccin/go v0.2.0 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
@@ -59,11 +60,10 @@ require (
github.com/evanphx/json-patch v0.5.2 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.8.0 // indirect
github.com/getkin/kin-openapi v0.118.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/swag v0.19.5 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/swag v0.23.0 // indirect
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/uuid v1.6.0 // indirect
@@ -85,7 +85,7 @@ require (
github.com/nikolalohinski/gonja v1.5.3 // indirect
github.com/ollama/ollama v0.5.12 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/perimeterx/marshmallow v1.1.4 // indirect
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/sagikazarmark/locafero v0.7.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
@@ -121,7 +121,6 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 // indirect
google.golang.org/grpc v1.71.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
require (
+8 -6
View File
@@ -63,8 +63,8 @@ github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqR
github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE=
github.com/bytedance/mockey v1.2.13 h1:jokWZAm/pUEbD939Rhznz615MKUCZNuvCFQlJ2+ntoo=
github.com/bytedance/mockey v1.2.13/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY=
github.com/bytedance/sonic v1.13.2 h1:8/H1FempDZqC4VqjptGo14QQlJx8VdZJegxs6wwfqpQ=
github.com/bytedance/sonic v1.13.2/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4=
github.com/bytedance/sonic v1.13.3 h1:MS8gmaH16Gtirygw7jV91pDCN33NyMrPbN7qiYhEsF0=
github.com/bytedance/sonic v1.13.3/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY=
github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
@@ -137,10 +137,12 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY=
github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY=
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
@@ -243,8 +245,9 @@ github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W
github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/perimeterx/marshmallow v1.1.4 h1:pZLDH9RjlLGGorbXhcaQLhfuV0pFMNfPO55FuFkxqLw=
github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -395,7 +398,6 @@ gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMy
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+8 -309
View File
@@ -3,32 +3,15 @@ package agent
import (
"context"
"fmt"
"io"
"sync"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/flow/agent"
"github.com/cloudwego/eino/schema"
"github.com/mark3labs/mcphost/internal/config"
"github.com/mark3labs/mcphost/internal/models"
"github.com/mark3labs/mcphost/internal/tools"
)
type state struct {
Messages []*schema.Message
ReturnDirectlyToolCallID string
}
const (
nodeKeyTools = "tools"
nodeKeyModel = "chat"
)
// MessageModifier modify the input messages before the model is called.
type MessageModifier func(ctx context.Context, input []*schema.Message) []*schema.Message
// AgentConfig is the config for agent.
type AgentConfig struct {
ModelConfig *models.ProviderConfig
@@ -36,17 +19,6 @@ type AgentConfig struct {
SystemPrompt string
MaxSteps int
MessageWindow int
// MessageModifier.
// modify the input messages before the model is called, it's useful when you want to add some system prompt or other messages.
MessageModifier MessageModifier
// Tools that will make agent return directly when the tool is called.
// When multiple tools are called and more than one tool is in the return directly list, only the first one will be returned.
ToolReturnDirectly map[string]struct{}
// StreamOutputHandler is a function to determine whether the model's streaming output contains tool calls.
StreamToolCallChecker func(ctx context.Context, modelOutput *schema.StreamReader[*schema.Message]) (bool, error)
}
// ToolCallHandler is a function type for handling tool calls as they happen
@@ -64,49 +36,14 @@ type ResponseHandler func(content string)
// ToolCallContentHandler is a function type for handling content that accompanies tool calls
type ToolCallContentHandler func(content string)
func firstChunkStreamToolCallChecker(_ context.Context, sr *schema.StreamReader[*schema.Message]) (bool, error) {
defer sr.Close()
for {
msg, err := sr.Recv()
if err == io.EOF {
return false, nil
}
if err != nil {
return false, err
}
if len(msg.ToolCalls) > 0 {
return true, nil
}
if len(msg.Content) == 0 { // skip empty chunks at the front
continue
}
return false, nil
}
}
const (
GraphName = "Agent"
ModelNodeName = "ChatModel"
ToolsNodeName = "Tools"
)
// Agent is the agent with real-time tool call display.
type Agent struct {
runnable compose.Runnable[[]*schema.Message, *schema.Message]
graph *compose.Graph[[]*schema.Message, *schema.Message]
graphAddNodeOpts []compose.GraphAddNodeOpt
toolManager *tools.MCPToolManager
model model.ToolCallingChatModel
maxSteps int
systemPrompt string
toolManager *tools.MCPToolManager
model model.ToolCallingChatModel
maxSteps int
systemPrompt string
}
var registerStateOnce sync.Once
// NewAgent creates an agent with MCP tool integration and real-time tool call display
func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
// Create the LLM provider
@@ -121,252 +58,19 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
return nil, fmt.Errorf("failed to load MCP tools: %v", err)
}
var (
toolsNode *compose.ToolsNode
toolInfos []*schema.ToolInfo
toolCallChecker = config.StreamToolCallChecker
messageModifier = config.MessageModifier
)
registerStateOnce.Do(func() {
err = compose.RegisterSerializableType[state]("_eino_agent_state")
})
if err != nil {
return nil, err
}
if toolCallChecker == nil {
toolCallChecker = firstChunkStreamToolCallChecker
}
// Create tools config
toolsConfig := compose.ToolsNodeConfig{
Tools: toolManager.GetTools(),
}
// Only set up tools if we have any
hasTools := len(toolsConfig.Tools) > 0
if hasTools {
if toolInfos, err = genToolInfos(ctx, toolsConfig); err != nil {
return nil, err
}
if toolsNode, err = compose.NewToolNode(ctx, &toolsConfig); err != nil {
return nil, err
}
}
chatModel, err := agent.ChatModelWithTools(nil, model, toolInfos)
if err != nil {
// If binding tools fails and we have no tools, just use the model directly
if !hasTools {
chatModel = model
} else {
return nil, err
}
}
maxSteps := config.MaxSteps
if maxSteps == 0 {
maxSteps = 20
}
graph := compose.NewGraph[[]*schema.Message, *schema.Message](compose.WithGenLocalState(func(ctx context.Context) *state {
return &state{Messages: make([]*schema.Message, 0, maxSteps+1)}
}))
modelPreHandle := func(ctx context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) {
state.Messages = append(state.Messages, input...)
// Add system prompt if provided and not already present
if config.SystemPrompt != "" {
hasSystemMessage := false
if len(state.Messages) > 0 && state.Messages[0].Role == schema.System {
hasSystemMessage = true
}
if !hasSystemMessage {
systemMsg := schema.SystemMessage(config.SystemPrompt)
state.Messages = append([]*schema.Message{systemMsg}, state.Messages...)
}
}
if messageModifier == nil {
return state.Messages, nil
}
modifiedInput := make([]*schema.Message, len(state.Messages))
copy(modifiedInput, state.Messages)
return messageModifier(ctx, modifiedInput), nil
}
if err = graph.AddChatModelNode(nodeKeyModel, chatModel, compose.WithStatePreHandler(modelPreHandle), compose.WithNodeName(ModelNodeName)); err != nil {
return nil, err
}
if err = graph.AddEdge(compose.START, nodeKeyModel); err != nil {
return nil, err
}
// Only add tools node and related logic if we have tools
if hasTools {
toolsNodePreHandle := func(ctx context.Context, input *schema.Message, state *state) (*schema.Message, error) {
if input == nil {
return state.Messages[len(state.Messages)-1], nil // used for rerun interrupt resume
}
state.Messages = append(state.Messages, input)
state.ReturnDirectlyToolCallID = getReturnDirectlyToolCallID(input, config.ToolReturnDirectly)
return input, nil
}
if err = graph.AddToolsNode(nodeKeyTools, toolsNode, compose.WithStatePreHandler(toolsNodePreHandle), compose.WithNodeName(ToolsNodeName)); err != nil {
return nil, err
}
modelPostBranchCondition := func(_ context.Context, sr *schema.StreamReader[*schema.Message]) (endNode string, err error) {
if isToolCall, err := toolCallChecker(ctx, sr); err != nil {
return "", err
} else if isToolCall {
return nodeKeyTools, nil
}
return compose.END, nil
}
if err = graph.AddBranch(nodeKeyModel, compose.NewStreamGraphBranch(modelPostBranchCondition, map[string]bool{nodeKeyTools: true, compose.END: true})); err != nil {
return nil, err
}
if len(config.ToolReturnDirectly) > 0 {
if err = buildReturnDirectly(graph); err != nil {
return nil, err
}
} else if err = graph.AddEdge(nodeKeyTools, nodeKeyModel); err != nil {
return nil, err
}
} else {
// No tools, so model goes directly to END
if err = graph.AddEdge(nodeKeyModel, compose.END); err != nil {
return nil, err
}
}
compileOpts := []compose.GraphCompileOption{compose.WithMaxRunSteps(maxSteps), compose.WithNodeTriggerMode(compose.AnyPredecessor), compose.WithGraphName(GraphName)}
runnable, err := graph.Compile(ctx, compileOpts...)
if err != nil {
return nil, err
}
return &Agent{
runnable: runnable,
graph: graph,
graphAddNodeOpts: []compose.GraphAddNodeOpt{compose.WithGraphCompileOptions(compileOpts...)},
toolManager: toolManager,
model: model,
maxSteps: maxSteps,
systemPrompt: config.SystemPrompt,
toolManager: toolManager,
model: model,
maxSteps: maxSteps,
systemPrompt: config.SystemPrompt,
}, nil
}
func buildReturnDirectly(graph *compose.Graph[[]*schema.Message, *schema.Message]) (err error) {
directReturn := func(ctx context.Context, msgs *schema.StreamReader[[]*schema.Message]) (*schema.StreamReader[*schema.Message], error) {
return schema.StreamReaderWithConvert(msgs, func(msgs []*schema.Message) (*schema.Message, error) {
var msg *schema.Message
err = compose.ProcessState[*state](ctx, func(_ context.Context, state *state) error {
for i := range msgs {
if msgs[i] != nil && msgs[i].ToolCallID == state.ReturnDirectlyToolCallID {
msg = msgs[i]
return nil
}
}
return nil
})
if err != nil {
return nil, err
}
if msg == nil {
return nil, schema.ErrNoValue
}
return msg, nil
}), nil
}
nodeKeyDirectReturn := "direct_return"
if err = graph.AddLambdaNode(nodeKeyDirectReturn, compose.TransformableLambda(directReturn)); err != nil {
return err
}
// this branch checks if the tool called should return directly. It either leads to END or back to ChatModel
err = graph.AddBranch(nodeKeyTools, compose.NewStreamGraphBranch(func(ctx context.Context, msgsStream *schema.StreamReader[[]*schema.Message]) (endNode string, err error) {
msgsStream.Close()
err = compose.ProcessState[*state](ctx, func(_ context.Context, state *state) error {
if len(state.ReturnDirectlyToolCallID) > 0 {
endNode = nodeKeyDirectReturn
} else {
endNode = nodeKeyModel
}
return nil
})
if err != nil {
return "", err
}
return endNode, nil
}, map[string]bool{nodeKeyModel: true, nodeKeyDirectReturn: true}))
if err != nil {
return err
}
return graph.AddEdge(nodeKeyDirectReturn, compose.END)
}
func genToolInfos(ctx context.Context, config compose.ToolsNodeConfig) ([]*schema.ToolInfo, error) {
toolInfos := make([]*schema.ToolInfo, 0, len(config.Tools))
for _, t := range config.Tools {
tl, err := t.Info(ctx)
if err != nil {
return nil, err
}
toolInfos = append(toolInfos, tl)
}
return toolInfos, nil
}
func getReturnDirectlyToolCallID(input *schema.Message, toolReturnDirectly map[string]struct{}) string {
if len(toolReturnDirectly) == 0 {
return ""
}
for _, toolCall := range input.ToolCalls {
if _, ok := toolReturnDirectly[toolCall.Function.Name]; ok {
return toolCall.ID
}
}
return ""
}
// Generate generates a response from the agent.
func (a *Agent) Generate(ctx context.Context, input []*schema.Message, opts ...compose.Option) (*schema.Message, error) {
// Convert compose options to agent options
agentOpts := []agent.AgentOption{}
if len(opts) > 0 {
agentOpts = append(agentOpts, agent.WithComposeOptions(opts...))
}
return a.runnable.Invoke(ctx, input, agent.GetComposeOptions(agentOpts...)...)
}
// Stream calls the agent and returns a stream response.
func (a *Agent) Stream(ctx context.Context, input []*schema.Message, opts ...compose.Option) (output *schema.StreamReader[*schema.Message], err error) {
// Convert compose options to agent options
agentOpts := []agent.AgentOption{}
if len(opts) > 0 {
agentOpts = append(agentOpts, agent.WithComposeOptions(opts...))
}
return a.runnable.Stream(ctx, input, agent.GetComposeOptions(agentOpts...)...)
}
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time
func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message,
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*schema.Message, error) {
@@ -489,8 +193,3 @@ func (a *Agent) GetTools() []tool.BaseTool {
func (a *Agent) Close() error {
return a.toolManager.Close()
}
// ExportGraph exports the underlying graph from Agent, along with the []compose.GraphAddNodeOpt to be used when adding this graph to another graph.
func (a *Agent) ExportGraph() (compose.AnyGraph, []compose.GraphAddNodeOpt) {
return a.graph, a.graphAddNodeOpts
}
+108 -92
View File
@@ -5,84 +5,33 @@ import (
"encoding/json"
"fmt"
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
"github.com/getkin/kin-openapi/openapi3"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcphost/internal/config"
)
// MCPTool wraps an MCP tool to implement eino's InvokableTool interface
type MCPTool struct {
client client.MCPClient
toolInfo *mcp.Tool
name string
}
// Info returns the tool information for eino
func (t *MCPTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
// Convert MCP tool schema to eino schema
properties := make(map[string]*schema.ParameterInfo)
// Handle the input schema
if t.toolInfo.InputSchema.Properties != nil {
for name, prop := range t.toolInfo.InputSchema.Properties {
if propMap, ok := prop.(map[string]interface{}); ok {
paramInfo := &schema.ParameterInfo{
Type: schema.String, // Default type
}
if typeVal, ok := propMap["type"].(string); ok {
paramInfo.Type = schema.DataType(typeVal)
}
if desc, ok := propMap["description"].(string); ok {
paramInfo.Desc = desc
}
properties[name] = paramInfo
}
}
}
return &schema.ToolInfo{
Name: t.name,
Desc: t.toolInfo.Description,
ParamsOneOf: schema.NewParamsOneOfByParams(properties),
}, nil
}
// InvokableRun implements the InvokableTool interface
func (t *MCPTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
var args map[string]interface{}
if err := json.Unmarshal([]byte(argumentsInJSON), &args); err != nil {
return "", fmt.Errorf("failed to parse arguments: %v", err)
}
req := mcp.CallToolRequest{}
req.Params.Name = t.toolInfo.Name
req.Params.Arguments = args
result, err := t.client.CallTool(ctx, req)
if err != nil {
return "", fmt.Errorf("failed to call tool: %v", err)
}
// Convert result to string
if result.Content != nil {
var resultText string
for _, item := range result.Content {
if textContent, ok := item.(mcp.TextContent); ok {
resultText += textContent.Text + " "
}
}
return resultText, nil
}
return "", nil
}
// MCPToolManager manages MCP tools and clients
type MCPToolManager struct {
clients map[string]client.MCPClient
tools []tool.BaseTool
clients map[string]client.MCPClient
tools []tool.BaseTool
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
}
// toolMapping stores the mapping between prefixed tool names and their original details
type toolMapping struct {
serverName string
originalName string
client client.MCPClient
}
// mcpToolImpl implements the eino tool interface with server prefixing
type mcpToolImpl struct {
info *schema.ToolInfo
mapping *toolMapping
}
// NewMCPToolManager creates a new MCP tool manager
@@ -90,6 +39,7 @@ func NewMCPToolManager() *MCPToolManager {
return &MCPToolManager{
clients: make(map[string]client.MCPClient),
tools: make([]tool.BaseTool, 0),
toolMap: make(map[string]*toolMapping),
}
}
@@ -109,23 +59,66 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) e
}
// Get tools from this server
toolsResult, err := client.ListTools(ctx, mcp.ListToolsRequest{})
listResults, err := client.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
return fmt.Errorf("failed to list tools from server %s: %v", serverName, err)
}
// Convert MCP tools to eino tools
for _, mcpTool := range toolsResult.Tools {
// Create name set for allowed tools
var nameSet map[string]struct{}
if len(serverConfig.AllowedTools) > 0 {
nameSet = make(map[string]struct{})
for _, name := range serverConfig.AllowedTools {
nameSet[name] = struct{}{}
}
}
// Convert MCP tools to eino tools with prefixed names
for _, mcpTool := range listResults.Tools {
// Filter tools based on allowedTools/excludedTools
if !m.shouldIncludeTool(mcpTool.Name, serverConfig) {
if len(serverConfig.AllowedTools) > 0 {
if _, ok := nameSet[mcpTool.Name]; !ok {
continue
}
}
// Check if tool should be excluded
if m.shouldExcludeTool(mcpTool.Name, serverConfig) {
continue
}
einoTool := &MCPTool{
client: client,
toolInfo: &mcpTool,
name: fmt.Sprintf("%s__%s", serverName, mcpTool.Name),
// Convert schema
marshaledInputSchema, err := sonic.Marshal(mcpTool.InputSchema)
if err != nil {
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
}
inputSchema := &openapi3.Schema{}
err = sonic.Unmarshal(marshaledInputSchema, inputSchema)
if err != nil {
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
}
// Create prefixed tool name
prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name)
// Create tool mapping
mapping := &toolMapping{
serverName: serverName,
originalName: mcpTool.Name,
client: client,
}
m.toolMap[prefixedName] = mapping
// Create eino tool
einoTool := &mcpToolImpl{
info: &schema.ToolInfo{
Name: prefixedName,
Desc: mcpTool.Description,
ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(inputSchema),
},
mapping: mapping,
}
m.tools = append(m.tools, einoTool)
}
}
@@ -133,6 +126,40 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) e
return nil
}
// Info returns the tool information
func (t *mcpToolImpl) Info(ctx context.Context) (*schema.ToolInfo, error) {
return t.info, nil
}
// InvokableRun executes the tool by mapping back to the original name and server
func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
result, err := t.mapping.client.CallTool(ctx, mcp.CallToolRequest{
Request: mcp.Request{
Method: "tools/call",
},
Params: struct {
Name string `json:"name"`
Arguments any `json:"arguments,omitempty"`
Meta *mcp.Meta `json:"_meta,omitempty"`
}{
Name: t.mapping.originalName, // Use original name, not prefixed
Arguments: json.RawMessage(argumentsInJSON),
},
})
if err != nil {
return "", fmt.Errorf("failed to call mcp tool: %w", err)
}
marshaledResult, err := sonic.MarshalString(result)
if err != nil {
return "", fmt.Errorf("failed to marshal mcp tool result: %w", err)
}
if result.IsError {
return "", fmt.Errorf("failed to call mcp tool, mcp server return error: %s", marshaledResult)
}
return marshaledResult, nil
}
// GetTools returns all loaded tools
func (m *MCPToolManager) GetTools() []tool.BaseTool {
return m.tools
@@ -148,29 +175,18 @@ func (m *MCPToolManager) Close() error {
return nil
}
// shouldIncludeTool determines if a tool should be included based on allowedTools/excludedTools
func (m *MCPToolManager) shouldIncludeTool(toolName string, serverConfig config.MCPServerConfig) bool {
// If allowedTools is specified, only include tools in the list
if len(serverConfig.AllowedTools) > 0 {
for _, allowedTool := range serverConfig.AllowedTools {
if allowedTool == toolName {
return true
}
}
return false
}
// shouldExcludeTool determines if a tool should be excluded based on excludedTools
func (m *MCPToolManager) shouldExcludeTool(toolName string, serverConfig config.MCPServerConfig) bool {
// If excludedTools is specified, exclude tools in the list
if len(serverConfig.ExcludedTools) > 0 {
for _, excludedTool := range serverConfig.ExcludedTools {
if excludedTool == toolName {
return false
return true
}
}
}
// Include by default
return true
return false
}
func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (client.MCPClient, error) {