Compare commits

...

4 Commits

Author SHA1 Message Date
Ed Zynda a322dfc59a fix(ui): eliminate mouse copy-selection drift during streaming
- Lock viewport scroll while a drag-select is active so highlighted
  content stays under the cursor (SetItems, appendStreamingChunk,
  MouseWheelDown all now honor IsMouseDown).
- HandleMouseDrag defensively clears autoScroll on every update so a
  racy re-enable can't shift the row mid-drag.
- Recompute scrollback yOffset/viewport height on each mouse event
  via currentScrollbackBounds() instead of relying on stale values
  cached during the previous View() pass.
- Account for canceling/ctrlCPressedOnce warning rows in
  distributeHeight and mark layoutDirty when those flags toggle so
  the height budget and mouse origin stay in sync.
- Add ScrollList regression tests covering the three invariants.
2026-05-15 13:30:57 +03:00
Ed Zynda b1387d837e feat(ui): add /copy slash command to copy last message
- Register /copy (alias /cp) in the System command category
- Walk the scrollback to find the last user/assistant/reasoning
  message, skipping transient system messages
- Reuse internal/ui/clipboard.CopyToClipboard for OSC 52 + native
  clipboard support (works over SSH)
- Document the command in /help
2026-05-15 13:06:35 +03:00
Ed Zynda f561f4cfd9 fix(session): order kept messages before post-compact branch in BuildContext
After /compact, BuildContext emitted [summary, post-compact, kept]
which placed an older kept user/assistant turn after the latest
post-compaction turn. This broke user/assistant alternation and caused
the model to respond as if the post-compaction turn never happened on
the next user message.

- Emit kept messages chronologically before post-compaction messages
- Mirror the same order in GetContextEntryIDs so cut-point to entry-ID
  mapping stays aligned across repeat compactions
- Update TestCompactionWithNewMessagesAfterCompaction to assert the
  correct chronological order
2026-05-14 20:42:20 +03:00
Ed Zynda 64caed57d4 fix(sdk): stop leaking fantasy types through pkg/kit.AgentConfig (#30) (#32)
* fix(sdk): stop leaking fantasy types through pkg/kit.AgentConfig (#30)

Replace the alias-based AgentConfig and handler types with SDK-owned
structs and function types. CoreTools / ExtraTools / ToolWrapper now
accept []kit.Tool, and the handler types (ToolCallHandler,
ToolExecutionHandler, ToolResultHandler, ResponseHandler,
StreamingResponseHandler, ToolCallContentHandler) plus SpinnerFunc are
declared in pkg/kit/ with signatures that reference only SDK types.

Consumers no longer need to import charm.land/fantasy to populate an
AgentConfig or assign a handler. go doc pkg/kit AgentConfig output no
longer mentions fantasy.*.

- Add unexported (*AgentConfig).toInternal() to convert at the SDK
  boundary; Tool is still an alias for the underlying tool type, so
  slice and function fields convert without allocation.
- Add agent_config_internal_test.go covering nil receiver, scalar
  fields, tool slices, ToolWrapper invocation, OnMCPServerLoaded, and
  auth/token-factory wiring.
- Add types_test.go cases that populate AgentConfig and SpinnerFunc
  without importing fantasy -- the file compiling is the regression
  proof for the leak.
- Update pkg/kit/README.md Re-exported Types section to record that
  AgentConfig and the handler types are now Kit-owned.

Fixes #30

* fix(sdk): add DebugLogger and MCPTaskConfig to kit.AgentConfig (#30)

The first revision of the SDK-owned AgentConfig dropped two fields that
internal/agent.AgentConfig carried: DebugLogger (tools.DebugLogger) and
MCPTaskConfig (tools.MCPTaskConfig). Restore them with SDK-owned
equivalents and wire them through toInternal().

- Add kit.DebugLogger interface (LogDebug / IsDebugEnabled) mirroring
  tools.DebugLogger. Interface-to-interface assignment is automatic
  because the method sets match.
- Add kit.MCPTaskConfig struct mirroring tools.MCPTaskConfig with SDK
  types (MCPTaskMode, MCPTaskProgressHandler) and a toToolsConfig()
  helper that converts at the SDK boundary.
- Wire both new fields in (*AgentConfig).toInternal().
- Extend agent_config_internal_test.go with cases for both fields.
- Document the additions in pkg/kit/README.md.
2026-05-13 21:10:28 +03:00
12 changed files with 905 additions and 119 deletions
+18 -9
View File
@@ -129,26 +129,35 @@ func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) {
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}}
_, _ = tm.AppendMessage(msg4)
// BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)]
// BuildContext should return: [summary] + [M3 (kept)] + [M4 (new after compaction)]
// Kept messages must appear BEFORE post-compaction messages so the LLM
// sees the conversation in chronological order. Otherwise the latest
// post-compaction user message would be followed by an older kept user
// message, breaking user/assistant alternation and causing the model to
// respond as if the post-compaction turn never happened.
messages, _, _ := tm.BuildContext()
if len(messages) != 3 {
t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages)
t.Fatalf("expected 3 messages (summary + M3 + M4), got %d: %+v", len(messages), messages)
}
// Verify order: summary, M4 (new), M3 (kept)
// Verify order: summary, M3 (kept), M4 (new)
if messages[0].Role != fantasy.MessageRoleSystem {
t.Errorf("first message should be summary, got %s", messages[0].Role)
}
if messages[1].Role != fantasy.MessageRoleAssistant {
t.Errorf("second message should be assistant (M4), got %s", messages[1].Role)
if messages[1].Role != fantasy.MessageRoleUser {
t.Errorf("second message should be user (M3 kept), got %s", messages[1].Role)
}
m4Text := messages[1].Content[0].(fantasy.TextPart).Text
m3Text := messages[1].Content[0].(fantasy.TextPart).Text
if m3Text != "Message 3 - kept" {
t.Errorf("unexpected M3 text: %s", m3Text)
}
if messages[2].Role != fantasy.MessageRoleAssistant {
t.Errorf("third message should be assistant (M4 post-compact), got %s", messages[2].Role)
}
m4Text := messages[2].Content[0].(fantasy.TextPart).Text
if m4Text != "Message 4 - after compaction" {
t.Errorf("unexpected M4 text: %s", m4Text)
}
if messages[2].Role != fantasy.MessageRoleUser {
t.Errorf("third message should be user (M3), got %s", messages[2].Role)
}
// Verify that M1 is NOT in the context
for i, msg := range messages {
+82 -79
View File
@@ -755,9 +755,17 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
}
}
// If there is a compaction, inject the summary first and collect
// the kept messages starting from FirstKeptEntryID (since the
// compaction entry's parent chain doesn't include them).
// If there is a compaction, inject the summary first, then the
// preserved "kept" messages (chronologically before the compaction),
// then the post-compaction messages (chronologically after).
//
// Order matters: the kept messages must come BEFORE the post-compaction
// branch so the LLM sees the conversation in chronological order. If the
// kept messages were appended last, the latest user message in the
// current branch would be followed by an older kept user message,
// breaking the strict user/assistant alternation that providers expect
// and causing the model to respond as if the previous turn never
// happened.
if lastCompaction != nil {
messages = append(messages, fantasy.Message{
Role: fantasy.MessageRoleSystem,
@@ -768,49 +776,10 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
},
})
// Collect entries from the compaction entry itself (at compactionIndex)
// and any entries before it in the branch (newer messages).
for i := compactionIndex; i < len(branch); i++ {
entry := branch[i]
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue // skip malformed entries
}
msgs := msg.ToLLMMessages()
messages = append(messages, msgs...)
case *BranchSummaryEntry:
// Convert branch summary to a user message for context.
if e.Summary != "" {
messages = append(messages, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
},
},
})
}
case *ModelChangeEntry:
provider = e.Provider
modelID = e.ModelID
case *CompactionEntry:
// Already handled above (summary injected).
continue
}
}
// Now collect the kept messages starting from FirstKeptEntryID.
// These are not in the current branch because the compaction entry
// is parented to the first kept entry's parent, not the first kept entry.
// We iterate through entries in order (not using getBranchLocked) to avoid
// walking back to old compacted messages.
// We stop when we reach the compaction entry to avoid double-counting
// messages that were added after the compaction.
// Step 1: collect the kept messages starting from FirstKeptEntryID.
// These are not on the current branch (the compaction entry is a
// new root with no parent), so we iterate tm.entries in append order
// and stop when we reach the compaction entry itself.
if lastCompaction.FirstKeptEntryID != "" {
found := false
for _, entry := range tm.entries {
@@ -825,13 +794,12 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
}
}
// Stop when we reach the compaction entry itself.
// Messages after the compaction are collected from the branch walk above.
// Stop when we reach the compaction entry itself; messages
// after it are collected from the branch walk below.
if entryID == lastCompaction.ID {
break
}
// Process this kept entry.
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
@@ -860,6 +828,42 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
}
}
// Step 2: collect entries on the current branch after the compaction
// entry (these are post-compaction messages). The compaction entry
// itself is skipped — its summary was already injected above.
for i := compactionIndex; i < len(branch); i++ {
entry := branch[i]
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue
}
msgs := msg.ToLLMMessages()
messages = append(messages, msgs...)
case *BranchSummaryEntry:
if e.Summary != "" {
messages = append(messages, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
},
},
})
}
case *ModelChangeEntry:
provider = e.Provider
modelID = e.ModelID
case *CompactionEntry:
// Summary already injected above.
continue
}
}
return messages, provider, modelID
}
@@ -1030,44 +1034,22 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
var ids []string
// If there's a compaction, we need to collect IDs from:
// 1. Entries after the compaction entry in the branch (newer messages)
// 2. Entries from FirstKeptEntryID onwards (kept messages)
// If there's a compaction, we collect IDs in the same order as
// BuildContext: [summary placeholder, kept messages, post-compaction
// messages]. This ordering must stay in sync with BuildContext so a
// cut-point index can be mapped back to the correct entry ID.
if lastCompaction != nil {
// Placeholder for the summary system message (no entry ID).
ids = append(ids, "")
// Collect IDs from entries after the compaction entry (newer messages).
for i := compactionIndex + 1; i < len(branch); i++ {
entry := branch[i]
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue
}
msgs := msg.ToLLMMessages()
for range msgs {
ids = append(ids, e.ID)
}
case *BranchSummaryEntry:
if e.Summary != "" {
ids = append(ids, e.ID)
}
}
}
// Collect IDs from the kept messages starting at FirstKeptEntryID.
// We iterate through entries in order (not using getBranchLocked) to avoid
// walking back to old compacted messages.
// We stop when we reach the compaction entry to avoid double-counting.
// Step 1: IDs of the kept messages starting at FirstKeptEntryID.
// Iterate tm.entries in append order and stop at the compaction
// entry to avoid double-counting post-compaction messages.
if lastCompaction.FirstKeptEntryID != "" {
found := false
for _, entry := range tm.entries {
entryID := tm.EntryID(entry)
// Skip entries until we reach the first kept entry.
if !found {
if entryID == lastCompaction.FirstKeptEntryID {
found = true
@@ -1076,7 +1058,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
}
}
// Stop when we reach the compaction entry itself.
if entryID == lastCompaction.ID {
break
}
@@ -1100,6 +1081,28 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
}
}
// Step 2: IDs of entries after the compaction entry on the current
// branch (post-compaction messages).
for i := compactionIndex + 1; i < len(branch); i++ {
entry := branch[i]
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue
}
msgs := msg.ToLLMMessages()
for range msgs {
ids = append(ids, e.ID)
}
case *BranchSummaryEntry:
if e.Summary != "" {
ids = append(ids, e.ID)
}
}
}
return ids
}
+6
View File
@@ -161,6 +161,12 @@ var SlashCommands = []SlashCommand{
Category: "Navigation",
Aliases: []string{"/r"},
},
{
Name: "/copy",
Description: "Copy the last message to the system clipboard",
Category: "System",
Aliases: []string{"/cp"},
},
{
Name: "/export",
Description: "Export session (JSONL by default, or /export path.jsonl)",
+117 -8
View File
@@ -1266,7 +1266,11 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.scrollList.autoScroll = false
case tea.MouseWheelDown:
m.scrollList.ScrollBy(scrollLines)
if m.scrollList.AtBottom() {
// Only re-enable auto-scroll when the user is not actively
// selecting text. Otherwise a wheel-down during a drag-select
// would re-arm GotoBottom on the next stream chunk, shifting
// the highlighted row out from under the cursor.
if m.scrollList.AtBottom() && !m.scrollList.IsMouseDown() {
m.scrollList.autoScroll = true
}
}
@@ -1274,9 +1278,14 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// ── Mouse click selection (crush-style character-level) ──────────────────
case tea.MouseClickMsg:
if msg.Button == tea.MouseLeft {
// Calculate viewport-relative coordinates.
viewY := msg.Y - m.scrollbackYOffset
if viewY >= 0 && viewY < m.scrollList.height {
// Compute the scrollback origin from the current frame's layout
// rather than the stale cached value from the previous View().
// scrollbackYOffset/scrollList.height are only refreshed inside
// View() and lag behind any state change that resized the header
// (extension widgets, warning rows, etc.) since the last render.
yOff, vpHeight := m.currentScrollbackBounds()
viewY := msg.Y - yOff
if viewY >= 0 && viewY < vpHeight {
// Clear any previous selection on a new click.
// HandleMouseDown will set up new selection state.
if m.scrollList.HandleMouseDown(msg.X, viewY) {
@@ -1287,8 +1296,9 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// ── Mouse motion/drag for character-level selection ──────────────────────
case tea.MouseMotionMsg:
viewY := msg.Y - m.scrollbackYOffset
if viewY >= 0 && viewY < m.scrollList.height {
yOff, vpHeight := m.currentScrollbackBounds()
viewY := msg.Y - yOff
if viewY >= 0 && viewY < vpHeight {
m.scrollList.HandleMouseDrag(msg.X, viewY)
}
@@ -1618,10 +1628,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// ── Cancel timer expired ─────────────────────────────────────────────────
case uicore.CancelTimerExpiredMsg:
if m.canceling {
m.layoutDirty = true
}
m.canceling = false
// ── Ctrl+C reset timer expired ────────────────────────────────────────────
case uicore.CtrlCResetMsg:
if m.ctrlCPressedOnce {
m.layoutDirty = true
}
m.ctrlCPressedOnce = false
// ── Input submitted ──────────────────────────────────────────────────────
@@ -3110,6 +3126,8 @@ func (m *AppModel) handleSlashCommand(sc *commands.SlashCommand, args string) te
return m.handleResumeCommand()
case "/export":
return m.handleExportCommand(args)
case "/copy":
return m.handleCopyCommand()
case "/share":
return m.handleShareCommand()
case "/import":
@@ -3524,6 +3542,7 @@ func (m *AppModel) printHelpMessage() {
"**System:**\n" +
"- `/compact [instructions]`: Summarise older messages to free context space\n" +
"- `/clear`: Clear message history\n" +
"- `/copy`: Copy the last message to the system clipboard\n" +
"- `/export [path]`: Export session as JSONL\n" +
"- `/import <path.jsonl>`: Import session from JSONL file\n" +
"- `/reset-usage`: Reset usage statistics\n" +
@@ -3760,7 +3779,12 @@ func (m *AppModel) appendStreamingChunk(role, content string) {
}
// Auto-scroll to bottom if enabled (iteratr pattern)
// Don't call SetItems() - the slice reference hasn't changed
if m.scrollList != nil {
//
// CRITICAL: never scroll the viewport while the user is actively
// selecting text (mouse button held). Doing so shifts the
// highlighted content out from under the cursor and produces the
// off-by-N-row drift users see when copy-selecting during streaming.
if m.scrollList != nil && !m.scrollList.IsMouseDown() {
if m.scrollList.autoScroll {
m.scrollList.GotoBottom()
} else if m.scrollList.AtBottom() {
@@ -3788,6 +3812,36 @@ func (m *AppModel) appendStreamingChunk(role, content string) {
m.refreshContent()
}
// currentScrollbackBounds returns the live (yOffset, viewportHeight) for the
// scrollback region, computed from the current state — not from the cached
// values populated inside View().
//
// scrollbackYOffset and scrollList.height are refreshed once per render, so
// any state change that resizes the header (extension widget toggles,
// warning rows, queued messages, etc.) leaves the cached values one frame
// stale. Mouse click handlers in Update() can then place the cursor on the
// wrong line, producing the off-by-N-row drift seen during copy-selection.
//
// This recomputes the header height by rendering it (cheap — the renderer
// returns "" when no extension header is set) and recomputes the viewport
// height the same way distributeHeight() does, so both inputs to the
// y → (item, line) mapping are always current.
func (m *AppModel) currentScrollbackBounds() (yOffset, viewportHeight int) {
// Force a fresh layout if anything in Update() marked the state dirty;
// otherwise scrollList.height still reflects the previous frame.
if m.layoutDirty {
m.distributeHeight()
m.layoutDirty = false
}
if headerView := m.renderHeaderFooter(m.getHeader); headerView != "" {
yOffset = lipgloss.Height(headerView)
}
if m.scrollList != nil {
viewportHeight = m.scrollList.height
}
return yOffset, viewportHeight
}
// distributeHeight recalculates child component heights after a window resize,
// queue change, widget update, or state transition, and propagates the computed
// stream height to the StreamComponent.
@@ -3860,7 +3914,20 @@ func (m *AppModel) distributeHeight() {
headerFooterLines += lipgloss.Height(footerView)
}
streamHeight := max(m.height-separatorLines-widgetLines-headerFooterLines-queuedLines-inputLines-statusBarLines, 0)
// Account for transient warning rows that View() injects between the
// scrollback and the separator. These flags are toggled by ESC/Ctrl+C
// handlers; without subtracting them here the joined view exceeds
// m.height by one line per active warning and the bottom of the screen
// gets silently clipped — which in turn invalidates scrollbackYOffset.
var warningLines int
if m.canceling {
warningLines++
}
if m.ctrlCPressedOnce {
warningLines++
}
streamHeight := max(m.height-separatorLines-widgetLines-headerFooterLines-queuedLines-inputLines-statusBarLines-warningLines, 0)
// In alt screen mode, give the calculated height to ScrollList instead of stream.
// The stream component still exists but is embedded as the last item in scrollList.
@@ -4284,6 +4351,48 @@ func (m *AppModel) handleNameCommand(args string) tea.Cmd {
return nil
}
// handleCopyCommand copies the last user or assistant message to the system
// clipboard. Skips transient system messages (e.g. /help output) so the user
// gets the actual last conversational message.
func (m *AppModel) handleCopyCommand() tea.Cmd {
if len(m.messages) == 0 {
m.printSystemMessage("No messages to copy.")
return nil
}
var (
text string
role string
)
for i := len(m.messages) - 1; i >= 0; i-- {
switch msg := m.messages[i].(type) {
case *TextMessageItem:
if msg.role == "user" || msg.role == "assistant" {
text = msg.content
role = msg.role
}
case *StreamingMessageItem:
if msg.role == "assistant" || msg.role == "reasoning" {
text = msg.content.String()
role = msg.role
}
}
if text != "" {
break
}
}
if strings.TrimSpace(text) == "" {
m.printSystemMessage("No copyable message found.")
return nil
}
m.printSystemMessage(fmt.Sprintf(
"Copied last %s message to clipboard (%d chars).", role, len(text),
))
return clipboard.CopyToClipboard(text)
}
// handleExportCommand exports the current session to a file.
// Usage: /export — copies the JSONL file to cwd with a descriptive name.
//
+19 -2
View File
@@ -60,10 +60,13 @@ func NewScrollList(width, height int) *ScrollList {
}
// SetItems replaces the items in the scroll list. If auto-scroll is enabled,
// the viewport will scroll to the bottom to show the latest content.
// the viewport will scroll to the bottom to show the latest content — EXCEPT
// when the user is actively selecting text (mouse button held), in which case
// the scroll position is locked so the highlighted content stays under the
// cursor. The pending bottom-scroll is deferred to MouseUp.
func (s *ScrollList) SetItems(items []MessageItem) {
s.items = items
if s.autoScroll {
if s.autoScroll && !s.sel.MouseDown {
s.GotoBottom()
}
}
@@ -157,6 +160,10 @@ func (s *ScrollList) HandleMouseDown(x, y int) bool {
// HandleMouseDrag handles mouse motion while button is held.
// Updates the selection endpoint for character-level precision.
// Returns true if selection was updated.
//
// Defensively disables auto-scroll on every drag update — even if the
// MouseDown handler missed (e.g. click landed in viewport padding), any
// active drag means the user is selecting and the viewport must not jump.
func (s *ScrollList) HandleMouseDrag(x, y int) bool {
if !s.sel.MouseDown {
return false
@@ -171,6 +178,9 @@ func (s *ScrollList) HandleMouseDrag(x, y int) bool {
return false
}
// Hard-lock the viewport while dragging.
s.autoScroll = false
s.sel.DragItemIdx = itemIdx
s.sel.DragLineIdx = lineIdx
s.sel.DragCol = x
@@ -178,6 +188,13 @@ func (s *ScrollList) HandleMouseDrag(x, y int) bool {
return true
}
// IsMouseDown reports whether the user currently has the mouse button held
// (i.e. a selection drag is in progress). Used by the parent model to avoid
// re-enabling auto-scroll during streaming while the user is selecting.
func (s *ScrollList) IsMouseDown() bool {
return s.sel.MouseDown
}
// HandleMouseUp handles mouse button release.
// Returns true if there was an active selection.
func (s *ScrollList) HandleMouseUp() bool {
+132
View File
@@ -0,0 +1,132 @@
package ui
import (
"fmt"
"strings"
"testing"
)
// fakeItem is a deterministic MessageItem for ScrollList tests.
type fakeItem struct {
id string
lines int
}
func (f *fakeItem) ID() string { return f.id }
func (f *fakeItem) Render(_ int) string {
if f.lines <= 0 {
return ""
}
parts := make([]string, f.lines)
for i := range parts {
parts[i] = fmt.Sprintf("%s-line-%d", f.id, i)
}
return strings.Join(parts, "\n")
}
func (f *fakeItem) Height() int { return f.lines }
// makeItems builds n fake items of `lines` height each.
func makeItems(n, lines int) []MessageItem {
out := make([]MessageItem, n)
for i := range out {
out[i] = &fakeItem{id: fmt.Sprintf("item-%d", i), lines: lines}
}
return out
}
// TestScrollList_MouseDownPreventsAutoScroll verifies the core fix for the
// copy-selection drift bug: while the user has the mouse button held
// (drag-selecting), incoming content updates must NOT shift the viewport,
// because doing so moves the highlighted content out from under the cursor.
func TestScrollList_MouseDownPreventsAutoScroll(t *testing.T) {
sl := NewScrollList(80, 10)
sl.SetItems(makeItems(20, 2)) // 40 lines of content into a 10-line viewport
// Capture the auto-scrolled-to-bottom position.
startOffsetIdx := sl.offsetIdx
startOffsetLine := sl.offsetLine
// User clicks somewhere in the visible area, starting a drag-select.
if !sl.HandleMouseDown(5, 3) {
t.Fatalf("HandleMouseDown should accept a click inside the viewport")
}
if !sl.IsMouseDown() {
t.Fatalf("IsMouseDown should be true after HandleMouseDown")
}
// New content arrives. With autoScroll still true, SetItems would
// normally call GotoBottom() and shift the viewport. The fix should
// suppress that while MouseDown is held.
sl.SetItems(makeItems(30, 2)) // 60 lines now
if sl.offsetIdx != startOffsetIdx || sl.offsetLine != startOffsetLine {
t.Errorf("viewport scrolled during active drag: was (%d,%d), now (%d,%d)",
startOffsetIdx, startOffsetLine, sl.offsetIdx, sl.offsetLine)
}
// User releases the mouse — drag is over.
sl.HandleMouseUp()
if sl.IsMouseDown() {
t.Fatalf("IsMouseDown should be false after HandleMouseUp")
}
// After release, a fresh content update should resume auto-scrolling
// (move the offset to track the new bottom).
afterReleaseIdx := sl.offsetIdx
afterReleaseLine := sl.offsetLine
sl.SetItems(makeItems(50, 2))
if sl.offsetIdx == afterReleaseIdx && sl.offsetLine == afterReleaseLine {
t.Errorf("autoscroll did not resume after MouseUp: offset stuck at (%d,%d)",
afterReleaseIdx, afterReleaseLine)
}
}
// TestScrollList_DragDisablesAutoScroll verifies that any successful
// HandleMouseDrag call clears autoScroll, even when HandleMouseDown didn't
// observe it (e.g. a stale wheel-down event set it back to true mid-stream).
func TestScrollList_DragDisablesAutoScroll(t *testing.T) {
sl := NewScrollList(80, 10)
sl.SetItems(makeItems(20, 2))
// Begin a selection.
if !sl.HandleMouseDown(5, 3) {
t.Fatalf("HandleMouseDown failed")
}
// Simulate an external code path that re-enabled autoScroll while
// MouseDown is still held (the precise condition that caused drift).
sl.autoScroll = true
// Drag motion should hard-lock the viewport again.
if !sl.HandleMouseDrag(10, 4) {
t.Fatalf("HandleMouseDrag failed")
}
if sl.autoScroll {
t.Errorf("HandleMouseDrag must clear autoScroll to prevent mid-drag jumps")
}
}
// TestScrollList_SetItemsRespectsMouseDown is the most direct regression
// test: even with autoScroll enabled and new content appended at the
// bottom, SetItems must not move the viewport while a mouse drag is in
// progress. This is what caused the "highlighting shifts by 1+ rows
// during streaming" symptom reported by the user.
func TestScrollList_SetItemsRespectsMouseDown(t *testing.T) {
sl := NewScrollList(80, 5)
sl.SetItems(makeItems(10, 2)) // 20 lines into a 5-line viewport
// At bottom.
preIdx, preLine := sl.offsetIdx, sl.offsetLine
// Hold mouse down (no actual drag needed).
if !sl.HandleMouseDown(0, 0) {
t.Fatalf("HandleMouseDown failed")
}
// Append several more items as if streaming. With the bug, each
// SetItems would call GotoBottom and shift the offset.
for n := 11; n <= 15; n++ {
sl.SetItems(makeItems(n, 2))
if sl.offsetIdx != preIdx || sl.offsetLine != preLine {
t.Fatalf("viewport drifted during streaming with mouse held: "+
"start=(%d,%d) now=(%d,%d) after adding item %d",
preIdx, preLine, sl.offsetIdx, sl.offsetLine, n)
}
}
}
+17 -2
View File
@@ -243,7 +243,7 @@ host.ClearSession()
## Re-exported Types
The SDK re-exports types so you don't need direct internal imports:
The SDK re-exports message/session/MCP types so you don't need direct internal imports. Agent-configuration types are Kit-owned (not aliases) and use only SDK types in their signatures, so consumers never need to import the underlying LLM-provider package.
```go
// Message types
@@ -251,13 +251,28 @@ kit.Message, kit.MessageRole, kit.ContentPart
kit.TextContent, kit.ReasoningContent, kit.ToolCall, kit.ToolResult, kit.Finish
kit.RoleUser, kit.RoleAssistant, kit.RoleTool, kit.RoleSystem
// LLM types — concrete Kit-owned structs, no external library dependency
// LLM types — Kit-owned `LLM*` aliases over the underlying provider types,
// so consumers never import the provider package directly
kit.LLMMessage // {Role LLMMessageRole, Content string}
kit.LLMMessageRole // "user" | "assistant" | "system" | "tool"
kit.LLMUsage // {InputTokens, OutputTokens, TotalTokens, ...}
kit.LLMResponse // {Content, FinishReason, Usage}
kit.LLMFilePart // {Filename, Data []byte, MediaType}
// Agent configuration — concrete Kit-owned structs and function types.
// All fields use SDK types (e.g. `[]kit.Tool`), so consumers can construct
// these without importing any LLM-provider package.
kit.AgentConfig // Lower-level agent config — prefer Options unless you need direct control
kit.DebugLogger // Interface: LogDebug(string) / IsDebugEnabled() bool
kit.MCPTaskConfig // Task-aware MCP tools/call config (modes, polling, progress)
kit.ToolCallHandler // func(toolCallID, toolName, toolArgs string)
kit.ToolExecutionHandler // func(toolCallID, toolName, toolArgs string, isStarting bool)
kit.ToolResultHandler // func(toolCallID, toolName, toolArgs, result, metadata string, isError bool)
kit.ResponseHandler // func(content string)
kit.StreamingResponseHandler // func(content string)
kit.ToolCallContentHandler // func(content string)
kit.SpinnerFunc // func(fn func() error) error
// MCP OAuth types
kit.MCPServer // *server.MCPServer for in-process MCP transport
kit.MCPServerConfig // Configuration for an MCP server (stdio, SSE, or in-process)
+208
View File
@@ -0,0 +1,208 @@
package kit
import (
"context"
"errors"
"testing"
"time"
"github.com/mark3labs/kit/internal/agent"
)
// TestAgentConfigToInternal verifies that the SDK-side AgentConfig converts
// faithfully to the internal agent.AgentConfig representation, preserving
// every field consumed by the internal agent layer.
//
// Regression test for https://github.com/mark3labs/kit/issues/30.
func TestAgentConfigToInternal(t *testing.T) {
t.Run("nil receiver returns nil", func(t *testing.T) {
var c *AgentConfig
if got := c.toInternal(); got != nil {
t.Errorf("nil.toInternal() = %v, want nil", got)
}
})
t.Run("scalar fields round-trip", func(t *testing.T) {
c := &AgentConfig{
SystemPrompt: "sys",
MaxSteps: 7,
StreamingEnabled: true,
DisableCoreTools: true,
}
got := c.toInternal()
if got == nil {
t.Fatal("toInternal() = nil")
}
if got.SystemPrompt != "sys" {
t.Errorf("SystemPrompt = %q, want %q", got.SystemPrompt, "sys")
}
if got.MaxSteps != 7 {
t.Errorf("MaxSteps = %d, want 7", got.MaxSteps)
}
if !got.StreamingEnabled {
t.Error("StreamingEnabled = false, want true")
}
if !got.DisableCoreTools {
t.Error("DisableCoreTools = false, want true")
}
})
t.Run("tool slices propagate without conversion", func(t *testing.T) {
// Tool is a type alias for the underlying LLM-tool type, so the
// SDK []Tool and internal []fantasy.AgentTool slices share the
// same backing array after conversion.
tool := NewTool[struct{}]("noop", "noop", nil)
c := &AgentConfig{
CoreTools: []Tool{tool},
ExtraTools: []Tool{tool, tool},
}
got := c.toInternal()
if len(got.CoreTools) != 1 {
t.Errorf("CoreTools len = %d, want 1", len(got.CoreTools))
}
if len(got.ExtraTools) != 2 {
t.Errorf("ExtraTools len = %d, want 2", len(got.ExtraTools))
}
})
t.Run("tool wrapper is invoked through internal config", func(t *testing.T) {
called := false
c := &AgentConfig{
ToolWrapper: func(in []Tool) []Tool {
called = true
return in
},
}
got := c.toInternal()
if got.ToolWrapper == nil {
t.Fatal("internal ToolWrapper is nil")
}
_ = got.ToolWrapper(nil)
if !called {
t.Error("SDK ToolWrapper was not invoked through the internal config")
}
})
t.Run("OnMCPServerLoaded propagates", func(t *testing.T) {
var captured string
wantErr := errors.New("boom")
c := &AgentConfig{
OnMCPServerLoaded: func(name string, _ int, _ error) {
captured = name
},
}
got := c.toInternal()
got.OnMCPServerLoaded("svr", 3, wantErr)
if captured != "svr" {
t.Errorf("OnMCPServerLoaded captured = %q, want %q", captured, "svr")
}
})
t.Run("DebugLogger propagates", func(t *testing.T) {
dl := &fakeDebugLogger{enabled: true}
c := &AgentConfig{DebugLogger: dl}
got := c.toInternal()
if got.DebugLogger == nil {
t.Fatal("internal DebugLogger is nil")
}
if !got.DebugLogger.IsDebugEnabled() {
t.Error("IsDebugEnabled = false, want true")
}
got.DebugLogger.LogDebug("hello")
if len(dl.messages) != 1 || dl.messages[0] != "hello" {
t.Errorf("messages = %v, want [hello]", dl.messages)
}
})
t.Run("MCPTaskConfig propagates with mode + progress", func(t *testing.T) {
c := &AgentConfig{
MCPTaskConfig: MCPTaskConfig{
PerServerMode: map[string]MCPTaskMode{
"build-svr": MCPTaskModeAlways,
},
DefaultTTL: 30 * time.Second,
PollInterval: 250 * time.Millisecond,
MaxPollInterval: 2 * time.Second,
Timeout: 5 * time.Minute,
Progress: func(_ MCPTaskProgress) {},
},
}
got := c.toInternal()
if got.MCPTaskConfig.DefaultTTL != 30*time.Second {
t.Errorf("DefaultTTL = %v, want 30s", got.MCPTaskConfig.DefaultTTL)
}
if got.MCPTaskConfig.PollInterval != 250*time.Millisecond {
t.Errorf("PollInterval = %v, want 250ms", got.MCPTaskConfig.PollInterval)
}
if got.MCPTaskConfig.MaxPollInterval != 2*time.Second {
t.Errorf("MaxPollInterval = %v, want 2s", got.MCPTaskConfig.MaxPollInterval)
}
if got.MCPTaskConfig.Timeout != 5*time.Minute {
t.Errorf("Timeout = %v, want 5m", got.MCPTaskConfig.Timeout)
}
mode, ok := got.MCPTaskConfig.PerServerMode["build-svr"]
if !ok {
t.Fatal("PerServerMode missing 'build-svr'")
}
if string(mode) != string(MCPTaskModeAlways) {
t.Errorf("mode = %q, want %q", mode, MCPTaskModeAlways)
}
if got.MCPTaskConfig.Progress == nil {
t.Fatal("internal Progress handler is nil")
}
})
t.Run("auth and token store factories are wired", func(t *testing.T) {
auth := &fakeAuthHandler{}
tokenCalls := 0
var tokenServer string
factory := MCPTokenStoreFactory(func(server string) (MCPTokenStore, error) {
tokenCalls++
tokenServer = server
return nil, nil
})
c := &AgentConfig{
AuthHandler: auth,
TokenStoreFactory: factory,
}
got := c.toInternal()
if got.AuthHandler == nil {
t.Fatal("internal AuthHandler is nil")
}
if got.TokenStoreFactory == nil {
t.Fatal("internal TokenStoreFactory is nil")
}
_, _ = got.TokenStoreFactory("https://example.test")
if tokenCalls != 1 {
t.Errorf("token factory call count = %d, want 1", tokenCalls)
}
if tokenServer != "https://example.test" {
t.Errorf("token factory server arg = %q", tokenServer)
}
if got.AuthHandler.RedirectURI() != "redirect" {
t.Errorf("RedirectURI = %q, want %q", got.AuthHandler.RedirectURI(), "redirect")
}
})
// Compile-time check that the internal type is what we expect.
//nolint:staticcheck // QF1011: explicit type asserts the conversion target.
var _ *agent.AgentConfig = (&AgentConfig{}).toInternal()
}
// fakeAuthHandler implements both kit.MCPAuthHandler and the structurally
// identical tools.MCPAuthHandler used by the internal layer.
type fakeAuthHandler struct{}
func (f *fakeAuthHandler) RedirectURI() string { return "redirect" }
func (f *fakeAuthHandler) HandleAuth(_ context.Context, _ string, _ string) (string, error) {
return "", nil
}
// fakeDebugLogger implements kit.DebugLogger for tests.
type fakeDebugLogger struct {
enabled bool
messages []string
}
func (f *fakeDebugLogger) LogDebug(m string) { f.messages = append(f.messages, m) }
func (f *fakeDebugLogger) IsDebugEnabled() bool { return f.enabled }
+1 -1
View File
@@ -1489,7 +1489,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
if opts.CLI != nil {
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
setupOpts.SpinnerFunc = agent.SpinnerFunc(opts.CLI.SpinnerFunc)
setupOpts.UseBufferedLogger = opts.CLI.UseBufferedLogger
if opts.CLI.ProgressReaderFunc != nil {
providerConfig.ProgressReaderFunc = opts.CLI.ProgressReaderFunc
+64
View File
@@ -98,6 +98,70 @@ type MCPTaskProgress struct {
// dispatched on a goroutine.
type MCPTaskProgressHandler func(MCPTaskProgress)
// MCPTaskConfig configures task-aware MCP tools/call execution. All fields
// are optional; the zero value disables progress callbacks and applies
// sensible polling defaults inside the engine.
//
// For most consumers, the flat [Options] fields (`MCPTaskMode`,
// `MCPTaskTTL`, `MCPTaskPollInterval`, `MCPTaskMaxPollInterval`,
// `MCPTaskTimeout`, `MCPTaskProgress`) are the preferred entry point.
// MCPTaskConfig is exposed for the low-level [AgentConfig] path.
type MCPTaskConfig struct {
// PerServerMode overrides the per-server task mode resolved from
// [MCPServerConfig]. Keys are server names. Missing entries fall back
// to the configured value.
PerServerMode map[string]MCPTaskMode
// DefaultTTL is the TTL hint sent in TaskParams when augmenting a
// tools/call. Zero means omit the TTL — let the server pick its own.
DefaultTTL time.Duration
// PollInterval is the fallback interval between tasks/get requests
// when the server does not suggest one. Zero defaults to 1 second.
PollInterval time.Duration
// MaxPollInterval caps the polling interval. Zero defaults to 5 seconds.
MaxPollInterval time.Duration
// Timeout is the maximum wall-clock duration to wait for a task to
// reach a terminal state. Zero defaults to 15 minutes. Independent
// of the per-call context deadline; whichever fires first wins.
Timeout time.Duration
// Progress, if non-nil, receives every status transition observed by
// the polling loop.
Progress MCPTaskProgressHandler
}
// toToolsConfig converts the SDK-level [MCPTaskConfig] to the internal
// tools-package representation. Keeps the dependency arrow internal-only.
func (c MCPTaskConfig) toToolsConfig() tools.MCPTaskConfig {
cfg := tools.MCPTaskConfig{
DefaultTTL: c.DefaultTTL,
PollInterval: c.PollInterval,
MaxPollInterval: c.MaxPollInterval,
Timeout: c.Timeout,
}
if len(c.PerServerMode) > 0 {
cfg.PerServerMode = make(map[string]tools.MCPTaskMode, len(c.PerServerMode))
for k, v := range c.PerServerMode {
cfg.PerServerMode[k] = tools.MCPTaskMode(v)
}
}
if c.Progress != nil {
h := c.Progress
cfg.Progress = func(p tools.MCPTaskProgress) {
h(MCPTaskProgress{
Server: p.Server,
TaskID: p.TaskID,
Status: MCPTaskStatus(p.Status),
Message: p.Message,
})
}
}
return cfg
}
// mcpTaskOptions carries SDK consumer configuration into the agent setup.
// Stored on Options as a single value so the public surface stays compact;
// individual fields are exposed via WithMCP* builder functions.
+145 -18
View File
@@ -11,6 +11,7 @@ import (
"github.com/mark3labs/kit/internal/message"
"github.com/mark3labs/kit/internal/models"
"github.com/mark3labs/kit/internal/session"
"github.com/mark3labs/kit/internal/tools"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/server"
)
@@ -75,25 +76,151 @@ type Config = config.Config
// local (stdio) and remote (StreamableHTTP/SSE) server types.
type MCPServerConfig = config.MCPServerConfig
// ==== Agent Types (internal/agent/) ====
// ==== Agent Types ====
// AgentConfig holds configuration options for creating a new Agent.
type AgentConfig = agent.AgentConfig
// DebugLogger is an SDK-owned interface for low-level debug logging from
// the engine and MCP tool plumbing. Implementations must be safe for
// concurrent use.
//
// Most consumers do not need to provide one; pass [Options.Debug] = true
// to use the default logger. DebugLogger is exposed for the low-level
// [AgentConfig] path and for embedders that want to route debug output
// into their own logging system.
type DebugLogger interface {
// LogDebug records a single debug message. Implementations may drop,
// buffer, or render the message however they choose.
LogDebug(message string)
// IsDebugEnabled reports whether debug logging is active. Callers may
// check this before doing expensive formatting work.
IsDebugEnabled() bool
}
type (
// ToolCallHandler is a function type for handling tool calls as they happen.
ToolCallHandler = agent.ToolCallHandler
// ToolExecutionHandler is a function type for handling tool execution start/end events.
ToolExecutionHandler = agent.ToolExecutionHandler
// ToolResultHandler is a function type for handling tool results.
ToolResultHandler = agent.ToolResultHandler
// ResponseHandler is a function type for handling LLM responses.
ResponseHandler = agent.ResponseHandler
// StreamingResponseHandler is a function type for handling streaming LLM responses.
StreamingResponseHandler = agent.StreamingResponseHandler
// ToolCallContentHandler is a function type for handling content that accompanies tool calls.
ToolCallContentHandler = agent.ToolCallContentHandler
)
// AgentConfig holds configuration options for constructing an agent at the
// SDK boundary. All fields use SDK-owned types, so consumers can populate
// this struct without importing any underlying LLM-provider package.
//
// For most use cases, prefer the high-level [New] entry point with
// [Options]. AgentConfig is exposed for advanced consumers that need
// direct access to the lower-level agent configuration shape.
type AgentConfig struct {
// ModelConfig holds the LLM provider configuration. A nil value means
// that the default provider/model resolution will be used.
ModelConfig *ProviderConfig
// MCPConfig describes any MCP servers whose tools should be loaded
// alongside core tools.
MCPConfig *Config
// SystemPrompt is the system prompt sent to the LLM.
SystemPrompt string
// MaxSteps caps the number of LLM iterations per turn. A value of
// zero means no cap is applied at this layer.
MaxSteps int
// StreamingEnabled controls whether the agent streams responses.
StreamingEnabled bool
// AuthHandler handles OAuth authorization for remote MCP servers.
// When nil, remote MCP servers requiring OAuth will fail to connect.
AuthHandler MCPAuthHandler
// TokenStoreFactory, if non-nil, creates a custom token store for each
// remote MCP server's OAuth tokens. When nil, the default file-based
// token store is used.
TokenStoreFactory MCPTokenStoreFactory
// CoreTools overrides the default core tool set. If empty, [AllTools]
// is used. Provide a custom tool set (e.g. [CodingTools] or tools
// built with a custom WorkDir) to scope agent capabilities.
CoreTools []Tool
// DisableCoreTools, when true, prevents loading any core tools.
// Combined with empty CoreTools this yields a chat-only agent with
// no built-in tools.
DisableCoreTools bool
// ExtraTools are additional tools loaded alongside core and MCP tools.
ExtraTools []Tool
// ToolWrapper, if non-nil, wraps the combined tool list before it is
// handed to the LLM. Used to intercept tool calls or results.
ToolWrapper func([]Tool) []Tool
// OnMCPServerLoaded, if non-nil, is invoked once for each MCP server
// when its tools have finished loading (or failed). Called from a
// background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
// DebugLogger receives low-level debug output from the engine and the
// MCP tool plumbing. Nil means no debug output is emitted at this
// layer (regardless of [Options.Debug], which feeds the higher-level
// [New] entry point). Pass an implementation here when wiring a custom
// logger through the lower-level AgentConfig path.
DebugLogger DebugLogger
// MCPTaskConfig configures task-aware MCP tools/call execution — mode
// overrides, polling intervals, timeouts, and the progress handler.
// The zero value preserves historical synchronous-only behaviour for
// any server that didn't advertise task support during initialize.
MCPTaskConfig MCPTaskConfig
}
// toInternal converts an AgentConfig to its internal representation.
// Slice and function fields convert without allocation because [Tool]
// is a type alias for the underlying LLM-tool type.
func (c *AgentConfig) toInternal() *agent.AgentConfig {
if c == nil {
return nil
}
out := &agent.AgentConfig{
ModelConfig: c.ModelConfig,
MCPConfig: c.MCPConfig,
SystemPrompt: c.SystemPrompt,
MaxSteps: c.MaxSteps,
StreamingEnabled: c.StreamingEnabled,
CoreTools: c.CoreTools,
DisableCoreTools: c.DisableCoreTools,
ExtraTools: c.ExtraTools,
ToolWrapper: c.ToolWrapper,
OnMCPServerLoaded: c.OnMCPServerLoaded,
}
if c.AuthHandler != nil {
out.AuthHandler = c.AuthHandler
}
if c.TokenStoreFactory != nil {
out.TokenStoreFactory = tools.TokenStoreFactory(c.TokenStoreFactory)
}
if c.DebugLogger != nil {
out.DebugLogger = c.DebugLogger
}
out.MCPTaskConfig = c.MCPTaskConfig.toToolsConfig()
return out
}
// ToolCallHandler is invoked when the LLM produces a tool call. It receives
// the call ID, tool name, and the JSON-encoded input arguments.
type ToolCallHandler func(toolCallID, toolName, toolArgs string)
// ToolExecutionHandler is invoked at the start and end of tool execution.
// The isStarting flag distinguishes the two phases.
type ToolExecutionHandler func(toolCallID, toolName, toolArgs string, isStarting bool)
// ToolResultHandler is invoked after a tool finishes executing. The metadata
// parameter carries optional structured data (e.g. file-diff info) from the
// tool execution, JSON-encoded; it may be empty.
type ToolResultHandler func(toolCallID, toolName, toolArgs, result, metadata string, isError bool)
// ResponseHandler is invoked with the final assistant text for each turn.
type ResponseHandler func(content string)
// StreamingResponseHandler is invoked with each streamed text delta as it
// arrives from the LLM.
type StreamingResponseHandler func(content string)
// ToolCallContentHandler is invoked with any assistant text that accompanies
// a tool call within the same step.
type ToolCallContentHandler func(content string)
// ==== Provider & Model Types (internal/models/) ====
@@ -126,7 +253,7 @@ type ModelsRegistry = models.ModelsRegistry
// SpinnerFunc wraps a function in a loading spinner animation. Used for
// Ollama model loading. Signature: func(fn func() error) error.
type SpinnerFunc = agent.SpinnerFunc
type SpinnerFunc func(fn func() error) error
// ==== LLM Types ====
//
+96
View File
@@ -1,6 +1,7 @@
package kit_test
import (
"context"
"encoding/json"
"testing"
@@ -263,6 +264,101 @@ func TestConvertFromLLMMessage(t *testing.T) {
}
}
// TestAgentConfigNoFantasyImport verifies AgentConfig can be populated with
// every field — including CoreTools, ExtraTools, and ToolWrapper — using
// only SDK-owned types. This test deliberately does not import
// "charm.land/fantasy"; the package compiling at all is the proof that the
// SDK no longer leaks the dependency name through AgentConfig.
//
// Regression test for https://github.com/mark3labs/kit/issues/30.
func TestAgentConfigNoFantasyImport(t *testing.T) {
myTool := kit.NewTool[struct{}]("noop", "does nothing", func(_ context.Context, _ struct{}) (kit.ToolOutput, error) {
return kit.TextResult("ok"), nil
})
wrapperCalled := false
cfg := kit.AgentConfig{
SystemPrompt: "you are a tester",
MaxSteps: 5,
StreamingEnabled: true,
CoreTools: []kit.Tool{myTool},
ExtraTools: []kit.Tool{myTool},
DisableCoreTools: false,
ToolWrapper: func(in []kit.Tool) []kit.Tool {
wrapperCalled = true
return in
},
OnMCPServerLoaded: func(_ string, _ int, _ error) {},
}
if cfg.SystemPrompt != "you are a tester" {
t.Errorf("SystemPrompt = %q, want %q", cfg.SystemPrompt, "you are a tester")
}
if cfg.MaxSteps != 5 {
t.Errorf("MaxSteps = %d, want 5", cfg.MaxSteps)
}
if !cfg.StreamingEnabled {
t.Error("StreamingEnabled = false, want true")
}
if len(cfg.CoreTools) != 1 {
t.Errorf("CoreTools len = %d, want 1", len(cfg.CoreTools))
}
if len(cfg.ExtraTools) != 1 {
t.Errorf("ExtraTools len = %d, want 1", len(cfg.ExtraTools))
}
// Exercise the wrapper to confirm the func type is usable.
out := cfg.ToolWrapper(cfg.CoreTools)
if !wrapperCalled {
t.Error("ToolWrapper was not invoked")
}
if len(out) != 1 {
t.Errorf("wrapped tool list len = %d, want 1", len(out))
}
}
// TestAgentConfigToolWrapperSignature documents that AgentConfig.ToolWrapper
// uses kit.Tool (not the underlying provider type) in its signature.
func TestAgentConfigToolWrapperSignature(t *testing.T) {
//nolint:staticcheck // QF1011: explicit type asserts the SDK-side func signature.
var _ func([]kit.Tool) []kit.Tool = func(in []kit.Tool) []kit.Tool { return in }
cfg := kit.AgentConfig{
ToolWrapper: func(in []kit.Tool) []kit.Tool { return in },
}
if cfg.ToolWrapper == nil {
t.Fatal("ToolWrapper assignment failed")
}
}
// TestSpinnerFuncSignature verifies SpinnerFunc has the documented signature
// and can be constructed without importing any provider package.
func TestSpinnerFuncSignature(t *testing.T) {
called := false
var sp kit.SpinnerFunc = func(fn func() error) error {
called = true
return fn()
}
err := sp(func() error { return nil })
if err != nil {
t.Errorf("SpinnerFunc returned err: %v", err)
}
if !called {
t.Error("SpinnerFunc did not invoke fn")
}
}
// TestHandlerTypesSignatures verifies the SDK-owned handler function types
// can be assigned from plain function literals using only standard library
// types in their signatures (no provider-package import required).
func TestHandlerTypesSignatures(t *testing.T) {
var _ kit.ToolCallHandler = func(_, _, _ string) {}
var _ kit.ToolExecutionHandler = func(_, _, _ string, _ bool) {}
var _ kit.ToolResultHandler = func(_, _, _, _, _ string, _ bool) {}
var _ kit.ResponseHandler = func(_ string) {}
var _ kit.StreamingResponseHandler = func(_ string) {}
var _ kit.ToolCallContentHandler = func(_ string) {}
}
// containsStr is a tiny helper to avoid importing strings in test.
func containsStr(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && indexStr(s, substr) >= 0)