diff --git a/.kit/extensions/subagent-monitor.go b/.kit/extensions/subagent-monitor.go index f7e86d0c..d61b2541 100644 --- a/.kit/extensions/subagent-monitor.go +++ b/.kit/extensions/subagent-monitor.go @@ -13,6 +13,8 @@ // - No channels in maps (Yaegi panics on range over map[string]chan) // - All ctx.* calls guarded with nil checks // - Simple data structures only +// - The extension runner serializes handler calls per-extension, so +// concurrent subagent events cannot race on this shared state. package main import ( @@ -43,7 +45,8 @@ const ( ) // --------------------------------------------------------------------------- -// Package-level state - all simple types +// Package-level state — safe because the runner serializes all handler +// invocations for the same extension (per-extension reentrant mutex). // --------------------------------------------------------------------------- var ( @@ -282,8 +285,8 @@ func Init(api ext.API) { submonPushWidget() - // Remove the entry immediately (no goroutine to avoid races) - newEntries := submonEntries[:0] + // Remove the entry — build a new slice to avoid aliasing bugs + newEntries := make([]*submonEntry, 0, len(submonEntries)) for _, en := range submonEntries { if en.callID != e.ToolCallID { newEntries = append(newEntries, en) diff --git a/examples/extensions/subagent-monitor_test.go b/examples/extensions/subagent-monitor_test.go index 0129d73f..25a5fd28 100644 --- a/examples/extensions/subagent-monitor_test.go +++ b/examples/extensions/subagent-monitor_test.go @@ -130,6 +130,58 @@ func TestSubagentMonitor_MultipleSubagents(t *testing.T) { time.Sleep(100 * time.Millisecond) } +// TestSubagentMonitor_ConcurrentSubagents verifies no panics when multiple +// subagents emit events concurrently from different goroutines. 
+func TestSubagentMonitor_ConcurrentSubagents(t *testing.T) { + harness := test.New(t) + harness.LoadFile("../../.kit/extensions/subagent-monitor.go") + + _, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"}) + if err != nil { + t.Fatalf("SessionStart should not error: %v", err) + } + + // Start 5 subagents concurrently + done := make(chan struct{}, 5) + for i := range 5 { + go func(idx int) { + defer func() { done <- struct{}{} }() + + callID := fmt.Sprintf("concurrent-%d", idx) + task := fmt.Sprintf("concurrent task %d", idx) + + _, _ = harness.Emit(extensions.SubagentStartEvent{ + ToolCallID: callID, + Task: task, + }) + + // Emit many chunks rapidly + for j := range 20 { + _, _ = harness.Emit(extensions.SubagentChunkEvent{ + ToolCallID: callID, + Task: task, + ChunkType: "text", + Content: fmt.Sprintf("agent %d chunk %d", idx, j), + }) + } + + _, _ = harness.Emit(extensions.SubagentEndEvent{ + ToolCallID: callID, + Task: task, + Response: "done", + }) + }(i) + } + + // Wait for all goroutines + for range 5 { + <-done + } + + // Allow any final processing + time.Sleep(200 * time.Millisecond) +} + // TestSubagentMonitor_SessionShutdown verifies shutdown doesn't panic // even with nil ctx functions. func TestSubagentMonitor_SessionShutdown(t *testing.T) { diff --git a/internal/extensions/runner.go b/internal/extensions/runner.go index c2b03f9e..0bd6aeac 100644 --- a/internal/extensions/runner.go +++ b/internal/extensions/runner.go @@ -1,21 +1,93 @@ package extensions import ( + "bytes" "fmt" "log" "os" + "runtime" "sort" + "strconv" "strings" "sync" "github.com/spf13/viper" ) +// --------------------------------------------------------------------------- +// reentrantMu — a per-extension mutex that allows the same goroutine to +// re-enter (e.g. handler → ctx.EmitCustomEvent → handler in same extension). +// Different goroutines are serialized, preventing concurrent state mutation. 
// ---------------------------------------------------------------------------

// reentrantMu is a mutual-exclusion lock that the goroutine currently holding
// it may acquire again without blocking. Other goroutines wait until the
// holder has called unlock once for every lock.
//
// The zero value is an unlocked mutex that is ready to use. A reentrantMu
// must not be copied after first use: it embeds a sync.Mutex, and its
// condition variable points at that mutex's address.
type reentrantMu struct {
	mu    sync.Mutex // guards cond, owner and depth
	cond  *sync.Cond // signalled when depth drops to zero; set by init or lazily by lock
	owner int64      // goroutine ID of the holder; meaningful only while depth > 0
	depth int        // outstanding lock calls by owner; 0 means unlocked
}

// init eagerly creates the condition variable. Calling it is optional, since
// lock creates the condition variable on first use. When it is called it must
// run after the struct is at its final memory location (the condition
// variable stores a pointer to r.mu) and before the mutex is shared between
// goroutines, because the assignment itself is not synchronized.
func (r *reentrantMu) init() {
	r.cond = sync.NewCond(&r.mu)
}

// lock acquires the mutex. If the calling goroutine already holds it, the
// call returns immediately and only the re-entrancy depth grows. Any other
// goroutine blocks until the mutex is fully released. Every call to lock must
// be paired with a call to unlock.
func (r *reentrantMu) lock() {
	gid := goroutineID()

	r.mu.Lock()
	defer r.mu.Unlock()

	// Lazily create the condition variable so the zero value is usable.
	if r.cond == nil {
		r.cond = sync.NewCond(&r.mu)
	}

	// "Held" is decided by depth, not by owner == 0. goroutineID reports 0
	// when it cannot determine the ID, and such a caller must never be
	// mistaken for the current holder (or for "nobody holds the lock"). For
	// gid == 0 the mutex therefore degrades to a plain, non-reentrant lock.
	if gid != 0 && r.depth > 0 && r.owner == gid {
		r.depth++
		return
	}

	// Wait releases r.mu while blocked and re-acquires it before returning,
	// so the condition is always re-checked under the lock.
	for r.depth > 0 {
		r.cond.Wait()
	}
	r.owner = gid
	r.depth = 1
}

// unlock undoes one call to lock. When the re-entrancy depth reaches zero the
// mutex becomes free and one waiting goroutine, if any, is woken.
//
// unlock panics if the mutex is not locked: an unbalanced unlock would
// otherwise drive depth negative and silently break mutual exclusion.
func (r *reentrantMu) unlock() {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.depth == 0 {
		panic("extensions: unlock of unlocked reentrantMu")
	}
	r.depth--
	if r.depth == 0 {
		r.owner = 0
		// Only one waiter can win the lock, so waking one is sufficient.
		r.cond.Signal()
	}
}

// goroutineID returns the ID of the calling goroutine, or 0 if it cannot be
// determined. The runtime package offers no accessor for this value, so the
// ID is parsed from the first line of runtime.Stack output, which has the
// form "goroutine NNN [state]:".
func goroutineID() int64 {
	const prefix = "goroutine "

	var buf [64]byte
	header := buf[:runtime.Stack(buf[:], false)]
	if !bytes.HasPrefix(header, []byte(prefix)) {
		return 0
	}
	header = header[len(prefix):]

	end := bytes.IndexByte(header, ' ')
	if end < 0 {
		return 0
	}
	id, err := strconv.ParseInt(string(header[:end]), 10, 64)
	if err != nil {
		return 0
	}
	return id
}

// Runner manages loaded extensions and dispatches events to their handlers
// sequentially. Handlers execute in extension
// load order; for cancellable events the first blocking result wins.
+// +// Each extension has a dedicated reentrant mutex so that handlers for the +// same extension are serialized (preventing data races on shared package-level +// state), while handlers for different extensions may execute concurrently. type Runner struct { extensions []LoadedExtension + extMu []reentrantMu // per-extension reentrant mutex, indexed by extension position ctx Context widgets map[string]WidgetConfig // keyed by widget ID statusEntries map[string]StatusBarEntry // keyed by status key @@ -52,7 +124,11 @@ type LoadedExtension struct { // NewRunner creates a Runner from a set of loaded extensions. func NewRunner(exts []LoadedExtension) *Runner { - return &Runner{extensions: exts} + mus := make([]reentrantMu, len(exts)) + for i := range mus { + mus[i].init() + } + return &Runner{extensions: exts, extMu: mus} } // SetContext updates the runtime context (session ID, model, etc.) that is @@ -367,6 +443,11 @@ func (r *Runner) Emit(event Event) (Result, error) { for i := range r.extensions { ext := &r.extensions[i] handlers := ext.Handlers[event.Type()] + if len(handlers) == 0 { + continue + } + + r.extMu[i].lock() for _, handler := range handlers { result, err := safeCall(handler, event, ctx) if err != nil { @@ -379,6 +460,7 @@ func (r *Runner) Emit(event Event) (Result, error) { // Check for blocking/short-circuit results. if isBlocking(result) { + r.extMu[i].unlock() return result, nil } @@ -386,6 +468,7 @@ func (r *Runner) Emit(event Event) (Result, error) { // the caller is responsible for applying the modifications. accumulated = result } + r.extMu[i].unlock() } return accumulated, nil } @@ -712,11 +795,17 @@ func (r *Runner) EmitCustomEvent(name, data string) { // Extension-registered handlers first (in load order). 
for i := range r.extensions { - for _, h := range r.extensions[i].CustomEventHandlers[name] { + extHandlers := r.extensions[i].CustomEventHandlers[name] + if len(extHandlers) == 0 { + continue + } + r.extMu[i].lock() + for _, h := range extHandlers { safeInvoke(h) } + r.extMu[i].unlock() } - // Then dynamic subscriptions. + // Then dynamic subscriptions (not extension-scoped, no per-ext lock). for _, h := range dynamicHandlers { safeInvoke(h) } diff --git a/internal/extensions/runner_test.go b/internal/extensions/runner_test.go index ec67be80..cedc076a 100644 --- a/internal/extensions/runner_test.go +++ b/internal/extensions/runner_test.go @@ -1,6 +1,7 @@ package extensions import ( + "sync" "testing" ) @@ -571,3 +572,142 @@ func TestRunner_ContextPrintNilSafe(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +func TestRunner_ConcurrentEmitSameExtension(t *testing.T) { + // Verify that concurrent Emit calls for the same extension are serialized + // and don't cause data races on shared handler state. + var counter int + ext := makeHandlerExt("shared-state.go", map[EventType][]HandlerFunc{ + SubagentStart: { + func(e Event, c Context) Result { + // Read-modify-write: racy without serialization. 
+ v := counter + counter = v + 1 + return nil + }, + }, + SubagentChunk: { + func(e Event, c Context) Result { + v := counter + counter = v + 1 + return nil + }, + }, + }) + + r := makeRunner(ext) + var wg sync.WaitGroup + const goroutines = 20 + const iterations = 50 + wg.Add(goroutines) + for range goroutines { + go func() { + defer wg.Done() + for range iterations { + _, _ = r.Emit(SubagentStartEvent{ToolCallID: "x"}) + _, _ = r.Emit(SubagentChunkEvent{ToolCallID: "x"}) + } + }() + } + wg.Wait() + if counter != goroutines*iterations*2 { + t.Errorf("expected counter=%d, got %d (race detected)", goroutines*iterations*2, counter) + } +} + +func TestRunner_ConcurrentEmitDifferentExtensions(t *testing.T) { + // Two extensions with independent state should not block each other + // and should both run correctly under concurrent Emit calls. + var counter1, counter2 int + ext1 := makeHandlerExt("ext1.go", map[EventType][]HandlerFunc{ + SubagentStart: { + func(e Event, c Context) Result { + v := counter1 + counter1 = v + 1 + return nil + }, + }, + }) + ext2 := makeHandlerExt("ext2.go", map[EventType][]HandlerFunc{ + SubagentStart: { + func(e Event, c Context) Result { + v := counter2 + counter2 = v + 1 + return nil + }, + }, + }) + + r := makeRunner(ext1, ext2) + var wg sync.WaitGroup + const goroutines = 20 + const iterations = 50 + wg.Add(goroutines) + for range goroutines { + go func() { + defer wg.Done() + for range iterations { + _, _ = r.Emit(SubagentStartEvent{ToolCallID: "x"}) + } + }() + } + wg.Wait() + expected := goroutines * iterations + if counter1 != expected { + t.Errorf("ext1 counter: expected %d, got %d", expected, counter1) + } + if counter2 != expected { + t.Errorf("ext2 counter: expected %d, got %d", expected, counter2) + } +} + +func TestRunner_ReentrantEmitCustomEvent(t *testing.T) { + // Verify that a handler can call EmitCustomEvent (which dispatches to + // the same extension's custom event handlers) without deadlocking. 
+ var order []string + ext := LoadedExtension{ + Path: "reentrant.go", + Handlers: map[EventType][]HandlerFunc{ + SessionStart: { + func(e Event, c Context) Result { + order = append(order, "session_start") + // This triggers EmitCustomEvent for the same extension + // via a direct runner call (simulating ctx.EmitCustomEvent). + return nil + }, + }, + }, + CustomEventHandlers: map[string][]func(string){ + "test-event": { + func(data string) { + order = append(order, "custom:"+data) + }, + }, + }, + } + + r := makeRunner(ext) + + // Wire up the handler to call EmitCustomEvent re-entrantly. + ext.Handlers[SessionStart] = []HandlerFunc{ + func(e Event, c Context) Result { + order = append(order, "session_start") + r.EmitCustomEvent("test-event", "hello") + return nil + }, + } + r.extensions[0] = ext + // Rebuild mutexes after modifying extensions slice. + r.extMu = make([]reentrantMu, len(r.extensions)) + for i := range r.extMu { + r.extMu[i].init() + } + + _, err := r.Emit(SessionStartEvent{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(order) != 2 || order[0] != "session_start" || order[1] != "custom:hello" { + t.Errorf("expected [session_start, custom:hello], got %v", order) + } +} diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index e6885a0f..d19b71ac 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -51,6 +51,7 @@ type Kit struct { bufferedLogger *tools.BufferedDebugLogger authHandler MCPAuthHandler // OAuth handler for remote MCP servers (may need Close) opts *Options // stored for reload operations (skills, etc.) + mcpConfig *config.Config // loaded MCP/server config, shared with subagents // hasCustomSystemPrompt is true when the user explicitly configured a // system prompt (via --system-prompt flag, config file, or SDK option). @@ -849,6 +850,13 @@ type Options struct { // (e.g. AGENTS.md) from the working directory. NoContextFiles bool + // MCPConfig provides a pre-loaded MCP configuration. 
When set, + // LoadAndValidateConfig is skipped during Kit creation — avoiding + // viper access entirely. This is set automatically for in-process + // subagents (inheriting the parent's loaded config) and can be used + // by SDK consumers who build config programmatically. + MCPConfig *config.Config + // InProcessMCPServers registers mcp-go servers that run in the same // process. Each key is the server name (used to prefix tool names, e.g. // "docs__search"). The value must be a *[server.MCPServer]. @@ -1136,8 +1144,11 @@ func New(ctx context.Context, opts *Options) (*Kit, error) { } // ---- viperInitMu released — heavy I/O below runs concurrently ---- - // Load MCP configuration. Use pre-loaded config if provided via CLI options. - if opts.CLI != nil && opts.CLI.MCPConfig != nil { + // Load MCP configuration. Use pre-loaded config if provided directly, + // via CLI options, or load from viper as a last resort. + if opts.MCPConfig != nil { + mcpConfig = opts.MCPConfig + } else if opts.CLI != nil && opts.CLI.MCPConfig != nil { mcpConfig = opts.CLI.MCPConfig } if mcpConfig == nil { @@ -1258,6 +1269,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) { bufferedLogger: agentResult.BufferedLogger, authHandler: setupOpts.AuthHandler, opts: opts, + mcpConfig: mcpConfig, hasCustomSystemPrompt: hasCustomSystemPrompt, beforeToolCall: beforeToolCall, afterToolResult: afterToolResult, @@ -1582,13 +1594,15 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult tools = SubagentTools() } - // Create child Kit instance. + // Create child Kit instance. Pass the parent's loaded MCP config to + // avoid re-reading viper (which races with concurrent subagent spawns). childOpts := &Options{ Model: model, SystemPrompt: systemPrompt, Tools: tools, NoSession: cfg.NoSession, Quiet: true, + MCPConfig: m.mcpConfig, } child, err := New(ctx, childOpts) if err != nil {