mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
fix(agent): track tool call args per ToolCallID for parallel calls (#33)
Previously GenerateWithCallbacks stored the most recent tool call's args in a single shared variable, which was clobbered when a provider emitted multiple tool_use blocks in a single step. Every OnToolResult callback then received the args of the last OnToolCall, regardless of which call it was actually resolving — breaking any downstream UI, log, or trace that derived its description from the toolArgs parameter.

- Replace the shared currentToolArgs with a map keyed by ToolCallID, guarded by a sync.Mutex in case the streaming layer dispatches callbacks from multiple goroutines.
- Delete each entry in OnToolResult so the map cannot accumulate across steps.
- Add a regression test driving the streaming wrapper with a fake fantasy.Agent that emits two parallel tool calls before either result, asserting each callback sees its own args.

Fixes #33
This commit is contained in:
+20
-5
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"charm.land/fantasy"
|
"charm.land/fantasy"
|
||||||
@@ -585,8 +586,13 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
|||||||
// This avoids type conflicts with provider-level options.
|
// This avoids type conflicts with provider-level options.
|
||||||
history = applyCacheControlToMessages(history)
|
history = applyCacheControlToMessages(history)
|
||||||
|
|
||||||
// Track current tool call args for callbacks
|
// Track tool call args per-ToolCallID so parallel tool calls in a single
|
||||||
var currentToolArgs string
|
// step don't clobber each other. Without this, OnToolResult callbacks would
|
||||||
|
// all see the args of the last OnToolCall in the step. The mutex guards
|
||||||
|
// against the possibility that the underlying streaming layer dispatches
|
||||||
|
// callbacks from multiple goroutines.
|
||||||
|
toolCallArgs := make(map[string]string)
|
||||||
|
var toolCallArgsMu sync.Mutex
|
||||||
|
|
||||||
// Use the streaming path when streaming is enabled OR when any callbacks are
|
// Use the streaming path when streaming is enabled OR when any callbacks are
|
||||||
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
|
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
|
||||||
@@ -773,7 +779,9 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
currentToolArgs = tc.Input
|
toolCallArgsMu.Lock()
|
||||||
|
toolCallArgs[tc.ToolCallID] = tc.Input
|
||||||
|
toolCallArgsMu.Unlock()
|
||||||
|
|
||||||
// Notify about the tool call
|
// Notify about the tool call
|
||||||
if cb.OnToolCall != nil {
|
if cb.OnToolCall != nil {
|
||||||
@@ -793,15 +801,22 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
// Look up the args recorded for this specific tool call. Delete
|
||||||
|
// the entry so the map doesn't accumulate across steps.
|
||||||
|
toolCallArgsMu.Lock()
|
||||||
|
args := toolCallArgs[tr.ToolCallID]
|
||||||
|
delete(toolCallArgs, tr.ToolCallID)
|
||||||
|
toolCallArgsMu.Unlock()
|
||||||
|
|
||||||
// Notify tool execution finished
|
// Notify tool execution finished
|
||||||
if cb.OnToolExecution != nil {
|
if cb.OnToolExecution != nil {
|
||||||
cb.OnToolExecution(tr.ToolCallID, tr.ToolName, currentToolArgs, false)
|
cb.OnToolExecution(tr.ToolCallID, tr.ToolName, args, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cb.OnToolResult != nil {
|
if cb.OnToolResult != nil {
|
||||||
// Extract result text and error status
|
// Extract result text and error status
|
||||||
resultText, isError := extractToolResultText(tr)
|
resultText, isError := extractToolResultText(tr)
|
||||||
cb.OnToolResult(tr.ToolCallID, tr.ToolName, currentToolArgs, resultText, tr.ClientMetadata, isError)
|
cb.OnToolResult(tr.ToolCallID, tr.ToolName, args, resultText, tr.ClientMetadata, isError)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -0,0 +1,109 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"charm.land/fantasy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeParallelAgent simulates a provider that emits two parallel tool_use
|
||||||
|
// blocks in a single step. It invokes the streaming callbacks in the order:
|
||||||
|
//
|
||||||
|
// OnToolCall(A) -> OnToolCall(B) -> OnToolResult(A) -> OnToolResult(B)
|
||||||
|
//
|
||||||
|
// Before the fix in #33 the agent-layer wrapper recorded a single
|
||||||
|
// `currentToolArgs` variable that was clobbered by the second OnToolCall, so
|
||||||
|
// both OnToolResult callbacks received B's args instead of their own.
|
||||||
|
type fakeParallelAgent struct {
|
||||||
|
calls []fantasy.ToolCallContent
|
||||||
|
results []fantasy.ToolResultContent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeParallelAgent) Generate(_ context.Context, _ fantasy.AgentCall) (*fantasy.AgentResult, error) {
|
||||||
|
return &fantasy.AgentResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeParallelAgent) Stream(_ context.Context, opts fantasy.AgentStreamCall) (*fantasy.AgentResult, error) {
|
||||||
|
for _, tc := range f.calls {
|
||||||
|
if opts.OnToolCall != nil {
|
||||||
|
if err := opts.OnToolCall(tc); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, tr := range f.results {
|
||||||
|
if opts.OnToolResult != nil {
|
||||||
|
if err := opts.OnToolResult(tr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &fantasy.AgentResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGenerateWithCallbacks_ParallelToolArgs is the regression test for #33.
|
||||||
|
// It drives the streaming-callback wiring inside GenerateWithCallbacks with a
|
||||||
|
// fake fantasy.Agent that emits two parallel tool calls before either result.
|
||||||
|
// Each OnToolResult must receive the args of its own tool call (matched by
|
||||||
|
// ToolCallID), not the args of the last OnToolCall in the step.
|
||||||
|
func TestGenerateWithCallbacks_ParallelToolArgs(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
argsA := `{"name":"scheduled_jobs"}`
|
||||||
|
argsB := `{"name":"gmail_trigger"}`
|
||||||
|
|
||||||
|
fake := &fakeParallelAgent{
|
||||||
|
calls: []fantasy.ToolCallContent{
|
||||||
|
{ToolCallID: "kit-A", ToolName: "load_skill", Input: argsA},
|
||||||
|
{ToolCallID: "kit-B", ToolName: "load_skill", Input: argsB},
|
||||||
|
},
|
||||||
|
results: []fantasy.ToolResultContent{
|
||||||
|
{ToolCallID: "kit-A", ToolName: "load_skill", Result: fantasy.ToolResultOutputContentText{Text: "ok-A"}},
|
||||||
|
{ToolCallID: "kit-B", ToolName: "load_skill", Result: fantasy.ToolResultOutputContentText{Text: "ok-B"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a := &Agent{
|
||||||
|
fantasyAgent: fake,
|
||||||
|
streamingEnabled: false, // exercise the "hasCallbacks" branch
|
||||||
|
}
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
resultArgs := map[string]string{}
|
||||||
|
executionArgs := map[string]string{} // captured when running == false
|
||||||
|
|
||||||
|
cb := GenerateCallbacks{
|
||||||
|
OnToolExecution: func(id, _, args string, running bool) {
|
||||||
|
if running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
executionArgs[id] = args
|
||||||
|
},
|
||||||
|
OnToolResult: func(id, _, args, _, _ string, _ bool) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
resultArgs[id] = args
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := a.GenerateWithCallbacks(context.Background(), nil, cb); err != nil {
|
||||||
|
t.Fatalf("GenerateWithCallbacks returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := resultArgs["kit-A"], argsA; got != want {
|
||||||
|
t.Errorf("OnToolResult for kit-A: args = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := resultArgs["kit-B"], argsB; got != want {
|
||||||
|
t.Errorf("OnToolResult for kit-B: args = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := executionArgs["kit-A"], argsA; got != want {
|
||||||
|
t.Errorf("OnToolExecution(finish) for kit-A: args = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := executionArgs["kit-B"], argsB; got != want {
|
||||||
|
t.Errorf("OnToolExecution(finish) for kit-B: args = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user