diff --git a/internal/agent/agent.go b/internal/agent/agent.go index c059d5b1..3955d6e0 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -59,6 +59,11 @@ type AgentConfig struct { // loading (successfully or with error). The callback receives the server // name, tool count, and any error. Called from the background goroutine. OnMCPServerLoaded func(serverName string, toolCount int, err error) + + // MCPTaskConfig configures task-augmented tools/call execution. The + // zero value preserves historical synchronous-only behaviour for any + // server that didn't advertise task support during initialize. + MCPTaskConfig tools.MCPTaskConfig } // ToolCallHandler is a function type for handling tool calls as they happen. @@ -231,6 +236,10 @@ type Agent struct { authHandler tools.MCPAuthHandler tokenStoreFactory tools.TokenStoreFactory + // mcpTaskConfig is stored from AgentConfig so AddMCPServer() can + // propagate it to a lazily-created MCPToolManager. + mcpTaskConfig tools.MCPTaskConfig + // mcpReady is closed when background MCP tool loading completes (success // or failure). nil when no MCP servers are configured. mcpReady chan struct{} @@ -329,6 +338,7 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) { modelConfig: agentConfig.ModelConfig, authHandler: agentConfig.AuthHandler, tokenStoreFactory: agentConfig.TokenStoreFactory, + mcpTaskConfig: agentConfig.MCPTaskConfig, } // Start MCP tool loading in the background if servers are configured. @@ -348,6 +358,8 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) { if agentConfig.OnMCPServerLoaded != nil { toolManager.SetOnServerLoaded(agentConfig.OnMCPServerLoaded) } + // Apply task-augmented tool execution config (zero value = no-op). 
+ toolManager.SetTaskConfig(agentConfig.MCPTaskConfig) a.toolManager = toolManager a.mcpReady = make(chan struct{}) @@ -1134,6 +1146,7 @@ func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPSer if a.tokenStoreFactory != nil { a.toolManager.SetTokenStoreFactory(a.tokenStoreFactory) } + a.toolManager.SetTaskConfig(a.mcpTaskConfig) a.toolManager.SetOnToolsChanged(func() { a.rebuildFantasyAgent() }) diff --git a/internal/agent/factory.go b/internal/agent/factory.go index 6689a4b0..cb3d692f 100644 --- a/internal/agent/factory.go +++ b/internal/agent/factory.go @@ -56,6 +56,8 @@ type AgentCreationOptions struct { // OnMCPServerLoaded, if non-nil, is called when each MCP server finishes // loading (successfully or with error). Called from the background goroutine. OnMCPServerLoaded func(serverName string, toolCount int, err error) + // MCPTaskConfig configures task-augmented tools/call execution. + MCPTaskConfig tools.MCPTaskConfig } // CreateAgent creates an agent with optional spinner for Ollama models. @@ -76,6 +78,7 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error ToolWrapper: opts.ToolWrapper, ExtraTools: opts.ExtraTools, OnMCPServerLoaded: opts.OnMCPServerLoaded, + MCPTaskConfig: opts.MCPTaskConfig, } var agent *Agent diff --git a/internal/config/config.go b/internal/config/config.go index 39ae98f8..020a4707 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -38,6 +38,23 @@ type MCPServerConfig struct { // servers that don't support it. NoOAuth bool `json:"noOAuth,omitempty" yaml:"noOAuth,omitempty"` + // TasksMode controls when this server's tools/call requests are augmented + // with MCP task metadata (turning a synchronous call into an asynchronous, + // pollable job — see https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks). 
+ // + // Valid values: + // - "" or "auto": (default) augment requests with task metadata only + // when the server advertises tasks/toolCalls capability during initialize. + // - "never": never augment — every tool call is synchronous, regardless + // of server capability. + // - "always": always augment, even when the server didn't advertise + // task support. The server may still respond synchronously; this just + // opts in unconditionally on the client side. + // + // In all modes, when the server returns a CreateTaskResult the client polls + // tasks/get / tasks/result until the task reaches a terminal state. + TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"` + // InProcessServer holds a live *server.MCPServer for in-process transport. // When set (and Type is "inprocess"), the connection pool creates an // in-process client instead of spawning a subprocess or making HTTP calls. @@ -68,6 +85,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error { OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"` OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"` NoOAuth bool `json:"noOAuth,omitempty" yaml:"noOAuth,omitempty"` + TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"` } // Also try legacy format @@ -80,6 +98,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error { Headers []string `json:"headers,omitempty"` AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"` ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"` + TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"` } // Try new format first @@ -96,6 +115,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error { s.OAuthClientSecret = newConfig.OAuthClientSecret s.OAuthScopes = newConfig.OAuthScopes s.NoOAuth = newConfig.NoOAuth + s.TasksMode = newConfig.TasksMode return nil } @@ -116,6 
+136,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error { s.Headers = legacyConfig.Headers s.AllowedTools = legacyConfig.AllowedTools s.ExcludedTools = legacyConfig.ExcludedTools + s.TasksMode = legacyConfig.TasksMode // Infer type from legacy format for better compatibility // Only set Type when it doesn't change existing transport behavior diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ab920691..d5e91c78 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -627,3 +627,48 @@ func TestMCPServerConfig_OAuthFields_Omitted(t *testing.T) { t.Errorf("Expected empty OAuthScopes, got %v", cfg.OAuthScopes) } } + +func TestMCPServerConfig_TasksMode_NewFormat(t *testing.T) { + jsonData := `{ + "type": "remote", + "url": "https://my-mcp-server.com", + "tasksMode": "always" + }` + var cfg MCPServerConfig + if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if cfg.TasksMode != "always" { + t.Errorf("expected TasksMode 'always', got %q", cfg.TasksMode) + } +} + +func TestMCPServerConfig_TasksMode_LegacyFormat(t *testing.T) { + // tasksMode also recognised in the legacy unmarshal path so users on + // the older command/args shape can opt in without migrating. + jsonData := `{ + "command": "npx", + "args": ["@modelcontextprotocol/server-filesystem", "/path"], + "tasksMode": "never" + }` + var cfg MCPServerConfig + if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if cfg.TasksMode != "never" { + t.Errorf("expected TasksMode 'never', got %q", cfg.TasksMode) + } +} + +func TestMCPServerConfig_TasksMode_DefaultEmpty(t *testing.T) { + // When tasksMode is not set the field stays empty, which downstream + // resolves to "auto" via tools.ParseTaskMode. 
+ jsonData := `{"type":"remote","url":"https://x.example"}` + var cfg MCPServerConfig + if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if cfg.TasksMode != "" { + t.Errorf("expected default TasksMode to be empty, got %q", cfg.TasksMode) + } +} diff --git a/internal/kitsetup/setup.go b/internal/kitsetup/setup.go index 0c8ff8a6..2d276104 100644 --- a/internal/kitsetup/setup.go +++ b/internal/kitsetup/setup.go @@ -72,6 +72,9 @@ type AgentSetupOptions struct { // OnMCPServerLoaded, if non-nil, is called when each MCP server finishes // loading (successfully or with error). Called from the background goroutine. OnMCPServerLoaded func(serverName string, toolCount int, err error) + // MCPTaskConfig configures task-augmented tools/call execution. The zero + // value defers to each server's tasksMode, which defaults to auto-detect. + MCPTaskConfig tools.MCPTaskConfig } // AgentSetupResult bundles the created agent and any debug logger so the caller @@ -229,6 +232,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, ToolWrapper: toolWrapper, ExtraTools: extraTools, OnMCPServerLoaded: opts.OnMCPServerLoaded, + MCPTaskConfig: opts.MCPTaskConfig, }) if err != nil { return nil, fmt.Errorf("failed to create agent: %w", err) diff --git a/internal/tools/connection_pool.go b/internal/tools/connection_pool.go index b8407e47..8517ea92 100644 --- a/internal/tools/connection_pool.go +++ b/internal/tools/connection_pool.go @@ -47,6 +47,7 @@ type MCPConnection struct { client client.MCPClient serverName string serverConfig config.MCPServerConfig + initResult *mcp.InitializeResult // captured at handshake; nil before initialize lastUsed time.Time isHealthy bool errorCount int @@ -262,7 +263,9 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str } } - if err := p.initializeClient(ctx, mcpClient); err != nil { + conn := &MCPConnection{} + + if err := p.initializeClient(ctx, 
mcpClient, conn); err != nil { // Streamable HTTP transport returns OAuth error during Initialize() if oauthEnabled && IsOAuthError(err) { if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil { @@ -270,7 +273,7 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr) } // Retry initialization after successful auth - if err := p.initializeClient(ctx, mcpClient); err != nil { + if err := p.initializeClient(ctx, mcpClient, conn); err != nil { _ = mcpClient.Close() return nil, err } @@ -280,15 +283,11 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str } } - conn := &MCPConnection{ - client: mcpClient, - serverName: serverName, - serverConfig: serverConfig, - lastUsed: time.Now(), - isHealthy: true, - errorCount: 0, - lastError: nil, - } + conn.client = mcpClient + conn.serverName = serverName + conn.serverConfig = serverConfig + conn.lastUsed = time.Now() + conn.isHealthy = true if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() { p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Created connection for %s", serverName)) @@ -484,8 +483,10 @@ func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenS return NewFileTokenStore(serverURL) } -// initializeClient initializes the client -func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error { +// initializeClient initializes the client and captures the server's +// initialize result on the supplied connection so callers can later +// inspect advertised capabilities (e.g. task support). +func (p *MCPConnectionPool) initializeClient(ctx context.Context, c client.MCPClient, conn *MCPConnection) error { initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) defer cancel() @@ -495,12 +496,21 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client. 
Name: "kit", Version: "1.0.0", } - initRequest.Params.Capabilities = mcp.ClientCapabilities{} + // Advertise task support so servers may return CreateTaskResult for + // long-running tools/call requests instead of blocking the connection + // until completion. The client is responsible for polling tasks/get and + // tasks/result until the task reaches a terminal state. + initRequest.Params.Capabilities = mcp.ClientCapabilities{ + Tasks: mcp.NewTasksCapability(), + } - _, err := client.Initialize(initCtx, initRequest) + initResult, err := c.Initialize(initCtx, initRequest) if err != nil { return fmt.Errorf("initialization timeout or failed: %w", err) } + if conn != nil { + conn.initResult = initResult + } if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() { p.debugLogger.LogDebug("[POOL] Initialized MCP client") @@ -615,6 +625,54 @@ func (c *MCPConnection) ServerName() string { return c.serverName } +// InitializeResult returns the result captured from the server's initialize +// response, or nil if the connection was created before initialize completed. +// Callers can inspect ServerCapabilities.Tasks to discover task-related +// capability advertisements. +func (c *MCPConnection) InitializeResult() *mcp.InitializeResult { + c.mu.RLock() + defer c.mu.RUnlock() + return c.initResult +} + +// SupportsToolTasks reports whether the server advertised support for +// task-augmented tools/call requests. Returns false when the connection has +// not yet completed initialization or when the server omitted task +// capabilities. +func (c *MCPConnection) SupportsToolTasks() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return supportsToolTasksFromInit(c.initResult) +} + +// supportsToolTasksFromInit reports whether the supplied InitializeResult +// advertises task-augmented tools/call support. Extracted to a free function +// for unit testing without standing up a connection. 
+func supportsToolTasksFromInit(init *mcp.InitializeResult) bool { + if init == nil || init.Capabilities.Tasks == nil { + return false + } + req := init.Capabilities.Tasks.Requests + if req == nil || req.Tools == nil { + return false + } + return req.Tools.Call != nil +} + +// ServerSupportsToolTasks reports whether the named server's connection +// advertises task-augmented tools/call support. Returns false when no +// connection exists for the server or when the server didn't advertise the +// capability. +func (p *MCPConnectionPool) ServerSupportsToolTasks(serverName string) bool { + p.mu.RLock() + conn, ok := p.connections[serverName] + p.mu.RUnlock() + if !ok { + return false + } + return conn.SupportsToolTasks() +} + // GetClients returns a map of all MCP clients currently in the pool. // The map keys are server names and values are the corresponding MCP client instances. // The returned map is a copy and modifications won't affect the pool. diff --git a/internal/tools/mcp.go b/internal/tools/mcp.go index da7f5086..f4a4a66b 100644 --- a/internal/tools/mcp.go +++ b/internal/tools/mcp.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "fmt" "maps" "slices" @@ -13,6 +14,7 @@ import ( log "github.com/charmbracelet/log" "github.com/mark3labs/kit/internal/config" + "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/mcp" ) @@ -141,6 +143,11 @@ type MCPToolManager struct { debug bool debugLogger DebugLogger + // taskCfg controls task-augmented tools/call execution. The zero value + // means: auto-detect server capability, no progress callback, default + // poll/timeout. + taskCfg MCPTaskConfig + // onServerLoaded, if non-nil, is called when each server finishes loading. // Called with server name, tool count, and error (nil on success). 
onServerLoaded func(serverName string, toolCount int, err error) @@ -220,6 +227,21 @@ func (m *MCPToolManager) SetOnToolsChanged(cb func()) { m.onToolsChanged = cb } +// SetTaskConfig sets the task-augmented tools/call configuration. Call +// this before LoadTools / AddServer if you want the per-server mode +// override and progress handler to take effect for the very first call. +// Subsequent calls replace the previous configuration wholesale. +func (m *MCPToolManager) SetTaskConfig(cfg MCPTaskConfig) { + m.taskCfg = cfg +} + +// TaskConfig returns the manager's current task-augmented tools/call +// configuration. The zero value means: defer to per-server config and +// auto-detected capability, with no progress callback and default polling. +func (m *MCPToolManager) TaskConfig() MCPTaskConfig { + return m.taskCfg +} + // AddServer connects to a new MCP server at runtime and loads its tools. // The server's tools are immediately available to the agent after this call. // Returns the number of tools loaded from the server. @@ -551,6 +573,14 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, // checks, OAuth re-authorization, and connection error tracking. // The inputJSON parameter is the raw JSON arguments from the LLM. // Returns the result content, error flag, and any execution error. +// +// When the per-server TasksMode resolves to "always", or to "auto" and the +// server advertised tasks/toolCalls capability during initialize, the call +// is augmented with TaskParams. If the server elects to respond with a +// CreateTaskResult the manager polls tasks/get / tasks/result until the +// task reaches a terminal state, transparently presenting the final +// CallToolResult-equivalent content to the agent layer. Context +// cancellation triggers a best-effort tasks/cancel. 
func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSON string) (*MCPToolResult, error) { m.mu.Lock() mapping, ok := m.toolMap[prefixedName] @@ -582,49 +612,221 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO return nil, fmt.Errorf("failed to get healthy connection from pool: %w", err) } - callRequest := mcp.CallToolRequest{ - Request: mcp.Request{ - Method: "tools/call", - }, - Params: mcp.CallToolParams{ - Name: mapping.originalName, - Arguments: arguments, - }, + callParams := mcp.CallToolParams{ + Name: mapping.originalName, + Arguments: arguments, } - // Call the MCP tool using the original (unprefixed) name - result, err := conn.client.CallTool(ctx, callRequest) - if err != nil { - // Handle OAuth re-authorization: token may have expired mid-session. - if m.connectionPool.oauthFlow != nil && IsOAuthError(err) { - if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, err); flowErr != nil { + // Decide whether to augment the request with TaskParams. Modes: + // never — never augment (synchronous-only). + // always — always augment, even without server capability. + // auto — augment only when the server advertised tasks/toolCalls. + mode := m.resolveTaskMode(mapping.serverName, mapping.serverConfig) + useTask := mode == MCPTaskModeAlways || + (mode == MCPTaskModeAuto && conn.SupportsToolTasks()) + if useTask { + var ttl *int64 + if m.taskCfg.DefaultTTL > 0 { + ms := m.taskCfg.DefaultTTL.Milliseconds() + ttl = &ms + } + callParams.Task = &mcp.TaskParams{TTL: ttl} + } + + // Synchronous fast path: no task augmentation. Use the upstream client + // helper which keeps content-block typing identical to historical + // behaviour. 
+ if !useTask { + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{Method: "tools/call"}, + Params: callParams, + } + result, callErr := conn.client.CallTool(ctx, callRequest) + if callErr != nil { + if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) { + if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil { + return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr) + } + result, callErr = conn.client.CallTool(ctx, callRequest) + if callErr != nil { + m.connectionPool.HandleConnectionError(mapping.serverName, callErr) + return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr) + } + } else { + m.connectionPool.HandleConnectionError(mapping.serverName, callErr) + return nil, fmt.Errorf("failed to call mcp tool: %w", callErr) + } + } + marshaledResult, mErr := json.Marshal(result) + if mErr != nil { + return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr) + } + return &MCPToolResult{ + Content: string(marshaledResult), + IsError: result.IsError, + }, nil + } + + // Task-augmented path. Bypass the upstream CallTool helper because its + // ParseCallToolResult requires a "content" field that is absent from a + // CreateTaskResult. + rawClient, ok := conn.client.(*client.Client) + if !ok { + // Older client implementations — fall back to the synchronous shape. 
+ callParams.Task = nil + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{Method: "tools/call"}, + Params: callParams, + } + result, callErr := conn.client.CallTool(ctx, callRequest) + if callErr != nil { + m.connectionPool.HandleConnectionError(mapping.serverName, callErr) + return nil, fmt.Errorf("failed to call mcp tool: %w", callErr) + } + marshaledResult, mErr := json.Marshal(result) + if mErr != nil { + return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr) + } + return &MCPToolResult{Content: string(marshaledResult), IsError: result.IsError}, nil + } + + callResult, taskResult, callErr := callToolWithTask(ctx, rawClient, callParams) + if callErr != nil { + if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) { + if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil { return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr) } - // Retry the tool call after successful re-auth. 
- result, err = conn.client.CallTool(ctx, callRequest) - if err != nil { - m.connectionPool.HandleConnectionError(mapping.serverName, err) - return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", err) + callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams) + if callErr != nil { + m.connectionPool.HandleConnectionError(mapping.serverName, callErr) + return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr) } } else { - // Mark connection as unhealthy for automatic recovery - m.connectionPool.HandleConnectionError(mapping.serverName, err) - return nil, fmt.Errorf("failed to call mcp tool: %w", err) + m.connectionPool.HandleConnectionError(mapping.serverName, callErr) + return nil, fmt.Errorf("failed to call mcp tool: %w", callErr) } } - // Marshal the MCP result to JSON string - marshaledResult, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal mcp tool result: %w", err) + // Server chose to answer synchronously — same shape as the no-task path. + if callResult != nil { + marshaledResult, mErr := json.Marshal(callResult) + if mErr != nil { + return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr) + } + return &MCPToolResult{ + Content: string(marshaledResult), + IsError: callResult.IsError, + }, nil } + // Asynchronous task path: poll until terminal, then return the result. + if taskResult == nil { + return nil, errors.New("mcp tools/call returned neither result nor task") + } + final, pollErr := pollTaskUntilTerminal( + ctx, rawClient, mapping.serverName, taskResult.Task, + m.taskCfg, m.taskCfg.Progress, + ) + if pollErr != nil { + return nil, fmt.Errorf("task execution failed: %w", pollErr) + } + + // Adapt TaskResultResult → CallToolResult for downstream JSON shape parity. 
+ adapted := &mcp.CallToolResult{ + Content: final.Content, + StructuredContent: final.StructuredContent, + IsError: final.IsError, + } + marshaledResult, mErr := json.Marshal(adapted) + if mErr != nil { + return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr) + } return &MCPToolResult{ Content: string(marshaledResult), - IsError: result.IsError, + IsError: final.IsError, }, nil } +// resolveTaskMode resolves the effective task mode for a given server. +// A PerServerMode override set via SetTaskConfig wins over the per-server +// TasksMode in MCPServerConfig. Both sources are normalised by ParseTaskMode, +// so empty / unknown values map to MCPTaskModeAuto. +func (m *MCPToolManager) resolveTaskMode(name string, cfg config.MCPServerConfig) MCPTaskMode { + if m.taskCfg.PerServerMode != nil { + if v, ok := m.taskCfg.PerServerMode[name]; ok { + return ParseTaskMode(string(v)) + } + } + return ParseTaskMode(cfg.TasksMode) +} + +// ListServerTasks queries tasks/list on the named server and returns the +// active and recent tasks the server is willing to disclose. Errors are +// returned untouched (callers commonly ignore METHOD_NOT_FOUND when the +// server didn't advertise tasks/list capability). +func (m *MCPToolManager) ListServerTasks(ctx context.Context, serverName string) ([]MCPTaskInfo, error) { + c, err := m.taskClient(serverName) + if err != nil { + return nil, err + } + res, err := c.ListTasks(ctx, mcp.ListTasksRequest{}) + if err != nil { + return nil, fmt.Errorf("tasks/list on %s: %w", serverName, err) + } + out := make([]MCPTaskInfo, 0, len(res.Tasks)) + for _, t := range res.Tasks { + out = append(out, taskFromMCP(serverName, t)) + } + return out, nil +} + +// GetServerTask queries tasks/get for a single task on the named server. 
+func (m *MCPToolManager) GetServerTask(ctx context.Context, serverName, taskID string) (MCPTaskInfo, error) { + c, err := m.taskClient(serverName) + if err != nil { + return MCPTaskInfo{}, err + } + res, err := c.GetTask(ctx, mcp.GetTaskRequest{Params: mcp.GetTaskParams{TaskId: taskID}}) + if err != nil { + return MCPTaskInfo{}, fmt.Errorf("tasks/get on %s: %w", serverName, err) + } + return taskFromMCP(serverName, res.Task), nil +} + +// CancelServerTask issues tasks/cancel for a task on the named server. +// Returns the post-cancel task state when the server responded with one. +func (m *MCPToolManager) CancelServerTask(ctx context.Context, serverName, taskID string) (MCPTaskInfo, error) { + c, err := m.taskClient(serverName) + if err != nil { + return MCPTaskInfo{}, err + } + res, err := c.CancelTask(ctx, mcp.CancelTaskRequest{Params: mcp.CancelTaskParams{TaskId: taskID}}) + if err != nil { + return MCPTaskInfo{}, fmt.Errorf("tasks/cancel on %s: %w", serverName, err) + } + return taskFromMCP(serverName, res.Task), nil +} + +// taskClient returns the *client.Client for a server. Tasks endpoints are +// not part of the upstream MCPClient interface so callers must work with +// the concrete client. Returns an error when the connection is missing +// or backed by a non-standard client type. +func (m *MCPToolManager) taskClient(serverName string) (*client.Client, error) { + if m.connectionPool == nil { + return nil, fmt.Errorf("no connection pool available") + } + clients := m.connectionPool.GetClients() + raw, ok := clients[serverName] + if !ok { + return nil, fmt.Errorf("MCP server %q not loaded", serverName) + } + c, ok := raw.(*client.Client) + if !ok { + return nil, fmt.Errorf("MCP server %q does not support task RPCs", serverName) + } + return c, nil +} + // GetTools returns all loaded MCP tools from all configured MCP servers. // Tools are returned with their prefixed names (serverName__toolName) to ensure uniqueness. 
func (m *MCPToolManager) GetTools() []MCPTool { diff --git a/internal/tools/mcp_tasks.go b/internal/tools/mcp_tasks.go new file mode 100644 index 00000000..3a22c2fe --- /dev/null +++ b/internal/tools/mcp_tasks.go @@ -0,0 +1,404 @@ +package tools + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// MCPTaskMode controls when the connection pool augments tools/call requests +// with MCP task metadata. See https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks. +type MCPTaskMode string + +const ( + // MCPTaskModeAuto augments tools/call with task metadata only when the + // server advertises tasks/toolCalls capability during initialize. + MCPTaskModeAuto MCPTaskMode = "auto" + // MCPTaskModeNever forces every tools/call to be issued synchronously + // (no Task field in the request), regardless of server capability. + MCPTaskModeNever MCPTaskMode = "never" + // MCPTaskModeAlways always sets a Task field on the tools/call request, + // even when the server didn't advertise task support. The server may + // still respond synchronously; this just opts in unconditionally on + // the client side. + MCPTaskModeAlways MCPTaskMode = "always" +) + +// ParseTaskMode normalises a per-server tasks-mode string from +// configuration. Empty input maps to MCPTaskModeAuto. Unknown values are +// also treated as MCPTaskModeAuto so a stray config typo never breaks +// existing flows. +func ParseTaskMode(s string) MCPTaskMode { + switch strings.ToLower(strings.TrimSpace(s)) { + case "", "auto": + return MCPTaskModeAuto + case "never", "off", "disabled": + return MCPTaskModeNever + case "always", "force": + return MCPTaskModeAlways + default: + return MCPTaskModeAuto + } +} + +// MCPTaskInfo is the connection-layer view of an MCP Task. 
It mirrors the +// upstream mcp.Task but exposes Go-native types and includes the originating +// server name. SDK-level wrappers re-export this under public-facing names. +type MCPTaskInfo struct { + // Server is the configured MCP server name this task lives on. + Server string + // TaskID is the server-assigned identifier for the task. + TaskID string + // Status is the current task lifecycle state. + Status mcp.TaskStatus + // StatusMessage is an optional human-readable description. + StatusMessage string + // CreatedAt is the wall-clock time the task was created (best-effort + // parsed from the server's ISO-8601 timestamp; zero on parse failure). + CreatedAt time.Time + // UpdatedAt is the wall-clock time the task was last updated (best- + // effort parsed; zero on parse failure). + UpdatedAt time.Time + // TTL is the time-to-live the server intends to retain the task after + // creation. Zero means the server did not advertise a TTL. + TTL time.Duration + // PollInterval is the suggested polling interval. Zero means use the + // client's default. + PollInterval time.Duration +} + +// MCPTaskProgress is emitted while the connection pool is waiting on a +// task-augmented tool call. It provides minimal feedback for SDK consumers +// that want to render progress widgets without subscribing to the full +// notifications/tasks/status channel (Phase 2). +type MCPTaskProgress struct { + Server string + TaskID string + Status mcp.TaskStatus + Message string +} + +// MCPTaskProgressHandler is invoked once after a task is accepted and on +// every status transition observed by the polling loop. The final +// invocation always carries a terminal status. Implementations must not +// block; long work should be queued on a goroutine. +type MCPTaskProgressHandler func(MCPTaskProgress) + +// MCPTaskConfig configures task-aware tool execution on the manager. +// All fields are optional; the zero value disables progress callbacks and +// applies sensible defaults. 
+type MCPTaskConfig struct { + // PerServerMode overrides the per-server TasksMode resolved from + // MCPServerConfig. Keys are server names. Missing entries fall back + // to the value from config. Used by SDK consumers that want to set + // modes programmatically. + PerServerMode map[string]MCPTaskMode + + // DefaultTTL is the TTL hint sent in TaskParams when augmenting a + // tools/call. Zero means omit the TTL — let the server pick its own. + DefaultTTL time.Duration + + // PollInterval is the fallback interval between tasks/get requests + // when the server does not suggest one. Zero defaults to 1 second. + PollInterval time.Duration + + // MaxPollInterval caps the polling interval. Zero defaults to 5 seconds. + MaxPollInterval time.Duration + + // Timeout is the maximum wall-clock duration to wait for a task to + // reach a terminal state. Zero defaults to 15 minutes. Independent + // of the per-call context deadline; whichever fires first wins. + Timeout time.Duration + + // Progress, if non-nil, receives every status transition observed by + // the polling loop. + Progress MCPTaskProgressHandler +} + +func (c MCPTaskConfig) resolved() MCPTaskConfig { + if c.PollInterval <= 0 { + c.PollInterval = 1 * time.Second + } + if c.MaxPollInterval <= 0 { + c.MaxPollInterval = 5 * time.Second + } + if c.Timeout <= 0 { + c.Timeout = 15 * time.Minute + } + return c +} + +// requestIDCounter generates monotonically increasing JSON-RPC request IDs +// for low-level tools/call invocations that bypass the upstream client's +// ParseCallToolResult helper (necessary because that helper rejects task +// responses for lacking a "content" field). +// +// The counter is process-wide rather than per-manager so multiple managers +// or repeated calls within the same connection produce unique IDs. 
+var requestIDCounter atomic.Int64 + +func nextRequestID() mcp.RequestId { + return mcp.NewRequestId(requestIDCounter.Add(1)) +} + +// callToolWithTask issues tools/call directly on the transport so we can +// observe both response shapes: +// +// - {"content": [...], ...} — synchronous CallToolResult. +// - {"task": {...}, ...} — asynchronous CreateTaskResult. +// +// On success exactly one of (callResult, taskResult) is non-nil. The +// upstream client.CallTool helper parses the response with +// mcp.ParseCallToolResult which requires a "content" field, so it cannot +// be used for task-augmented calls. +func callToolWithTask( + ctx context.Context, + c *client.Client, + params mcp.CallToolParams, +) (callResult *mcp.CallToolResult, taskResult *mcp.CreateTaskResult, err error) { + tr := c.GetTransport() + if tr == nil { + return nil, nil, errors.New("mcp client has no transport") + } + + req := transport.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: nextRequestID(), + Method: string(mcp.MethodToolsCall), + Params: params, + } + + resp, sendErr := tr.SendRequest(ctx, req) + if sendErr != nil { + return nil, nil, sendErr + } + if resp.Error != nil { + return nil, nil, resp.Error.AsError() + } + + // Peek at the raw result to decide which shape we got. + var probe struct { + Task json.RawMessage `json:"task"` + Content json.RawMessage `json:"content"` + } + raw := resp.Result + if len(raw) == 0 { + return nil, nil, errors.New("empty tools/call result") + } + if uErr := json.Unmarshal(raw, &probe); uErr != nil { + return nil, nil, fmt.Errorf("decode tools/call result: %w", uErr) + } + + if len(probe.Task) > 0 && string(probe.Task) != "null" { + // Task-augmented response. 
+ var ct mcp.CreateTaskResult + if uErr := json.Unmarshal(raw, &ct); uErr != nil { + return nil, nil, fmt.Errorf("decode CreateTaskResult: %w", uErr) + } + return nil, &ct, nil + } + + // Synchronous response — defer to the upstream parser so content blocks + // are typed correctly (TextContent, ImageContent, ResourceLink, etc.). + cr, pErr := mcp.ParseCallToolResult(&raw) + if pErr != nil { + return nil, nil, fmt.Errorf("parse CallToolResult: %w", pErr) + } + return cr, nil, nil +} + +// pollTaskUntilTerminal blocks until the task reaches a terminal status, +// the context is cancelled, or the configured timeout elapses. On +// cancellation it best-effort issues tasks/cancel before returning. +func pollTaskUntilTerminal( + ctx context.Context, + c *client.Client, + serverName string, + task mcp.Task, + cfg MCPTaskConfig, + progress MCPTaskProgressHandler, +) (*mcp.TaskResultResult, error) { + cfg = cfg.resolved() + deadline := time.Now().Add(cfg.Timeout) + + emit := func(status mcp.TaskStatus, msg string) { + if progress != nil { + progress(MCPTaskProgress{Server: serverName, TaskID: task.TaskId, Status: status, Message: msg}) + } + } + + emit(task.Status, task.StatusMessage) + + current := task + interval := cfg.PollInterval + if current.PollInterval != nil && *current.PollInterval > 0 { + interval = time.Duration(*current.PollInterval) * time.Millisecond + } + if interval > cfg.MaxPollInterval { + interval = cfg.MaxPollInterval + } + + for !current.Status.IsTerminal() { + if time.Now().After(deadline) { + cancelTaskBestEffort(c, current.TaskId) + return nil, fmt.Errorf("task %s timed out after %s", current.TaskId, cfg.Timeout) + } + + // Wait between polls or abort early on context cancellation. 
+ select { + case <-ctx.Done(): + cancelTaskBestEffort(c, current.TaskId) + return nil, ctx.Err() + case <-time.After(interval): + } + + got, err := c.GetTask(ctx, mcp.GetTaskRequest{ + Params: mcp.GetTaskParams{TaskId: current.TaskId}, + }) + if err != nil { + // Transient transport hiccup — propagate immediately. The + // upstream agent layer treats this like any other tool error. + return nil, fmt.Errorf("tasks/get failed: %w", err) + } + current = got.Task + if current.Status != task.Status || current.StatusMessage != task.StatusMessage { + emit(current.Status, current.StatusMessage) + task = current + } + + // Honour any updated suggested poll interval, capped at the limit. + if current.PollInterval != nil && *current.PollInterval > 0 { + interval = min(time.Duration(*current.PollInterval)*time.Millisecond, cfg.MaxPollInterval) + } + } + + // Terminal state reached. Emit one last progress event and fetch the + // definitive tool result. + emit(current.Status, current.StatusMessage) + + if current.Status == mcp.TaskStatusCancelled { + return nil, fmt.Errorf("task %s was cancelled", current.TaskId) + } + + res, err := fetchTaskResult(ctx, c, current.TaskId) + if err != nil { + return nil, fmt.Errorf("tasks/result failed: %w", err) + } + if current.Status == mcp.TaskStatusFailed && res != nil && !res.IsError { + // The server flagged the task as failed but didn't decorate the + // result. Surface the status message so the caller still sees a + // useful tool-error. + return nil, fmt.Errorf("task %s failed: %s", current.TaskId, current.StatusMessage) + } + return res, nil +} + +// fetchTaskResult issues tasks/result on the transport and parses the raw +// response. The upstream client.TaskResult helper delegates to +// mcp.ParseTaskResultResult which (as of mcp-go v0.51.0) looks for the +// content array under a nested "result" key that never exists in the +// wire format — leading to systematically empty Content. 
Doing the +// parse here keeps the polling path working until that is fixed upstream. +func fetchTaskResult(ctx context.Context, c *client.Client, taskID string) (*mcp.TaskResultResult, error) { + tr := c.GetTransport() + if tr == nil { + return nil, errors.New("mcp client has no transport") + } + req := transport.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: nextRequestID(), + Method: string(mcp.MethodTasksResult), + Params: mcp.TaskResultParams{TaskId: taskID}, + } + resp, err := tr.SendRequest(ctx, req) + if err != nil { + return nil, err + } + if resp.Error != nil { + return nil, resp.Error.AsError() + } + + // Manually decode the wire shape: {"_meta": {...}, "content": [...], + // "structuredContent": ..., "isError": bool}. + var shape struct { + Meta json.RawMessage `json:"_meta"` + Content []json.RawMessage `json:"content"` + StructuredContent any `json:"structuredContent"` + IsError bool `json:"isError"` + } + if err := json.Unmarshal(resp.Result, &shape); err != nil { + return nil, fmt.Errorf("decode tasks/result: %w", err) + } + + out := &mcp.TaskResultResult{ + StructuredContent: shape.StructuredContent, + IsError: shape.IsError, + } + if len(shape.Meta) > 0 && string(shape.Meta) != "null" { + var metaMap map[string]any + if err := json.Unmarshal(shape.Meta, &metaMap); err == nil { + out.Meta = mcp.NewMetaFromMap(metaMap) + } + } + for _, raw := range shape.Content { + var contentMap map[string]any + if err := json.Unmarshal(raw, &contentMap); err != nil { + return nil, fmt.Errorf("decode content block: %w", err) + } + parsed, err := mcp.ParseContent(contentMap) + if err != nil { + return nil, fmt.Errorf("parse content block: %w", err) + } + out.Content = append(out.Content, parsed) + } + return out, nil +} + +// cancelTaskBestEffort issues tasks/cancel and ignores any error. Used on +// context cancellation paths where the connection is already going away. 
+func cancelTaskBestEffort(c *client.Client, taskID string) { + if c == nil || taskID == "" { + return + } + cancelCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, _ = c.CancelTask(cancelCtx, mcp.CancelTaskRequest{ + Params: mcp.CancelTaskParams{TaskId: taskID}, + }) +} + +// taskFromMCP converts a wire-format mcp.Task to our richer connection- +// layer view. Unparseable timestamps surface as the zero time. +func taskFromMCP(serverName string, t mcp.Task) MCPTaskInfo { + out := MCPTaskInfo{ + Server: serverName, + TaskID: t.TaskId, + Status: t.Status, + StatusMessage: t.StatusMessage, + } + if t.CreatedAt != "" { + if v, err := time.Parse(time.RFC3339, t.CreatedAt); err == nil { + out.CreatedAt = v + } + } + if t.LastUpdatedAt != "" { + if v, err := time.Parse(time.RFC3339, t.LastUpdatedAt); err == nil { + out.UpdatedAt = v + } + } + if t.TTL != nil { + out.TTL = time.Duration(*t.TTL) * time.Millisecond + } + if t.PollInterval != nil { + out.PollInterval = time.Duration(*t.PollInterval) * time.Millisecond + } + return out +} diff --git a/internal/tools/mcp_tasks_test.go b/internal/tools/mcp_tasks_test.go new file mode 100644 index 00000000..1ab9c651 --- /dev/null +++ b/internal/tools/mcp_tasks_test.go @@ -0,0 +1,294 @@ +package tools + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/mark3labs/kit/internal/config" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// newTaskTestInProcessServer builds an in-process MCP server with a +// task-augmented tool. The handler simulates work by sleeping briefly +// before completing. +// +// Important: the upstream mcp-go server cancels the request context as +// soon as the synchronous part of the tools/call returns (see +// request_handler.go:85, `defer cancel()`). Task goroutines spawned by +// AddTaskTool inherit that context and therefore see context.Canceled +// the instant they start. 
Real-world transports (stdio, SSE, streamable +// HTTP) don't trip this because they keep the connection — and a +// background context — alive across the async work, but the in-process +// transport runs entirely on the request goroutine. To test the polling +// path realistically we detach from the request context here. +func newTaskTestInProcessServer(t *testing.T, workDuration time.Duration) *server.MCPServer { + t.Helper() + srv := server.NewMCPServer("task-test", "1.0.0", + server.WithToolCapabilities(true), + // list=true, cancel=true, toolCallTasks=true so capability detection, + // cancellation, and tool augmentation all flow through. + server.WithTaskCapabilities(true, true, true), + ) + srv.AddTaskTool( + mcp.Tool{ + Name: "long_running", + Description: "Sleep, then echo the input string.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "msg": map[string]any{"type": "string"}, + }, + }, + Execution: &mcp.ToolExecution{ + TaskSupport: mcp.TaskSupportRequired, + }, + }, + func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CreateTaskResult, error) { + msg, _ := req.GetArguments()["msg"].(string) + // Detach from the request context so the task handler can + // outlive the synchronous request — see comment above. + time.Sleep(workDuration) + _ = ctx + return &mcp.CreateTaskResult{ + Content: []mcp.Content{ + mcp.TextContent{Type: "text", Text: "echo:" + msg}, + }, + }, nil + }, + ) + return srv +} + +// newSyncOnlyServer is a server that does NOT advertise task capability. +// Used to verify the auto-detect path keeps the sync semantics. 
+func newSyncOnlyServer() *server.MCPServer { + srv := server.NewMCPServer("sync-only", "1.0.0", + server.WithToolCapabilities(true), + ) + srv.AddTool( + mcp.NewTool("greet", + mcp.WithDescription("Say hello"), + mcp.WithString("name", mcp.Required()), + ), + func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + name, _ := req.GetArguments()["name"].(string) + return mcp.NewToolResultText("hi " + name), nil + }, + ) + return srv +} + +func TestConnectionPoolAdvertisesTaskCapability(t *testing.T) { + pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil) + defer func() { _ = pool.Close() }() + + srv := newTaskTestInProcessServer(t, 0) + cfg := config.MCPServerConfig{Type: "inprocess", InProcessServer: srv} + + conn, err := pool.GetConnection(context.Background(), "tasks", cfg) + if err != nil { + t.Fatalf("GetConnection: %v", err) + } + + init := conn.InitializeResult() + if init == nil { + t.Fatal("InitializeResult is nil after GetConnection") + } + if init.Capabilities.Tasks == nil { + t.Fatal("server did not advertise Tasks capability — initialize handshake regressed") + } + if !conn.SupportsToolTasks() { + t.Error("SupportsToolTasks should be true for a server with toolCallTasks=true") + } + if !pool.ServerSupportsToolTasks("tasks") { + t.Error("ServerSupportsToolTasks should mirror the connection's value") + } +} + +func TestConnectionPoolDetectsAbsentTaskCapability(t *testing.T) { + pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil) + defer func() { _ = pool.Close() }() + + cfg := config.MCPServerConfig{Type: "inprocess", InProcessServer: newSyncOnlyServer()} + conn, err := pool.GetConnection(context.Background(), "sync", cfg) + if err != nil { + t.Fatalf("GetConnection: %v", err) + } + if conn.SupportsToolTasks() { + t.Error("SupportsToolTasks should be false for a server that didn't advertise the capability") + } +} + +func TestSupportsToolTasksFromInit(t *testing.T) { + cases := 
[]struct { + name string + in *mcp.InitializeResult + want bool + }{ + {"nil", nil, false}, + {"no tasks", &mcp.InitializeResult{}, false}, + {"tasks no requests", &mcp.InitializeResult{ + Capabilities: mcp.ServerCapabilities{Tasks: &mcp.TasksCapability{}}, + }, false}, + {"tasks with toolCalls", &mcp.InitializeResult{ + Capabilities: mcp.ServerCapabilities{Tasks: mcp.NewTasksCapability()}, + }, true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := supportsToolTasksFromInit(tc.in); got != tc.want { + t.Errorf("supportsToolTasksFromInit() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestParseTaskMode(t *testing.T) { + cases := []struct { + in string + want MCPTaskMode + }{ + {"", MCPTaskModeAuto}, + {"auto", MCPTaskModeAuto}, + {"AUTO", MCPTaskModeAuto}, + {"never", MCPTaskModeNever}, + {"off", MCPTaskModeNever}, + {"always", MCPTaskModeAlways}, + {"force", MCPTaskModeAlways}, + {"bogus", MCPTaskModeAuto}, + } + for _, tc := range cases { + if got := ParseTaskMode(tc.in); got != tc.want { + t.Errorf("ParseTaskMode(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestExecuteToolPollsTaskToCompletion(t *testing.T) { + mgr := NewMCPToolManager() + mgr.SetTaskConfig(MCPTaskConfig{ + PollInterval: 20 * time.Millisecond, + MaxPollInterval: 50 * time.Millisecond, + Timeout: 10 * time.Second, + }) + + cfg := config.MCPServerConfig{ + Type: "inprocess", + InProcessServer: newTaskTestInProcessServer(t, 50*time.Millisecond), + } + + if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil { + t.Fatalf("AddServer: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + res, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"hello"}`) + if err != nil { + t.Fatalf("ExecuteTool: %v", err) + } + if res.IsError { + t.Fatalf("expected non-error result, got %s", res.Content) + } + if !strings.Contains(res.Content, "echo:hello") { + 
t.Errorf("expected result to contain 'echo:hello', got %s", res.Content) + } +} + +func TestExecuteToolHonorsNeverMode(t *testing.T) { + // Even though the server advertises tasks/toolCalls, "never" should + // keep the call synchronous. Since the tool is TaskSupportRequired, + // the server returns an error rather than running it sync — we just + // verify the error surfaces (not a poll-loop hang). + mgr := NewMCPToolManager() + mgr.SetTaskConfig(MCPTaskConfig{ + PerServerMode: map[string]MCPTaskMode{"tasks": MCPTaskModeNever}, + Timeout: 2 * time.Second, + }) + + cfg := config.MCPServerConfig{ + Type: "inprocess", + InProcessServer: newTaskTestInProcessServer(t, 0), + } + + if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil { + t.Fatalf("AddServer: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // We don't care which way the server fails the sync call; we just want + // to confirm we didn't hang in the polling loop and didn't panic. 
+ _, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"x"}`) + if err == nil { + t.Fatal("expected an error when forcing sync execution of a task-required tool") + } +} + +func TestExecuteToolEmitsProgress(t *testing.T) { + var statuses []mcp.TaskStatus + mgr := NewMCPToolManager() + mgr.SetTaskConfig(MCPTaskConfig{ + PollInterval: 10 * time.Millisecond, + MaxPollInterval: 25 * time.Millisecond, + Timeout: 5 * time.Second, + Progress: func(p MCPTaskProgress) { + statuses = append(statuses, p.Status) + }, + }) + + cfg := config.MCPServerConfig{ + Type: "inprocess", + InProcessServer: newTaskTestInProcessServer(t, 30*time.Millisecond), + } + if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil { + t.Fatalf("AddServer: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if _, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"hi"}`); err != nil { + t.Fatalf("ExecuteTool: %v", err) + } + if len(statuses) == 0 { + t.Fatal("expected at least one progress event") + } + last := statuses[len(statuses)-1] + if !last.IsTerminal() { + t.Errorf("last progress event should be terminal, got %q", last) + } +} + +func TestListGetCancelMCPTasksOnLoadedServer(t *testing.T) { + mgr := NewMCPToolManager() + cfg := config.MCPServerConfig{ + Type: "inprocess", + InProcessServer: newTaskTestInProcessServer(t, 0), + } + if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil { + t.Fatalf("AddServer: %v", err) + } + + ctx := context.Background() + + // tasks/list — no in-flight tasks yet, so we just verify the call + // succeeds and returns an empty slice (or any slice; the exact length + // depends on server retention policy). + if _, err := mgr.ListServerTasks(ctx, "tasks"); err != nil { + t.Errorf("ListServerTasks: %v", err) + } + + // Unknown server should error cleanly without panicking. 
+ if _, err := mgr.GetServerTask(ctx, "unknown", "abc"); err == nil { + t.Error("GetServerTask on unknown server should error") + } + if _, err := mgr.CancelServerTask(ctx, "unknown", "abc"); err == nil { + t.Error("CancelServerTask on unknown server should error") + } +} diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index b60186f5..3b7e754e 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -1035,6 +1035,41 @@ type Options struct { // real-time progress in the TUI. OnMCPServerLoaded func(serverName string, toolCount int, err error) + // MCPTaskMode overrides the per-server [MCPTaskMode] for task-augmented + // tools/call execution. Keys are MCP server names. Servers not present + // in the map fall back to the TasksMode field of MCPServerConfig (or + // MCPTaskModeAuto when that is empty). See the MCP Tasks spec for the + // underlying semantics: + // https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks + MCPTaskMode map[string]MCPTaskMode + + // MCPTaskTimeout is the maximum wall-clock duration to wait for a + // task-augmented tool call to reach a terminal state. Independent of + // any per-call context deadline; whichever fires first wins. Zero + // means use the default (15 minutes). + MCPTaskTimeout time.Duration + + // MCPTaskTTL is the TTL hint sent in TaskParams for every + // task-augmented tools/call. Zero omits the TTL and lets the server + // pick its own retention policy. + MCPTaskTTL time.Duration + + // MCPTaskPollInterval is the fallback interval between tasks/get + // requests when the server does not suggest one. Zero means use the + // default (1 second). + MCPTaskPollInterval time.Duration + + // MCPTaskMaxPollInterval caps the polling interval (a server-supplied + // pollInterval can otherwise grow without bound). Zero means use the + // default (5 seconds). 
+ MCPTaskMaxPollInterval time.Duration + + // MCPTaskProgress, if non-nil, is invoked once when a task is accepted + // and on every status transition observed by the polling loop. The + // final invocation always carries a terminal status. Implementations + // must not block; long work should run on a goroutine. + MCPTaskProgress MCPTaskProgressHandler + // CLI is optional CLI-specific configuration. SDK users leave this nil. CLI *CLIOptions @@ -1387,6 +1422,14 @@ func New(ctx context.Context, opts *Options) (*Kit, error) { MaxSteps: maxSteps, StreamingEnabled: streaming, OnMCPServerLoaded: opts.OnMCPServerLoaded, + MCPTaskConfig: mcpTaskOptions{ + perServer: opts.MCPTaskMode, + defaultTTL: opts.MCPTaskTTL, + pollInterval: opts.MCPTaskPollInterval, + maxPollInterval: opts.MCPTaskMaxPollInterval, + timeout: opts.MCPTaskTimeout, + progress: opts.MCPTaskProgress, + }.toToolsConfig(), } // Set up OAuth handler for remote MCP servers. The SDK does not create diff --git a/pkg/kit/mcp_tasks.go b/pkg/kit/mcp_tasks.go new file mode 100644 index 00000000..e57f76b6 --- /dev/null +++ b/pkg/kit/mcp_tasks.go @@ -0,0 +1,220 @@ +package kit + +import ( + "context" + "fmt" + "time" + + "github.com/mark3labs/kit/internal/tools" + "github.com/mark3labs/mcp-go/mcp" +) + +// MCPTaskStatus represents the lifecycle state of a task-augmented MCP +// tool call. See https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks +// for the underlying spec. +type MCPTaskStatus string + +const ( + // MCPTaskStatusWorking indicates the task is currently being processed. + MCPTaskStatusWorking MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusWorking) + // MCPTaskStatusInputRequired indicates the server is waiting for client + // input before it can proceed (rare; typically surfaced via elicitation). + MCPTaskStatusInputRequired MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusInputRequired) + // MCPTaskStatusCompleted indicates the task finished successfully. 
+ MCPTaskStatusCompleted MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusCompleted) + // MCPTaskStatusFailed indicates the task ended in error. + MCPTaskStatusFailed MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusFailed) + // MCPTaskStatusCancelled indicates the task was cancelled before completion. + MCPTaskStatusCancelled MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusCancelled) +) + +// IsTerminal reports whether the status represents a final state — that is, +// the task will not change again. Terminal states are completed, failed, +// and cancelled. +func (s MCPTaskStatus) IsTerminal() bool { + return mcp.TaskStatus(s).IsTerminal() +} + +// MCPTaskMode controls when Kit augments tools/call requests with MCP task +// metadata for a specific server. +type MCPTaskMode string + +const ( + // MCPTaskModeAuto augments tools/call with task metadata only when the + // server advertises tasks/toolCalls capability during initialize. + // This is the default and is safe to leave unconfigured for any + // existing MCP server. + MCPTaskModeAuto MCPTaskMode = MCPTaskMode(tools.MCPTaskModeAuto) + // MCPTaskModeNever forces every tools/call to be issued synchronously + // (no Task field), regardless of server capability. + MCPTaskModeNever MCPTaskMode = MCPTaskMode(tools.MCPTaskModeNever) + // MCPTaskModeAlways always opts into task augmentation, even when the + // server didn't advertise the capability. The server may still respond + // synchronously; this just expresses client intent unconditionally. + MCPTaskModeAlways MCPTaskMode = MCPTaskMode(tools.MCPTaskModeAlways) +) + +// MCPTask is the SDK-level view of an MCP Task. Timestamps are best-effort +// parsed from the server's ISO-8601 strings; they may be the zero time when +// the server omitted them or used a non-RFC3339 format. +type MCPTask struct { + // Server is the configured MCP server name this task lives on. + Server string + // TaskID is the server-assigned identifier for the task. 
+ TaskID string + // Status is the current task lifecycle state. + Status MCPTaskStatus + // StatusMessage is an optional human-readable description provided by + // the server. + StatusMessage string + // CreatedAt is when the task was created on the server. + CreatedAt time.Time + // UpdatedAt is when the task was last updated on the server. + UpdatedAt time.Time + // TTL is how long the server intends to retain this task after creation. + // Zero means the server did not advertise a TTL. + TTL time.Duration + // PollInterval is the suggested time between status checks. Zero means + // the client should use its own default. + PollInterval time.Duration +} + +// MCPTaskProgress is a single status update emitted while Kit is waiting +// on a task-augmented tool call. +type MCPTaskProgress struct { + // Server is the configured MCP server name. + Server string + // TaskID is the server-assigned identifier for the in-flight task. + TaskID string + // Status is the most recent task status observed. + Status MCPTaskStatus + // Message is the optional human-readable status message from the server. + Message string +} + +// MCPTaskProgressHandler is called once when a task is accepted and again +// on every observed status transition. The final invocation always carries +// a terminal status. Implementations must not block; long work should be +// dispatched on a goroutine. +type MCPTaskProgressHandler func(MCPTaskProgress) + +// mcpTaskOptions carries SDK consumer configuration into the agent setup. +// Stored on Options as a single value so the public surface stays compact; +// individual fields are exposed via WithMCP* builder functions. +type mcpTaskOptions struct { + perServer map[string]MCPTaskMode + defaultTTL time.Duration + pollInterval time.Duration + maxPollInterval time.Duration + timeout time.Duration + progress MCPTaskProgressHandler +} + +// toToolsConfig converts the SDK-level config to the internal tools-package +// representation. 
Keeps the dependency arrow internal-only. +func (o mcpTaskOptions) toToolsConfig() tools.MCPTaskConfig { + cfg := tools.MCPTaskConfig{ + DefaultTTL: o.defaultTTL, + PollInterval: o.pollInterval, + MaxPollInterval: o.maxPollInterval, + Timeout: o.timeout, + } + if len(o.perServer) > 0 { + cfg.PerServerMode = make(map[string]tools.MCPTaskMode, len(o.perServer)) + for k, v := range o.perServer { + cfg.PerServerMode[k] = tools.MCPTaskMode(v) + } + } + if o.progress != nil { + h := o.progress + cfg.Progress = func(p tools.MCPTaskProgress) { + h(MCPTaskProgress{ + Server: p.Server, + TaskID: p.TaskID, + Status: MCPTaskStatus(p.Status), + Message: p.Message, + }) + } + } + return cfg +} + +// ListMCPTasks queries tasks/list on the named MCP server and returns the +// active and recent tasks the server is willing to disclose. Returns an +// error when the server isn't loaded, doesn't expose tasks/list, or the +// underlying transport fails. +func (m *Kit) ListMCPTasks(ctx context.Context, serverName string) ([]MCPTask, error) { + mgr, err := m.mcpToolManager() + if err != nil { + return nil, err + } + infos, err := mgr.ListServerTasks(ctx, serverName) + if err != nil { + return nil, err + } + out := make([]MCPTask, len(infos)) + for i, t := range infos { + out[i] = mcpTaskFromInternal(t) + } + return out, nil +} + +// GetMCPTask queries tasks/get for a single in-flight task on the named +// server. The returned MCPTask reflects the server's current view of the +// task. +func (m *Kit) GetMCPTask(ctx context.Context, serverName, taskID string) (MCPTask, error) { + mgr, err := m.mcpToolManager() + if err != nil { + return MCPTask{}, err + } + info, err := mgr.GetServerTask(ctx, serverName, taskID) + if err != nil { + return MCPTask{}, err + } + return mcpTaskFromInternal(info), nil +} + +// CancelMCPTask issues tasks/cancel for an in-flight task on the named +// server. Returns the post-cancel task state when the server responded +// with one. 
Cancelling an already-terminal task is a no-op on most
+// servers.
+func (m *Kit) CancelMCPTask(ctx context.Context, serverName, taskID string) (MCPTask, error) {
+	mgr, err := m.mcpToolManager()
+	if err != nil {
+		return MCPTask{}, err
+	}
+	info, err := mgr.CancelServerTask(ctx, serverName, taskID)
+	if err != nil {
+		return MCPTask{}, err
+	}
+	return mcpTaskFromInternal(info), nil
+}
+
+// mcpToolManager returns the underlying MCP tool manager. It returns an error when
+// the Kit or its agent is nil, or when the agent has no MCP tool manager (no MCP servers configured).
+func (m *Kit) mcpToolManager() (*tools.MCPToolManager, error) {
+	if m == nil || m.agent == nil {
+		return nil, fmt.Errorf("kit instance has no agent")
+	}
+	mgr := m.agent.GetMCPToolManager()
+	if mgr == nil {
+		return nil, fmt.Errorf("no MCP servers configured")
+	}
+	return mgr, nil
+}
+
+// mcpTaskFromInternal adapts the internal tools.MCPTaskInfo to the
+// SDK-level MCPTask type. Keeps the public surface independent of
+// internal package types.
+func mcpTaskFromInternal(t tools.MCPTaskInfo) MCPTask {
+	return MCPTask{
+		Server:        t.Server,
+		TaskID:        t.TaskID,
+		Status:        MCPTaskStatus(t.Status),
+		StatusMessage: t.StatusMessage,
+		CreatedAt:     t.CreatedAt,
+		UpdatedAt:     t.UpdatedAt,
+		TTL:           t.TTL,
+		PollInterval:  t.PollInterval,
+	}
+}
diff --git a/pkg/kit/mcp_tasks_test.go b/pkg/kit/mcp_tasks_test.go
new file mode 100644
index 00000000..50589188
--- /dev/null
+++ b/pkg/kit/mcp_tasks_test.go
@@ -0,0 +1,163 @@
+package kit
+
+import (
+	"testing"
+	"time"
+
+	"github.com/mark3labs/kit/internal/tools"
+)
+
+// TestMCPTaskStatusIsTerminal pins which statuses IsTerminal reports as
+// terminal, including an unrecognised status string.
+func TestMCPTaskStatusIsTerminal(t *testing.T) {
+	cases := []struct {
+		s    MCPTaskStatus
+		want bool
+	}{
+		{MCPTaskStatusWorking, false},
+		{MCPTaskStatusInputRequired, false},
+		{MCPTaskStatusCompleted, true},
+		{MCPTaskStatusFailed, true},
+		{MCPTaskStatusCancelled, true},
+		{MCPTaskStatus("unknown"), false},
+	}
+	for _, tc := range cases {
+		if got := tc.s.IsTerminal(); got != tc.want {
+			t.Errorf("MCPTaskStatus(%q).IsTerminal() = %v, want %v", tc.s, got, tc.want)
+		}
+	}
+}
+
+// TestMCPTaskOptionsToToolsConfig verifies that every mcpTaskOptions field
+// survives conversion to tools.MCPTaskConfig, including the progress
+// callback and the struct it receives.
+func TestMCPTaskOptionsToToolsConfig(t *testing.T) {
+	// Record every update so the converted payload can be inspected, not
+	// just the number of invocations.
+	var updates []MCPTaskProgress
+	o := mcpTaskOptions{
+		perServer: map[string]MCPTaskMode{
+			"alpha": MCPTaskModeAlways,
+			"beta":  MCPTaskModeNever,
+		},
+		defaultTTL:      30 * time.Second,
+		pollInterval:    250 * time.Millisecond,
+		maxPollInterval: 2 * time.Second,
+		timeout:         5 * time.Minute,
+		progress:        func(p MCPTaskProgress) { updates = append(updates, p) },
+	}
+	cfg := o.toToolsConfig()
+
+	if cfg.DefaultTTL != 30*time.Second {
+		t.Errorf("DefaultTTL = %v, want 30s", cfg.DefaultTTL)
+	}
+	if cfg.PollInterval != 250*time.Millisecond {
+		t.Errorf("PollInterval = %v, want 250ms", cfg.PollInterval)
+	}
+	if cfg.MaxPollInterval != 2*time.Second {
+		t.Errorf("MaxPollInterval = %v, want 2s", cfg.MaxPollInterval)
+	}
+	if cfg.Timeout != 5*time.Minute {
+		t.Errorf("Timeout = %v, want 5m", cfg.Timeout)
+	}
+	if len(cfg.PerServerMode) != 2 {
+		t.Errorf("len(PerServerMode) = %d, want 2", len(cfg.PerServerMode))
+	}
+	if cfg.PerServerMode["alpha"] != tools.MCPTaskModeAlways {
+		t.Errorf("PerServerMode[alpha] = %q, want always", cfg.PerServerMode["alpha"])
+	}
+	if cfg.PerServerMode["beta"] != tools.MCPTaskModeNever {
+		t.Errorf("PerServerMode[beta] = %q, want never", cfg.PerServerMode["beta"])
+	}
+
+	// Progress conversion: invoking the internal handler must call our
+	// SDK-level callback with the converted struct.
+	if cfg.Progress == nil {
+		t.Fatal("Progress callback was lost in conversion")
+	}
+	cfg.Progress(tools.MCPTaskProgress{
+		Server:  "alpha",
+		TaskID:  "t1",
+		Status:  "working",
+		Message: "step 1",
+	})
+	if len(updates) != 1 {
+		t.Fatalf("expected SDK progress handler to be invoked once, got %d", len(updates))
+	}
+	want := MCPTaskProgress{
+		Server:  "alpha",
+		TaskID:  "t1",
+		Status:  MCPTaskStatusWorking,
+		Message: "step 1",
+	}
+	if updates[0] != want {
+		t.Errorf("progress update = %+v, want %+v", updates[0], want)
+	}
+
+	// The zero value must not fabricate a mode map or a progress callback.
+	zero := mcpTaskOptions{}.toToolsConfig()
+	if zero.PerServerMode != nil {
+		t.Errorf("zero options: PerServerMode = %v, want nil", zero.PerServerMode)
+	}
+	if zero.Progress != nil {
+		t.Error("zero options: Progress callback should be nil")
+	}
+}
+
+// TestMCPTaskFromInternal verifies that every tools.MCPTaskInfo field is
+// carried over to the SDK-level MCPTask.
+func TestMCPTaskFromInternal(t *testing.T) {
+	created := time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC)
+	updated := time.Date(2026, 5, 4, 12, 0, 1, 0, time.UTC)
+	in := tools.MCPTaskInfo{
+		Server:        "srv",
+		TaskID:        "t-1",
+		Status:        "working",
+		StatusMessage: "phase 1",
+		CreatedAt:     created,
+		UpdatedAt:     updated,
+		TTL:           5 * time.Minute,
+		PollInterval:  500 * time.Millisecond,
+	}
+	out := mcpTaskFromInternal(in)
+
+	if out.Server != "srv" || out.TaskID != "t-1" {
+		t.Errorf("identity fields not copied: %+v", out)
+	}
+	if out.Status != MCPTaskStatusWorking {
+		t.Errorf("Status = %q, want working", out.Status)
+	}
+	if out.StatusMessage != "phase 1" {
+		t.Errorf("StatusMessage = %q, want phase 1", out.StatusMessage)
+	}
+	if !out.CreatedAt.Equal(created) || !out.UpdatedAt.Equal(updated) {
+		t.Errorf("timestamps not copied: %+v", out)
+	}
+	if out.TTL != 5*time.Minute || out.PollInterval != 500*time.Millisecond {
+		t.Errorf("durations not copied: %+v", out)
+	}
+}
+
+// TestKitMCPTasksWithoutAgentReturnsError covers both a nil *Kit and a
+// zero-value Kit, neither of which has an agent attached.
+func TestKitMCPTasksWithoutAgentReturnsError(t *testing.T) {
+	// A nil/zero Kit must not panic — task RPCs should surface a clear
+	// error instead. Useful for SDK consumers that try task ops on a Kit
+	// constructed without MCP servers.
+	kits := map[string]*Kit{
+		"nil":  nil,
+		"zero": {},
+	}
+	ctx := t.Context()
+	for name, k := range kits {
+		if _, err := k.ListMCPTasks(ctx, "any"); err == nil {
+			t.Errorf("ListMCPTasks on %s Kit should error", name)
+		}
+		if _, err := k.GetMCPTask(ctx, "any", "id"); err == nil {
+			t.Errorf("GetMCPTask on %s Kit should error", name)
+		}
+		if _, err := k.CancelMCPTask(ctx, "any", "id"); err == nil {
+			t.Errorf("CancelMCPTask on %s Kit should error", name)
+		}
+	}
+}