mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| aecce001ee | |||
| 32d73171fd | |||
| 265fd2ec0c | |||
| efebf2eba6 | |||
| f7b655ae33 |
+82
-8
@@ -717,13 +717,20 @@ func runNormalMode(ctx context.Context) error {
|
||||
|
||||
// Build Kit options from CLI flags and create the SDK instance.
|
||||
// kit.New() handles: config → skills → agent → session → extension bridge.
|
||||
authHandler, authErr := kit.NewCLIMCPAuthHandler()
|
||||
if authErr != nil {
|
||||
// Non-fatal: OAuth just won't be available for remote MCP servers.
|
||||
fmt.Fprintf(os.Stderr, "Warning: Failed to create OAuth handler: %v\n", authErr)
|
||||
}
|
||||
|
||||
kitOpts := &kit.Options{
|
||||
Quiet: quietFlag,
|
||||
Debug: debugMode,
|
||||
NoSession: noSessionFlag,
|
||||
Continue: continueFlag,
|
||||
SessionPath: sessionPath,
|
||||
AutoCompact: autoCompactFlag,
|
||||
Quiet: quietFlag,
|
||||
Debug: debugMode,
|
||||
NoSession: noSessionFlag,
|
||||
Continue: continueFlag,
|
||||
SessionPath: sessionPath,
|
||||
AutoCompact: autoCompactFlag,
|
||||
MCPAuthHandler: authHandler,
|
||||
CLI: &kit.CLIOptions{
|
||||
MCPConfig: mcpConfig,
|
||||
ShowSpinner: true,
|
||||
@@ -796,6 +803,13 @@ func runNormalMode(ctx context.Context) error {
|
||||
appInstance := app.New(appOpts, messages)
|
||||
defer appInstance.Close()
|
||||
|
||||
// Wire OAuth handler to route messages through the TUI once it's running.
|
||||
if authHandler != nil {
|
||||
authHandler.NotifyFunc = func(serverName, message string) {
|
||||
appInstance.PrintFromExtension("info", message)
|
||||
}
|
||||
}
|
||||
|
||||
// Buffer for extension messages during startup (printed after startup banner).
|
||||
var startupExtensionMessages []string
|
||||
|
||||
@@ -819,7 +833,37 @@ func runNormalMode(ctx context.Context) error {
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
@@ -1240,7 +1284,37 @@ func runNormalMode(ctx context.Context) error {
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
|
||||
@@ -168,6 +168,10 @@ var (
|
||||
// Test
|
||||
pendingTest *PendingTest
|
||||
|
||||
// Typing indicator
|
||||
typingTicker *time.Ticker
|
||||
typingStop chan struct{}
|
||||
|
||||
// Latest context for background goroutines
|
||||
latestCtx ext.Context
|
||||
latestCtxSet bool
|
||||
@@ -203,8 +207,23 @@ func configDir() string {
|
||||
return filepath.Join(home, ".config", "kit")
|
||||
}
|
||||
|
||||
func globalConfigDir() string {
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".config", "kit")
|
||||
}
|
||||
|
||||
func configPath() string {
|
||||
return filepath.Join(configDir(), "kit-telegram.json")
|
||||
// Prefer project-local config, fall back to global config.
|
||||
local := filepath.Join(configDir(), "kit-telegram.json")
|
||||
if _, err := os.Stat(local); err == nil {
|
||||
return local
|
||||
}
|
||||
global := filepath.Join(globalConfigDir(), "kit-telegram.json")
|
||||
if _, err := os.Stat(global); err == nil {
|
||||
return global
|
||||
}
|
||||
// Neither exists — return local path (will be created on connect).
|
||||
return local
|
||||
}
|
||||
|
||||
func failureLogDir() string {
|
||||
@@ -387,6 +406,14 @@ func tgEditMessageText(token string, chatID int64, messageID int, text string) (
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func tgSendChatAction(token string, chatID int64, action string) error {
|
||||
_, err := telegramRequest(token, "sendChatAction", map[string]any{
|
||||
"chat_id": chatID,
|
||||
"action": action,
|
||||
}, 15)
|
||||
return err
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Error classification
|
||||
// ──────────────────────────────────────────────
|
||||
@@ -637,6 +664,48 @@ func clearHealthTimer() {
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Typing indicator
|
||||
// ──────────────────────────────────────────────
|
||||
|
||||
func startTypingLoop() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if typingTicker != nil {
|
||||
return
|
||||
}
|
||||
cfg := config
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return
|
||||
}
|
||||
token := cfg.BotToken
|
||||
chatID := cfg.ChatID
|
||||
typingTicker = time.NewTicker(4 * time.Second)
|
||||
typingStop = make(chan struct{})
|
||||
// Send immediately, then every 4 seconds.
|
||||
go func() {
|
||||
tgSendChatAction(token, chatID, "typing")
|
||||
for {
|
||||
select {
|
||||
case <-typingTicker.C:
|
||||
tgSendChatAction(token, chatID, "typing")
|
||||
case <-typingStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func stopTypingLoop() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if typingTicker != nil {
|
||||
typingTicker.Stop()
|
||||
close(typingStop)
|
||||
typingTicker = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Polling lifecycle
|
||||
// ──────────────────────────────────────────────
|
||||
@@ -2105,6 +2174,7 @@ func Init(api ext.API) {
|
||||
mu.Unlock()
|
||||
|
||||
sendShutdownDisconnectedMessage()
|
||||
stopTypingLoop()
|
||||
stopPolling()
|
||||
clearHealthTimer()
|
||||
clearFooter()
|
||||
@@ -2128,6 +2198,7 @@ func Init(api ext.API) {
|
||||
mu.Unlock()
|
||||
|
||||
report("run.start", fmt.Sprintf("runId=%d", run.ID))
|
||||
startTypingLoop()
|
||||
ensureProgressMessage()
|
||||
updateProgressMessage()
|
||||
})
|
||||
@@ -2140,6 +2211,8 @@ func Init(api ext.API) {
|
||||
run := activeRun
|
||||
mu.Unlock()
|
||||
|
||||
stopTypingLoop()
|
||||
|
||||
if run != nil {
|
||||
// Capture final response from event
|
||||
if e.Response != "" {
|
||||
|
||||
@@ -25,6 +25,11 @@ type AgentConfig struct {
|
||||
StreamingEnabled bool
|
||||
DebugLogger tools.DebugLogger
|
||||
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
// If nil, remote MCP servers that require OAuth will fail to connect.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used. This allows SDK users to provide a custom tool set (e.g.
|
||||
// CodingTools or tools with a custom WorkDir).
|
||||
@@ -139,6 +144,10 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
toolManager = tools.NewMCPToolManager()
|
||||
toolManager.SetModel(providerResult.Model)
|
||||
|
||||
if agentConfig.AuthHandler != nil {
|
||||
toolManager.SetAuthHandler(agentConfig.AuthHandler)
|
||||
}
|
||||
|
||||
if agentConfig.DebugLogger != nil {
|
||||
toolManager.SetDebugLogger(agentConfig.DebugLogger)
|
||||
}
|
||||
|
||||
@@ -36,6 +36,8 @@ type AgentCreationOptions struct {
|
||||
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
|
||||
// DebugLogger is an optional logger for debugging MCP communications
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used.
|
||||
CoreTools []fantasy.AgentTool
|
||||
@@ -56,6 +58,7 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
|
||||
MaxSteps: opts.MaxSteps,
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
DebugLogger: opts.DebugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
CoreTools: opts.CoreTools,
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
|
||||
@@ -162,6 +162,24 @@ func (a *App) CancelCurrentStep() {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// IsBusy returns true when the agent is currently processing a turn.
|
||||
func (a *App) IsBusy() bool {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.busy
|
||||
}
|
||||
|
||||
// Abort cancels the current agent step (if running) and clears the queue.
|
||||
// Unlike InterruptAndSend, no new message is injected — the agent simply
|
||||
// stops. Safe to call when idle (no-op).
|
||||
func (a *App) Abort() {
|
||||
a.mu.Lock()
|
||||
a.queue = a.queue[:0]
|
||||
cancel := a.cancelStep
|
||||
a.mu.Unlock()
|
||||
cancel()
|
||||
}
|
||||
|
||||
// QueueLength returns the number of prompts currently waiting in the queue.
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
@@ -399,6 +417,78 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompactAsync is like CompactConversation but calls onComplete/onError
|
||||
// callbacks instead of sending TUI events. Used by the extension API's
|
||||
// ctx.Compact() which needs callback-based notification.
|
||||
func (a *App) CompactAsync(customInstructions string, onComplete func(), onError func(string)) error {
|
||||
a.mu.Lock()
|
||||
if a.closed {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("app is closed")
|
||||
}
|
||||
if a.busy {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("cannot compact while the agent is working")
|
||||
}
|
||||
if a.opts.Kit == nil {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("SDK instance not available")
|
||||
}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
if a.program != nil {
|
||||
a.program.Send(msg)
|
||||
}
|
||||
}
|
||||
unsub := a.subscribeSDKEvents(sendFn, nil)
|
||||
defer unsub()
|
||||
|
||||
result, err := a.opts.Kit.Compact(a.rootCtx, nil, customInstructions)
|
||||
if err != nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: err})
|
||||
if onError != nil {
|
||||
onError(err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
if result == nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: fmt.Errorf("nothing to compact")})
|
||||
if onError != nil {
|
||||
onError("nothing to compact")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Sync in-memory store with the compacted session.
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
|
||||
a.sendEvent(CompactCompleteEvent{
|
||||
Summary: result.Summary,
|
||||
OriginalTokens: result.OriginalTokens,
|
||||
CompactedTokens: result.CompactedTokens,
|
||||
MessagesRemoved: result.MessagesRemoved,
|
||||
})
|
||||
if onComplete != nil {
|
||||
onComplete()
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Non-interactive execution
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
@@ -77,6 +77,64 @@ type Context struct {
|
||||
// ctx.CancelAndSend("Stop what you're doing and focus on the tests")
|
||||
CancelAndSend func(string)
|
||||
|
||||
// Abort cancels the current agent turn (if running) and clears the
|
||||
// message queue. Unlike CancelAndSend, no new message is injected —
|
||||
// the agent simply stops. Safe to call when idle (no-op).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ctx.Abort() // stop whatever the agent is doing
|
||||
Abort func()
|
||||
|
||||
// IsIdle returns true when the agent is not processing a turn.
|
||||
// Extensions can use this to decide whether to dispatch immediately
|
||||
// or queue work for later.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// if ctx.IsIdle() {
|
||||
// ctx.SendMessage("start new task")
|
||||
// }
|
||||
IsIdle func() bool
|
||||
|
||||
// Compact triggers context compaction, summarising older messages to
|
||||
// free context window space. Returns an error if compaction cannot
|
||||
// start (e.g. agent is busy or app is closed). The actual compaction
|
||||
// runs asynchronously; use OnComplete/OnError callbacks in
|
||||
// CompactConfig to observe the result.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := ctx.Compact(ext.CompactConfig{
|
||||
// OnComplete: func() { ctx.PrintInfo("Compaction done") },
|
||||
// OnError: func(errMsg string) { ctx.PrintError("Compact failed: " + errMsg) },
|
||||
// })
|
||||
Compact func(CompactConfig) error
|
||||
|
||||
// SendMultimodalMessage injects a message with file attachments (images,
|
||||
// documents) into the conversation and triggers a new agent turn. Files
|
||||
// are described by FilePart structs containing the raw bytes, filename,
|
||||
// and MIME type. If the agent is busy the message is queued.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// data, _ := os.ReadFile("photo.jpg")
|
||||
// ctx.SendMultimodalMessage("Describe this image", []ext.FilePart{
|
||||
// {Filename: "photo.jpg", Data: data, MediaType: "image/jpeg"},
|
||||
// })
|
||||
SendMultimodalMessage func(text string, files []FilePart)
|
||||
|
||||
// GetSessionUsage returns aggregated token usage and cost statistics
|
||||
// for the current session. This includes total input/output tokens,
|
||||
// cache read/write tokens, total cost, and request count.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// usage := ctx.GetSessionUsage()
|
||||
// fmt.Sprintf("Tokens: ↑%d ↓%d Cost: $%.3f",
|
||||
// usage.TotalInputTokens, usage.TotalOutputTokens, usage.TotalCost)
|
||||
GetSessionUsage func() SessionUsage
|
||||
|
||||
// SetWidget places or updates a persistent widget in the TUI. Widgets
|
||||
// remain visible across agent turns until explicitly removed. The
|
||||
// widget is identified by WidgetConfig.ID; calling SetWidget with the
|
||||
@@ -937,6 +995,48 @@ type StatusBarEntry struct {
|
||||
Priority int
|
||||
}
|
||||
|
||||
// CompactConfig configures a programmatic context compaction request.
|
||||
type CompactConfig struct {
|
||||
// CustomInstructions is optional text appended to the summary prompt
|
||||
// (e.g. "Focus on the API design decisions"). Empty uses the default.
|
||||
CustomInstructions string
|
||||
// OnComplete is called when compaction finishes successfully.
|
||||
// May be nil if the caller doesn't need notification.
|
||||
OnComplete func()
|
||||
// OnError is called when compaction fails. The argument is the error message.
|
||||
// May be nil if the caller doesn't need notification.
|
||||
OnError func(errMsg string)
|
||||
}
|
||||
|
||||
// FilePart describes a file attachment for multimodal messages. Extensions
|
||||
// use this with SendMultimodalMessage to attach images or documents.
|
||||
type FilePart struct {
|
||||
// Filename is the name of the file (e.g. "photo.jpg").
|
||||
Filename string
|
||||
// Data is the raw file content.
|
||||
Data []byte
|
||||
// MediaType is the MIME type (e.g. "image/jpeg", "application/pdf").
|
||||
MediaType string
|
||||
}
|
||||
|
||||
// SessionUsage contains aggregated token usage and cost statistics for
|
||||
// the current session. Extensions use this with GetSessionUsage() to
|
||||
// report usage information.
|
||||
type SessionUsage struct {
|
||||
// TotalInputTokens is the sum of input tokens across all requests.
|
||||
TotalInputTokens int
|
||||
// TotalOutputTokens is the sum of output tokens across all requests.
|
||||
TotalOutputTokens int
|
||||
// TotalCacheReadTokens is the sum of cache read tokens.
|
||||
TotalCacheReadTokens int
|
||||
// TotalCacheWriteTokens is the sum of cache write tokens.
|
||||
TotalCacheWriteTokens int
|
||||
// TotalCost is the total cost in USD across all requests.
|
||||
TotalCost float64
|
||||
// RequestCount is the number of LLM requests made in this session.
|
||||
RequestCount int
|
||||
}
|
||||
|
||||
// PrintBlockOpts configures a custom styled block for PrintBlock.
|
||||
type PrintBlockOpts struct {
|
||||
// Text is the main content to display.
|
||||
|
||||
@@ -154,6 +154,11 @@ func NewInstaller(projectDir string) *Installer {
|
||||
|
||||
// Install clones a git repository to the appropriate scope.
|
||||
func (i *Installer) Install(source *GitSource, scope InstallScope) error {
|
||||
return i.install(source, scope, nil)
|
||||
}
|
||||
|
||||
// install is the internal implementation that supports optional include paths.
|
||||
func (i *Installer) install(source *GitSource, scope InstallScope, includePaths []string) error {
|
||||
targetDir := i.getInstallPath(source, scope)
|
||||
|
||||
// Check if already installed
|
||||
@@ -199,6 +204,7 @@ func (i *Installer) Install(source *GitSource, scope InstallScope) error {
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
Include: includePaths,
|
||||
}
|
||||
if err := i.addToManifest(entry, scope); err != nil {
|
||||
// Don't fail the install, just log the error
|
||||
@@ -268,7 +274,22 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
|
||||
cleanCmd.Dir = targetDir
|
||||
_ = cleanCmd.Run() // Ignore errors - clean is best effort
|
||||
|
||||
// Update manifest timestamp
|
||||
// Update manifest timestamp, preserving existing fields like Include
|
||||
existing, _ := i.loadManifest(scope)
|
||||
var include []string
|
||||
var installed time.Time
|
||||
if existing != nil {
|
||||
for _, p := range existing.Packages {
|
||||
if p.Host+"/"+p.Path == source.Identity() {
|
||||
include = p.Include
|
||||
installed = p.Installed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if installed.IsZero() {
|
||||
installed = time.Now()
|
||||
}
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
@@ -277,8 +298,9 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
|
||||
Ref: "",
|
||||
Pinned: false,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
Installed: installed,
|
||||
Updated: time.Now(),
|
||||
Include: include,
|
||||
}
|
||||
_ = i.addToManifest(entry, scope) // Best effort - don't fail update if manifest fails
|
||||
|
||||
@@ -503,30 +525,7 @@ func (i *Installer) PreviewExtensions(source *GitSource) ([]ExtensionPreview, st
|
||||
// InstallWithInclude clones a repo and installs only the specified extensions.
|
||||
// includePaths are relative paths like "./git/main.go" - if empty, installs all.
|
||||
func (i *Installer) InstallWithInclude(source *GitSource, scope InstallScope, includePaths []string) error {
|
||||
// First, do a regular install
|
||||
if err := i.Install(source, scope); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If specific includes were requested, update the manifest
|
||||
if len(includePaths) > 0 {
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
Host: source.Host,
|
||||
Path: source.Path,
|
||||
Ref: source.Ref,
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Include: includePaths,
|
||||
}
|
||||
|
||||
if err := addEntryToManifest(entry, scope); err != nil {
|
||||
return fmt.Errorf("updating manifest with includes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return i.install(source, scope, includePaths)
|
||||
}
|
||||
|
||||
// CleanupTempDir removes a temporary directory used for preview.
|
||||
|
||||
@@ -133,7 +133,7 @@ func findExtensionsInDir(dir string) []string {
|
||||
|
||||
for _, entry := range entries {
|
||||
full := filepath.Join(dir, entry.Name())
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") {
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") && !strings.HasSuffix(entry.Name(), "_test.go") {
|
||||
results = append(results, full)
|
||||
} else if entry.IsDir() {
|
||||
main := filepath.Join(full, "main.go")
|
||||
@@ -190,9 +190,13 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
// Allow walking into examples/ so we can reach examples/extensions/ etc,
|
||||
// but don't treat examples/ itself or non-extension subdirs as extension locations.
|
||||
if relPath == "examples" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
if !isExtDir {
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
if relPath == base { // Top-level directory
|
||||
@@ -202,13 +206,6 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
results = append(results, mainPath)
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
@@ -227,7 +224,7 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
}
|
||||
|
||||
// It's a file
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -253,10 +253,13 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
// Or check if it's a subdirectory of examples/ that might contain extensions
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
// Allow walking into examples/ so we can reach examples/extensions/ etc,
|
||||
// but don't treat examples/ itself or non-extension subdirs as extension locations.
|
||||
if relPath == "examples" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
if !isExtDir {
|
||||
// Check for main.go before skipping
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
@@ -272,18 +275,6 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
// Inside a valid extensions directory
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
previews = append(previews, ExtensionPreview{
|
||||
Path: "./" + relPath + "/main.go",
|
||||
Name: deriveExtensionName(relPath+"/main.go", true),
|
||||
IsMain: true,
|
||||
})
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
|
||||
// Not an extension location
|
||||
@@ -309,7 +300,7 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
}
|
||||
|
||||
// It's a file - check if it's a valid extension
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -86,6 +86,21 @@ func normalizeContext(ctx Context) Context {
|
||||
if ctx.CancelAndSend == nil {
|
||||
ctx.CancelAndSend = func(string) {}
|
||||
}
|
||||
if ctx.Abort == nil {
|
||||
ctx.Abort = func() {}
|
||||
}
|
||||
if ctx.IsIdle == nil {
|
||||
ctx.IsIdle = func() bool { return true }
|
||||
}
|
||||
if ctx.Compact == nil {
|
||||
ctx.Compact = func(CompactConfig) error { return fmt.Errorf("compact not available") }
|
||||
}
|
||||
if ctx.SendMultimodalMessage == nil {
|
||||
ctx.SendMultimodalMessage = func(string, []FilePart) {}
|
||||
}
|
||||
if ctx.GetSessionUsage == nil {
|
||||
ctx.GetSessionUsage = func() SessionUsage { return SessionUsage{} }
|
||||
}
|
||||
if ctx.SetWidget == nil {
|
||||
ctx.SetWidget = func(WidgetConfig) {}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ func Symbols() interp.Exports {
|
||||
// Session types
|
||||
"SessionMessage": reflect.ValueOf((*SessionMessage)(nil)),
|
||||
"ExtensionEntry": reflect.ValueOf((*ExtensionEntry)(nil)),
|
||||
"SessionUsage": reflect.ValueOf((*SessionUsage)(nil)),
|
||||
|
||||
// Option types
|
||||
"OptionDef": reflect.ValueOf((*OptionDef)(nil)),
|
||||
@@ -44,6 +45,8 @@ func Symbols() interp.Exports {
|
||||
// LLM completion types
|
||||
"CompleteRequest": reflect.ValueOf((*CompleteRequest)(nil)),
|
||||
"CompleteResponse": reflect.ValueOf((*CompleteResponse)(nil)),
|
||||
"CompactConfig": reflect.ValueOf((*CompactConfig)(nil)),
|
||||
"FilePart": reflect.ValueOf((*FilePart)(nil)),
|
||||
|
||||
// Status bar types
|
||||
"StatusBarEntry": reflect.ValueOf((*StatusBarEntry)(nil)),
|
||||
|
||||
@@ -58,6 +58,9 @@ type AgentSetupOptions struct {
|
||||
// StreamingEnabled controls streaming. Only meaningful when ProviderConfig
|
||||
// is also set.
|
||||
StreamingEnabled bool
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
}
|
||||
|
||||
// AgentSetupResult bundles the created agent and any debug logger so the caller
|
||||
@@ -185,6 +188,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
Quiet: opts.Quiet,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
CoreTools: opts.CoreTools,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
|
||||
@@ -524,13 +524,13 @@ func buildOpenAIProviderOptions(config *ProviderConfig, modelName string) fantas
|
||||
func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort {
|
||||
switch level {
|
||||
case ThinkingMinimal:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortMinimal)
|
||||
return new(openai.ReasoningEffortMinimal)
|
||||
case ThinkingLow:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortLow)
|
||||
return new(openai.ReasoningEffortLow)
|
||||
case ThinkingMedium:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortMedium)
|
||||
return new(openai.ReasoningEffortMedium)
|
||||
case ThinkingHigh:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortHigh)
|
||||
return new(openai.ReasoningEffortHigh)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -68,6 +68,7 @@ type MCPConnectionPool struct {
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
oauthFlow *OAuthFlowRunner
|
||||
}
|
||||
|
||||
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
|
||||
@@ -75,7 +76,7 @@ type MCPConnectionPool struct {
|
||||
// goroutine for periodic health checks that runs until Close is called.
|
||||
// The model parameter is used for MCP servers that require sampling support.
|
||||
// Thread-safe for concurrent use immediately after creation.
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool) *MCPConnectionPool {
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler) *MCPConnectionPool {
|
||||
if config == nil {
|
||||
config = DefaultConnectionPoolConfig()
|
||||
}
|
||||
@@ -90,6 +91,10 @@ func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageMo
|
||||
debug: debug,
|
||||
}
|
||||
|
||||
if authHandler != nil {
|
||||
pool.oauthFlow = NewOAuthFlowRunner(authHandler)
|
||||
}
|
||||
|
||||
go pool.startHealthCheck()
|
||||
return pool
|
||||
}
|
||||
@@ -103,6 +108,15 @@ func (p *MCPConnectionPool) SetDebugLogger(logger DebugLogger) {
|
||||
p.debugLogger = logger
|
||||
}
|
||||
|
||||
// SetOAuthFlow sets the OAuth flow runner for the connection pool.
|
||||
// When set, the pool can trigger OAuth re-authorization when a tool call fails
|
||||
// with an OAuth error (e.g. expired token). Thread-safe and can be called at any time.
|
||||
func (p *MCPConnectionPool) SetOAuthFlow(flow *OAuthFlowRunner) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.oauthFlow = flow
|
||||
}
|
||||
|
||||
// GetConnection retrieves or creates a connection for the specified MCP server.
|
||||
// If a healthy, non-idle connection exists in the pool, it will be reused.
|
||||
// Otherwise, a new connection is created and added to the pool.
|
||||
@@ -230,18 +244,43 @@ func (p *MCPConnectionPool) performHealthCheck(ctx context.Context, conn *MCPCon
|
||||
|
||||
// createConnection creates a new connection
|
||||
func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
|
||||
client, err := p.createMCPClient(ctx, serverName, serverConfig)
|
||||
mcpClient, err := p.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// SSE transport can return OAuth error during Start()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry after successful auth
|
||||
mcpClient, err = p.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.initializeClient(ctx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
// Streamable HTTP transport returns OAuth error during Initialize()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry initialization after successful auth
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
conn := &MCPConnection{
|
||||
client: client,
|
||||
client: mcpClient,
|
||||
serverName: serverName,
|
||||
serverConfig: serverConfig,
|
||||
lastUsed: time.Now(),
|
||||
@@ -323,13 +362,29 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
|
||||
}
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
|
||||
// scopes are discovered automatically via dynamic client registration and
|
||||
// server metadata (RFC 9728).
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
options = append(options, transport.WithOAuth(transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}))
|
||||
}
|
||||
|
||||
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := sseClient.Start(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to start SSE client: %v", err)
|
||||
return nil, fmt.Errorf("failed to start SSE client: %w", err)
|
||||
}
|
||||
|
||||
return sseClient, nil
|
||||
@@ -354,13 +409,29 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
|
||||
}
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
|
||||
// scopes are discovered automatically via dynamic client registration and
|
||||
// server metadata (RFC 9728).
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
options = append(options, transport.WithHTTPOAuth(transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}))
|
||||
}
|
||||
|
||||
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := streamableClient.Start(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to start streamable HTTP client: %v", err)
|
||||
return nil, fmt.Errorf("failed to start streamable HTTP client: %w", err)
|
||||
}
|
||||
|
||||
return streamableClient, nil
|
||||
@@ -381,7 +452,7 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.
|
||||
|
||||
_, err := client.Initialize(initCtx, initRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialization timeout or failed: %v", err)
|
||||
return fmt.Errorf("initialization timeout or failed: %w", err)
|
||||
}
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
@@ -539,6 +610,9 @@ func (p *MCPConnectionPool) Close() error {
|
||||
|
||||
// isConnectionError checks if the error is connection-related
|
||||
func isConnectionError(err error) bool {
|
||||
if IsOAuthError(err) {
|
||||
return false // OAuth errors are recoverable, not connection failures
|
||||
}
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "Connection not found") ||
|
||||
strings.Contains(errStr, "transport error") ||
|
||||
|
||||
@@ -59,9 +59,30 @@ func (t *mcpFantasyTool) Run(ctx context.Context, call fantasy.ToolCall) (fantas
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// Mark connection as unhealthy for automatic recovery
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
// Handle OAuth re-authorization: token may have expired mid-session.
|
||||
if t.mapping.manager.connectionPool.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := t.mapping.manager.connectionPool.oauthFlow.RunAuthFlow(ctx, t.mapping.serverName, err); flowErr != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", t.mapping.originalName, flowErr)
|
||||
}
|
||||
// Retry the tool call after successful re-auth.
|
||||
result, err = conn.client.CallTool(ctx, mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: t.mapping.originalName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool after re-auth: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Mark connection as unhealthy for automatic recovery
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal the MCP result to JSON string
|
||||
|
||||
+10
-1
@@ -22,6 +22,7 @@ type MCPToolManager struct {
|
||||
tools []fantasy.AgentTool
|
||||
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
|
||||
model fantasy.LanguageModel // LLM model for sampling
|
||||
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
|
||||
config *config.Config
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
@@ -53,6 +54,14 @@ func (m *MCPToolManager) SetModel(model fantasy.LanguageModel) {
|
||||
m.model = model
|
||||
}
|
||||
|
||||
// SetAuthHandler sets the OAuth handler for remote MCP server authentication.
|
||||
// When set, remote transports (streamable HTTP, SSE) are configured with OAuth
|
||||
// support, enabling automatic authorization flows when servers require authentication.
|
||||
// This method should be called before LoadTools.
|
||||
func (m *MCPToolManager) SetAuthHandler(handler MCPAuthHandler) {
|
||||
m.authHandler = handler
|
||||
}
|
||||
|
||||
// SetDebugLogger sets the debug logger for the tool manager.
|
||||
// The logger will be used to output detailed debugging information about MCP connections,
|
||||
// tool loading, and execution. If a connection pool exists, it will also be configured
|
||||
@@ -76,7 +85,7 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) e
|
||||
if m.debugLogger == nil {
|
||||
m.debugLogger = NewSimpleDebugLogger(config.Debug)
|
||||
}
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug)
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug, m.authHandler)
|
||||
m.connectionPool.SetDebugLogger(m.debugLogger)
|
||||
|
||||
var loadErrors []string
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
)
|
||||
|
||||
// MCPAuthHandler is the internal interface for handling MCP OAuth flows.
|
||||
// The SDK-level kit.MCPAuthHandler is adapted to this interface in cmd/root.go
|
||||
// or pkg/kit/kit.go, keeping the tools package decoupled from the SDK.
|
||||
type MCPAuthHandler interface {
|
||||
// RedirectURI returns the OAuth redirect URI for transport setup.
|
||||
RedirectURI() string
|
||||
// HandleAuth is called when a server requires OAuth authorization.
|
||||
// It receives the server name and the authorization URL the user must visit.
|
||||
// It returns the full callback URL (containing code and state query params)
|
||||
// after the user completes authorization.
|
||||
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
|
||||
}
|
||||
|
||||
// OAuthFlowRunner handles the OAuth authorization flow when an MCP server
|
||||
// returns an OAuthAuthorizationRequiredError. It coordinates dynamic client
|
||||
// registration, PKCE generation, user authorization (via MCPAuthHandler),
|
||||
// and token exchange.
|
||||
type OAuthFlowRunner struct {
|
||||
handler MCPAuthHandler
|
||||
}
|
||||
|
||||
// NewOAuthFlowRunner creates a new OAuthFlowRunner with the given auth handler.
|
||||
func NewOAuthFlowRunner(handler MCPAuthHandler) *OAuthFlowRunner {
|
||||
return &OAuthFlowRunner{handler: handler}
|
||||
}
|
||||
|
||||
// RunAuthFlow executes the OAuth authorization flow for the given server.
|
||||
// It extracts the OAuthHandler from the error, performs dynamic client registration
|
||||
// if needed, generates PKCE parameters, delegates to the MCPAuthHandler for user
|
||||
// interaction, and exchanges the authorization code for a token.
|
||||
func (r *OAuthFlowRunner) RunAuthFlow(ctx context.Context, serverName string, authErr error) error {
|
||||
// Extract the OAuthHandler from the authorization-required error.
|
||||
oauthHandler := client.GetOAuthHandler(authErr)
|
||||
if oauthHandler == nil {
|
||||
return fmt.Errorf("oauth flow: failed to extract OAuth handler from error: %w", authErr)
|
||||
}
|
||||
|
||||
// Perform dynamic client registration if no client ID is configured yet.
|
||||
if oauthHandler.GetClientID() == "" {
|
||||
if err := oauthHandler.RegisterClient(ctx, "kit"); err != nil {
|
||||
return fmt.Errorf("oauth flow: dynamic client registration failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate PKCE code verifier and challenge.
|
||||
codeVerifier, err := client.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to generate code verifier: %w", err)
|
||||
}
|
||||
codeChallenge := client.GenerateCodeChallenge(codeVerifier)
|
||||
|
||||
// Generate a random state parameter for CSRF protection.
|
||||
state, err := client.GenerateState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Build the authorization URL the user needs to visit.
|
||||
authURL, err := oauthHandler.GetAuthorizationURL(ctx, state, codeChallenge)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to get authorization URL: %w", err)
|
||||
}
|
||||
|
||||
// Delegate to the MCPAuthHandler for user-facing authorization (e.g. open
|
||||
// browser, wait for redirect). It returns the full callback URL containing
|
||||
// the authorization code and state.
|
||||
callbackURL, err := r.handler.HandleAuth(ctx, serverName, authURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: user authorization failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse the callback URL to extract the authorization code and state.
|
||||
parsed, err := url.Parse(callbackURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to parse callback URL: %w", err)
|
||||
}
|
||||
|
||||
code := parsed.Query().Get("code")
|
||||
returnedState := parsed.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
return fmt.Errorf("oauth flow: callback URL missing 'code' parameter")
|
||||
}
|
||||
if returnedState == "" {
|
||||
return fmt.Errorf("oauth flow: callback URL missing 'state' parameter")
|
||||
}
|
||||
|
||||
// Exchange the authorization code for an access token.
|
||||
if err := oauthHandler.ProcessAuthorizationResponse(ctx, code, returnedState, codeVerifier); err != nil {
|
||||
return fmt.Errorf("oauth flow: token exchange failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsOAuthError returns true if the error is an OAuthAuthorizationRequiredError.
|
||||
func IsOAuthError(err error) bool {
|
||||
return client.IsOAuthAuthorizationRequiredError(err)
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
)
|
||||
|
||||
// Compile-time check that FileTokenStore implements transport.TokenStore.
|
||||
var _ transport.TokenStore = (*FileTokenStore)(nil)
|
||||
|
||||
// FileTokenStore is a file-backed implementation of transport.TokenStore that
|
||||
// persists OAuth tokens as JSON on disk. Tokens are stored in a shared JSON file
|
||||
// keyed by server URL, allowing multiple MCP servers to maintain independent tokens.
|
||||
//
|
||||
// The token file is located at $XDG_CONFIG_HOME/.kit/mcp_tokens.json, falling back
|
||||
// to ~/.config/.kit/mcp_tokens.json when XDG_CONFIG_HOME is not set.
|
||||
//
|
||||
// FileTokenStore is safe for concurrent use.
|
||||
type FileTokenStore struct {
|
||||
serverKey string
|
||||
filePath string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewFileTokenStore creates a new FileTokenStore for the given server URL.
|
||||
// The serverKey is used as the map key in the shared token file, and should
|
||||
// typically be the MCP server's base URL.
|
||||
//
|
||||
// Returns an error if the token file path cannot be resolved.
|
||||
func NewFileTokenStore(serverKey string) (*FileTokenStore, error) {
|
||||
filePath, err := resolveTokenFilePath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving token file path: %w", err)
|
||||
}
|
||||
|
||||
return &FileTokenStore{
|
||||
serverKey: serverKey,
|
||||
filePath: filePath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetToken returns the stored token for this store's server key.
|
||||
// Returns transport.ErrNoToken if no token exists for the server key or if
|
||||
// the token file does not yet exist.
|
||||
// Returns context.Canceled or context.DeadlineExceeded if the context is done.
|
||||
func (s *FileTokenStore) GetToken(ctx context.Context) (*transport.Token, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
tokens, err := readTokenFile(s.filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, transport.ErrNoToken
|
||||
}
|
||||
return nil, fmt.Errorf("reading token file: %w", err)
|
||||
}
|
||||
|
||||
token, ok := tokens[s.serverKey]
|
||||
if !ok {
|
||||
return nil, transport.ErrNoToken
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// SaveToken persists the given token for this store's server key.
|
||||
// If the token file or its parent directories do not exist, they are created.
|
||||
// Existing tokens for other server keys are preserved.
|
||||
// Returns context.Canceled or context.DeadlineExceeded if the context is done.
|
||||
func (s *FileTokenStore) SaveToken(ctx context.Context, token *transport.Token) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
tokens, err := readTokenFile(s.filePath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("reading token file: %w", err)
|
||||
}
|
||||
if tokens == nil {
|
||||
tokens = make(map[string]*transport.Token)
|
||||
}
|
||||
|
||||
tokens[s.serverKey] = token
|
||||
|
||||
if err := writeTokenFile(s.filePath, tokens); err != nil {
|
||||
return fmt.Errorf("writing token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveTokenFilePath determines the path to the token file using
|
||||
// XDG_CONFIG_HOME if set, otherwise falling back to ~/.config/.kit/.
|
||||
func resolveTokenFilePath() (string, error) {
|
||||
configDir := os.Getenv("XDG_CONFIG_HOME")
|
||||
if configDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("determining user home directory: %w", err)
|
||||
}
|
||||
configDir = filepath.Join(home, ".config")
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, ".kit", "mcp_tokens.json"), nil
|
||||
}
|
||||
|
||||
// readTokenFile reads and unmarshals the token file into a server-keyed map.
|
||||
// Returns os.ErrNotExist (via os.IsNotExist) if the file does not exist.
|
||||
func readTokenFile(path string) (map[string]*transport.Token, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokens map[string]*transport.Token
|
||||
if err := json.Unmarshal(data, &tokens); err != nil {
|
||||
return nil, fmt.Errorf("unmarshaling token file: %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// writeTokenFile marshals the token map and writes it to disk, creating
|
||||
// parent directories as needed. The file is written with 0600 permissions
|
||||
// to protect sensitive token data.
|
||||
func writeTokenFile(path string, tokens map[string]*transport.Token) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("creating token directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(tokens, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling tokens: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0600); err != nil {
|
||||
return fmt.Errorf("writing token file %s: %w", path, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -48,6 +48,7 @@ type Kit struct {
|
||||
skills []*skills.Skill
|
||||
extRunner *extensions.Runner
|
||||
bufferedLogger *tools.BufferedDebugLogger
|
||||
authHandler MCPAuthHandler // OAuth handler for remote MCP servers (may need Close)
|
||||
|
||||
// Hook registries — interception layer (see hooks.go).
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
|
||||
@@ -439,6 +440,18 @@ type Options struct {
|
||||
// Debug enables debug logging for the SDK.
|
||||
Debug bool
|
||||
|
||||
// MCPAuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports (streamable HTTP, SSE) are configured with
|
||||
// OAuth support. If the server returns a 401, the handler is invoked to
|
||||
// let the user authorize via browser.
|
||||
//
|
||||
// If nil, a [DefaultMCPAuthHandler] is created automatically — opening the
|
||||
// system browser and listening on a local callback server.
|
||||
//
|
||||
// Set to a custom implementation to control the authorization UX (e.g.
|
||||
// display a URL in a custom UI, redirect to a web app, etc.).
|
||||
MCPAuthHandler MCPAuthHandler
|
||||
|
||||
// CLI is optional CLI-specific configuration. SDK users leave this nil.
|
||||
CLI *CLIOptions
|
||||
}
|
||||
@@ -655,6 +668,23 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
MaxSteps: maxSteps,
|
||||
StreamingEnabled: streaming,
|
||||
}
|
||||
|
||||
// Set up OAuth handler for remote MCP servers.
|
||||
// The SDK MCPAuthHandler interface is structurally identical to
|
||||
// tools.MCPAuthHandler, so any implementation satisfies both.
|
||||
if opts.MCPAuthHandler != nil {
|
||||
setupOpts.AuthHandler = opts.MCPAuthHandler
|
||||
} else {
|
||||
// Create a default handler that opens the system browser.
|
||||
defaultHandler, authErr := NewDefaultMCPAuthHandler()
|
||||
if authErr != nil {
|
||||
// Non-fatal: OAuth just won't be available for remote servers.
|
||||
charmlog.Warn("Failed to create OAuth handler; remote MCP servers requiring auth will fail", "error", authErr)
|
||||
} else {
|
||||
setupOpts.AuthHandler = defaultHandler
|
||||
}
|
||||
}
|
||||
|
||||
if opts.CLI != nil {
|
||||
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
|
||||
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
|
||||
@@ -685,6 +715,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
skills: loadedSkills,
|
||||
extRunner: agentResult.ExtRunner,
|
||||
bufferedLogger: agentResult.BufferedLogger,
|
||||
authHandler: setupOpts.AuthHandler,
|
||||
beforeToolCall: beforeToolCall,
|
||||
afterToolResult: afterToolResult,
|
||||
beforeTurn: beforeTurn,
|
||||
@@ -1645,5 +1676,9 @@ func (m *Kit) Close() error {
|
||||
if m.treeSession != nil {
|
||||
_ = m.treeSession.Close()
|
||||
}
|
||||
// Release the OAuth callback port if we own the handler.
|
||||
if closer, ok := m.authHandler.(interface{ Close() error }); ok {
|
||||
_ = closer.Close()
|
||||
}
|
||||
return m.agent.Close()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,265 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MCPAuthHandler handles OAuth authorization for MCP servers.
|
||||
// Implementations control the user experience — opening a browser, showing a
|
||||
// prompt, displaying a URL, etc.
|
||||
//
|
||||
// The default implementation ([DefaultMCPAuthHandler]) opens the system browser
|
||||
// and starts a local HTTP callback server to receive the authorization code.
|
||||
type MCPAuthHandler interface {
|
||||
// RedirectURI returns the OAuth redirect URI that the callback server
|
||||
// will listen on. This is called during MCP transport setup — before any
|
||||
// OAuth errors occur — so the redirect URI can be registered with the
|
||||
// authorization server.
|
||||
RedirectURI() string
|
||||
|
||||
// HandleAuth is called when an MCP server requires OAuth authorization.
|
||||
// It receives the server name and an authorization URL that the user must
|
||||
// visit. The handler must:
|
||||
// 1. Direct the user to authURL (e.g. open browser, display URL)
|
||||
// 2. Listen for the OAuth callback on the redirect URI
|
||||
// 3. Return the full callback URL (with code and state query params)
|
||||
//
|
||||
// Return an error to abort the connection to this MCP server.
|
||||
// The context controls the overall timeout; implementations should
|
||||
// respect ctx.Done().
|
||||
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
|
||||
}
|
||||
|
||||
// DefaultMCPAuthHandler opens the system browser and starts a local HTTP
|
||||
// callback server to receive the OAuth authorization code. It eagerly reserves
|
||||
// a TCP port on construction so [RedirectURI] is stable for the lifetime of
|
||||
// the handler.
|
||||
//
|
||||
// Create instances with [NewDefaultMCPAuthHandler] (random port) or
|
||||
// [NewDefaultMCPAuthHandlerWithPort] (explicit port).
|
||||
type DefaultMCPAuthHandler struct {
|
||||
listener net.Listener
|
||||
port int
|
||||
mu sync.Mutex // guards listener lifecycle
|
||||
}
|
||||
|
||||
// NewDefaultMCPAuthHandler creates a handler that listens on a random
|
||||
// available port on localhost. The port is reserved immediately so
|
||||
// [RedirectURI] returns a stable value. Call [DefaultMCPAuthHandler.Close]
|
||||
// when the handler is no longer needed to release the port.
|
||||
func NewDefaultMCPAuthHandler() (*DefaultMCPAuthHandler, error) {
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen for OAuth callback: %w", err)
|
||||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
return &DefaultMCPAuthHandler{listener: listener, port: port}, nil
|
||||
}
|
||||
|
||||
// NewDefaultMCPAuthHandlerWithPort creates a handler that listens on the
|
||||
// specified port on localhost. The port is reserved immediately. Pass 0 to
|
||||
// let the OS pick a free port (equivalent to [NewDefaultMCPAuthHandler]).
|
||||
// Call [DefaultMCPAuthHandler.Close] when the handler is no longer needed.
|
||||
func NewDefaultMCPAuthHandlerWithPort(port int) (*DefaultMCPAuthHandler, error) {
|
||||
addr := fmt.Sprintf("localhost:%d", port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen on %s for OAuth callback: %w", addr, err)
|
||||
}
|
||||
actualPort := listener.Addr().(*net.TCPAddr).Port
|
||||
return &DefaultMCPAuthHandler{listener: listener, port: actualPort}, nil
|
||||
}
|
||||
|
||||
// RedirectURI returns the OAuth redirect URI pointing to the local callback
|
||||
// server. This value is stable for the lifetime of the handler.
|
||||
func (h *DefaultMCPAuthHandler) RedirectURI() string {
|
||||
return fmt.Sprintf("http://localhost:%d/oauth/callback", h.port)
|
||||
}
|
||||
|
||||
// Port returns the TCP port the callback server is bound to.
|
||||
func (h *DefaultMCPAuthHandler) Port() int {
|
||||
return h.port
|
||||
}
|
||||
|
||||
// HandleAuth opens the system browser to authURL and waits for the OAuth
|
||||
// callback on the local server. It returns the full callback URL including
|
||||
// query parameters (code, state, etc.).
|
||||
//
|
||||
// If the context has no deadline, a default 2-minute timeout is applied.
|
||||
// The callback server is started for each HandleAuth call and shut down
|
||||
// before returning.
|
||||
func (h *DefaultMCPAuthHandler) HandleAuth(ctx context.Context, serverName string, authURL string) (string, error) {
|
||||
h.mu.Lock()
|
||||
listener := h.listener
|
||||
h.mu.Unlock()
|
||||
|
||||
if listener == nil {
|
||||
return "", fmt.Errorf("OAuth callback handler is closed")
|
||||
}
|
||||
|
||||
// Apply default timeout if the context has no deadline.
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, 2*time.Minute)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Channel receives the full callback URL from the HTTP handler.
|
||||
callbackCh := make(chan string, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Reconstruct the full callback URL as the caller expects it.
|
||||
fullURL := fmt.Sprintf("http://localhost:%d%s", h.port, r.RequestURI)
|
||||
|
||||
// Send the callback URL to the waiting goroutine (non-blocking).
|
||||
select {
|
||||
case callbackCh <- fullURL:
|
||||
default:
|
||||
}
|
||||
|
||||
// Respond with a friendly HTML page so the user knows they can
|
||||
// close the browser tab.
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprint(w, oauthSuccessHTML)
|
||||
})
|
||||
|
||||
server := &http.Server{
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// Start serving on the pre-reserved listener. We need to create a new
|
||||
// listener on the same port because http.Server.Serve takes ownership
|
||||
// and closes the listener when done. The original listener is kept open
|
||||
// to reserve the port; we create a second listener via SO_REUSEADDR
|
||||
// semantics (Go's default on most platforms) or, more reliably, we
|
||||
// temporarily release and re-acquire.
|
||||
//
|
||||
// Strategy: use the held listener directly for Serve. After Serve
|
||||
// returns (due to Shutdown), re-acquire the listener to keep the port
|
||||
// reserved for future HandleAuth calls.
|
||||
h.mu.Lock()
|
||||
serveListener := h.listener
|
||||
h.listener = nil // Serve will close it
|
||||
h.mu.Unlock()
|
||||
|
||||
if serveListener == nil {
|
||||
return "", fmt.Errorf("OAuth callback handler is closed")
|
||||
}
|
||||
|
||||
// Start the HTTP server in a background goroutine.
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
err := server.Serve(serveListener)
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
serverErrCh <- err
|
||||
}
|
||||
close(serverErrCh)
|
||||
}()
|
||||
|
||||
// Re-acquire the listener after Serve completes (deferred).
|
||||
defer func() {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer shutdownCancel()
|
||||
_ = server.Shutdown(shutdownCtx)
|
||||
|
||||
// Re-reserve the port for future HandleAuth calls.
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if h.listener == nil {
|
||||
newListener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", h.port))
|
||||
if err == nil {
|
||||
h.listener = newListener
|
||||
}
|
||||
// If re-listen fails, the handler degrades gracefully — the
|
||||
// next HandleAuth call will return an error.
|
||||
}
|
||||
}()
|
||||
|
||||
// Open the system browser.
|
||||
if err := openBrowser(authURL); err != nil {
|
||||
// Browser open is best-effort; the user can still navigate manually.
|
||||
_ = err
|
||||
}
|
||||
|
||||
// Wait for the callback, a server error, or context cancellation.
|
||||
select {
|
||||
case url := <-callbackCh:
|
||||
return url, nil
|
||||
case err := <-serverErrCh:
|
||||
return "", fmt.Errorf("OAuth callback server error for %q: %w", serverName, err)
|
||||
case <-ctx.Done():
|
||||
return "", fmt.Errorf("OAuth authorization timed out for %q: %w", serverName, ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases the reserved port and shuts down the handler. After Close,
|
||||
// HandleAuth will return an error. Close is safe to call multiple times.
|
||||
func (h *DefaultMCPAuthHandler) Close() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if h.listener != nil {
|
||||
err := h.listener.Close()
|
||||
h.listener = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// openBrowser opens the default system browser to the given URL. This is a
|
||||
// best-effort operation — errors are returned but callers typically ignore
|
||||
// them since the user can navigate manually.
|
||||
func openBrowser(url string) error {
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
return exec.Command("xdg-open", url).Start()
|
||||
case "windows":
|
||||
return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||
case "darwin":
|
||||
return exec.Command("open", url).Start()
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// oauthSuccessHTML is the HTML page returned to the browser after a
|
||||
// successful OAuth callback.
|
||||
const oauthSuccessHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Authorization Successful</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
min-height: 100vh;
|
||||
margin: 0;
|
||||
background: #f8f9fa;
|
||||
color: #333;
|
||||
}
|
||||
.container {
|
||||
text-align: center;
|
||||
padding: 2rem;
|
||||
}
|
||||
h1 { color: #22863a; }
|
||||
p { color: #586069; margin-top: 0.5rem; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>✓ Authorization Successful</h1>
|
||||
<p>You can close this tab and return to the terminal.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
@@ -0,0 +1,68 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// CLIMCPAuthHandler wraps a [DefaultMCPAuthHandler] and prints status messages
|
||||
// to a writer (typically stderr) so the user knows what's happening during
|
||||
// OAuth authorization. This is the handler used by the CLI/TUI binary.
|
||||
//
|
||||
// For TUI integration, set NotifyFunc to route messages through the TUI's
|
||||
// event system instead of (or in addition to) the writer.
|
||||
type CLIMCPAuthHandler struct {
|
||||
inner *DefaultMCPAuthHandler
|
||||
w io.Writer
|
||||
|
||||
// NotifyFunc, when set, is called with status messages instead of writing
|
||||
// to the writer. This allows the TUI to display system messages in the
|
||||
// chat stream. If nil, messages are written to w.
|
||||
NotifyFunc func(serverName, message string)
|
||||
}
|
||||
|
||||
// NewCLIMCPAuthHandler creates a CLI auth handler that prints status messages
|
||||
// to stderr and delegates the actual OAuth flow to a [DefaultMCPAuthHandler].
|
||||
func NewCLIMCPAuthHandler() (*CLIMCPAuthHandler, error) {
|
||||
inner, err := NewDefaultMCPAuthHandler()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CLIMCPAuthHandler{inner: inner, w: os.Stderr}, nil
|
||||
}
|
||||
|
||||
// RedirectURI returns the OAuth redirect URI from the inner handler.
|
||||
func (h *CLIMCPAuthHandler) RedirectURI() string {
|
||||
return h.inner.RedirectURI()
|
||||
}
|
||||
|
||||
// HandleAuth prints status messages and delegates to the inner handler.
|
||||
func (h *CLIMCPAuthHandler) HandleAuth(ctx context.Context, serverName string, authURL string) (string, error) {
|
||||
h.notify(serverName, fmt.Sprintf("🔐 MCP server %q requires authentication. Opening browser...", serverName))
|
||||
h.notify(serverName, fmt.Sprintf(" If the browser doesn't open, visit:\n %s", authURL))
|
||||
|
||||
callbackURL, err := h.inner.HandleAuth(ctx, serverName, authURL)
|
||||
if err != nil {
|
||||
h.notify(serverName, fmt.Sprintf("✗ Authentication failed for %q: %v", serverName, err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
h.notify(serverName, fmt.Sprintf("✓ Authenticated with %q", serverName))
|
||||
return callbackURL, nil
|
||||
}
|
||||
|
||||
// Close releases the inner handler's resources.
|
||||
func (h *CLIMCPAuthHandler) Close() error {
|
||||
return h.inner.Close()
|
||||
}
|
||||
|
||||
// notify sends a message through NotifyFunc if set, otherwise writes to w.
|
||||
func (h *CLIMCPAuthHandler) notify(serverName, message string) {
|
||||
if h.NotifyFunc != nil {
|
||||
h.NotifyFunc(serverName, message)
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprintln(h.w, message)
|
||||
}
|
||||
Reference in New Issue
Block a user