mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
cleanup
This commit is contained in:
@@ -5,6 +5,7 @@ go 1.23.4
|
||||
toolchain go1.23.9
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.13.3
|
||||
github.com/charmbracelet/huh v0.3.0
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834
|
||||
github.com/cloudwego/eino v0.3.41
|
||||
@@ -12,6 +13,7 @@ require (
|
||||
github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18
|
||||
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18
|
||||
github.com/getkin/kin-openapi v0.118.0
|
||||
github.com/google/generative-ai-go v0.19.0
|
||||
github.com/mark3labs/mcp-go v0.31.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
@@ -46,7 +48,6 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.33.9 // indirect
|
||||
github.com/aws/smithy-go v1.22.1 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/bytedance/sonic v1.13.2 // indirect
|
||||
github.com/bytedance/sonic/loader v0.2.4 // indirect
|
||||
github.com/catppuccin/go v0.2.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
@@ -59,11 +60,10 @@ require (
|
||||
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
||||
github.com/getkin/kin-openapi v0.118.0 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.19.5 // indirect
|
||||
github.com/go-openapi/swag v0.19.5 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.21.0 // indirect
|
||||
github.com/go-openapi/swag v0.23.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
@@ -85,7 +85,7 @@ require (
|
||||
github.com/nikolalohinski/gonja v1.5.3 // indirect
|
||||
github.com/ollama/ollama v0.5.12 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/perimeterx/marshmallow v1.1.4 // indirect
|
||||
github.com/perimeterx/marshmallow v1.1.5 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.7.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
@@ -121,7 +121,6 @@ require (
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 // indirect
|
||||
google.golang.org/grpc v1.71.0 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
@@ -63,8 +63,8 @@ github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqR
|
||||
github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE=
|
||||
github.com/bytedance/mockey v1.2.13 h1:jokWZAm/pUEbD939Rhznz615MKUCZNuvCFQlJ2+ntoo=
|
||||
github.com/bytedance/mockey v1.2.13/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY=
|
||||
github.com/bytedance/sonic v1.13.2 h1:8/H1FempDZqC4VqjptGo14QQlJx8VdZJegxs6wwfqpQ=
|
||||
github.com/bytedance/sonic v1.13.2/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4=
|
||||
github.com/bytedance/sonic v1.13.3 h1:MS8gmaH16Gtirygw7jV91pDCN33NyMrPbN7qiYhEsF0=
|
||||
github.com/bytedance/sonic v1.13.3/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY=
|
||||
github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
|
||||
@@ -137,10 +137,12 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY=
|
||||
github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
|
||||
github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY=
|
||||
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
|
||||
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
|
||||
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
|
||||
github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
|
||||
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
|
||||
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
|
||||
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
|
||||
@@ -243,8 +245,9 @@ github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W
|
||||
github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
|
||||
github.com/perimeterx/marshmallow v1.1.4 h1:pZLDH9RjlLGGorbXhcaQLhfuV0pFMNfPO55FuFkxqLw=
|
||||
github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
|
||||
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
|
||||
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -395,7 +398,6 @@ gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMy
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
+8
-309
@@ -3,32 +3,15 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/flow/agent"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/mark3labs/mcphost/internal/config"
|
||||
"github.com/mark3labs/mcphost/internal/models"
|
||||
"github.com/mark3labs/mcphost/internal/tools"
|
||||
)
|
||||
|
||||
type state struct {
|
||||
Messages []*schema.Message
|
||||
ReturnDirectlyToolCallID string
|
||||
}
|
||||
|
||||
const (
|
||||
nodeKeyTools = "tools"
|
||||
nodeKeyModel = "chat"
|
||||
)
|
||||
|
||||
// MessageModifier modify the input messages before the model is called.
|
||||
type MessageModifier func(ctx context.Context, input []*schema.Message) []*schema.Message
|
||||
|
||||
// AgentConfig is the config for agent.
|
||||
type AgentConfig struct {
|
||||
ModelConfig *models.ProviderConfig
|
||||
@@ -36,17 +19,6 @@ type AgentConfig struct {
|
||||
SystemPrompt string
|
||||
MaxSteps int
|
||||
MessageWindow int
|
||||
|
||||
// MessageModifier.
|
||||
// modify the input messages before the model is called, it's useful when you want to add some system prompt or other messages.
|
||||
MessageModifier MessageModifier
|
||||
|
||||
// Tools that will make agent return directly when the tool is called.
|
||||
// When multiple tools are called and more than one tool is in the return directly list, only the first one will be returned.
|
||||
ToolReturnDirectly map[string]struct{}
|
||||
|
||||
// StreamOutputHandler is a function to determine whether the model's streaming output contains tool calls.
|
||||
StreamToolCallChecker func(ctx context.Context, modelOutput *schema.StreamReader[*schema.Message]) (bool, error)
|
||||
}
|
||||
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen
|
||||
@@ -64,49 +36,14 @@ type ResponseHandler func(content string)
|
||||
// ToolCallContentHandler is a function type for handling content that accompanies tool calls
|
||||
type ToolCallContentHandler func(content string)
|
||||
|
||||
func firstChunkStreamToolCallChecker(_ context.Context, sr *schema.StreamReader[*schema.Message]) (bool, error) {
|
||||
defer sr.Close()
|
||||
|
||||
for {
|
||||
msg, err := sr.Recv()
|
||||
if err == io.EOF {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if len(msg.Content) == 0 { // skip empty chunks at the front
|
||||
continue
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
GraphName = "Agent"
|
||||
ModelNodeName = "ChatModel"
|
||||
ToolsNodeName = "Tools"
|
||||
)
|
||||
|
||||
// Agent is the agent with real-time tool call display.
|
||||
type Agent struct {
|
||||
runnable compose.Runnable[[]*schema.Message, *schema.Message]
|
||||
graph *compose.Graph[[]*schema.Message, *schema.Message]
|
||||
graphAddNodeOpts []compose.GraphAddNodeOpt
|
||||
toolManager *tools.MCPToolManager
|
||||
model model.ToolCallingChatModel
|
||||
maxSteps int
|
||||
systemPrompt string
|
||||
toolManager *tools.MCPToolManager
|
||||
model model.ToolCallingChatModel
|
||||
maxSteps int
|
||||
systemPrompt string
|
||||
}
|
||||
|
||||
var registerStateOnce sync.Once
|
||||
|
||||
// NewAgent creates an agent with MCP tool integration and real-time tool call display
|
||||
func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
|
||||
// Create the LLM provider
|
||||
@@ -121,252 +58,19 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
|
||||
return nil, fmt.Errorf("failed to load MCP tools: %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
toolsNode *compose.ToolsNode
|
||||
toolInfos []*schema.ToolInfo
|
||||
toolCallChecker = config.StreamToolCallChecker
|
||||
messageModifier = config.MessageModifier
|
||||
)
|
||||
|
||||
registerStateOnce.Do(func() {
|
||||
err = compose.RegisterSerializableType[state]("_eino_agent_state")
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if toolCallChecker == nil {
|
||||
toolCallChecker = firstChunkStreamToolCallChecker
|
||||
}
|
||||
|
||||
// Create tools config
|
||||
toolsConfig := compose.ToolsNodeConfig{
|
||||
Tools: toolManager.GetTools(),
|
||||
}
|
||||
|
||||
// Only set up tools if we have any
|
||||
hasTools := len(toolsConfig.Tools) > 0
|
||||
|
||||
if hasTools {
|
||||
if toolInfos, err = genToolInfos(ctx, toolsConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if toolsNode, err = compose.NewToolNode(ctx, &toolsConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
chatModel, err := agent.ChatModelWithTools(nil, model, toolInfos)
|
||||
if err != nil {
|
||||
// If binding tools fails and we have no tools, just use the model directly
|
||||
if !hasTools {
|
||||
chatModel = model
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
maxSteps := config.MaxSteps
|
||||
if maxSteps == 0 {
|
||||
maxSteps = 20
|
||||
}
|
||||
|
||||
graph := compose.NewGraph[[]*schema.Message, *schema.Message](compose.WithGenLocalState(func(ctx context.Context) *state {
|
||||
return &state{Messages: make([]*schema.Message, 0, maxSteps+1)}
|
||||
}))
|
||||
|
||||
modelPreHandle := func(ctx context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) {
|
||||
state.Messages = append(state.Messages, input...)
|
||||
|
||||
// Add system prompt if provided and not already present
|
||||
if config.SystemPrompt != "" {
|
||||
hasSystemMessage := false
|
||||
if len(state.Messages) > 0 && state.Messages[0].Role == schema.System {
|
||||
hasSystemMessage = true
|
||||
}
|
||||
|
||||
if !hasSystemMessage {
|
||||
systemMsg := schema.SystemMessage(config.SystemPrompt)
|
||||
state.Messages = append([]*schema.Message{systemMsg}, state.Messages...)
|
||||
}
|
||||
}
|
||||
|
||||
if messageModifier == nil {
|
||||
return state.Messages, nil
|
||||
}
|
||||
|
||||
modifiedInput := make([]*schema.Message, len(state.Messages))
|
||||
copy(modifiedInput, state.Messages)
|
||||
return messageModifier(ctx, modifiedInput), nil
|
||||
}
|
||||
|
||||
if err = graph.AddChatModelNode(nodeKeyModel, chatModel, compose.WithStatePreHandler(modelPreHandle), compose.WithNodeName(ModelNodeName)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = graph.AddEdge(compose.START, nodeKeyModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Only add tools node and related logic if we have tools
|
||||
if hasTools {
|
||||
toolsNodePreHandle := func(ctx context.Context, input *schema.Message, state *state) (*schema.Message, error) {
|
||||
if input == nil {
|
||||
return state.Messages[len(state.Messages)-1], nil // used for rerun interrupt resume
|
||||
}
|
||||
state.Messages = append(state.Messages, input)
|
||||
state.ReturnDirectlyToolCallID = getReturnDirectlyToolCallID(input, config.ToolReturnDirectly)
|
||||
return input, nil
|
||||
}
|
||||
if err = graph.AddToolsNode(nodeKeyTools, toolsNode, compose.WithStatePreHandler(toolsNodePreHandle), compose.WithNodeName(ToolsNodeName)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelPostBranchCondition := func(_ context.Context, sr *schema.StreamReader[*schema.Message]) (endNode string, err error) {
|
||||
if isToolCall, err := toolCallChecker(ctx, sr); err != nil {
|
||||
return "", err
|
||||
} else if isToolCall {
|
||||
return nodeKeyTools, nil
|
||||
}
|
||||
return compose.END, nil
|
||||
}
|
||||
|
||||
if err = graph.AddBranch(nodeKeyModel, compose.NewStreamGraphBranch(modelPostBranchCondition, map[string]bool{nodeKeyTools: true, compose.END: true})); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(config.ToolReturnDirectly) > 0 {
|
||||
if err = buildReturnDirectly(graph); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err = graph.AddEdge(nodeKeyTools, nodeKeyModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// No tools, so model goes directly to END
|
||||
if err = graph.AddEdge(nodeKeyModel, compose.END); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
compileOpts := []compose.GraphCompileOption{compose.WithMaxRunSteps(maxSteps), compose.WithNodeTriggerMode(compose.AnyPredecessor), compose.WithGraphName(GraphName)}
|
||||
runnable, err := graph.Compile(ctx, compileOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Agent{
|
||||
runnable: runnable,
|
||||
graph: graph,
|
||||
graphAddNodeOpts: []compose.GraphAddNodeOpt{compose.WithGraphCompileOptions(compileOpts...)},
|
||||
toolManager: toolManager,
|
||||
model: model,
|
||||
maxSteps: maxSteps,
|
||||
systemPrompt: config.SystemPrompt,
|
||||
toolManager: toolManager,
|
||||
model: model,
|
||||
maxSteps: maxSteps,
|
||||
systemPrompt: config.SystemPrompt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildReturnDirectly(graph *compose.Graph[[]*schema.Message, *schema.Message]) (err error) {
|
||||
directReturn := func(ctx context.Context, msgs *schema.StreamReader[[]*schema.Message]) (*schema.StreamReader[*schema.Message], error) {
|
||||
return schema.StreamReaderWithConvert(msgs, func(msgs []*schema.Message) (*schema.Message, error) {
|
||||
var msg *schema.Message
|
||||
err = compose.ProcessState[*state](ctx, func(_ context.Context, state *state) error {
|
||||
for i := range msgs {
|
||||
if msgs[i] != nil && msgs[i].ToolCallID == state.ReturnDirectlyToolCallID {
|
||||
msg = msgs[i]
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg == nil {
|
||||
return nil, schema.ErrNoValue
|
||||
}
|
||||
return msg, nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
nodeKeyDirectReturn := "direct_return"
|
||||
if err = graph.AddLambdaNode(nodeKeyDirectReturn, compose.TransformableLambda(directReturn)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// this branch checks if the tool called should return directly. It either leads to END or back to ChatModel
|
||||
err = graph.AddBranch(nodeKeyTools, compose.NewStreamGraphBranch(func(ctx context.Context, msgsStream *schema.StreamReader[[]*schema.Message]) (endNode string, err error) {
|
||||
msgsStream.Close()
|
||||
|
||||
err = compose.ProcessState[*state](ctx, func(_ context.Context, state *state) error {
|
||||
if len(state.ReturnDirectlyToolCallID) > 0 {
|
||||
endNode = nodeKeyDirectReturn
|
||||
} else {
|
||||
endNode = nodeKeyModel
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return endNode, nil
|
||||
}, map[string]bool{nodeKeyModel: true, nodeKeyDirectReturn: true}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return graph.AddEdge(nodeKeyDirectReturn, compose.END)
|
||||
}
|
||||
|
||||
func genToolInfos(ctx context.Context, config compose.ToolsNodeConfig) ([]*schema.ToolInfo, error) {
|
||||
toolInfos := make([]*schema.ToolInfo, 0, len(config.Tools))
|
||||
for _, t := range config.Tools {
|
||||
tl, err := t.Info(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolInfos = append(toolInfos, tl)
|
||||
}
|
||||
|
||||
return toolInfos, nil
|
||||
}
|
||||
|
||||
func getReturnDirectlyToolCallID(input *schema.Message, toolReturnDirectly map[string]struct{}) string {
|
||||
if len(toolReturnDirectly) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, toolCall := range input.ToolCalls {
|
||||
if _, ok := toolReturnDirectly[toolCall.Function.Name]; ok {
|
||||
return toolCall.ID
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Generate generates a response from the agent.
|
||||
func (a *Agent) Generate(ctx context.Context, input []*schema.Message, opts ...compose.Option) (*schema.Message, error) {
|
||||
// Convert compose options to agent options
|
||||
agentOpts := []agent.AgentOption{}
|
||||
if len(opts) > 0 {
|
||||
agentOpts = append(agentOpts, agent.WithComposeOptions(opts...))
|
||||
}
|
||||
return a.runnable.Invoke(ctx, input, agent.GetComposeOptions(agentOpts...)...)
|
||||
}
|
||||
|
||||
// Stream calls the agent and returns a stream response.
|
||||
func (a *Agent) Stream(ctx context.Context, input []*schema.Message, opts ...compose.Option) (output *schema.StreamReader[*schema.Message], err error) {
|
||||
// Convert compose options to agent options
|
||||
agentOpts := []agent.AgentOption{}
|
||||
if len(opts) > 0 {
|
||||
agentOpts = append(agentOpts, agent.WithComposeOptions(opts...))
|
||||
}
|
||||
return a.runnable.Stream(ctx, input, agent.GetComposeOptions(agentOpts...)...)
|
||||
}
|
||||
|
||||
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time
|
||||
func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*schema.Message, error) {
|
||||
@@ -489,8 +193,3 @@ func (a *Agent) GetTools() []tool.BaseTool {
|
||||
func (a *Agent) Close() error {
|
||||
return a.toolManager.Close()
|
||||
}
|
||||
|
||||
// ExportGraph exports the underlying graph from Agent, along with the []compose.GraphAddNodeOpt to be used when adding this graph to another graph.
|
||||
func (a *Agent) ExportGraph() (compose.AnyGraph, []compose.GraphAddNodeOpt) {
|
||||
return a.graph, a.graphAddNodeOpts
|
||||
}
|
||||
|
||||
+108
-92
@@ -5,84 +5,33 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcphost/internal/config"
|
||||
)
|
||||
|
||||
// MCPTool wraps an MCP tool to implement eino's InvokableTool interface
|
||||
type MCPTool struct {
|
||||
client client.MCPClient
|
||||
toolInfo *mcp.Tool
|
||||
name string
|
||||
}
|
||||
|
||||
// Info returns the tool information for eino
|
||||
func (t *MCPTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
// Convert MCP tool schema to eino schema
|
||||
properties := make(map[string]*schema.ParameterInfo)
|
||||
|
||||
// Handle the input schema
|
||||
if t.toolInfo.InputSchema.Properties != nil {
|
||||
for name, prop := range t.toolInfo.InputSchema.Properties {
|
||||
if propMap, ok := prop.(map[string]interface{}); ok {
|
||||
paramInfo := &schema.ParameterInfo{
|
||||
Type: schema.String, // Default type
|
||||
}
|
||||
if typeVal, ok := propMap["type"].(string); ok {
|
||||
paramInfo.Type = schema.DataType(typeVal)
|
||||
}
|
||||
if desc, ok := propMap["description"].(string); ok {
|
||||
paramInfo.Desc = desc
|
||||
}
|
||||
properties[name] = paramInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &schema.ToolInfo{
|
||||
Name: t.name,
|
||||
Desc: t.toolInfo.Description,
|
||||
ParamsOneOf: schema.NewParamsOneOfByParams(properties),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InvokableRun implements the InvokableTool interface
|
||||
func (t *MCPTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(argumentsInJSON), &args); err != nil {
|
||||
return "", fmt.Errorf("failed to parse arguments: %v", err)
|
||||
}
|
||||
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Name = t.toolInfo.Name
|
||||
req.Params.Arguments = args
|
||||
|
||||
result, err := t.client.CallTool(ctx, req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to call tool: %v", err)
|
||||
}
|
||||
|
||||
// Convert result to string
|
||||
if result.Content != nil {
|
||||
var resultText string
|
||||
for _, item := range result.Content {
|
||||
if textContent, ok := item.(mcp.TextContent); ok {
|
||||
resultText += textContent.Text + " "
|
||||
}
|
||||
}
|
||||
return resultText, nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// MCPToolManager manages MCP tools and clients
|
||||
type MCPToolManager struct {
|
||||
clients map[string]client.MCPClient
|
||||
tools []tool.BaseTool
|
||||
clients map[string]client.MCPClient
|
||||
tools []tool.BaseTool
|
||||
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
|
||||
}
|
||||
|
||||
// toolMapping stores the mapping between prefixed tool names and their original details
|
||||
type toolMapping struct {
|
||||
serverName string
|
||||
originalName string
|
||||
client client.MCPClient
|
||||
}
|
||||
|
||||
// mcpToolImpl implements the eino tool interface with server prefixing
|
||||
type mcpToolImpl struct {
|
||||
info *schema.ToolInfo
|
||||
mapping *toolMapping
|
||||
}
|
||||
|
||||
// NewMCPToolManager creates a new MCP tool manager
|
||||
@@ -90,6 +39,7 @@ func NewMCPToolManager() *MCPToolManager {
|
||||
return &MCPToolManager{
|
||||
clients: make(map[string]client.MCPClient),
|
||||
tools: make([]tool.BaseTool, 0),
|
||||
toolMap: make(map[string]*toolMapping),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,23 +59,66 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) e
|
||||
}
|
||||
|
||||
// Get tools from this server
|
||||
toolsResult, err := client.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
listResults, err := client.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list tools from server %s: %v", serverName, err)
|
||||
}
|
||||
|
||||
// Convert MCP tools to eino tools
|
||||
for _, mcpTool := range toolsResult.Tools {
|
||||
// Create name set for allowed tools
|
||||
var nameSet map[string]struct{}
|
||||
if len(serverConfig.AllowedTools) > 0 {
|
||||
nameSet = make(map[string]struct{})
|
||||
for _, name := range serverConfig.AllowedTools {
|
||||
nameSet[name] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert MCP tools to eino tools with prefixed names
|
||||
for _, mcpTool := range listResults.Tools {
|
||||
// Filter tools based on allowedTools/excludedTools
|
||||
if !m.shouldIncludeTool(mcpTool.Name, serverConfig) {
|
||||
if len(serverConfig.AllowedTools) > 0 {
|
||||
if _, ok := nameSet[mcpTool.Name]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Check if tool should be excluded
|
||||
if m.shouldExcludeTool(mcpTool.Name, serverConfig) {
|
||||
continue
|
||||
}
|
||||
|
||||
einoTool := &MCPTool{
|
||||
client: client,
|
||||
toolInfo: &mcpTool,
|
||||
name: fmt.Sprintf("%s__%s", serverName, mcpTool.Name),
|
||||
// Convert schema
|
||||
marshaledInputSchema, err := sonic.Marshal(mcpTool.InputSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
|
||||
}
|
||||
inputSchema := &openapi3.Schema{}
|
||||
err = sonic.Unmarshal(marshaledInputSchema, inputSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
|
||||
}
|
||||
|
||||
// Create prefixed tool name
|
||||
prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name)
|
||||
|
||||
// Create tool mapping
|
||||
mapping := &toolMapping{
|
||||
serverName: serverName,
|
||||
originalName: mcpTool.Name,
|
||||
client: client,
|
||||
}
|
||||
m.toolMap[prefixedName] = mapping
|
||||
|
||||
// Create eino tool
|
||||
einoTool := &mcpToolImpl{
|
||||
info: &schema.ToolInfo{
|
||||
Name: prefixedName,
|
||||
Desc: mcpTool.Description,
|
||||
ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(inputSchema),
|
||||
},
|
||||
mapping: mapping,
|
||||
}
|
||||
|
||||
m.tools = append(m.tools, einoTool)
|
||||
}
|
||||
}
|
||||
@@ -133,6 +126,40 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) e
|
||||
return nil
|
||||
}
|
||||
|
||||
// Info returns the tool information
|
||||
func (t *mcpToolImpl) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return t.info, nil
|
||||
}
|
||||
|
||||
// InvokableRun executes the tool by mapping back to the original name and server
|
||||
func (t *mcpToolImpl) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||
result, err := t.mapping.client.CallTool(ctx, mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments any `json:"arguments,omitempty"`
|
||||
Meta *mcp.Meta `json:"_meta,omitempty"`
|
||||
}{
|
||||
Name: t.mapping.originalName, // Use original name, not prefixed
|
||||
Arguments: json.RawMessage(argumentsInJSON),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
}
|
||||
|
||||
marshaledResult, err := sonic.MarshalString(result)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal mcp tool result: %w", err)
|
||||
}
|
||||
if result.IsError {
|
||||
return "", fmt.Errorf("failed to call mcp tool, mcp server return error: %s", marshaledResult)
|
||||
}
|
||||
return marshaledResult, nil
|
||||
}
|
||||
|
||||
// GetTools returns all loaded tools
|
||||
func (m *MCPToolManager) GetTools() []tool.BaseTool {
|
||||
return m.tools
|
||||
@@ -148,29 +175,18 @@ func (m *MCPToolManager) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldIncludeTool determines if a tool should be included based on allowedTools/excludedTools
|
||||
func (m *MCPToolManager) shouldIncludeTool(toolName string, serverConfig config.MCPServerConfig) bool {
|
||||
// If allowedTools is specified, only include tools in the list
|
||||
if len(serverConfig.AllowedTools) > 0 {
|
||||
for _, allowedTool := range serverConfig.AllowedTools {
|
||||
if allowedTool == toolName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// shouldExcludeTool determines if a tool should be excluded based on excludedTools
|
||||
func (m *MCPToolManager) shouldExcludeTool(toolName string, serverConfig config.MCPServerConfig) bool {
|
||||
// If excludedTools is specified, exclude tools in the list
|
||||
if len(serverConfig.ExcludedTools) > 0 {
|
||||
for _, excludedTool := range serverConfig.ExcludedTools {
|
||||
if excludedTool == toolName {
|
||||
return false
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Include by default
|
||||
return true
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
|
||||
Reference in New Issue
Block a user