diff --git a/pkg/kit/README.md b/pkg/kit/README.md index c8ba204a..24c813a9 100644 --- a/pkg/kit/README.md +++ b/pkg/kit/README.md @@ -263,6 +263,8 @@ kit.LLMFilePart // {Filename, Data []byte, MediaType} // All fields use SDK types (e.g. `[]kit.Tool`), so consumers can construct // these without importing any LLM-provider package. kit.AgentConfig // Lower-level agent config — prefer Options unless you need direct control +kit.DebugLogger // Interface: LogDebug(string) / IsDebugEnabled() bool +kit.MCPTaskConfig // Task-aware MCP tools/call config (modes, polling, progress) kit.ToolCallHandler // func(toolCallID, toolName, toolArgs string) kit.ToolExecutionHandler // func(toolCallID, toolName, toolArgs string, isStarting bool) kit.ToolResultHandler // func(toolCallID, toolName, toolArgs, result, metadata string, isError bool) diff --git a/pkg/kit/agent_config_internal_test.go b/pkg/kit/agent_config_internal_test.go index ef9d5b4e..dc9a3e4d 100644 --- a/pkg/kit/agent_config_internal_test.go +++ b/pkg/kit/agent_config_internal_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" "github.com/mark3labs/kit/internal/agent" ) @@ -97,6 +98,60 @@ func TestAgentConfigToInternal(t *testing.T) { } }) + t.Run("DebugLogger propagates", func(t *testing.T) { + dl := &fakeDebugLogger{enabled: true} + c := &AgentConfig{DebugLogger: dl} + got := c.toInternal() + if got.DebugLogger == nil { + t.Fatal("internal DebugLogger is nil") + } + if !got.DebugLogger.IsDebugEnabled() { + t.Error("IsDebugEnabled = false, want true") + } + got.DebugLogger.LogDebug("hello") + if len(dl.messages) != 1 || dl.messages[0] != "hello" { + t.Errorf("messages = %v, want [hello]", dl.messages) + } + }) + + t.Run("MCPTaskConfig propagates with mode + progress", func(t *testing.T) { + c := &AgentConfig{ + MCPTaskConfig: MCPTaskConfig{ + PerServerMode: map[string]MCPTaskMode{ + "build-svr": MCPTaskModeAlways, + }, + DefaultTTL: 30 * time.Second, + PollInterval: 250 * time.Millisecond, + MaxPollInterval: 2 * time.Second, + Timeout: 5 * time.Minute, + Progress: func(_ MCPTaskProgress) {}, + }, + } + got := c.toInternal() + if got.MCPTaskConfig.DefaultTTL != 30*time.Second { + t.Errorf("DefaultTTL = %v, want 30s", got.MCPTaskConfig.DefaultTTL) + } + if got.MCPTaskConfig.PollInterval != 250*time.Millisecond { + t.Errorf("PollInterval = %v, want 250ms", got.MCPTaskConfig.PollInterval) + } + if got.MCPTaskConfig.MaxPollInterval != 2*time.Second { + t.Errorf("MaxPollInterval = %v, want 2s", got.MCPTaskConfig.MaxPollInterval) + } + if got.MCPTaskConfig.Timeout != 5*time.Minute { + t.Errorf("Timeout = %v, want 5m", got.MCPTaskConfig.Timeout) + } + mode, ok := got.MCPTaskConfig.PerServerMode["build-svr"] + if !ok { + t.Fatal("PerServerMode missing 'build-svr'") + } + if string(mode) != string(MCPTaskModeAlways) { + t.Errorf("mode = %q, want %q", mode, MCPTaskModeAlways) + } + if got.MCPTaskConfig.Progress == nil { + t.Fatal("internal Progress handler is nil") + } + }) + t.Run("auth and token store factories are wired", func(t *testing.T) { auth := &fakeAuthHandler{} tokenCalls := 0 @@ -142,3 +197,12 @@ func (f *fakeAuthHandler) RedirectURI() string { return "redirect" } func (f *fakeAuthHandler) HandleAuth(_ context.Context, _ string, _ string) (string, error) { return "", nil } + +// fakeDebugLogger implements kit.DebugLogger for tests. +type fakeDebugLogger struct { + enabled bool + messages []string +} + +func (f *fakeDebugLogger) LogDebug(m string) { f.messages = append(f.messages, m) } +func (f *fakeDebugLogger) IsDebugEnabled() bool { return f.enabled } diff --git a/pkg/kit/mcp_tasks.go b/pkg/kit/mcp_tasks.go index e341f023..f46e9303 100644 --- a/pkg/kit/mcp_tasks.go +++ b/pkg/kit/mcp_tasks.go @@ -98,6 +98,70 @@ type MCPTaskProgress struct { // dispatched on a goroutine. type MCPTaskProgressHandler func(MCPTaskProgress) +// MCPTaskConfig configures task-aware MCP tools/call execution. All fields +// are optional; the zero value disables progress callbacks and applies +// sensible polling defaults inside the engine. +// +// For most consumers, the flat [Options] fields (`MCPTaskMode`, +// `MCPTaskTTL`, `MCPTaskPollInterval`, `MCPTaskMaxPollInterval`, +// `MCPTaskTimeout`, `MCPTaskProgress`) are the preferred entry point. +// MCPTaskConfig is exposed for the low-level [AgentConfig] path. +type MCPTaskConfig struct { + // PerServerMode overrides the per-server task mode resolved from + // [MCPServerConfig]. Keys are server names. Missing entries fall back + // to the configured value. + PerServerMode map[string]MCPTaskMode + + // DefaultTTL is the TTL hint sent in TaskParams when augmenting a + // tools/call. Zero means omit the TTL — let the server pick its own. + DefaultTTL time.Duration + + // PollInterval is the fallback interval between tasks/get requests + // when the server does not suggest one. Zero defaults to 1 second. + PollInterval time.Duration + + // MaxPollInterval caps the polling interval. Zero defaults to 5 seconds. + MaxPollInterval time.Duration + + // Timeout is the maximum wall-clock duration to wait for a task to + // reach a terminal state. Zero defaults to 15 minutes. Independent + // of the per-call context deadline; whichever fires first wins. + Timeout time.Duration + + // Progress, if non-nil, receives every status transition observed by + // the polling loop. + Progress MCPTaskProgressHandler +} + +// toToolsConfig converts the SDK-level [MCPTaskConfig] to the internal +// tools-package representation. Keeps the dependency arrow internal-only. +func (c MCPTaskConfig) toToolsConfig() tools.MCPTaskConfig { + cfg := tools.MCPTaskConfig{ + DefaultTTL: c.DefaultTTL, + PollInterval: c.PollInterval, + MaxPollInterval: c.MaxPollInterval, + Timeout: c.Timeout, + } + if len(c.PerServerMode) > 0 { + cfg.PerServerMode = make(map[string]tools.MCPTaskMode, len(c.PerServerMode)) + for k, v := range c.PerServerMode { + cfg.PerServerMode[k] = tools.MCPTaskMode(v) + } + } + if c.Progress != nil { + h := c.Progress + cfg.Progress = func(p tools.MCPTaskProgress) { + h(MCPTaskProgress{ + Server: p.Server, + TaskID: p.TaskID, + Status: MCPTaskStatus(p.Status), + Message: p.Message, + }) + } + } + return cfg +} + // mcpTaskOptions carries SDK consumer configuration into the agent setup. // Stored on Options as a single value so the public surface stays compact; // individual fields are exposed via WithMCP* builder functions. diff --git a/pkg/kit/types.go b/pkg/kit/types.go index a92346c5..5f2428a5 100644 --- a/pkg/kit/types.go +++ b/pkg/kit/types.go @@ -78,6 +78,23 @@ type MCPServerConfig = config.MCPServerConfig // ==== Agent Types ==== +// DebugLogger is an SDK-owned interface for low-level debug logging from +// the engine and MCP tool plumbing. Implementations must be safe for +// concurrent use. +// +// Most consumers do not need to provide one; pass [Options.Debug] = true +// to use the default logger. DebugLogger is exposed for the low-level +// [AgentConfig] path and for embedders that want to route debug output +// into their own logging system. +type DebugLogger interface { + // LogDebug records a single debug message. Implementations may drop, + // buffer, or render the message however they choose. + LogDebug(message string) + // IsDebugEnabled reports whether debug logging is active. Callers may + // check this before doing expensive formatting work. + IsDebugEnabled() bool +} + // AgentConfig holds configuration options for constructing an agent at the // SDK boundary. All fields use SDK-owned types, so consumers can populate // this struct without importing any underlying LLM-provider package. @@ -134,6 +151,19 @@ type AgentConfig struct { // when its tools have finished loading (or failed). Called from a // background goroutine. OnMCPServerLoaded func(serverName string, toolCount int, err error) + + // DebugLogger receives low-level debug output from the engine and the + // MCP tool plumbing. Nil means no debug output is emitted at this + // layer (regardless of [Options.Debug], which feeds the higher-level + // [New] entry point). Pass an implementation here when wiring a custom + // logger through the lower-level AgentConfig path. + DebugLogger DebugLogger + + // MCPTaskConfig configures task-aware MCP tools/call execution — mode + // overrides, polling intervals, timeouts, and the progress handler. + // The zero value preserves historical synchronous-only behaviour for + // any server that didn't advertise task support during initialize. + MCPTaskConfig MCPTaskConfig } // toInternal converts an AgentConfig to its internal representation. @@ -161,6 +191,10 @@ func (c *AgentConfig) toInternal() *agent.AgentConfig { if c.TokenStoreFactory != nil { out.TokenStoreFactory = tools.TokenStoreFactory(c.TokenStoreFactory) } + if c.DebugLogger != nil { + out.DebugLogger = c.DebugLogger + } + out.MCPTaskConfig = c.MCPTaskConfig.toToolsConfig() return out }