mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
e6084b7bd0
Implement Phase 1 of the MCP Tasks spec so long-running tools/call
requests can run asynchronously, survive proxy timeouts, and be
cancelled mid-flight.
- connection pool now advertises mcp.NewTasksCapability() during
initialize and captures the InitializeResult so callers can detect
per-server task support
- new MCPServerConfig.TasksMode (auto|never|always, default auto)
parsed from both new and legacy mcp.json shapes
- ExecuteTool augments tools/call with TaskParams when policy and
capability allow, polls tasks/get / tasks/result until terminal,
and best-effort tasks/cancel on context cancellation
- new MCPToolManager methods: SetTaskConfig, ListServerTasks,
GetServerTask, CancelServerTask
- public SDK surface in pkg/kit: MCPTask, MCPTaskStatus, MCPTaskMode,
MCPTaskProgress, MCPTaskProgressHandler, plus Options fields
(MCPTaskMode, MCPTaskTimeout, MCPTaskTTL, MCPTaskPollInterval,
MCPTaskMaxPollInterval, MCPTaskProgress) and Kit.{List,Get,Cancel}
MCPTask methods
- works around two upstream mcp-go v0.51.0 parser bugs
(ParseCallToolResult rejects task responses; ParseTaskResultResult
looks for content under a non-existent nested key) by decoding the
wire shape directly via the transport
- defaults to MCPTaskModeAuto so servers that don't advertise task
support behave exactly as before
Fixes #21
1387 lines
45 KiB
Go
1387 lines
45 KiB
Go
package tools
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"maps"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
|
|
log "github.com/charmbracelet/log"
|
|
|
|
"github.com/mark3labs/kit/internal/config"
|
|
"github.com/mark3labs/mcp-go/client"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
)
|
|
|
|
// MCPTool represents a tool discovered from an MCP server. It contains all
|
|
// the metadata needed to present the tool to an LLM (name, description, JSON
|
|
// schema) plus the server origin information needed to execute it.
|
|
type MCPTool struct {
|
|
// Name is the prefixed tool name: "serverName__toolName".
|
|
Name string
|
|
// Description is the human-readable tool description.
|
|
Description string
|
|
// Parameters is the JSON Schema properties for the tool's input.
|
|
Parameters map[string]any
|
|
// Required lists the required parameter names.
|
|
Required []string
|
|
// ServerName is the MCP server this tool belongs to.
|
|
ServerName string
|
|
// OriginalName is the unprefixed tool name on the MCP server.
|
|
OriginalName string
|
|
}
|
|
|
|
// MCPToolResult is the result of executing an MCP tool via ExecuteTool.
|
|
type MCPToolResult struct {
|
|
// Content is the JSON-encoded result from the MCP server.
|
|
Content string
|
|
// IsError indicates the MCP server reported a tool-level error.
|
|
IsError bool
|
|
}
|
|
|
|
// MCPPrompt represents a prompt discovered from an MCP server.
|
|
type MCPPrompt struct {
|
|
// Name is the prompt name on the MCP server.
|
|
Name string
|
|
// Description is the human-readable prompt description.
|
|
Description string
|
|
// Arguments lists the prompt's expected arguments.
|
|
Arguments []MCPPromptArgument
|
|
// ServerName is the MCP server this prompt belongs to.
|
|
ServerName string
|
|
}
|
|
|
|
// MCPPromptArgument describes an argument that a prompt template can accept.
|
|
type MCPPromptArgument struct {
|
|
// Name is the argument name.
|
|
Name string
|
|
// Description is a human-readable description.
|
|
Description string
|
|
// Required indicates whether this argument must be provided.
|
|
Required bool
|
|
}
|
|
|
|
// MCPPromptMessage is a single message returned by a prompt expansion.
|
|
type MCPPromptMessage struct {
|
|
// Role is "user" or "assistant".
|
|
Role string
|
|
// Content is the text content of the message.
|
|
Content string
|
|
// FileParts contains binary attachments extracted from embedded resources,
|
|
// images, or audio content blocks. Empty for text-only messages.
|
|
FileParts []MCPFilePart
|
|
}
|
|
|
|
// MCPFilePart represents a binary file attachment extracted from an MCP prompt
|
|
// content block (ImageContent, AudioContent, or EmbeddedResource with blob data).
|
|
type MCPFilePart struct {
|
|
// Filename is a best-effort name derived from the resource URI or content type.
|
|
Filename string
|
|
// Data is the raw binary content (already base64-decoded).
|
|
Data []byte
|
|
// MediaType is the MIME type (e.g. "image/png", "audio/wav").
|
|
MediaType string
|
|
}
|
|
|
|
// MCPPromptResult is the result of expanding an MCP prompt via GetPrompt.
|
|
type MCPPromptResult struct {
|
|
// Description is an optional description returned by the server.
|
|
Description string
|
|
// Messages contains the expanded prompt messages.
|
|
Messages []MCPPromptMessage
|
|
}
|
|
|
|
// MCPResource represents a resource discovered from an MCP server.
|
|
type MCPResource struct {
|
|
// URI is the unique resource identifier (e.g. "file:///path" or custom scheme).
|
|
URI string
|
|
// Name is a human-readable name for the resource.
|
|
Name string
|
|
// Description is an optional description of the resource.
|
|
Description string
|
|
// MIMEType is the MIME type of the resource, if known.
|
|
MIMEType string
|
|
// ServerName is the MCP server this resource belongs to.
|
|
ServerName string
|
|
}
|
|
|
|
// MCPResourceContent is the result of reading an MCP resource via ReadResource.
|
|
type MCPResourceContent struct {
|
|
// URI is the resource URI that was read.
|
|
URI string
|
|
// MIMEType is the MIME type of the content.
|
|
MIMEType string
|
|
// Text is the text content (non-empty for text resources).
|
|
Text string
|
|
// BlobData is the decoded binary content (non-empty for blob resources).
|
|
BlobData []byte
|
|
// IsBlob is true when the content is binary (BlobData is set).
|
|
IsBlob bool
|
|
}
|
|
|
|
// MCPToolManager manages MCP (Model Context Protocol) tools and clients across multiple servers.
|
|
// It provides a unified interface for loading, managing, and executing tools from various MCP servers,
|
|
// including stdio, SSE, streamable HTTP, and built-in server types. The manager handles connection
|
|
// pooling, health checks, tool name prefixing to avoid conflicts, and OAuth re-authorization.
|
|
// Thread-safe for concurrent tool invocations.
|
|
type MCPToolManager struct {
|
|
connectionPool *MCPConnectionPool
|
|
tools []MCPTool
|
|
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
|
|
prompts []MCPPrompt // prompts discovered from all servers
|
|
resources []MCPResource // resources discovered from all servers
|
|
subscriptions map[string]string // resource URI → server name for active subscriptions
|
|
mu sync.Mutex // protects tools, toolMap, prompts, resources during parallel loading
|
|
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
|
|
tokenStoreFactory TokenStoreFactory // factory for creating per-server token stores (nil = default FileTokenStore)
|
|
config *config.Config
|
|
debug bool
|
|
debugLogger DebugLogger
|
|
|
|
// taskCfg controls task-augmented tools/call execution. The zero value
|
|
// means: auto-detect server capability, no progress callback, default
|
|
// poll/timeout.
|
|
taskCfg MCPTaskConfig
|
|
|
|
// onServerLoaded, if non-nil, is called when each server finishes loading.
|
|
// Called with server name, tool count, and error (nil on success).
|
|
onServerLoaded func(serverName string, toolCount int, err error)
|
|
|
|
// onToolsChanged, if non-nil, is called after AddServer or RemoveServer
|
|
// mutates the tool list. The agent layer uses this to trigger a rebuild
|
|
// so the LLM sees the updated tools.
|
|
onToolsChanged func()
|
|
|
|
// onResourcesChanged, if non-nil, is called when a subscribed resource
|
|
// is updated by the server.
|
|
onResourcesChanged func()
|
|
}
|
|
|
|
// toolMapping stores the mapping between prefixed tool names and their original details
|
|
type toolMapping struct {
|
|
serverName string
|
|
originalName string
|
|
serverConfig config.MCPServerConfig
|
|
}
|
|
|
|
// NewMCPToolManager creates a new MCP tool manager instance.
|
|
// Returns an initialized manager with empty tool collections ready to load tools from MCP servers.
|
|
// The manager must be configured with LoadTools before use.
|
|
func NewMCPToolManager() *MCPToolManager {
|
|
return &MCPToolManager{
|
|
tools: make([]MCPTool, 0),
|
|
toolMap: make(map[string]*toolMapping),
|
|
subscriptions: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
// SetAuthHandler sets the OAuth handler for remote MCP server authentication.
|
|
// When set, remote transports (streamable HTTP, SSE) are configured with OAuth
|
|
// support, enabling automatic authorization flows when servers require authentication.
|
|
// This method should be called before LoadTools.
|
|
func (m *MCPToolManager) SetAuthHandler(handler MCPAuthHandler) {
|
|
m.authHandler = handler
|
|
}
|
|
|
|
// GetAuthHandler returns the OAuth handler for remote MCP server authentication.
|
|
// Returns nil if no handler is configured.
|
|
func (m *MCPToolManager) GetAuthHandler() MCPAuthHandler {
|
|
return m.authHandler
|
|
}
|
|
|
|
// SetTokenStoreFactory sets a custom factory for creating per-server OAuth token
|
|
// stores. When set, the factory is called for each remote MCP server instead of
|
|
// using the default file-based token store. This method should be called before
|
|
// LoadTools.
|
|
func (m *MCPToolManager) SetTokenStoreFactory(factory TokenStoreFactory) {
|
|
m.tokenStoreFactory = factory
|
|
}
|
|
|
|
// SetDebugLogger sets the debug logger for the tool manager.
|
|
// The logger will be used to output detailed debugging information about MCP connections,
|
|
// tool loading, and execution. If a connection pool exists, it will also be configured
|
|
// to use the same logger for consistent debugging output.
|
|
func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) {
|
|
m.debugLogger = logger
|
|
if m.connectionPool != nil {
|
|
m.connectionPool.SetDebugLogger(logger)
|
|
}
|
|
}
|
|
|
|
// SetOnServerLoaded sets the callback that's invoked when each MCP server finishes
|
|
// loading. The callback receives the server name, tool count, and any error.
|
|
// Call this before LoadTools to receive per-server notifications.
|
|
func (m *MCPToolManager) SetOnServerLoaded(cb func(serverName string, toolCount int, err error)) {
|
|
m.onServerLoaded = cb
|
|
}
|
|
|
|
// SetOnToolsChanged sets the callback that's invoked after AddServer or
|
|
// RemoveServer mutates the tool list. The agent layer uses this to trigger
|
|
// a rebuild so the LLM sees the updated tool set.
|
|
func (m *MCPToolManager) SetOnToolsChanged(cb func()) {
|
|
m.onToolsChanged = cb
|
|
}
|
|
|
|
// SetTaskConfig sets the task-augmented tools/call configuration. Call
|
|
// this before LoadTools / AddServer if you want the per-server mode
|
|
// override and progress handler to take effect for the very first call.
|
|
// Subsequent calls replace the previous configuration wholesale.
|
|
func (m *MCPToolManager) SetTaskConfig(cfg MCPTaskConfig) {
|
|
m.taskCfg = cfg
|
|
}
|
|
|
|
// TaskConfig returns the manager's current task-augmented tools/call
|
|
// configuration. The zero value means: defer to per-server config and
|
|
// auto-detected capability, with no progress callback and default polling.
|
|
func (m *MCPToolManager) TaskConfig() MCPTaskConfig {
|
|
return m.taskCfg
|
|
}
|
|
|
|
// AddServer connects to a new MCP server at runtime and loads its tools.
|
|
// The server's tools are immediately available to the agent after this call.
|
|
// Returns the number of tools loaded from the server.
|
|
//
|
|
// If the connection pool has not been initialised yet (i.e. LoadTools was never
|
|
// called), AddServer creates one automatically using the manager's current
|
|
// configuration.
|
|
//
|
|
// Returns an error if a server with the same name is already loaded, or if
|
|
// the connection or tool loading fails.
|
|
func (m *MCPToolManager) AddServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
|
|
m.mu.Lock()
|
|
// Check for duplicate.
|
|
if _, exists := m.toolMap[name+"__"]; exists {
|
|
m.mu.Unlock()
|
|
return 0, fmt.Errorf("MCP server %q is already loaded", name)
|
|
}
|
|
// More thorough duplicate check: scan toolMap for any key with the server prefix.
|
|
prefix := name + "__"
|
|
for k := range m.toolMap {
|
|
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
|
m.mu.Unlock()
|
|
return 0, fmt.Errorf("MCP server %q is already loaded", name)
|
|
}
|
|
}
|
|
m.mu.Unlock()
|
|
|
|
// Lazily create the connection pool if LoadTools was never called.
|
|
m.ensureConnectionPool()
|
|
|
|
count, err := m.loadServerTools(ctx, name, cfg)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to add MCP server %q: %w", name, err)
|
|
}
|
|
|
|
// Notify listeners.
|
|
if m.onServerLoaded != nil {
|
|
m.onServerLoaded(name, count, nil)
|
|
}
|
|
if m.onToolsChanged != nil {
|
|
m.onToolsChanged()
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// RemoveServer disconnects an MCP server and removes all its tools and prompts.
|
|
// After this call the agent will no longer see or be able to call tools from
|
|
// the named server. Returns an error if the server is not loaded.
|
|
func (m *MCPToolManager) RemoveServer(name string) error {
|
|
prefix := name + "__"
|
|
|
|
m.mu.Lock()
|
|
|
|
// Check the server actually has tools or prompts loaded.
|
|
found := false
|
|
for k := range m.toolMap {
|
|
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
// Also check prompts — a server might expose only prompts.
|
|
for _, p := range m.prompts {
|
|
if p.ServerName == name {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if !found {
|
|
m.mu.Unlock()
|
|
return fmt.Errorf("MCP server %q is not loaded", name)
|
|
}
|
|
|
|
// Remove tools belonging to this server.
|
|
newTools := make([]MCPTool, 0, len(m.tools))
|
|
for _, t := range m.tools {
|
|
if len(t.Name) < len(prefix) || t.Name[:len(prefix)] != prefix {
|
|
newTools = append(newTools, t)
|
|
}
|
|
}
|
|
m.tools = newTools
|
|
|
|
// Remove tool mappings.
|
|
for k := range m.toolMap {
|
|
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
|
delete(m.toolMap, k)
|
|
}
|
|
}
|
|
|
|
// Remove prompts belonging to this server.
|
|
newPrompts := make([]MCPPrompt, 0, len(m.prompts))
|
|
for _, p := range m.prompts {
|
|
if p.ServerName != name {
|
|
newPrompts = append(newPrompts, p)
|
|
}
|
|
}
|
|
m.prompts = newPrompts
|
|
|
|
// Remove resources belonging to this server.
|
|
newResources := make([]MCPResource, 0, len(m.resources))
|
|
for _, r := range m.resources {
|
|
if r.ServerName != name {
|
|
newResources = append(newResources, r)
|
|
} else {
|
|
// Clean up any active subscription for this resource.
|
|
delete(m.subscriptions, r.URI)
|
|
}
|
|
}
|
|
m.resources = newResources
|
|
|
|
m.mu.Unlock()
|
|
|
|
// Close the connection in the pool (best-effort).
|
|
if m.connectionPool != nil {
|
|
_ = m.connectionPool.RemoveConnection(name)
|
|
}
|
|
|
|
if m.onToolsChanged != nil {
|
|
m.onToolsChanged()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ensureConnectionPool lazily creates a connection pool if one does not exist.
|
|
// This allows AddServer to work even if LoadTools was never called.
|
|
func (m *MCPToolManager) ensureConnectionPool() {
|
|
if m.connectionPool != nil {
|
|
return
|
|
}
|
|
debug := false
|
|
if m.config != nil {
|
|
debug = m.config.Debug
|
|
}
|
|
if m.debugLogger == nil {
|
|
m.debugLogger = NewSimpleDebugLogger(debug)
|
|
}
|
|
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), debug, m.authHandler, m.tokenStoreFactory)
|
|
m.connectionPool.SetDebugLogger(m.debugLogger)
|
|
}
|
|
|
|
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
|
|
// It initializes the connection pool, connects to each configured server, and loads their tools.
|
|
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
|
|
// Returns an error only if all configured servers fail to load; partial failures are logged as warnings.
|
|
// This method is thread-safe and idempotent.
|
|
func (m *MCPToolManager) LoadTools(ctx context.Context, cfg *config.Config) error {
|
|
// Initialize connection pool
|
|
m.config = cfg
|
|
m.debug = cfg.Debug
|
|
if m.debugLogger == nil {
|
|
m.debugLogger = NewSimpleDebugLogger(cfg.Debug)
|
|
}
|
|
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), cfg.Debug, m.authHandler, m.tokenStoreFactory)
|
|
m.connectionPool.SetDebugLogger(m.debugLogger)
|
|
|
|
// Load all servers in parallel. Each server connection (subprocess
|
|
// spawn, MCP initialize handshake, ListTools) is independent and
|
|
// typically dominated by process startup latency. Running them
|
|
// concurrently reduces total wall-clock time from O(n * avg) to
|
|
// O(max).
|
|
type serverResult struct {
|
|
name string
|
|
err error
|
|
}
|
|
|
|
results := make(chan serverResult, len(cfg.MCPServers))
|
|
var wg sync.WaitGroup
|
|
|
|
for serverName, serverConfig := range cfg.MCPServers {
|
|
wg.Add(1)
|
|
go func(name string, sc config.MCPServerConfig) {
|
|
defer wg.Done()
|
|
count, err := m.loadServerTools(ctx, name, sc)
|
|
results <- serverResult{name: name, err: err}
|
|
// Notify callback if set (for real-time UI updates).
|
|
if m.onServerLoaded != nil {
|
|
m.onServerLoaded(name, count, err)
|
|
}
|
|
}(serverName, serverConfig)
|
|
}
|
|
|
|
// Close results channel once all goroutines finish.
|
|
go func() {
|
|
wg.Wait()
|
|
close(results)
|
|
}()
|
|
|
|
var loadErrors []string
|
|
for r := range results {
|
|
if r.err != nil {
|
|
loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", r.name, r.err))
|
|
fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", r.name, r.err)
|
|
}
|
|
}
|
|
|
|
// If all servers failed to load, return an error
|
|
if len(loadErrors) == len(cfg.MCPServers) && len(cfg.MCPServers) > 0 {
|
|
return fmt.Errorf("all MCP servers failed to load: %s", strings.Join(loadErrors, "; "))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// loadServerTools loads tools from a single MCP server.
|
|
// Thread-safe: may be called concurrently for different servers.
|
|
// Returns the number of tools loaded from this server, or -1 on error.
|
|
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (int, error) {
|
|
// Add debug logging
|
|
m.debugLogConnectionInfo(serverName, serverConfig)
|
|
|
|
// Get connection from pool
|
|
conn, err := m.connectionPool.GetConnection(ctx, serverName, serverConfig)
|
|
if err != nil {
|
|
return -1, fmt.Errorf("failed to get connection from pool: %v", err)
|
|
}
|
|
|
|
// Get tools from this server
|
|
listResults, err := conn.client.ListTools(ctx, mcp.ListToolsRequest{})
|
|
if err != nil {
|
|
// Handle connection error
|
|
m.connectionPool.HandleConnectionError(serverName, err)
|
|
return -1, fmt.Errorf("failed to list tools: %v", err)
|
|
}
|
|
|
|
// Create name set for allowed tools
|
|
var nameSet map[string]struct{}
|
|
if len(serverConfig.AllowedTools) > 0 {
|
|
nameSet = make(map[string]struct{})
|
|
for _, name := range serverConfig.AllowedTools {
|
|
nameSet[name] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// Build tools locally before acquiring the lock.
|
|
var localTools []MCPTool
|
|
localMap := make(map[string]*toolMapping)
|
|
|
|
// Convert MCP tools to MCPTool structs with prefixed names
|
|
for _, mcpTool := range listResults.Tools {
|
|
// Filter tools based on allowedTools/excludedTools
|
|
if len(serverConfig.AllowedTools) > 0 {
|
|
if _, ok := nameSet[mcpTool.Name]; !ok {
|
|
continue
|
|
}
|
|
}
|
|
|
|
// Check if tool should be excluded
|
|
if m.shouldExcludeTool(mcpTool.Name, serverConfig) {
|
|
continue
|
|
}
|
|
|
|
// Convert MCP InputSchema to map[string]any
|
|
marshaledSchema, err := json.Marshal(mcpTool.InputSchema)
|
|
if err != nil {
|
|
return -1, fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
|
|
}
|
|
|
|
// Fix for JSON Schema draft-07 vs draft-04 compatibility
|
|
marshaledSchema = convertExclusiveBoundsToBoolean(marshaledSchema)
|
|
|
|
// Parse into map[string]any
|
|
var schemaMap map[string]any
|
|
if err := json.Unmarshal(marshaledSchema, &schemaMap); err != nil {
|
|
return -1, fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
|
|
}
|
|
|
|
// Extract properties and required from the schema
|
|
parameters := make(map[string]any)
|
|
required := []string{}
|
|
|
|
if props, ok := schemaMap["properties"].(map[string]any); ok {
|
|
parameters = props
|
|
}
|
|
|
|
// Fix for issue #89: Ensure object schemas have a properties field.
|
|
// When schema type is "object" with no properties, we keep the
|
|
// empty parameters map.
|
|
|
|
if req, ok := schemaMap["required"].([]any); ok {
|
|
for _, r := range req {
|
|
if s, ok := r.(string); ok {
|
|
required = append(required, s)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create prefixed tool name
|
|
prefixedName := fmt.Sprintf("%s__%s", serverName, mcpTool.Name)
|
|
|
|
// Create tool mapping
|
|
mapping := &toolMapping{
|
|
serverName: serverName,
|
|
originalName: mcpTool.Name,
|
|
serverConfig: serverConfig,
|
|
}
|
|
localMap[prefixedName] = mapping
|
|
|
|
// Create MCPTool
|
|
localTools = append(localTools, MCPTool{
|
|
Name: prefixedName,
|
|
Description: mcpTool.Description,
|
|
Parameters: parameters,
|
|
Required: required,
|
|
ServerName: serverName,
|
|
OriginalName: mcpTool.Name,
|
|
})
|
|
}
|
|
|
|
// Merge into the manager under the lock.
|
|
m.mu.Lock()
|
|
maps.Copy(m.toolMap, localMap)
|
|
m.tools = append(m.tools, localTools...)
|
|
m.mu.Unlock()
|
|
|
|
// Also load prompts from this server (best-effort, non-blocking).
|
|
m.loadServerPrompts(ctx, serverName, conn)
|
|
|
|
// Also load resources from this server (best-effort, non-blocking).
|
|
m.loadServerResources(ctx, serverName, conn)
|
|
|
|
return len(localTools), nil
|
|
}
|
|
|
|
// ExecuteTool calls an MCP tool through the connection pool, handling health
|
|
// checks, OAuth re-authorization, and connection error tracking.
|
|
// The inputJSON parameter is the raw JSON arguments from the LLM.
|
|
// Returns the result content, error flag, and any execution error.
|
|
//
|
|
// When the per-server TasksMode resolves to "always", or to "auto" and the
|
|
// server advertised tasks/toolCalls capability during initialize, the call
|
|
// is augmented with TaskParams. If the server elects to respond with a
|
|
// CreateTaskResult the manager polls tasks/get / tasks/result until the
|
|
// task reaches a terminal state, transparently presenting the final
|
|
// CallToolResult-equivalent content to the agent layer. Context
|
|
// cancellation triggers a best-effort tasks/cancel.
|
|
func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSON string) (*MCPToolResult, error) {
|
|
m.mu.Lock()
|
|
mapping, ok := m.toolMap[prefixedName]
|
|
m.mu.Unlock()
|
|
if !ok {
|
|
return nil, fmt.Errorf("tool %q not found", prefixedName)
|
|
}
|
|
|
|
// Parse and validate JSON arguments
|
|
var arguments any
|
|
if inputJSON == "" || inputJSON == "{}" {
|
|
arguments = nil
|
|
} else {
|
|
var temp any
|
|
if err := json.Unmarshal([]byte(inputJSON), &temp); err != nil {
|
|
return &MCPToolResult{
|
|
Content: fmt.Sprintf("invalid JSON arguments: %v", err),
|
|
IsError: true,
|
|
}, nil
|
|
}
|
|
arguments = json.RawMessage(inputJSON)
|
|
}
|
|
|
|
// Get connection from pool with health check
|
|
conn, err := m.connectionPool.GetConnectionWithHealthCheck(
|
|
ctx, mapping.serverName, mapping.serverConfig,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get healthy connection from pool: %w", err)
|
|
}
|
|
|
|
callParams := mcp.CallToolParams{
|
|
Name: mapping.originalName,
|
|
Arguments: arguments,
|
|
}
|
|
|
|
// Decide whether to augment the request with TaskParams. Modes:
|
|
// never — never augment (synchronous-only).
|
|
// always — always augment, even without server capability.
|
|
// auto — augment only when the server advertised tasks/toolCalls.
|
|
mode := m.resolveTaskMode(mapping.serverName, mapping.serverConfig)
|
|
useTask := mode == MCPTaskModeAlways ||
|
|
(mode == MCPTaskModeAuto && conn.SupportsToolTasks())
|
|
if useTask {
|
|
var ttl *int64
|
|
if m.taskCfg.DefaultTTL > 0 {
|
|
ms := m.taskCfg.DefaultTTL.Milliseconds()
|
|
ttl = &ms
|
|
}
|
|
callParams.Task = &mcp.TaskParams{TTL: ttl}
|
|
}
|
|
|
|
// Synchronous fast path: no task augmentation. Use the upstream client
|
|
// helper which keeps content-block typing identical to historical
|
|
// behaviour.
|
|
if !useTask {
|
|
callRequest := mcp.CallToolRequest{
|
|
Request: mcp.Request{Method: "tools/call"},
|
|
Params: callParams,
|
|
}
|
|
result, callErr := conn.client.CallTool(ctx, callRequest)
|
|
if callErr != nil {
|
|
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
|
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
|
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
|
}
|
|
result, callErr = conn.client.CallTool(ctx, callRequest)
|
|
if callErr != nil {
|
|
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
|
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
|
}
|
|
} else {
|
|
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
|
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
|
}
|
|
}
|
|
marshaledResult, mErr := json.Marshal(result)
|
|
if mErr != nil {
|
|
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
|
}
|
|
return &MCPToolResult{
|
|
Content: string(marshaledResult),
|
|
IsError: result.IsError,
|
|
}, nil
|
|
}
|
|
|
|
// Task-augmented path. Bypass the upstream CallTool helper because its
|
|
// ParseCallToolResult requires a "content" field that is absent from a
|
|
// CreateTaskResult.
|
|
rawClient, ok := conn.client.(*client.Client)
|
|
if !ok {
|
|
// Older client implementations — fall back to the synchronous shape.
|
|
callParams.Task = nil
|
|
callRequest := mcp.CallToolRequest{
|
|
Request: mcp.Request{Method: "tools/call"},
|
|
Params: callParams,
|
|
}
|
|
result, callErr := conn.client.CallTool(ctx, callRequest)
|
|
if callErr != nil {
|
|
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
|
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
|
}
|
|
marshaledResult, mErr := json.Marshal(result)
|
|
if mErr != nil {
|
|
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
|
}
|
|
return &MCPToolResult{Content: string(marshaledResult), IsError: result.IsError}, nil
|
|
}
|
|
|
|
callResult, taskResult, callErr := callToolWithTask(ctx, rawClient, callParams)
|
|
if callErr != nil {
|
|
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
|
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
|
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
|
}
|
|
callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
|
|
if callErr != nil {
|
|
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
|
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
|
}
|
|
} else {
|
|
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
|
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
|
}
|
|
}
|
|
|
|
// Server chose to answer synchronously — same shape as the no-task path.
|
|
if callResult != nil {
|
|
marshaledResult, mErr := json.Marshal(callResult)
|
|
if mErr != nil {
|
|
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
|
}
|
|
return &MCPToolResult{
|
|
Content: string(marshaledResult),
|
|
IsError: callResult.IsError,
|
|
}, nil
|
|
}
|
|
|
|
// Asynchronous task path: poll until terminal, then return the result.
|
|
if taskResult == nil {
|
|
return nil, errors.New("mcp tools/call returned neither result nor task")
|
|
}
|
|
final, pollErr := pollTaskUntilTerminal(
|
|
ctx, rawClient, mapping.serverName, taskResult.Task,
|
|
m.taskCfg, m.taskCfg.Progress,
|
|
)
|
|
if pollErr != nil {
|
|
return nil, fmt.Errorf("task execution failed: %w", pollErr)
|
|
}
|
|
|
|
// Adapt TaskResultResult → CallToolResult for downstream JSON shape parity.
|
|
adapted := &mcp.CallToolResult{
|
|
Content: final.Content,
|
|
StructuredContent: final.StructuredContent,
|
|
IsError: final.IsError,
|
|
}
|
|
marshaledResult, mErr := json.Marshal(adapted)
|
|
if mErr != nil {
|
|
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
|
}
|
|
return &MCPToolResult{
|
|
Content: string(marshaledResult),
|
|
IsError: final.IsError,
|
|
}, nil
|
|
}
|
|
|
|
// resolveTaskMode resolves the effective task mode for a given server.
|
|
// Programmatic overrides via SetTaskConfig take precedence over the
|
|
// per-server TasksMode in MCPServerConfig. Empty / unknown values map to
|
|
// MCPTaskModeAuto.
|
|
func (m *MCPToolManager) resolveTaskMode(name string, cfg config.MCPServerConfig) MCPTaskMode {
|
|
if m.taskCfg.PerServerMode != nil {
|
|
if v, ok := m.taskCfg.PerServerMode[name]; ok {
|
|
return v
|
|
}
|
|
}
|
|
return ParseTaskMode(cfg.TasksMode)
|
|
}
|
|
|
|
// ListServerTasks queries tasks/list on the named server and returns the
|
|
// active and recent tasks the server is willing to disclose. Errors are
|
|
// returned untouched (callers commonly ignore METHOD_NOT_FOUND when the
|
|
// server didn't advertise tasks/list capability).
|
|
func (m *MCPToolManager) ListServerTasks(ctx context.Context, serverName string) ([]MCPTaskInfo, error) {
|
|
c, err := m.taskClient(serverName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
res, err := c.ListTasks(ctx, mcp.ListTasksRequest{})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("tasks/list on %s: %w", serverName, err)
|
|
}
|
|
out := make([]MCPTaskInfo, 0, len(res.Tasks))
|
|
for _, t := range res.Tasks {
|
|
out = append(out, taskFromMCP(serverName, t))
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// GetServerTask queries tasks/get for a single task on the named server.
|
|
func (m *MCPToolManager) GetServerTask(ctx context.Context, serverName, taskID string) (MCPTaskInfo, error) {
|
|
c, err := m.taskClient(serverName)
|
|
if err != nil {
|
|
return MCPTaskInfo{}, err
|
|
}
|
|
res, err := c.GetTask(ctx, mcp.GetTaskRequest{Params: mcp.GetTaskParams{TaskId: taskID}})
|
|
if err != nil {
|
|
return MCPTaskInfo{}, fmt.Errorf("tasks/get on %s: %w", serverName, err)
|
|
}
|
|
return taskFromMCP(serverName, res.Task), nil
|
|
}
|
|
|
|
// CancelServerTask issues tasks/cancel for a task on the named server.
|
|
// Returns the post-cancel task state when the server responded with one.
|
|
func (m *MCPToolManager) CancelServerTask(ctx context.Context, serverName, taskID string) (MCPTaskInfo, error) {
|
|
c, err := m.taskClient(serverName)
|
|
if err != nil {
|
|
return MCPTaskInfo{}, err
|
|
}
|
|
res, err := c.CancelTask(ctx, mcp.CancelTaskRequest{Params: mcp.CancelTaskParams{TaskId: taskID}})
|
|
if err != nil {
|
|
return MCPTaskInfo{}, fmt.Errorf("tasks/cancel on %s: %w", serverName, err)
|
|
}
|
|
return taskFromMCP(serverName, res.Task), nil
|
|
}
|
|
|
|
// taskClient returns the *client.Client for a server. Tasks endpoints are
|
|
// not part of the upstream MCPClient interface so callers must work with
|
|
// the concrete client. Returns an error when the connection is missing
|
|
// or backed by a non-standard client type.
|
|
func (m *MCPToolManager) taskClient(serverName string) (*client.Client, error) {
|
|
if m.connectionPool == nil {
|
|
return nil, fmt.Errorf("no connection pool available")
|
|
}
|
|
clients := m.connectionPool.GetClients()
|
|
raw, ok := clients[serverName]
|
|
if !ok {
|
|
return nil, fmt.Errorf("MCP server %q not loaded", serverName)
|
|
}
|
|
c, ok := raw.(*client.Client)
|
|
if !ok {
|
|
return nil, fmt.Errorf("MCP server %q does not support task RPCs", serverName)
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
// GetTools returns all loaded MCP tools from all configured MCP servers.
|
|
// Tools are returned with their prefixed names (serverName__toolName) to ensure uniqueness.
|
|
func (m *MCPToolManager) GetTools() []MCPTool {
|
|
return m.tools
|
|
}
|
|
|
|
// GetPrompts returns all prompts discovered from connected MCP servers.
|
|
func (m *MCPToolManager) GetPrompts() []MCPPrompt {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
result := make([]MCPPrompt, len(m.prompts))
|
|
copy(result, m.prompts)
|
|
return result
|
|
}
|
|
|
|
// GetPrompt retrieves and expands a specific prompt from an MCP server.
|
|
// The serverName identifies which server to query, promptName is the prompt's
|
|
// name on that server, and args are the template arguments to substitute.
|
|
// This call is lazy — it contacts the MCP server on each invocation.
|
|
func (m *MCPToolManager) GetPrompt(ctx context.Context, serverName, promptName string, args map[string]string) (*MCPPromptResult, error) {
|
|
if m.connectionPool == nil {
|
|
return nil, fmt.Errorf("no connection pool available")
|
|
}
|
|
|
|
clients := m.connectionPool.GetClients()
|
|
mcpClient, ok := clients[serverName]
|
|
if !ok {
|
|
return nil, fmt.Errorf("MCP server %q not found", serverName)
|
|
}
|
|
|
|
req := mcp.GetPromptRequest{}
|
|
req.Params.Name = promptName
|
|
if len(args) > 0 {
|
|
req.Params.Arguments = args
|
|
}
|
|
|
|
result, err := mcpClient.GetPrompt(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get prompt %q from server %q: %w", promptName, serverName, err)
|
|
}
|
|
|
|
// Convert MCP messages to our types, extracting all content types.
|
|
var messages []MCPPromptMessage
|
|
for _, msg := range result.Messages {
|
|
text, fileParts := extractPromptContent(msg.Content)
|
|
if text != "" || len(fileParts) > 0 {
|
|
messages = append(messages, MCPPromptMessage{
|
|
Role: string(msg.Role),
|
|
Content: text,
|
|
FileParts: fileParts,
|
|
})
|
|
}
|
|
}
|
|
|
|
return &MCPPromptResult{
|
|
Description: result.Description,
|
|
Messages: messages,
|
|
}, nil
|
|
}
|
|
|
|
// extractPromptContent extracts text and binary attachments from an MCP Content value.
|
|
// Handles all MCP content types: TextContent, ImageContent, AudioContent,
|
|
// EmbeddedResource (text and blob), and ResourceLink.
|
|
func extractPromptContent(content mcp.Content) (string, []MCPFilePart) {
|
|
switch c := content.(type) {
|
|
case mcp.TextContent:
|
|
return c.Text, nil
|
|
case *mcp.TextContent:
|
|
if c != nil {
|
|
return c.Text, nil
|
|
}
|
|
return "", nil
|
|
|
|
case mcp.ImageContent:
|
|
return "", decodeBase64FilePart(c.Data, c.MIMEType, "image/png", "image.png")
|
|
case *mcp.ImageContent:
|
|
if c != nil {
|
|
return "", decodeBase64FilePart(c.Data, c.MIMEType, "image/png", "image.png")
|
|
}
|
|
return "", nil
|
|
|
|
case mcp.AudioContent:
|
|
return "", decodeBase64FilePart(c.Data, c.MIMEType, "audio/wav", "audio.wav")
|
|
case *mcp.AudioContent:
|
|
if c != nil {
|
|
return "", decodeBase64FilePart(c.Data, c.MIMEType, "audio/wav", "audio.wav")
|
|
}
|
|
return "", nil
|
|
|
|
case mcp.EmbeddedResource:
|
|
return extractEmbeddedResourceContent(c.Resource)
|
|
case *mcp.EmbeddedResource:
|
|
if c != nil {
|
|
return extractEmbeddedResourceContent(c.Resource)
|
|
}
|
|
return "", nil
|
|
|
|
case mcp.ResourceLink:
|
|
// ResourceLink is a reference without inline content — include as a
|
|
// text annotation so the LLM knows about it.
|
|
return fmt.Sprintf("[Referenced resource: %s (%s)]", c.URI, c.Name), nil
|
|
case *mcp.ResourceLink:
|
|
if c != nil {
|
|
return fmt.Sprintf("[Referenced resource: %s (%s)]", c.URI, c.Name), nil
|
|
}
|
|
return "", nil
|
|
|
|
default:
|
|
return "", nil
|
|
}
|
|
}
|
|
|
|
// extractEmbeddedResourceContent handles the two variants of embedded resource
|
|
// content: text resources are inlined as fenced code blocks, blob resources
|
|
// are base64-decoded into MCPFilePart attachments.
|
|
func extractEmbeddedResourceContent(res mcp.ResourceContents) (string, []MCPFilePart) {
|
|
switch r := res.(type) {
|
|
case mcp.TextResourceContents:
|
|
return fmt.Sprintf("[File: %s]\n```\n%s\n```", r.URI, r.Text), nil
|
|
case *mcp.TextResourceContents:
|
|
if r != nil {
|
|
return fmt.Sprintf("[File: %s]\n```\n%s\n```", r.URI, r.Text), nil
|
|
}
|
|
return "", nil
|
|
case mcp.BlobResourceContents:
|
|
return "", decodeBase64FilePart(r.Blob, r.MIMEType, "application/octet-stream", filenameFromURI(r.URI))
|
|
case *mcp.BlobResourceContents:
|
|
if r != nil {
|
|
return "", decodeBase64FilePart(r.Blob, r.MIMEType, "application/octet-stream", filenameFromURI(r.URI))
|
|
}
|
|
return "", nil
|
|
default:
|
|
return "", nil
|
|
}
|
|
}
|
|
|
|
// decodeBase64FilePart decodes base64-encoded data into an MCPFilePart.
|
|
// Returns nil on decode failure (logged as a warning).
|
|
func decodeBase64FilePart(data, mimeType, defaultMIME, filename string) []MCPFilePart {
|
|
decoded, err := base64.StdEncoding.DecodeString(data)
|
|
if err != nil {
|
|
log.Warn("mcp prompt: failed to decode base64 content", "filename", filename, "error", err)
|
|
return nil
|
|
}
|
|
if mimeType == "" {
|
|
mimeType = defaultMIME
|
|
}
|
|
return []MCPFilePart{{
|
|
Filename: filename,
|
|
Data: decoded,
|
|
MediaType: mimeType,
|
|
}}
|
|
}
|
|
|
|
// filenameFromURI extracts a filename from a URI (e.g. "file:///path/to/img.png" → "img.png").
|
|
func filenameFromURI(uri string) string {
|
|
uri = strings.TrimPrefix(uri, "file://")
|
|
if idx := strings.LastIndex(uri, "/"); idx >= 0 {
|
|
return uri[idx+1:]
|
|
}
|
|
if uri == "" {
|
|
return "resource"
|
|
}
|
|
return uri
|
|
}
|
|
|
|
// loadServerPrompts loads prompts from a single MCP server connection.
|
|
// Called inside loadServerTools after a successful connection is established.
|
|
// Thread-safe: acquires m.mu to merge results.
|
|
func (m *MCPToolManager) loadServerPrompts(ctx context.Context, serverName string, conn *MCPConnection) {
|
|
listResult, err := conn.client.ListPrompts(ctx, mcp.ListPromptsRequest{})
|
|
if err != nil {
|
|
// Prompts are optional — servers may not support them.
|
|
// Silently skip.
|
|
return
|
|
}
|
|
|
|
if len(listResult.Prompts) == 0 {
|
|
return
|
|
}
|
|
|
|
var localPrompts []MCPPrompt
|
|
for _, p := range listResult.Prompts {
|
|
var args []MCPPromptArgument
|
|
for _, a := range p.Arguments {
|
|
args = append(args, MCPPromptArgument{
|
|
Name: a.Name,
|
|
Description: a.Description,
|
|
Required: a.Required,
|
|
})
|
|
}
|
|
localPrompts = append(localPrompts, MCPPrompt{
|
|
Name: p.Name,
|
|
Description: p.Description,
|
|
Arguments: args,
|
|
ServerName: serverName,
|
|
})
|
|
}
|
|
|
|
m.mu.Lock()
|
|
m.prompts = append(m.prompts, localPrompts...)
|
|
m.mu.Unlock()
|
|
}
|
|
|
|
// loadServerResources loads resources from a single MCP server connection.
|
|
// Called inside loadServerTools after a successful connection is established.
|
|
// Thread-safe: acquires m.mu to merge results.
|
|
func (m *MCPToolManager) loadServerResources(ctx context.Context, serverName string, conn *MCPConnection) {
|
|
listResult, err := conn.client.ListResources(ctx, mcp.ListResourcesRequest{})
|
|
if err != nil {
|
|
// Resources are optional — servers may not support them.
|
|
return
|
|
}
|
|
|
|
if len(listResult.Resources) == 0 {
|
|
return
|
|
}
|
|
|
|
var localResources []MCPResource
|
|
for _, r := range listResult.Resources {
|
|
localResources = append(localResources, MCPResource{
|
|
URI: r.URI,
|
|
Name: r.Name,
|
|
Description: r.Description,
|
|
MIMEType: r.MIMEType,
|
|
ServerName: serverName,
|
|
})
|
|
}
|
|
|
|
m.mu.Lock()
|
|
m.resources = append(m.resources, localResources...)
|
|
m.mu.Unlock()
|
|
}
|
|
|
|
// GetResources returns all resources discovered from connected MCP servers.
|
|
func (m *MCPToolManager) GetResources() []MCPResource {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
result := make([]MCPResource, len(m.resources))
|
|
copy(result, m.resources)
|
|
return result
|
|
}
|
|
|
|
// SetOnResourcesChanged sets the callback invoked when a subscribed resource
|
|
// changes. Used by the UI layer to refresh autocomplete or re-read content.
|
|
func (m *MCPToolManager) SetOnResourcesChanged(cb func()) {
|
|
m.onResourcesChanged = cb
|
|
}
|
|
|
|
// ReadResource reads a specific resource from an MCP server by URI.
|
|
// Returns the resource content (text or binary blob).
|
|
func (m *MCPToolManager) ReadResource(ctx context.Context, serverName, uri string) (*MCPResourceContent, error) {
|
|
if m.connectionPool == nil {
|
|
return nil, fmt.Errorf("no connection pool available")
|
|
}
|
|
|
|
clients := m.connectionPool.GetClients()
|
|
mcpClient, ok := clients[serverName]
|
|
if !ok {
|
|
return nil, fmt.Errorf("MCP server %q not found", serverName)
|
|
}
|
|
|
|
req := mcp.ReadResourceRequest{}
|
|
req.Params.URI = uri
|
|
|
|
result, err := mcpClient.ReadResource(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read resource %q from server %q: %w", uri, serverName, err)
|
|
}
|
|
|
|
if len(result.Contents) == 0 {
|
|
return nil, fmt.Errorf("resource %q returned no content", uri)
|
|
}
|
|
|
|
// Process the first content item (most resources return exactly one).
|
|
content := result.Contents[0]
|
|
switch c := content.(type) {
|
|
case mcp.TextResourceContents:
|
|
return &MCPResourceContent{
|
|
URI: c.URI,
|
|
MIMEType: c.MIMEType,
|
|
Text: c.Text,
|
|
IsBlob: false,
|
|
}, nil
|
|
case *mcp.TextResourceContents:
|
|
if c == nil {
|
|
return nil, fmt.Errorf("resource %q returned nil text content", uri)
|
|
}
|
|
return &MCPResourceContent{
|
|
URI: c.URI,
|
|
MIMEType: c.MIMEType,
|
|
Text: c.Text,
|
|
IsBlob: false,
|
|
}, nil
|
|
case mcp.BlobResourceContents:
|
|
decoded, err := base64.StdEncoding.DecodeString(c.Blob)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode blob resource %q: %w", uri, err)
|
|
}
|
|
return &MCPResourceContent{
|
|
URI: c.URI,
|
|
MIMEType: c.MIMEType,
|
|
BlobData: decoded,
|
|
IsBlob: true,
|
|
}, nil
|
|
case *mcp.BlobResourceContents:
|
|
if c == nil {
|
|
return nil, fmt.Errorf("resource %q returned nil blob content", uri)
|
|
}
|
|
decoded, err := base64.StdEncoding.DecodeString(c.Blob)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode blob resource %q: %w", uri, err)
|
|
}
|
|
return &MCPResourceContent{
|
|
URI: c.URI,
|
|
MIMEType: c.MIMEType,
|
|
BlobData: decoded,
|
|
IsBlob: true,
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("resource %q returned unknown content type %T", uri, content)
|
|
}
|
|
}
|
|
|
|
// SubscribeResource subscribes to change notifications for a resource.
|
|
// When the resource changes on the server, onResourcesChanged is called
|
|
// and the resource list is refreshed automatically.
|
|
func (m *MCPToolManager) SubscribeResource(ctx context.Context, serverName, uri string) error {
|
|
if m.connectionPool == nil {
|
|
return fmt.Errorf("no connection pool available")
|
|
}
|
|
|
|
clients := m.connectionPool.GetClients()
|
|
mcpClient, ok := clients[serverName]
|
|
if !ok {
|
|
return fmt.Errorf("MCP server %q not found", serverName)
|
|
}
|
|
|
|
req := mcp.SubscribeRequest{}
|
|
req.Params.URI = uri
|
|
|
|
if err := mcpClient.Subscribe(ctx, req); err != nil {
|
|
return fmt.Errorf("failed to subscribe to resource %q on server %q: %w", uri, serverName, err)
|
|
}
|
|
|
|
m.mu.Lock()
|
|
m.subscriptions[uri] = serverName
|
|
m.mu.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// UnsubscribeResource cancels change notifications for a resource.
|
|
func (m *MCPToolManager) UnsubscribeResource(ctx context.Context, serverName, uri string) error {
|
|
if m.connectionPool == nil {
|
|
return fmt.Errorf("no connection pool available")
|
|
}
|
|
|
|
clients := m.connectionPool.GetClients()
|
|
mcpClient, ok := clients[serverName]
|
|
if !ok {
|
|
return fmt.Errorf("MCP server %q not found", serverName)
|
|
}
|
|
|
|
req := mcp.UnsubscribeRequest{}
|
|
req.Params.URI = uri
|
|
|
|
if err := mcpClient.Unsubscribe(ctx, req); err != nil {
|
|
return fmt.Errorf("failed to unsubscribe from resource %q on server %q: %w", uri, serverName, err)
|
|
}
|
|
|
|
m.mu.Lock()
|
|
delete(m.subscriptions, uri)
|
|
m.mu.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// RefreshServerResources re-fetches resources from a specific server.
|
|
// Called when a resource change notification is received.
|
|
func (m *MCPToolManager) RefreshServerResources(ctx context.Context, serverName string) {
|
|
if m.connectionPool == nil {
|
|
return
|
|
}
|
|
|
|
clients := m.connectionPool.GetClients()
|
|
mcpClient, ok := clients[serverName]
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
listResult, err := mcpClient.ListResources(ctx, mcp.ListResourcesRequest{})
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var newResources []MCPResource
|
|
for _, r := range listResult.Resources {
|
|
newResources = append(newResources, MCPResource{
|
|
URI: r.URI,
|
|
Name: r.Name,
|
|
Description: r.Description,
|
|
MIMEType: r.MIMEType,
|
|
ServerName: serverName,
|
|
})
|
|
}
|
|
|
|
m.mu.Lock()
|
|
// Remove old resources from this server, add new ones.
|
|
filtered := make([]MCPResource, 0, len(m.resources))
|
|
for _, r := range m.resources {
|
|
if r.ServerName != serverName {
|
|
filtered = append(filtered, r)
|
|
}
|
|
}
|
|
m.resources = append(filtered, newResources...)
|
|
m.mu.Unlock()
|
|
|
|
if m.onResourcesChanged != nil {
|
|
m.onResourcesChanged()
|
|
}
|
|
}
|
|
|
|
// GetLoadedServerNames returns the names of all successfully loaded MCP servers.
|
|
// This includes servers that are currently connected and have had their tools loaded,
|
|
// regardless of their current health status. Useful for debugging and status reporting.
|
|
func (m *MCPToolManager) GetLoadedServerNames() []string {
|
|
var names []string
|
|
for serverName := range m.connectionPool.GetClients() {
|
|
names = append(names, serverName)
|
|
}
|
|
return names
|
|
}
|
|
|
|
// Close closes all MCP client connections and cleans up resources.
|
|
// This method should be called when the tool manager is no longer needed to ensure
|
|
// proper cleanup of stdio processes, network connections, and other resources.
|
|
// It is safe to call Close multiple times.
|
|
func (m *MCPToolManager) Close() error {
|
|
if m.connectionPool == nil {
|
|
return nil
|
|
}
|
|
return m.connectionPool.Close()
|
|
}
|
|
|
|
// shouldExcludeTool determines if a tool should be excluded based on excludedTools
|
|
func (m *MCPToolManager) shouldExcludeTool(toolName string, serverConfig config.MCPServerConfig) bool {
|
|
if len(serverConfig.ExcludedTools) > 0 {
|
|
if slices.Contains(serverConfig.ExcludedTools, toolName) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// debugLogConnectionInfo logs detailed connection information for debugging
|
|
func (m *MCPToolManager) debugLogConnectionInfo(serverName string, serverConfig config.MCPServerConfig) {
|
|
if m.debugLogger == nil || !m.debugLogger.IsDebugEnabled() {
|
|
return
|
|
}
|
|
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Connecting to MCP server: %s", serverName))
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Transport type: %s", serverConfig.GetTransportType()))
|
|
|
|
switch serverConfig.GetTransportType() {
|
|
case "stdio":
|
|
if len(serverConfig.Command) > 0 {
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Command: %s %v", serverConfig.Command[0], serverConfig.Command[1:]))
|
|
}
|
|
if len(serverConfig.Environment) > 0 {
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Environment variables: %d", len(serverConfig.Environment)))
|
|
}
|
|
case "sse", "streamable":
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] URL: %s", serverConfig.URL))
|
|
if len(serverConfig.Headers) > 0 {
|
|
m.debugLogger.LogDebug(fmt.Sprintf("[DEBUG] Headers: %v", serverConfig.Headers))
|
|
}
|
|
}
|
|
}
|
|
|
|
// convertExclusiveBoundsToBoolean converts JSON Schema draft-07 style exclusive bounds
|
|
// (where exclusiveMinimum/exclusiveMaximum are numbers) to draft-04 style
|
|
// (where they are booleans that modify minimum/maximum).
|
|
func convertExclusiveBoundsToBoolean(schemaJSON []byte) []byte {
|
|
var data map[string]any
|
|
if err := json.Unmarshal(schemaJSON, &data); err != nil {
|
|
return schemaJSON
|
|
}
|
|
|
|
convertSchemaRecursive(data)
|
|
|
|
result, err := json.Marshal(data)
|
|
if err != nil {
|
|
return schemaJSON
|
|
}
|
|
return result
|
|
}
|
|
|
|
// convertSchemaRecursive recursively processes a schema map to:
|
|
// - Convert numeric exclusiveMinimum/exclusiveMaximum to boolean format (draft-07 → draft-04)
|
|
// - Remove null "required" fields that cause OpenAI API validation errors
|
|
func convertSchemaRecursive(schema map[string]any) {
|
|
if exMin, ok := schema["exclusiveMinimum"]; ok {
|
|
if num, isNum := exMin.(float64); isNum {
|
|
schema["minimum"] = num
|
|
schema["exclusiveMinimum"] = true
|
|
}
|
|
}
|
|
|
|
if exMax, ok := schema["exclusiveMaximum"]; ok {
|
|
if num, isNum := exMax.(float64); isNum {
|
|
schema["maximum"] = num
|
|
schema["exclusiveMaximum"] = true
|
|
}
|
|
}
|
|
|
|
// Fix null "required" fields — OpenAI rejects "required": null,
|
|
// it must be an array or absent entirely.
|
|
if req, exists := schema["required"]; exists {
|
|
if req == nil {
|
|
delete(schema, "required")
|
|
} else if _, isArr := req.([]any); !isArr {
|
|
// Not an array — remove invalid value
|
|
delete(schema, "required")
|
|
}
|
|
}
|
|
|
|
if props, ok := schema["properties"].(map[string]any); ok {
|
|
for _, prop := range props {
|
|
if propSchema, ok := prop.(map[string]any); ok {
|
|
convertSchemaRecursive(propSchema)
|
|
}
|
|
}
|
|
}
|
|
|
|
if items, ok := schema["items"].(map[string]any); ok {
|
|
convertSchemaRecursive(items)
|
|
}
|
|
|
|
if addProps, ok := schema["additionalProperties"].(map[string]any); ok {
|
|
convertSchemaRecursive(addProps)
|
|
}
|
|
|
|
for _, key := range []string{"allOf", "anyOf", "oneOf"} {
|
|
if arr, ok := schema[key].([]any); ok {
|
|
for _, item := range arr {
|
|
if itemSchema, ok := item.(map[string]any); ok {
|
|
convertSchemaRecursive(itemSchema)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if not, ok := schema["not"].(map[string]any); ok {
|
|
convertSchemaRecursive(not)
|
|
}
|
|
}
|