diff --git a/cmd/root.go b/cmd/root.go index 36d820fb..51008ed8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -717,13 +717,20 @@ func runNormalMode(ctx context.Context) error { // Build Kit options from CLI flags and create the SDK instance. // kit.New() handles: config → skills → agent → session → extension bridge. + authHandler, authErr := kit.NewCLIMCPAuthHandler() + if authErr != nil { + // Non-fatal: OAuth just won't be available for remote MCP servers. + fmt.Fprintf(os.Stderr, "Warning: Failed to create OAuth handler: %v\n", authErr) + } + kitOpts := &kit.Options{ - Quiet: quietFlag, - Debug: debugMode, - NoSession: noSessionFlag, - Continue: continueFlag, - SessionPath: sessionPath, - AutoCompact: autoCompactFlag, + Quiet: quietFlag, + Debug: debugMode, + NoSession: noSessionFlag, + Continue: continueFlag, + SessionPath: sessionPath, + AutoCompact: autoCompactFlag, + MCPAuthHandler: authHandler, CLI: &kit.CLIOptions{ MCPConfig: mcpConfig, ShowSpinner: true, @@ -796,6 +803,13 @@ func runNormalMode(ctx context.Context) error { appInstance := app.New(appOpts, messages) defer appInstance.Close() + // Wire OAuth handler to route messages through the TUI once it's running. + if authHandler != nil { + authHandler.NotifyFunc = func(serverName, message string) { + appInstance.PrintFromExtension("info", message) + } + } + // Buffer for extension messages during startup (printed after startup banner). var startupExtensionMessages []string diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 0e45d28c..8e9502cb 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -25,6 +25,11 @@ type AgentConfig struct { StreamingEnabled bool DebugLogger tools.DebugLogger + // AuthHandler handles OAuth authorization for remote MCP servers. + // When set, remote transports are configured with OAuth support. + // If nil, remote MCP servers that require OAuth will fail to connect. 
+ AuthHandler tools.MCPAuthHandler + // CoreTools overrides the default core tool set. If empty, core.AllTools() // is used. This allows SDK users to provide a custom tool set (e.g. // CodingTools or tools with a custom WorkDir). @@ -139,6 +144,10 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) { toolManager = tools.NewMCPToolManager() toolManager.SetModel(providerResult.Model) + if agentConfig.AuthHandler != nil { + toolManager.SetAuthHandler(agentConfig.AuthHandler) + } + if agentConfig.DebugLogger != nil { toolManager.SetDebugLogger(agentConfig.DebugLogger) } diff --git a/internal/agent/factory.go b/internal/agent/factory.go index 249f2b60..de01de46 100644 --- a/internal/agent/factory.go +++ b/internal/agent/factory.go @@ -36,6 +36,8 @@ type AgentCreationOptions struct { SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller) // DebugLogger is an optional logger for debugging MCP communications DebugLogger tools.DebugLogger // Optional debug logger + // AuthHandler handles OAuth authorization for remote MCP servers + AuthHandler tools.MCPAuthHandler // CoreTools overrides the default core tool set. If empty, core.AllTools() // is used. CoreTools []fantasy.AgentTool @@ -56,6 +58,7 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error MaxSteps: opts.MaxSteps, StreamingEnabled: opts.StreamingEnabled, DebugLogger: opts.DebugLogger, + AuthHandler: opts.AuthHandler, CoreTools: opts.CoreTools, ToolWrapper: opts.ToolWrapper, ExtraTools: opts.ExtraTools, diff --git a/internal/kitsetup/setup.go b/internal/kitsetup/setup.go index 4d46e2f8..eee65820 100644 --- a/internal/kitsetup/setup.go +++ b/internal/kitsetup/setup.go @@ -58,6 +58,9 @@ type AgentSetupOptions struct { // StreamingEnabled controls streaming. Only meaningful when ProviderConfig // is also set. StreamingEnabled bool + // AuthHandler handles OAuth authorization for remote MCP servers. 
+ // When set, remote transports are configured with OAuth support. + AuthHandler tools.MCPAuthHandler } // AgentSetupResult bundles the created agent and any debug logger so the caller @@ -185,6 +188,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, Quiet: opts.Quiet, SpinnerFunc: opts.SpinnerFunc, DebugLogger: debugLogger, + AuthHandler: opts.AuthHandler, CoreTools: opts.CoreTools, ToolWrapper: toolWrapper, ExtraTools: extraTools, diff --git a/internal/tools/connection_pool.go b/internal/tools/connection_pool.go index b3565955..6379cd97 100644 --- a/internal/tools/connection_pool.go +++ b/internal/tools/connection_pool.go @@ -68,6 +68,7 @@ type MCPConnectionPool struct { cancel context.CancelFunc debug bool debugLogger DebugLogger + oauthFlow *OAuthFlowRunner } // NewMCPConnectionPool creates a new MCP connection pool with the specified configuration. @@ -75,7 +76,7 @@ type MCPConnectionPool struct { // goroutine for periodic health checks that runs until Close is called. // The model parameter is used for MCP servers that require sampling support. // Thread-safe for concurrent use immediately after creation. -func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool) *MCPConnectionPool { +func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler) *MCPConnectionPool { if config == nil { config = DefaultConnectionPoolConfig() } @@ -90,6 +91,10 @@ func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageMo debug: debug, } + if authHandler != nil { + pool.oauthFlow = NewOAuthFlowRunner(authHandler) + } + go pool.startHealthCheck() return pool } @@ -103,6 +108,15 @@ func (p *MCPConnectionPool) SetDebugLogger(logger DebugLogger) { p.debugLogger = logger } +// SetOAuthFlow sets the OAuth flow runner for the connection pool. 
+// When set, the pool can trigger OAuth re-authorization when a tool call fails +// with an OAuth error (e.g. expired token). Thread-safe and can be called at any time. +func (p *MCPConnectionPool) SetOAuthFlow(flow *OAuthFlowRunner) { + p.mu.Lock() + defer p.mu.Unlock() + p.oauthFlow = flow +} + // GetConnection retrieves or creates a connection for the specified MCP server. // If a healthy, non-idle connection exists in the pool, it will be reused. // Otherwise, a new connection is created and added to the pool. @@ -230,18 +244,43 @@ func (p *MCPConnectionPool) performHealthCheck(ctx context.Context, conn *MCPCon // createConnection creates a new connection func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) { - client, err := p.createMCPClient(ctx, serverName, serverConfig) + mcpClient, err := p.createMCPClient(ctx, serverName, serverConfig) if err != nil { - return nil, err + // SSE transport can return OAuth error during Start() + if p.oauthFlow != nil && IsOAuthError(err) { + if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil { + return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr) + } + // Retry after successful auth + mcpClient, err = p.createMCPClient(ctx, serverName, serverConfig) + if err != nil { + return nil, err + } + } else { + return nil, err + } } - if err := p.initializeClient(ctx, client); err != nil { - _ = client.Close() - return nil, err + if err := p.initializeClient(ctx, mcpClient); err != nil { + // Streamable HTTP transport returns OAuth error during Initialize() + if p.oauthFlow != nil && IsOAuthError(err) { + if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil { + _ = mcpClient.Close() + return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr) + } + // Retry initialization after successful auth + if err := p.initializeClient(ctx, mcpClient); err != nil { + _ = 
mcpClient.Close() + return nil, err + } + } else { + _ = mcpClient.Close() + return nil, err + } } conn := &MCPConnection{ - client: client, + client: mcpClient, serverName: serverName, serverConfig: serverConfig, lastUsed: time.Now(), @@ -323,13 +362,29 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co } } + // Enable OAuth for remote transports when an auth handler is configured. + // The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and + // scopes are discovered automatically via dynamic client registration and + // server metadata (RFC 9728). + if p.oauthFlow != nil { + tokenStore, tsErr := NewFileTokenStore(serverConfig.URL) + if tsErr != nil { + return nil, fmt.Errorf("failed to create token store: %w", tsErr) + } + options = append(options, transport.WithOAuth(transport.OAuthConfig{ + RedirectURI: p.oauthFlow.handler.RedirectURI(), + PKCEEnabled: true, + TokenStore: tokenStore, + })) + } + sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...) if err != nil { return nil, err } if err := sseClient.Start(ctx); err != nil { - return nil, fmt.Errorf("failed to start SSE client: %v", err) + return nil, fmt.Errorf("failed to start SSE client: %w", err) } return sseClient, nil @@ -354,13 +409,29 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo } } + // Enable OAuth for remote transports when an auth handler is configured. + // The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and + // scopes are discovered automatically via dynamic client registration and + // server metadata (RFC 9728). 
+ if p.oauthFlow != nil { + tokenStore, tsErr := NewFileTokenStore(serverConfig.URL) + if tsErr != nil { + return nil, fmt.Errorf("failed to create token store: %w", tsErr) + } + options = append(options, transport.WithHTTPOAuth(transport.OAuthConfig{ + RedirectURI: p.oauthFlow.handler.RedirectURI(), + PKCEEnabled: true, + TokenStore: tokenStore, + })) + } + streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...) if err != nil { return nil, err } if err := streamableClient.Start(ctx); err != nil { - return nil, fmt.Errorf("failed to start streamable HTTP client: %v", err) + return nil, fmt.Errorf("failed to start streamable HTTP client: %w", err) } return streamableClient, nil @@ -381,7 +452,7 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client. _, err := client.Initialize(initCtx, initRequest) if err != nil { - return fmt.Errorf("initialization timeout or failed: %v", err) + return fmt.Errorf("initialization timeout or failed: %w", err) } if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() { @@ -539,6 +610,9 @@ func (p *MCPConnectionPool) Close() error { // isConnectionError checks if the error is connection-related func isConnectionError(err error) bool { + if IsOAuthError(err) { + return false // OAuth errors are recoverable, not connection failures + } errStr := err.Error() return strings.Contains(errStr, "Connection not found") || strings.Contains(errStr, "transport error") || diff --git a/internal/tools/fantasy_adapter.go b/internal/tools/fantasy_adapter.go index 2fc3c1f5..e89e6d90 100644 --- a/internal/tools/fantasy_adapter.go +++ b/internal/tools/fantasy_adapter.go @@ -59,9 +59,30 @@ func (t *mcpFantasyTool) Run(ctx context.Context, call fantasy.ToolCall) (fantas }, }) if err != nil { - // Mark connection as unhealthy for automatic recovery - t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err) - return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: 
%w", err) + // Handle OAuth re-authorization: token may have expired mid-session. + if t.mapping.manager.connectionPool.oauthFlow != nil && IsOAuthError(err) { + if flowErr := t.mapping.manager.connectionPool.oauthFlow.RunAuthFlow(ctx, t.mapping.serverName, err); flowErr != nil { + return fantasy.ToolResponse{}, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", t.mapping.originalName, flowErr) + } + // Retry the tool call after successful re-auth. + result, err = conn.client.CallTool(ctx, mcp.CallToolRequest{ + Request: mcp.Request{ + Method: "tools/call", + }, + Params: mcp.CallToolParams{ + Name: t.mapping.originalName, + Arguments: arguments, + }, + }) + if err != nil { + t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err) + return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool after re-auth: %w", err) + } + } else { + // Mark connection as unhealthy for automatic recovery + t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err) + return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err) + } } // Marshal the MCP result to JSON string diff --git a/internal/tools/mcp.go b/internal/tools/mcp.go index f97136aa..e282395f 100644 --- a/internal/tools/mcp.go +++ b/internal/tools/mcp.go @@ -22,6 +22,7 @@ type MCPToolManager struct { tools []fantasy.AgentTool toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name model fantasy.LanguageModel // LLM model for sampling + authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth) config *config.Config debug bool debugLogger DebugLogger @@ -53,6 +54,14 @@ func (m *MCPToolManager) SetModel(model fantasy.LanguageModel) { m.model = model } +// SetAuthHandler sets the OAuth handler for remote MCP server authentication. 
+// When set, remote transports (streamable HTTP, SSE) are configured with OAuth +// support, enabling automatic authorization flows when servers require authentication. +// This method should be called before LoadTools. +func (m *MCPToolManager) SetAuthHandler(handler MCPAuthHandler) { + m.authHandler = handler +} + // SetDebugLogger sets the debug logger for the tool manager. // The logger will be used to output detailed debugging information about MCP connections, // tool loading, and execution. If a connection pool exists, it will also be configured @@ -76,7 +85,7 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) e if m.debugLogger == nil { m.debugLogger = NewSimpleDebugLogger(config.Debug) } - m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug) + m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug, m.authHandler) m.connectionPool.SetDebugLogger(m.debugLogger) var loadErrors []string diff --git a/internal/tools/oauth_flow.go b/internal/tools/oauth_flow.go new file mode 100644 index 00000000..5acfaff1 --- /dev/null +++ b/internal/tools/oauth_flow.go @@ -0,0 +1,109 @@ +package tools + +import ( + "context" + "fmt" + "net/url" + + "github.com/mark3labs/mcp-go/client" +) + +// MCPAuthHandler is the internal interface for handling MCP OAuth flows. +// The SDK-level kit.MCPAuthHandler is adapted to this interface in cmd/root.go +// or pkg/kit/kit.go, keeping the tools package decoupled from the SDK. +type MCPAuthHandler interface { + // RedirectURI returns the OAuth redirect URI for transport setup. + RedirectURI() string + // HandleAuth is called when a server requires OAuth authorization. + // It receives the server name and the authorization URL the user must visit. + // It returns the full callback URL (containing code and state query params) + // after the user completes authorization. 
+ HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error) +} + +// OAuthFlowRunner handles the OAuth authorization flow when an MCP server +// returns an OAuthAuthorizationRequiredError. It coordinates dynamic client +// registration, PKCE generation, user authorization (via MCPAuthHandler), +// and token exchange. +type OAuthFlowRunner struct { + handler MCPAuthHandler +} + +// NewOAuthFlowRunner creates a new OAuthFlowRunner with the given auth handler. +func NewOAuthFlowRunner(handler MCPAuthHandler) *OAuthFlowRunner { + return &OAuthFlowRunner{handler: handler} +} + +// RunAuthFlow executes the OAuth authorization flow for the given server. +// It extracts the OAuthHandler from the error, performs dynamic client registration +// if needed, generates PKCE parameters, delegates to the MCPAuthHandler for user +// interaction, and exchanges the authorization code for a token. +func (r *OAuthFlowRunner) RunAuthFlow(ctx context.Context, serverName string, authErr error) error { + // Extract the OAuthHandler from the authorization-required error. + oauthHandler := client.GetOAuthHandler(authErr) + if oauthHandler == nil { + return fmt.Errorf("oauth flow: failed to extract OAuth handler from error: %w", authErr) + } + + // Perform dynamic client registration if no client ID is configured yet. + if oauthHandler.GetClientID() == "" { + if err := oauthHandler.RegisterClient(ctx, "kit"); err != nil { + return fmt.Errorf("oauth flow: dynamic client registration failed: %w", err) + } + } + + // Generate PKCE code verifier and challenge. + codeVerifier, err := client.GenerateCodeVerifier() + if err != nil { + return fmt.Errorf("oauth flow: failed to generate code verifier: %w", err) + } + codeChallenge := client.GenerateCodeChallenge(codeVerifier) + + // Generate a random state parameter for CSRF protection. 
+ state, err := client.GenerateState() + if err != nil { + return fmt.Errorf("oauth flow: failed to generate state: %w", err) + } + + // Build the authorization URL the user needs to visit. + authURL, err := oauthHandler.GetAuthorizationURL(ctx, state, codeChallenge) + if err != nil { + return fmt.Errorf("oauth flow: failed to get authorization URL: %w", err) + } + + // Delegate to the MCPAuthHandler for user-facing authorization (e.g. open + // browser, wait for redirect). It returns the full callback URL containing + // the authorization code and state. + callbackURL, err := r.handler.HandleAuth(ctx, serverName, authURL) + if err != nil { + return fmt.Errorf("oauth flow: user authorization failed: %w", err) + } + + // Parse the callback URL to extract the authorization code and state. + parsed, err := url.Parse(callbackURL) + if err != nil { + return fmt.Errorf("oauth flow: failed to parse callback URL: %w", err) + } + + code := parsed.Query().Get("code") + returnedState := parsed.Query().Get("state") + + if code == "" { + return fmt.Errorf("oauth flow: callback URL missing 'code' parameter") + } + if returnedState == "" { + return fmt.Errorf("oauth flow: callback URL missing 'state' parameter") + } + + // Exchange the authorization code for an access token. + if err := oauthHandler.ProcessAuthorizationResponse(ctx, code, returnedState, codeVerifier); err != nil { + return fmt.Errorf("oauth flow: token exchange failed: %w", err) + } + + return nil +} + +// IsOAuthError returns true if the error is an OAuthAuthorizationRequiredError. 
+func IsOAuthError(err error) bool { + return client.IsOAuthAuthorizationRequiredError(err) +} diff --git a/internal/tools/token_store.go b/internal/tools/token_store.go new file mode 100644 index 00000000..8c9763d3 --- /dev/null +++ b/internal/tools/token_store.go @@ -0,0 +1,155 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + + "github.com/mark3labs/mcp-go/client/transport" +) + +// Compile-time check that FileTokenStore implements transport.TokenStore. +var _ transport.TokenStore = (*FileTokenStore)(nil) + +// FileTokenStore is a file-backed implementation of transport.TokenStore that +// persists OAuth tokens as JSON on disk. Tokens are stored in a shared JSON file +// keyed by server URL, allowing multiple MCP servers to maintain independent tokens. +// +// The token file is located at $XDG_CONFIG_HOME/.kit/mcp_tokens.json, falling back +// to ~/.config/.kit/mcp_tokens.json when XDG_CONFIG_HOME is not set. +// +// FileTokenStore is safe for concurrent use. +type FileTokenStore struct { + serverKey string + filePath string + mu sync.RWMutex +} + +// NewFileTokenStore creates a new FileTokenStore for the given server URL. +// The serverKey is used as the map key in the shared token file, and should +// typically be the MCP server's base URL. +// +// Returns an error if the token file path cannot be resolved. +func NewFileTokenStore(serverKey string) (*FileTokenStore, error) { + filePath, err := resolveTokenFilePath() + if err != nil { + return nil, fmt.Errorf("resolving token file path: %w", err) + } + + return &FileTokenStore{ + serverKey: serverKey, + filePath: filePath, + }, nil +} + +// GetToken returns the stored token for this store's server key. +// Returns transport.ErrNoToken if no token exists for the server key or if +// the token file does not yet exist. +// Returns context.Canceled or context.DeadlineExceeded if the context is done. 
+func (s *FileTokenStore) GetToken(ctx context.Context) (*transport.Token, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + s.mu.RLock() + defer s.mu.RUnlock() + + tokens, err := readTokenFile(s.filePath) + if err != nil { + if os.IsNotExist(err) { + return nil, transport.ErrNoToken + } + return nil, fmt.Errorf("reading token file: %w", err) + } + + token, ok := tokens[s.serverKey] + if !ok { + return nil, transport.ErrNoToken + } + + return token, nil +} + +// SaveToken persists the given token for this store's server key. +// If the token file or its parent directories do not exist, they are created. +// Existing tokens for other server keys are preserved. +// Returns context.Canceled or context.DeadlineExceeded if the context is done. +func (s *FileTokenStore) SaveToken(ctx context.Context, token *transport.Token) error { + if err := ctx.Err(); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + tokens, err := readTokenFile(s.filePath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("reading token file: %w", err) + } + if tokens == nil { + tokens = make(map[string]*transport.Token) + } + + tokens[s.serverKey] = token + + if err := writeTokenFile(s.filePath, tokens); err != nil { + return fmt.Errorf("writing token file: %w", err) + } + + return nil +} + +// resolveTokenFilePath determines the path to the token file using +// XDG_CONFIG_HOME if set, otherwise falling back to ~/.config/.kit/. +func resolveTokenFilePath() (string, error) { + configDir := os.Getenv("XDG_CONFIG_HOME") + if configDir == "" { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("determining user home directory: %w", err) + } + configDir = filepath.Join(home, ".config") + } + + return filepath.Join(configDir, ".kit", "mcp_tokens.json"), nil +} + +// readTokenFile reads and unmarshals the token file into a server-keyed map. +// Returns os.ErrNotExist (via os.IsNotExist) if the file does not exist. 
+func readTokenFile(path string) (map[string]*transport.Token, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var tokens map[string]*transport.Token + if err := json.Unmarshal(data, &tokens); err != nil { + return nil, fmt.Errorf("unmarshaling token file: %w", err) + } + + return tokens, nil +} + +// writeTokenFile marshals the token map and writes it to disk, creating +// parent directories as needed. The file is written with 0600 permissions +// to protect sensitive token data. +func writeTokenFile(path string, tokens map[string]*transport.Token) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("creating token directory %s: %w", dir, err) + } + + data, err := json.MarshalIndent(tokens, "", " ") + if err != nil { + return fmt.Errorf("marshaling tokens: %w", err) + } + + if err := os.WriteFile(path, data, 0600); err != nil { + return fmt.Errorf("writing token file %s: %w", path, err) + } + + return nil +} diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index 1243e6d1..664cd6a2 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -48,6 +48,7 @@ type Kit struct { skills []*skills.Skill extRunner *extensions.Runner bufferedLogger *tools.BufferedDebugLogger + authHandler MCPAuthHandler // OAuth handler for remote MCP servers (may need Close) // Hook registries — interception layer (see hooks.go). beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult] @@ -439,6 +440,18 @@ type Options struct { // Debug enables debug logging for the SDK. Debug bool + // MCPAuthHandler handles OAuth authorization for remote MCP servers. + // When set, remote transports (streamable HTTP, SSE) are configured with + // OAuth support. If the server returns a 401, the handler is invoked to + // let the user authorize via browser. + // + // If nil, a [DefaultMCPAuthHandler] is created automatically — opening the + // system browser and listening on a local callback server. 
+ // + // Set to a custom implementation to control the authorization UX (e.g. + // display a URL in a custom UI, redirect to a web app, etc.). + MCPAuthHandler MCPAuthHandler + // CLI is optional CLI-specific configuration. SDK users leave this nil. CLI *CLIOptions } @@ -655,6 +668,23 @@ func New(ctx context.Context, opts *Options) (*Kit, error) { MaxSteps: maxSteps, StreamingEnabled: streaming, } + + // Set up OAuth handler for remote MCP servers. + // The SDK MCPAuthHandler interface is structurally identical to + // tools.MCPAuthHandler, so any implementation satisfies both. + if opts.MCPAuthHandler != nil { + setupOpts.AuthHandler = opts.MCPAuthHandler + } else { + // Create a default handler that opens the system browser. + defaultHandler, authErr := NewDefaultMCPAuthHandler() + if authErr != nil { + // Non-fatal: OAuth just won't be available for remote servers. + charmlog.Warn("Failed to create OAuth handler; remote MCP servers requiring auth will fail", "error", authErr) + } else { + setupOpts.AuthHandler = defaultHandler + } + } + if opts.CLI != nil { setupOpts.ShowSpinner = opts.CLI.ShowSpinner setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc @@ -685,6 +715,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) { skills: loadedSkills, extRunner: agentResult.ExtRunner, bufferedLogger: agentResult.BufferedLogger, + authHandler: setupOpts.AuthHandler, beforeToolCall: beforeToolCall, afterToolResult: afterToolResult, beforeTurn: beforeTurn, @@ -1645,5 +1676,9 @@ func (m *Kit) Close() error { if m.treeSession != nil { _ = m.treeSession.Close() } + // Release the OAuth callback port if we own the handler. 
+ if closer, ok := m.authHandler.(interface{ Close() error }); ok { + _ = closer.Close() + } return m.agent.Close() } diff --git a/pkg/kit/oauth.go b/pkg/kit/oauth.go new file mode 100644 index 00000000..07991d33 --- /dev/null +++ b/pkg/kit/oauth.go @@ -0,0 +1,265 @@ +package kit + +import ( + "context" + "fmt" + "net" + "net/http" + "os/exec" + "runtime" + "sync" + "time" +) + +// MCPAuthHandler handles OAuth authorization for MCP servers. +// Implementations control the user experience — opening a browser, showing a +// prompt, displaying a URL, etc. +// +// The default implementation ([DefaultMCPAuthHandler]) opens the system browser +// and starts a local HTTP callback server to receive the authorization code. +type MCPAuthHandler interface { + // RedirectURI returns the OAuth redirect URI that the callback server + // will listen on. This is called during MCP transport setup — before any + // OAuth errors occur — so the redirect URI can be registered with the + // authorization server. + RedirectURI() string + + // HandleAuth is called when an MCP server requires OAuth authorization. + // It receives the server name and an authorization URL that the user must + // visit. The handler must: + // 1. Direct the user to authURL (e.g. open browser, display URL) + // 2. Listen for the OAuth callback on the redirect URI + // 3. Return the full callback URL (with code and state query params) + // + // Return an error to abort the connection to this MCP server. + // The context controls the overall timeout; implementations should + // respect ctx.Done(). + HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error) +} + +// DefaultMCPAuthHandler opens the system browser and starts a local HTTP +// callback server to receive the OAuth authorization code. It eagerly reserves +// a TCP port on construction so [RedirectURI] is stable for the lifetime of +// the handler. 
//
// Create instances with [NewDefaultMCPAuthHandler] (random port) or
// [NewDefaultMCPAuthHandlerWithPort] (explicit port).
type DefaultMCPAuthHandler struct {
	listener net.Listener // listener opened by the constructor to hold the port
	port     int          // port number read back from the listener's address
	mu       sync.Mutex   // guards listener lifecycle
}

// NewDefaultMCPAuthHandler builds a handler bound to an OS-chosen free port
// on localhost. The listener is opened here, so the value reported by
// [RedirectURI] is fixed from this point on. Release the port with
// [DefaultMCPAuthHandler.Close] once the handler is no longer needed.
func NewDefaultMCPAuthHandler() (*DefaultMCPAuthHandler, error) {
	ln, listenErr := net.Listen("tcp", "localhost:0")
	if listenErr != nil {
		return nil, fmt.Errorf("failed to listen for OAuth callback: %w", listenErr)
	}

	handler := &DefaultMCPAuthHandler{listener: ln}
	handler.port = ln.Addr().(*net.TCPAddr).Port
	return handler, nil
}

// NewDefaultMCPAuthHandlerWithPort builds a handler bound to the given port
// on localhost; the listener is opened immediately. A port of 0 asks the OS
// for any free port, exactly like [NewDefaultMCPAuthHandler]. Release the
// port with [DefaultMCPAuthHandler.Close] once the handler is no longer
// needed.
func NewDefaultMCPAuthHandlerWithPort(port int) (*DefaultMCPAuthHandler, error) {
	bindAddr := fmt.Sprintf("localhost:%d", port)

	ln, listenErr := net.Listen("tcp", bindAddr)
	if listenErr != nil {
		return nil, fmt.Errorf("failed to listen on %s for OAuth callback: %w", bindAddr, listenErr)
	}

	// Read the port back from the listener so a requested port of 0 is
	// replaced by the one the OS actually assigned.
	handler := &DefaultMCPAuthHandler{listener: ln}
	handler.port = ln.Addr().(*net.TCPAddr).Port
	return handler, nil
}

// RedirectURI reports the OAuth redirect URI of the local callback endpoint.
// It is derived solely from the port recorded at construction time.
func (h *DefaultMCPAuthHandler) RedirectURI() string {
	return "http://localhost:" + fmt.Sprint(h.port) + "/oauth/callback"
}

// Port returns the TCP port the callback server is bound to.
+func (h *DefaultMCPAuthHandler) Port() int { + return h.port +} + +// HandleAuth opens the system browser to authURL and waits for the OAuth +// callback on the local server. It returns the full callback URL including +// query parameters (code, state, etc.). +// +// If the context has no deadline, a default 2-minute timeout is applied. +// The callback server is started for each HandleAuth call and shut down +// before returning. +func (h *DefaultMCPAuthHandler) HandleAuth(ctx context.Context, serverName string, authURL string) (string, error) { + h.mu.Lock() + listener := h.listener + h.mu.Unlock() + + if listener == nil { + return "", fmt.Errorf("OAuth callback handler is closed") + } + + // Apply default timeout if the context has no deadline. + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 2*time.Minute) + defer cancel() + } + + // Channel receives the full callback URL from the HTTP handler. + callbackCh := make(chan string, 1) + + mux := http.NewServeMux() + mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) { + // Reconstruct the full callback URL as the caller expects it. + fullURL := fmt.Sprintf("http://localhost:%d%s", h.port, r.RequestURI) + + // Send the callback URL to the waiting goroutine (non-blocking). + select { + case callbackCh <- fullURL: + default: + } + + // Respond with a friendly HTML page so the user knows they can + // close the browser tab. + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, oauthSuccessHTML) + }) + + server := &http.Server{ + Handler: mux, + } + + // Start serving on the pre-reserved listener. We need to create a new + // listener on the same port because http.Server.Serve takes ownership + // and closes the listener when done. 
The original listener is kept open + // to reserve the port; we create a second listener via SO_REUSEADDR + // semantics (Go's default on most platforms) or, more reliably, we + // temporarily release and re-acquire. + // + // Strategy: use the held listener directly for Serve. After Serve + // returns (due to Shutdown), re-acquire the listener to keep the port + // reserved for future HandleAuth calls. + h.mu.Lock() + serveListener := h.listener + h.listener = nil // Serve will close it + h.mu.Unlock() + + if serveListener == nil { + return "", fmt.Errorf("OAuth callback handler is closed") + } + + // Start the HTTP server in a background goroutine. + serverErrCh := make(chan error, 1) + go func() { + err := server.Serve(serveListener) + if err != nil && err != http.ErrServerClosed { + serverErrCh <- err + } + close(serverErrCh) + }() + + // Re-acquire the listener after Serve completes (deferred). + defer func() { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + _ = server.Shutdown(shutdownCtx) + + // Re-reserve the port for future HandleAuth calls. + h.mu.Lock() + defer h.mu.Unlock() + if h.listener == nil { + newListener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", h.port)) + if err == nil { + h.listener = newListener + } + // If re-listen fails, the handler degrades gracefully — the + // next HandleAuth call will return an error. + } + }() + + // Open the system browser. + if err := openBrowser(authURL); err != nil { + // Browser open is best-effort; the user can still navigate manually. + _ = err + } + + // Wait for the callback, a server error, or context cancellation. 
+ select { + case url := <-callbackCh: + return url, nil + case err := <-serverErrCh: + return "", fmt.Errorf("OAuth callback server error for %q: %w", serverName, err) + case <-ctx.Done(): + return "", fmt.Errorf("OAuth authorization timed out for %q: %w", serverName, ctx.Err()) + } +} + +// Close releases the reserved port and shuts down the handler. After Close, +// HandleAuth will return an error. Close is safe to call multiple times. +func (h *DefaultMCPAuthHandler) Close() error { + h.mu.Lock() + defer h.mu.Unlock() + if h.listener != nil { + err := h.listener.Close() + h.listener = nil + return err + } + return nil +} + +// openBrowser opens the default system browser to the given URL. This is a +// best-effort operation — errors are returned but callers typically ignore +// them since the user can navigate manually. +func openBrowser(url string) error { + switch runtime.GOOS { + case "linux": + return exec.Command("xdg-open", url).Start() + case "windows": + return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + return exec.Command("open", url).Start() + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} + +// oauthSuccessHTML is the HTML page returned to the browser after a +// successful OAuth callback. +const oauthSuccessHTML = ` + +
+ +You can close this tab and return to the terminal.
+