mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3ea0db69ea | |||
| 4304a5e899 | |||
| 4019c1e4f7 | |||
| 30ad7c1d0b | |||
| e33564c569 | |||
| 5ff28445fd | |||
| 13d177e5d0 | |||
| 3ffc995f27 |
@@ -545,6 +545,28 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
### Custom Tools
|
||||
|
||||
Create custom tools with automatic schema generation — no external dependencies needed:
|
||||
|
||||
```go
|
||||
type SearchInput struct {
|
||||
Query string `json:"query" description:"Search query"`
|
||||
}
|
||||
|
||||
searchTool := kit.NewTool("search", "Search the codebase",
|
||||
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("Found: ..."), nil
|
||||
},
|
||||
)
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{searchTool}, // adds alongside built-in tools
|
||||
})
|
||||
```
|
||||
|
||||
Use `kit.NewParallelTool` for tools safe to run concurrently. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
|
||||
|
||||
### With Callbacks
|
||||
|
||||
```go
|
||||
|
||||
+31
-3
@@ -89,6 +89,14 @@ type ReasoningCompleteHandler func()
|
||||
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
|
||||
type ToolOutputHandler = core.ToolOutputCallback
|
||||
|
||||
// StepMessagesHandler is a function type for persisting messages after each
|
||||
// complete step in a multi-step agent turn. The handler receives the messages
|
||||
// produced by the step (typically an assistant message with tool calls followed
|
||||
// by a tool-role message with results, or a final assistant message with text).
|
||||
// This enables incremental session persistence so that progress is saved as
|
||||
// it happens rather than only at the end of the turn.
|
||||
type StepMessagesHandler func(stepMessages []fantasy.Message)
|
||||
|
||||
// StepUsageHandler is a function type for handling token usage after each
|
||||
// complete step in a multi-step agent turn. This enables real-time cost
|
||||
// tracking during long-running tool-calling conversations.
|
||||
@@ -141,6 +149,11 @@ type GenerateWithLoopResult struct {
|
||||
TotalUsage fantasy.Usage
|
||||
// StopReason is the LLM provider's finish reason for the final response.
|
||||
StopReason string
|
||||
// PersistedMessageCount is the number of new messages (beyond the original
|
||||
// input) that were already persisted incrementally via OnStepMessages during
|
||||
// generation. The caller should skip these when doing post-generation
|
||||
// persistence to avoid duplicates.
|
||||
PersistedMessageCount int
|
||||
}
|
||||
|
||||
// NewAgent creates a new Agent with core tools and optional MCP tool integration.
|
||||
@@ -377,7 +390,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil, nil)
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
@@ -390,6 +403,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onReasoningDelta ReasoningDeltaHandler,
|
||||
onReasoningComplete ReasoningCompleteHandler,
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepMessages StepMessagesHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
|
||||
@@ -429,6 +443,10 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// when it returns an error, but the OnStepFinish callback fires
|
||||
// for every step that completed before the error occurred.
|
||||
var completedStepMessages []fantasy.Message
|
||||
// persistedCount tracks how many new messages (beyond the original
|
||||
// input) were persisted incrementally via onStepMessages, so the
|
||||
// caller can skip them during post-generation persistence.
|
||||
var persistedCount int
|
||||
|
||||
// Use the streaming agent
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
@@ -514,6 +532,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// persisted even if a later step is cancelled.
|
||||
completedStepMessages = append(completedStepMessages, step.Messages...)
|
||||
|
||||
// Persist step messages incrementally so progress is saved
|
||||
// as it happens rather than only at the end of the turn.
|
||||
if onStepMessages != nil && len(step.Messages) > 0 {
|
||||
onStepMessages(step.Messages)
|
||||
persistedCount += len(step.Messages)
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
@@ -592,7 +617,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
partialMessages = append(partialMessages, messages...)
|
||||
partialMessages = append(partialMessages, completedStepMessages...)
|
||||
return &GenerateWithLoopResult{
|
||||
ConversationMessages: partialMessages,
|
||||
ConversationMessages: partialMessages,
|
||||
PersistedMessageCount: persistedCount,
|
||||
}, err
|
||||
}
|
||||
return nil, err
|
||||
@@ -607,7 +633,9 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
r := convertAgentResult(result, messages)
|
||||
r.PersistedMessageCount = persistedCount
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Non-streaming path with no callbacks — use the simpler Generate call.
|
||||
|
||||
@@ -2,11 +2,11 @@ package extensions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/traefik/yaegi/interp"
|
||||
"github.com/traefik/yaegi/stdlib"
|
||||
"github.com/traefik/yaegi/stdlib/unrestricted"
|
||||
@@ -34,11 +34,10 @@ func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
|
||||
for _, p := range paths {
|
||||
ext, err := loadSingleExtension(p)
|
||||
if err != nil {
|
||||
log.Printf("WARN skipping extension: path=%s err=%v", p, err)
|
||||
continue
|
||||
}
|
||||
loaded = append(loaded, *ext)
|
||||
log.Printf("DEBUG loaded extension: path=%s handlers=%d tools=%d commands=%d tool_renderers=%d", p, countHandlers(ext), len(ext.Tools), len(ext.Commands), len(ext.ToolRenderers))
|
||||
log.Debug("loaded extension", "path", p, "handlers", countHandlers(ext), "tools", len(ext.Tools), "commands", len(ext.Commands), "tool_renderers", len(ext.ToolRenderers))
|
||||
}
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"charm.land/lipgloss/v2"
|
||||
"golang.org/x/term"
|
||||
|
||||
@@ -173,33 +172,6 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) {
|
||||
fmt.Println(c.renderer.RenderDebugConfigMessage(config, time.Now()).Content)
|
||||
}
|
||||
|
||||
// UpdateUsageFromResponse records token usage using metadata from the fantasy
|
||||
// response. Only actual API-reported tokens are used for cost tracking.
|
||||
// If the provider doesn't report token counts, no usage is recorded.
|
||||
func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText string) {
|
||||
if c.usageTracker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage := response.Usage
|
||||
inputTokens := int(usage.InputTokens)
|
||||
outputTokens := int(usage.OutputTokens)
|
||||
|
||||
// Only use actual API-reported tokens for cost tracking.
|
||||
// We intentionally do NOT estimate tokens - estimation is inaccurate
|
||||
// and should never be used for cost calculations.
|
||||
if inputTokens > 0 {
|
||||
cacheReadTokens := int(usage.CacheReadTokens)
|
||||
cacheWriteTokens := int(usage.CacheCreationTokens)
|
||||
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
// Per-response usage is a single API call, so it represents the
|
||||
// actual context window fill level.
|
||||
c.usageTracker.SetContextTokens(inputTokens + outputTokens)
|
||||
}
|
||||
// If inputTokens is 0, the provider didn't report usage - we skip recording
|
||||
// rather than estimating, to ensure cost accuracy.
|
||||
}
|
||||
|
||||
// DisplayUsageAfterResponse renders and displays token usage information immediately
|
||||
// following an AI response. This provides real-time feedback about the cost and
|
||||
// token consumption of each interaction.
|
||||
|
||||
@@ -69,7 +69,7 @@ type InputComponent struct {
|
||||
hideHint bool
|
||||
|
||||
// agentBusy indicates the agent is currently working. When true, the
|
||||
// hint text shows steering shortcut (Ctrl+S) instead of submit.
|
||||
// hint text shows steering shortcut (Ctrl+X s) instead of submit.
|
||||
agentBusy bool
|
||||
|
||||
// pendingImages holds clipboard images attached to the next submission.
|
||||
@@ -109,7 +109,7 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
|
||||
ta.Placeholder = "Type your message..."
|
||||
ta.ShowLineNumbers = false
|
||||
ta.Prompt = ""
|
||||
ta.CharLimit = 5000
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(3) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
@@ -514,12 +514,12 @@ func (s *InputComponent) View() tea.View {
|
||||
availableHintWidth := s.width - 3
|
||||
if s.agentBusy {
|
||||
// When the agent is working, show steering shortcut.
|
||||
if availableHintWidth >= 55 {
|
||||
hint = "enter queue • ctrl+s steer • esc esc cancel"
|
||||
} else if availableHintWidth >= 35 {
|
||||
hint = "↵ queue • ^S steer • esc×2 cancel"
|
||||
if availableHintWidth >= 60 {
|
||||
hint = "enter queue • ctrl+x s steer • esc esc cancel"
|
||||
} else if availableHintWidth >= 40 {
|
||||
hint = "↵ queue • ^X s steer • esc×2 cancel"
|
||||
} else {
|
||||
hint = "^S steer"
|
||||
hint = "^X s steer"
|
||||
}
|
||||
} else if availableHintWidth >= 67 {
|
||||
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
|
||||
|
||||
@@ -152,7 +152,7 @@ func (r *MessageRenderer) SetWidth(width int) {
|
||||
|
||||
// RenderUserMessage renders a user's input message using herald Tip alert
|
||||
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
rendered := render.UserBlock(content, r.ty, style.GetTheme())
|
||||
rendered := render.UserBlock(content, r.width, r.ty, style.GetTheme())
|
||||
|
||||
return UIMessage{
|
||||
Type: UserMessage,
|
||||
|
||||
+143
-59
@@ -477,7 +477,7 @@ type AppModel struct {
|
||||
queuedMessages []string
|
||||
|
||||
// steeringMessages stores the text of prompts that were sent as steer
|
||||
// messages (injected mid-turn via Ctrl+S). Rendered with a "STEERING"
|
||||
// messages (injected mid-turn via Ctrl+X s). Rendered with a "STEERING"
|
||||
// badge above the input. Cleared when the steer is consumed.
|
||||
steeringMessages []string
|
||||
|
||||
@@ -498,6 +498,11 @@ type AppModel struct {
|
||||
// A second ESC within 2 seconds will cancel the current step.
|
||||
canceling bool
|
||||
|
||||
// leaderKeyActive tracks whether the Ctrl+X leader key prefix has been
|
||||
// pressed. The next keypress is interpreted as a chord suffix (e.g. "s"
|
||||
// for steer). Cleared on any subsequent keypress.
|
||||
leaderKeyActive bool
|
||||
|
||||
// providerName is the LLM provider for the startup message.
|
||||
providerName string
|
||||
|
||||
@@ -1268,6 +1273,71 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
// ── Leader key chord handling (Ctrl+X prefix) ──────────────
|
||||
// If the leader key was previously pressed, the current key
|
||||
// completes the chord. We consume it regardless of match so
|
||||
// the prefix doesn't leak to child components.
|
||||
if m.leaderKeyActive {
|
||||
m.leaderKeyActive = false
|
||||
switch msg.String() {
|
||||
case "s":
|
||||
// Ctrl+X s → Steer: inject the current input as a steering
|
||||
// message into the running agent turn.
|
||||
if m.state == stateWorking && m.appCtrl != nil {
|
||||
var text string
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
text = strings.TrimSpace(ic.textarea.Value())
|
||||
}
|
||||
if text != "" {
|
||||
// Clear the input, collect pending images, and push to history.
|
||||
var images []uicore.ImageAttachment
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.pushHistory(text)
|
||||
ic.textarea.SetValue("")
|
||||
images = ic.ClearPendingImages()
|
||||
}
|
||||
|
||||
// Preprocess @file references.
|
||||
processedText := text
|
||||
if m.cwd != "" {
|
||||
processedText = fileutil.ProcessFileAttachments(text, m.cwd)
|
||||
}
|
||||
|
||||
// Convert image attachments to kit.LLMFilePart for the app layer.
|
||||
var fileParts []kit.LLMFilePart
|
||||
for _, img := range images {
|
||||
fileParts = append(fileParts, kit.LLMFilePart{
|
||||
Data: img.Data,
|
||||
MediaType: img.MediaType,
|
||||
})
|
||||
}
|
||||
|
||||
// Build display text (include image count if any).
|
||||
displayText := text
|
||||
if len(images) > 0 {
|
||||
displayText = fmt.Sprintf("%s\n[%d image(s) attached]", text, len(images))
|
||||
}
|
||||
|
||||
// Inject the steer message.
|
||||
sLen := m.appCtrl.SteerWithFiles(processedText, fileParts)
|
||||
if sLen > 0 {
|
||||
m.steeringMessages = append(m.steeringMessages, displayText)
|
||||
m.layoutDirty = true
|
||||
} else {
|
||||
// Started immediately (agent was idle).
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Chord consumed — don't propagate to children.
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
if m.state == stateWorking {
|
||||
@@ -1286,61 +1356,10 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
// In other states pass ESC through to children below.
|
||||
|
||||
case "ctrl+s":
|
||||
// Steer: inject the current input as a steering message into the
|
||||
// running agent turn. Only active during stateWorking — in input
|
||||
// state, Ctrl+S is passed through to children (no-op by default).
|
||||
if m.state == stateWorking && m.appCtrl != nil {
|
||||
var text string
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
text = strings.TrimSpace(ic.textarea.Value())
|
||||
}
|
||||
if text != "" {
|
||||
// Clear the input, collect pending images, and push to history.
|
||||
var images []uicore.ImageAttachment
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.pushHistory(text)
|
||||
ic.textarea.SetValue("")
|
||||
images = ic.ClearPendingImages()
|
||||
}
|
||||
|
||||
// Preprocess @file references.
|
||||
processedText := text
|
||||
if m.cwd != "" {
|
||||
processedText = fileutil.ProcessFileAttachments(text, m.cwd)
|
||||
}
|
||||
|
||||
// Convert image attachments to kit.LLMFilePart for the app layer.
|
||||
var fileParts []kit.LLMFilePart
|
||||
for _, img := range images {
|
||||
fileParts = append(fileParts, kit.LLMFilePart{
|
||||
Data: img.Data,
|
||||
MediaType: img.MediaType,
|
||||
})
|
||||
}
|
||||
|
||||
// Build display text (include image count if any).
|
||||
displayText := text
|
||||
if len(images) > 0 {
|
||||
displayText = fmt.Sprintf("%s\n[%d image(s) attached]", text, len(images))
|
||||
}
|
||||
|
||||
// Inject the steer message.
|
||||
sLen := m.appCtrl.SteerWithFiles(processedText, fileParts)
|
||||
if sLen > 0 {
|
||||
m.steeringMessages = append(m.steeringMessages, displayText)
|
||||
m.layoutDirty = true
|
||||
} else {
|
||||
// Started immediately (agent was idle).
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
case "ctrl+x":
|
||||
// Activate leader key prefix — the next keypress completes the chord.
|
||||
m.leaderKeyActive = true
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
// Route key events to the focused child. Check for editor
|
||||
@@ -2462,22 +2481,34 @@ func (m *AppModel) renderHeaderFooter(getter func() *WidgetData) string {
|
||||
return renderContentBlock(data.Text, m.width, opts...)
|
||||
}
|
||||
|
||||
// maxQueuedMessageLines is the maximum number of visible content lines
|
||||
// rendered for each queued or steering message block. Messages exceeding
|
||||
// this limit are truncated with an ellipsis to prevent large pastes from
|
||||
// overflowing the screen and squeezing the stream region to zero.
|
||||
const maxQueuedMessageLines = 3
|
||||
|
||||
// renderQueuedMessages renders queued and steering prompts as styled content
|
||||
// blocks with badges, anchored between the separator and input. Steering
|
||||
// messages use a distinct "STEERING" badge to differentiate from queued ones.
|
||||
// Long messages are visually truncated to maxQueuedMessageLines.
|
||||
func (m *AppModel) renderQueuedMessages() string {
|
||||
if len(m.queuedMessages) == 0 && len(m.steeringMessages) == 0 {
|
||||
return ""
|
||||
}
|
||||
theme := style.GetTheme()
|
||||
|
||||
// Available content width inside the block: container minus border (1)
|
||||
// minus left padding (2). Used to estimate line wrapping for truncation.
|
||||
contentWidth := max(m.width-3, 10)
|
||||
|
||||
var blocks []string
|
||||
|
||||
// Render steering messages first (higher priority).
|
||||
if len(m.steeringMessages) > 0 {
|
||||
badge := style.CreateBadge("STEERING", theme.Warning)
|
||||
for _, msg := range m.steeringMessages {
|
||||
content := msg + "\n" + badge
|
||||
display := truncateMessageForBlock(msg, maxQueuedMessageLines, contentWidth)
|
||||
content := display + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
@@ -2492,7 +2523,8 @@ func (m *AppModel) renderQueuedMessages() string {
|
||||
if len(m.queuedMessages) > 0 {
|
||||
badge := style.CreateBadge("QUEUED", theme.Accent)
|
||||
for _, msg := range m.queuedMessages {
|
||||
content := msg + "\n" + badge
|
||||
display := truncateMessageForBlock(msg, maxQueuedMessageLines, contentWidth)
|
||||
content := display + "\n" + badge
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
@@ -2506,6 +2538,58 @@ func (m *AppModel) renderQueuedMessages() string {
|
||||
return strings.Join(blocks, "\n")
|
||||
}
|
||||
|
||||
// truncateMessageForBlock truncates a message to at most maxLines visible
|
||||
// lines, accounting for soft-wrapping at the given width. If the message is
|
||||
// truncated, the last visible line is replaced with an ellipsis ("…").
|
||||
func truncateMessageForBlock(msg string, maxLines, width int) string {
|
||||
if width <= 0 {
|
||||
width = 1
|
||||
}
|
||||
|
||||
lines := strings.Split(msg, "\n")
|
||||
|
||||
// Count visible lines (each hard line may wrap into multiple visual lines).
|
||||
var kept []string
|
||||
visibleCount := 0
|
||||
truncated := false
|
||||
|
||||
for _, line := range lines {
|
||||
// Calculate how many visual lines this hard line occupies.
|
||||
lineWidth := lipgloss.Width(line)
|
||||
wrapped := 1
|
||||
if lineWidth > width {
|
||||
wrapped = (lineWidth + width - 1) / width // ceil division
|
||||
}
|
||||
|
||||
if visibleCount+wrapped > maxLines {
|
||||
// This line would exceed the limit. Keep a partial if we
|
||||
// still have room for at least one more visual line.
|
||||
remaining := maxLines - visibleCount
|
||||
if remaining > 0 {
|
||||
// Truncate the line to fit the remaining visual lines.
|
||||
runes := []rune(line)
|
||||
maxRunes := remaining * width
|
||||
if maxRunes < len(runes) {
|
||||
kept = append(kept, string(runes[:maxRunes]))
|
||||
} else {
|
||||
kept = append(kept, line)
|
||||
}
|
||||
}
|
||||
truncated = true
|
||||
break
|
||||
}
|
||||
|
||||
kept = append(kept, line)
|
||||
visibleCount += wrapped
|
||||
}
|
||||
|
||||
if !truncated {
|
||||
return msg
|
||||
}
|
||||
|
||||
return strings.Join(kept, "\n") + "…"
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Print helpers — add content to ScrollList
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -2876,7 +2960,7 @@ func (m *AppModel) printHelpMessage() {
|
||||
"**Keys:**\n" +
|
||||
"- `Ctrl+C`: Exit at any time\n" +
|
||||
"- `ESC` (x2): Cancel ongoing LLM generation\n" +
|
||||
"- `Ctrl+S`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
|
||||
"- `Ctrl+X s`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
|
||||
"- `Enter` (while working): Queue message for after the agent finishes\n\n" +
|
||||
"You can also just type your message to chat with the AI assistant."
|
||||
m.printSystemMessage(help)
|
||||
|
||||
@@ -2,6 +2,7 @@ package ui
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
@@ -892,3 +893,107 @@ func TestSubmit_duringWorking_stays(t *testing.T) {
|
||||
t.Fatalf("expected Run('queued prompt') called, got %v", ctrl.runCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// truncateMessageForBlock
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TestTruncateMessageForBlock_shortMessage verifies that short messages are
|
||||
// returned unchanged.
|
||||
func TestTruncateMessageForBlock_shortMessage(t *testing.T) {
|
||||
msg := "hello world"
|
||||
got := truncateMessageForBlock(msg, 3, 80)
|
||||
if got != msg {
|
||||
t.Fatalf("expected unchanged message, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_exactLines verifies that a message with exactly
|
||||
// maxLines hard lines is returned unchanged.
|
||||
func TestTruncateMessageForBlock_exactLines(t *testing.T) {
|
||||
msg := "line1\nline2\nline3"
|
||||
got := truncateMessageForBlock(msg, 3, 80)
|
||||
if got != msg {
|
||||
t.Fatalf("expected unchanged message, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_tooManyLines verifies that messages exceeding
|
||||
// maxLines are truncated with an ellipsis.
|
||||
func TestTruncateMessageForBlock_tooManyLines(t *testing.T) {
|
||||
msg := "line1\nline2\nline3\nline4\nline5"
|
||||
got := truncateMessageForBlock(msg, 3, 80)
|
||||
want := "line1\nline2\nline3…"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_longWrappingLine verifies that a single long
|
||||
// line that would wrap beyond maxLines is truncated.
|
||||
func TestTruncateMessageForBlock_longWrappingLine(t *testing.T) {
|
||||
// 100 chars at width 20 = 5 visual lines, exceeds maxLines=3
|
||||
msg := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
got := truncateMessageForBlock(msg, 3, 20)
|
||||
// Should be truncated to 3*20=60 runes + "…"
|
||||
if len([]rune(got)) != 61 { // 60 runes + "…"
|
||||
t.Fatalf("expected 61 runes (60 + ellipsis), got %d runes: %q", len([]rune(got)), got)
|
||||
}
|
||||
if got[len(got)-3:] != "…" { // "…" is 3 bytes in UTF-8
|
||||
t.Fatal("expected trailing ellipsis")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_emptyMessage verifies that empty messages are
|
||||
// returned unchanged.
|
||||
func TestTruncateMessageForBlock_emptyMessage(t *testing.T) {
|
||||
got := truncateMessageForBlock("", 3, 80)
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty string, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateMessageForBlock_mixedWrapAndHardLines verifies truncation when
|
||||
// some hard lines wrap and the total exceeds maxLines.
|
||||
func TestTruncateMessageForBlock_mixedWrapAndHardLines(t *testing.T) {
|
||||
// First line: 40 chars at width 20 = 2 visual lines
|
||||
// Second line: "short" = 1 visual line (total: 3, exactly at limit)
|
||||
// Third line: would exceed
|
||||
msg := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nshort\nextra"
|
||||
got := truncateMessageForBlock(msg, 3, 20)
|
||||
want := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nshort…"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRenderQueuedMessages_truncatesLongMessages verifies that the rendered
|
||||
// queued message view truncates long messages instead of showing them in full.
|
||||
func TestRenderQueuedMessages_truncatesLongMessages(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
m, _, _ := newTestAppModel(ctrl)
|
||||
m.width = 80
|
||||
|
||||
// Queue a very long message (20 lines).
|
||||
var b strings.Builder
|
||||
for i := range 20 {
|
||||
if i > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString("This is a long line of text for testing purposes")
|
||||
}
|
||||
m.queuedMessages = []string{b.String()}
|
||||
|
||||
rendered := m.renderQueuedMessages()
|
||||
if rendered == "" {
|
||||
t.Fatal("expected non-empty rendered output")
|
||||
}
|
||||
|
||||
// The full message would be ~20+ lines. With truncation to 3 content
|
||||
// lines + badge + padding, it should be much shorter.
|
||||
lines := len(strings.Split(rendered, "\n"))
|
||||
// 3 content lines + 1 badge + 2 padding + border overhead ≈ ~7 lines max
|
||||
if lines > 10 {
|
||||
t.Fatalf("expected truncated output to be ≤10 lines, got %d lines", lines)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func newInputPrompt(message, placeholder, defaultValue string, width, height int
|
||||
ta.Placeholder = placeholder
|
||||
ta.ShowLineNumbers = false
|
||||
ta.Prompt = ""
|
||||
ta.CharLimit = 1000
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 12) // account for border + padding
|
||||
ta.SetHeight(1)
|
||||
ta.Focus()
|
||||
|
||||
@@ -14,11 +14,19 @@ import (
|
||||
)
|
||||
|
||||
// UserBlock renders a user message with herald Tip styling.
|
||||
func UserBlock(content string, ty *herald.Typography, theme style.Theme) string {
|
||||
// The width parameter controls line wrapping so long messages don't overflow.
|
||||
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
// Wrap content before passing to herald Alert so long lines break
|
||||
// inside the alert box. Subtract 4 to account for the alert bar
|
||||
// prefix ("│ ") and a small margin.
|
||||
if width > 4 {
|
||||
content = lipgloss.Wrap(content, width-4, "")
|
||||
}
|
||||
|
||||
rendered := ty.Tip(content)
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
@@ -85,11 +85,13 @@ func GetMarkdownTypography() *herald.Typography {
|
||||
return ty
|
||||
}
|
||||
|
||||
// ToMarkdown renders markdown content using herald-md.
|
||||
// The width parameter is currently unused as herald handles wrapping
|
||||
// based on terminal width internally.
|
||||
// ToMarkdown renders markdown content using herald-md and wraps the result
|
||||
// to the given width so that long lines do not overflow the terminal.
|
||||
func ToMarkdown(content string, width int) string {
|
||||
ty := GetMarkdownTypography()
|
||||
rendered := heraldmd.Render(ty, []byte(content))
|
||||
if width > 0 {
|
||||
rendered = lipgloss.Wrap(rendered, width, "")
|
||||
}
|
||||
return rendered
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInp
|
||||
ta := textarea.New()
|
||||
ta.Placeholder = ""
|
||||
ta.ShowLineNumbers = false
|
||||
ta.CharLimit = 1000
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(4) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
|
||||
+42
-18
@@ -1309,6 +1309,17 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
IsStderr: isStderr,
|
||||
})
|
||||
},
|
||||
// Persist step messages incrementally so that progress survives
|
||||
// crashes and long-running turns don't lose work. Each step's
|
||||
// messages are persisted as a unit: for tool-calling steps this is
|
||||
// the assistant message (with tool_use parts) + tool-role message
|
||||
// (with tool_result parts) as a pair; for the final step it's the
|
||||
// assistant text/reasoning message alone.
|
||||
func(stepMessages []fantasy.Message) {
|
||||
for _, msg := range stepMessages {
|
||||
_, _ = m.session.AppendMessage(msg)
|
||||
}
|
||||
},
|
||||
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
|
||||
// Emit step usage event for real-time cost tracking
|
||||
if viper.GetBool("debug") {
|
||||
@@ -1331,11 +1342,17 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
// 2. Persist pre-generation messages to the tree session.
|
||||
// 3. Build context from the tree (walks leaf-to-root for current branch).
|
||||
// 4. Emit turn/message start events.
|
||||
// 5. Run generation.
|
||||
// 6. Emit turn/message end events.
|
||||
// 7. Persist post-generation messages (tool calls, results, assistant).
|
||||
// 5. Run generation (messages are persisted incrementally per step).
|
||||
// 6. Persist any remaining messages not covered by incremental persistence.
|
||||
// 7. Emit turn/message end events.
|
||||
// 8. Run AfterTurn hooks.
|
||||
//
|
||||
// During generation, each completed step's messages are persisted immediately
|
||||
// via the onStepMessages callback. Tool calls are always persisted as
|
||||
// call/response pairs (assistant + tool messages together). Reasoning and
|
||||
// text-only assistant messages are persisted as soon as their step completes.
|
||||
// This ensures long-running turns don't lose progress on crash or cancellation.
|
||||
//
|
||||
// promptLabel is the human-readable label emitted in TurnStartEvent.Prompt.
|
||||
// prompt is the raw user text passed to BeforeTurn hooks.
|
||||
func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, preMessages []fantasy.Message) (*TurnResult, error) {
|
||||
@@ -1405,16 +1422,18 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
|
||||
result, err := m.generate(ctx, messages)
|
||||
if err != nil {
|
||||
// Persist any messages from completed steps (tool call/result
|
||||
// pairs) so partial progress is not lost. The agent layer only
|
||||
// includes fully-paired tool_use + tool_result messages in
|
||||
// completedStepMessages, so there are no orphaned entries that
|
||||
// would break subsequent API requests. The user message and any
|
||||
// completed work remain in the session; only the in-progress
|
||||
// (pending) message or tool call is discarded.
|
||||
if result != nil && len(result.ConversationMessages) > sentCount {
|
||||
for _, msg := range result.ConversationMessages[sentCount:] {
|
||||
_, _ = m.session.AppendMessage(msg)
|
||||
// Persist any messages from completed steps that were NOT already
|
||||
// persisted incrementally by the onStepMessages callback. The agent
|
||||
// layer only includes fully-paired tool_use + tool_result messages
|
||||
// in completedStepMessages, so there are no orphaned entries that
|
||||
// would break subsequent API requests.
|
||||
if result != nil {
|
||||
newMessages := result.ConversationMessages[sentCount:]
|
||||
alreadyPersisted := result.PersistedMessageCount
|
||||
if alreadyPersisted < len(newMessages) {
|
||||
for _, msg := range newMessages[alreadyPersisted:] {
|
||||
_, _ = m.session.AppendMessage(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
m.events.emit(TurnEndEvent{Error: err})
|
||||
@@ -1425,12 +1444,17 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
|
||||
responseText := result.FinalResponse.Content.Text()
|
||||
|
||||
// Persist new messages (tool calls, tool results, assistant response)
|
||||
// BEFORE emitting events so that extension handlers calling
|
||||
// GetContextStats() see up-to-date token counts.
|
||||
// Persist any new messages that were NOT already persisted incrementally
|
||||
// by the onStepMessages callback during generation. This handles the
|
||||
// non-streaming path (where onStepMessages is not called) and any edge
|
||||
// cases where the final response messages weren't covered by step callbacks.
|
||||
if len(result.ConversationMessages) > sentCount {
|
||||
for _, msg := range result.ConversationMessages[sentCount:] {
|
||||
_, _ = m.session.AppendMessage(msg)
|
||||
newMessages := result.ConversationMessages[sentCount:]
|
||||
alreadyPersisted := result.PersistedMessageCount
|
||||
if alreadyPersisted < len(newMessages) {
|
||||
for _, msg := range newMessages[alreadyPersisted:] {
|
||||
_, _ = m.session.AppendMessage(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,9 +6,21 @@ import (
|
||||
|
||||
// SessionManager defines the contract for conversation storage backends.
|
||||
// Implementations can use files (default), databases, cloud storage, etc.
|
||||
//
|
||||
// Implementations must be safe for concurrent use. During generation,
|
||||
// AppendMessage is called incrementally from the agent's step-completion
|
||||
// callback while read methods (GetMessages, GetCurrentBranch, etc.) may be
|
||||
// called concurrently from the UI or extension goroutines.
|
||||
type SessionManager interface {
|
||||
// AppendMessage adds a message to the current branch and returns its entry ID.
|
||||
// The entry ID is used for tree navigation and must be unique within the session.
|
||||
//
|
||||
// During generation, AppendMessage is called incrementally after each
|
||||
// completed agent step rather than in a batch at the end of the turn.
|
||||
// For tool-calling steps, the assistant message (containing tool_use parts)
|
||||
// and the tool-role message (containing tool_result parts) are appended
|
||||
// together as a pair. This ensures the session never contains an orphaned
|
||||
// tool call without its result, which would break subsequent LLM requests.
|
||||
AppendMessage(msg LLMMessage) (entryID string, err error)
|
||||
|
||||
// GetMessages returns all messages on the current branch (from root to leaf),
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/core"
|
||||
@@ -16,6 +18,123 @@ type ToolOption = core.ToolOption
|
||||
// If empty, os.Getwd() is used at execution time.
|
||||
var WithWorkDir = core.WithWorkDir
|
||||
|
||||
// --- Custom tool creation ---
|
||||
|
||||
// ToolOutput is the return value from custom tool handlers created with
|
||||
// [NewTool] or [NewParallelTool]. It provides a dependency-free way to
|
||||
// return results without importing the underlying LLM framework.
|
||||
type ToolOutput struct {
|
||||
// Content is the text content returned to the LLM.
|
||||
Content string
|
||||
|
||||
// IsError, when true, signals to the LLM that the tool call failed.
|
||||
IsError bool
|
||||
|
||||
// Data contains optional binary data (images, audio, etc.).
|
||||
Data []byte
|
||||
|
||||
// MediaType is the MIME type for binary Data (e.g. "image/png").
|
||||
MediaType string
|
||||
|
||||
// Metadata is optional opaque metadata attached to the response.
|
||||
// It is not sent to the LLM but may be consumed by hooks or the UI.
|
||||
Metadata any
|
||||
}
|
||||
|
||||
// TextResult creates a successful text [ToolOutput].
|
||||
func TextResult(content string) ToolOutput {
|
||||
return ToolOutput{Content: content}
|
||||
}
|
||||
|
||||
// ErrorResult creates an error [ToolOutput]. The LLM will see the content
|
||||
// as a tool error, allowing it to retry or adjust its approach.
|
||||
func ErrorResult(content string) ToolOutput {
|
||||
return ToolOutput{Content: content, IsError: true}
|
||||
}
|
||||
|
||||
// toolCallIDKey is the context key for the tool call ID.
|
||||
type toolCallIDKey struct{}
|
||||
|
||||
// ToolCallIDFromContext extracts the tool call ID from the context.
|
||||
// The call ID is set automatically by [NewTool] and [NewParallelTool]
|
||||
// before invoking the handler. Returns an empty string if no ID is present.
|
||||
func ToolCallIDFromContext(ctx context.Context) string {
|
||||
s, _ := ctx.Value(toolCallIDKey{}).(string)
|
||||
return s
|
||||
}
|
||||
|
||||
// NewTool creates a custom [Tool] with automatic JSON schema generation from
|
||||
// the TInput struct type. The handler receives a typed input (deserialized
|
||||
// from the LLM's JSON arguments) and returns a [ToolResult].
|
||||
//
|
||||
// Struct tags on TInput control the generated schema:
|
||||
//
|
||||
// json:"name" → parameter name
|
||||
// description:"..." → parameter description shown to the LLM
|
||||
// enum:"a,b,c" → restrict valid values
|
||||
// omitempty → marks the parameter as optional
|
||||
//
|
||||
// The tool call ID is injected into the context and can be retrieved with
|
||||
// [ToolCallIDFromContext].
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type WeatherInput struct {
|
||||
// City string `json:"city" description:"City name"`
|
||||
// }
|
||||
//
|
||||
// tool := kit.NewTool("get_weather", "Get weather for a city",
|
||||
// func(ctx context.Context, input WeatherInput) (kit.ToolResult, error) {
|
||||
// return kit.TextResult("72°F, sunny in " + input.City), nil
|
||||
// },
|
||||
// )
|
||||
func NewTool[TInput any](name, description string, fn func(ctx context.Context, input TInput) (ToolOutput, error)) Tool {
|
||||
return fantasy.NewAgentTool(name, description,
|
||||
func(ctx context.Context, input TInput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
ctx = context.WithValue(ctx, toolCallIDKey{}, call.ID)
|
||||
result, err := fn(ctx, input)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
resp := fantasy.ToolResponse{
|
||||
Content: result.Content,
|
||||
IsError: result.IsError,
|
||||
Data: result.Data,
|
||||
MediaType: result.MediaType,
|
||||
}
|
||||
if result.Metadata != nil {
|
||||
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// NewParallelTool is like [NewTool] but marks the tool as safe for concurrent
|
||||
// execution alongside other tools. Use this when the tool has no side effects
|
||||
// or when concurrent calls are safe.
|
||||
func NewParallelTool[TInput any](name, description string, fn func(ctx context.Context, input TInput) (ToolOutput, error)) Tool {
|
||||
return fantasy.NewParallelAgentTool(name, description,
|
||||
func(ctx context.Context, input TInput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
ctx = context.WithValue(ctx, toolCallIDKey{}, call.ID)
|
||||
result, err := fn(ctx, input)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
resp := fantasy.ToolResponse{
|
||||
Content: result.Content,
|
||||
IsError: result.IsError,
|
||||
Data: result.Data,
|
||||
MediaType: result.MediaType,
|
||||
}
|
||||
if result.Metadata != nil {
|
||||
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// --- Individual tool constructors ---
|
||||
|
||||
// NewReadTool creates a file-reading tool.
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
package kit_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// TestNewTool_BasicTextResult verifies that NewTool creates a working tool
|
||||
// that returns text content via ToolOutput.
|
||||
func TestNewTool_BasicTextResult(t *testing.T) {
|
||||
type Input struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
tool := kit.NewTool("greet", "Greet someone",
|
||||
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("hello " + input.Name), nil
|
||||
},
|
||||
)
|
||||
|
||||
info := tool.Info()
|
||||
if info.Name != "greet" {
|
||||
t.Errorf("Info().Name = %q, want %q", info.Name, "greet")
|
||||
}
|
||||
if info.Description != "Greet someone" {
|
||||
t.Errorf("Info().Description = %q, want %q", info.Description, "Greet someone")
|
||||
}
|
||||
if info.Parallel {
|
||||
t.Error("NewTool should not mark tool as parallel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewParallelTool_MarkedParallel verifies that NewParallelTool marks the
|
||||
// tool as safe for concurrent execution.
|
||||
func TestNewParallelTool_MarkedParallel(t *testing.T) {
|
||||
type Input struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
tool := kit.NewParallelTool("search", "Search for things",
|
||||
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("found: " + input.Query), nil
|
||||
},
|
||||
)
|
||||
|
||||
info := tool.Info()
|
||||
if info.Name != "search" {
|
||||
t.Errorf("Info().Name = %q, want %q", info.Name, "search")
|
||||
}
|
||||
if !info.Parallel {
|
||||
t.Error("NewParallelTool should mark tool as parallel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextResult verifies the TextResult convenience constructor.
|
||||
func TestTextResult(t *testing.T) {
|
||||
r := kit.TextResult("ok")
|
||||
if r.Content != "ok" {
|
||||
t.Errorf("Content = %q, want %q", r.Content, "ok")
|
||||
}
|
||||
if r.IsError {
|
||||
t.Error("TextResult should not set IsError")
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorResult verifies the ErrorResult convenience constructor.
|
||||
func TestErrorResult(t *testing.T) {
|
||||
r := kit.ErrorResult("bad input")
|
||||
if r.Content != "bad input" {
|
||||
t.Errorf("Content = %q, want %q", r.Content, "bad input")
|
||||
}
|
||||
if !r.IsError {
|
||||
t.Error("ErrorResult should set IsError")
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallIDFromContext verifies round-trip context injection.
|
||||
func TestToolCallIDFromContext(t *testing.T) {
|
||||
// Empty context returns empty string.
|
||||
if id := kit.ToolCallIDFromContext(context.Background()); id != "" {
|
||||
t.Errorf("expected empty string from bare context, got %q", id)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolOutput_Metadata verifies that metadata can be set on ToolOutput.
|
||||
func TestToolOutput_Metadata(t *testing.T) {
|
||||
r := kit.ToolOutput{
|
||||
Content: "data",
|
||||
Metadata: map[string]string{"key": "value"},
|
||||
}
|
||||
if r.Metadata == nil {
|
||||
t.Error("expected non-nil Metadata")
|
||||
}
|
||||
m, ok := r.Metadata.(map[string]string)
|
||||
if !ok {
|
||||
t.Fatalf("expected map[string]string, got %T", r.Metadata)
|
||||
}
|
||||
if m["key"] != "value" {
|
||||
t.Errorf("Metadata[key] = %q, want %q", m["key"], "value")
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolOutput_BinaryData verifies that binary data fields work correctly.
|
||||
func TestToolOutput_BinaryData(t *testing.T) {
|
||||
data := []byte{0x89, 0x50, 0x4E, 0x47}
|
||||
r := kit.ToolOutput{
|
||||
Content: "image result",
|
||||
Data: data,
|
||||
MediaType: "image/png",
|
||||
}
|
||||
if len(r.Data) != 4 {
|
||||
t.Errorf("Data len = %d, want 4", len(r.Data))
|
||||
}
|
||||
if r.MediaType != "image/png" {
|
||||
t.Errorf("MediaType = %q, want %q", r.MediaType, "image/png")
|
||||
}
|
||||
}
|
||||
@@ -347,6 +347,77 @@ Lower values run first. Within the same priority, registration order applies. Fi
|
||||
|
||||
## Tools
|
||||
|
||||
### Creating custom tools
|
||||
|
||||
Use `kit.NewTool` to create custom tools. The JSON schema is auto-generated from the input struct — no external dependencies required:
|
||||
|
||||
```go
|
||||
type WeatherInput struct {
|
||||
City string `json:"city" description:"City name, e.g. 'San Francisco'"`
|
||||
}
|
||||
|
||||
weatherTool := kit.NewTool("get_weather", "Get current weather for a city",
|
||||
func(ctx context.Context, input WeatherInput) (kit.ToolOutput, error) {
|
||||
// Your logic here (API calls, database lookups, etc.)
|
||||
return kit.TextResult("72°F, sunny in " + input.City), nil
|
||||
},
|
||||
)
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{weatherTool},
|
||||
})
|
||||
```
|
||||
|
||||
**Struct tags** control the generated schema:
|
||||
|
||||
| Tag | Purpose | Example |
|
||||
|-----|---------|---------|
|
||||
| `json:"name"` | Parameter name | `json:"city"` |
|
||||
| `description:"..."` | Description shown to the LLM | `description:"City name"` |
|
||||
| `enum:"a,b,c"` | Restrict valid values | `enum:"json,text,csv"` |
|
||||
| `omitempty` | Marks parameter as optional | `json:"limit,omitempty"` |
|
||||
|
||||
**Return helpers:**
|
||||
|
||||
| Function | Description |
|
||||
|----------|-------------|
|
||||
| `kit.TextResult(content)` | Successful text result |
|
||||
| `kit.ErrorResult(content)` | Error result (LLM sees it as a tool error) |
|
||||
|
||||
**ToolOutput fields** (for advanced use):
|
||||
|
||||
```go
|
||||
kit.ToolOutput{
|
||||
Content: "result text", // text returned to the LLM
|
||||
IsError: false, // true = LLM sees this as an error
|
||||
Data: pngBytes, // optional binary data (images, audio)
|
||||
MediaType: "image/png", // MIME type for binary Data
|
||||
Metadata: map[string]any{}, // opaque metadata for hooks/UI (not sent to LLM)
|
||||
}
|
||||
```
|
||||
|
||||
**Parallel tools** — mark as safe for concurrent execution:
|
||||
|
||||
```go
|
||||
searchTool := kit.NewParallelTool("search", "Search the web",
|
||||
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("results..."), nil
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
**Tool call ID** — available in context for logging/tracing:
|
||||
|
||||
```go
|
||||
tool := kit.NewTool("my_tool", "...",
|
||||
func(ctx context.Context, input MyInput) (kit.ToolOutput, error) {
|
||||
callID := kit.ToolCallIDFromContext(ctx) // correlation ID from the LLM
|
||||
log.Printf("[%s] my_tool called", callID)
|
||||
return kit.TextResult("ok"), nil
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
### Built-in tool constructors
|
||||
|
||||
```go
|
||||
|
||||
+49
-21
@@ -7,17 +7,16 @@ description: Monitor tool calls and streaming output with the Kit Go SDK.
|
||||
|
||||
## Event-based monitoring
|
||||
|
||||
For more granular control, use the event subscription API:
|
||||
Subscribe to events for real-time monitoring. Each method returns an unsubscribe function:
|
||||
|
||||
```go
|
||||
// Subscribe returns an unsubscribe function
|
||||
unsub := host.OnToolCall(func(event kit.ToolCallEvent) {
|
||||
fmt.Printf("Tool: %s, Args: %s\n", event.Name, event.Args)
|
||||
fmt.Printf("Tool: %s, Args: %s\n", event.ToolName, event.ToolArgs)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
unsub2 := host.OnToolResult(func(event kit.ToolResultEvent) {
|
||||
fmt.Printf("Result: %s (error: %v)\n", event.Name, event.IsError)
|
||||
fmt.Printf("Result: %s (error: %v)\n", event.ToolName, event.IsError)
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
@@ -44,33 +43,62 @@ defer unsub6()
|
||||
|
||||
## Hook system
|
||||
|
||||
Hooks allow you to intercept and modify behavior. Unlike events, hooks can modify or cancel operations:
|
||||
Hooks can **modify or cancel** operations. Unlike events (read-only), hooks are read-write interceptors.
|
||||
|
||||
### BeforeToolCall — block tool execution
|
||||
|
||||
```go
|
||||
// Intercept tool calls before execution
|
||||
host.OnBeforeToolCall(0, func(ctx context.Context, name string, args string) (string, error) {
|
||||
if name == "bash" {
|
||||
log.Println("Bash command:", args)
|
||||
host.OnBeforeToolCall(kit.HookPriorityNormal, func(h kit.BeforeToolCallHook) *kit.BeforeToolCallResult {
|
||||
// h.ToolCallID, h.ToolName, h.ToolArgs
|
||||
if h.ToolName == "bash" && strings.Contains(h.ToolArgs, "rm -rf") {
|
||||
return &kit.BeforeToolCallResult{Block: true, Reason: "dangerous command"}
|
||||
}
|
||||
return args, nil // return modified args or error to cancel
|
||||
return nil // allow
|
||||
})
|
||||
```
|
||||
|
||||
// Process results after tool execution
|
||||
host.OnAfterToolResult(0, func(ctx context.Context, name string, result string) (string, error) {
|
||||
return result, nil
|
||||
})
|
||||
### AfterToolResult — modify tool output
|
||||
|
||||
// Before/after each agent turn
|
||||
host.OnBeforeTurn(0, func(ctx context.Context) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
host.OnAfterTurn(0, func(ctx context.Context) error {
|
||||
```go
|
||||
host.OnAfterToolResult(kit.HookPriorityNormal, func(h kit.AfterToolResultHook) *kit.AfterToolResultResult {
|
||||
// h.ToolCallID, h.ToolName, h.ToolArgs, h.Result, h.IsError
|
||||
if h.ToolName == "read" {
|
||||
filtered := redactSecrets(h.Result)
|
||||
return &kit.AfterToolResultResult{Result: &filtered}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
The first argument is a priority (lower = runs first).
|
||||
### BeforeTurn — modify prompt, inject messages
|
||||
|
||||
```go
|
||||
host.OnBeforeTurn(kit.HookPriorityNormal, func(h kit.BeforeTurnHook) *kit.BeforeTurnResult {
|
||||
// h.Prompt
|
||||
newPrompt := h.Prompt + "\nAlways respond in JSON."
|
||||
return &kit.BeforeTurnResult{Prompt: &newPrompt}
|
||||
// Also available: SystemPrompt *string, InjectText *string
|
||||
})
|
||||
```
|
||||
|
||||
### AfterTurn — observation only
|
||||
|
||||
```go
|
||||
host.OnAfterTurn(kit.HookPriorityNormal, func(h kit.AfterTurnHook) {
|
||||
// h.Response, h.Error
|
||||
log.Printf("Turn completed: %d chars", len(h.Response))
|
||||
})
|
||||
```
|
||||
|
||||
### Hook priorities
|
||||
|
||||
```go
|
||||
kit.HookPriorityHigh = 0 // runs first
|
||||
kit.HookPriorityNormal = 50 // default
|
||||
kit.HookPriorityLow = 100 // runs last
|
||||
```
|
||||
|
||||
Lower values run first. First non-nil result wins.
|
||||
|
||||
## Subagent event monitoring
|
||||
|
||||
|
||||
@@ -68,3 +68,28 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
| `CompactionOptions` | `*CompactionOptions` | — | Configuration for auto-compaction |
|
||||
| `Skills` | `[]string` | — | Explicit skill files/dirs to load |
|
||||
| `SkillsDir` | `string` | — | Override default skills directory |
|
||||
|
||||
## Tool configuration
|
||||
|
||||
**`Tools`** replaces ALL default tools (core + MCP + extension). **`ExtraTools`** adds tools alongside the defaults. Use `Tools` to restrict capabilities; use `ExtraTools` to extend them.
|
||||
|
||||
Create custom tools with `kit.NewTool` — no external dependencies needed:
|
||||
|
||||
```go
|
||||
type LookupInput struct {
|
||||
ID string `json:"id" description:"Record ID to look up"`
|
||||
}
|
||||
|
||||
lookupTool := kit.NewTool("lookup", "Look up a record by ID",
|
||||
func(ctx context.Context, input LookupInput) (kit.ToolOutput, error) {
|
||||
record := db.Find(input.ID)
|
||||
return kit.TextResult(record.String()), nil
|
||||
},
|
||||
)
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{lookupTool},
|
||||
})
|
||||
```
|
||||
|
||||
See [Overview](/sdk/overview#custom-tools) for full custom tool documentation.
|
||||
|
||||
@@ -68,6 +68,44 @@ The SDK provides several prompt variants:
|
||||
| `Steer(ctx, instruction)` | System-level steering without user message |
|
||||
| `FollowUp(ctx, text)` | Continue without new user input |
|
||||
|
||||
## Custom tools
|
||||
|
||||
Create custom tools with `kit.NewTool`. The JSON schema is auto-generated from the input struct — no external dependencies required:
|
||||
|
||||
```go
|
||||
type WeatherInput struct {
|
||||
City string `json:"city" description:"City name"`
|
||||
}
|
||||
|
||||
weatherTool := kit.NewTool("get_weather", "Get current weather for a city",
|
||||
func(ctx context.Context, input WeatherInput) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("72°F, sunny in " + input.City), nil
|
||||
},
|
||||
)
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{weatherTool},
|
||||
})
|
||||
```
|
||||
|
||||
Struct tags control the schema:
|
||||
|
||||
- `json:"name"` — parameter name
|
||||
- `description:"..."` — description shown to the LLM
|
||||
- `enum:"a,b,c"` — restrict valid values
|
||||
- `omitempty` — marks the parameter as optional
|
||||
|
||||
Return values:
|
||||
|
||||
| Helper | Description |
|
||||
|--------|-------------|
|
||||
| `kit.TextResult(s)` | Successful text result |
|
||||
| `kit.ErrorResult(s)` | Error result (LLM sees it as a tool error) |
|
||||
|
||||
For advanced use, return a `kit.ToolOutput` struct directly with `Data`, `MediaType`, and `Metadata` fields.
|
||||
|
||||
Use `kit.NewParallelTool` for tools that are safe to run concurrently. Use `kit.ToolCallIDFromContext(ctx)` to retrieve the LLM-assigned call ID for logging or tracing.
|
||||
|
||||
## Event system
|
||||
|
||||
Subscribe to events for monitoring:
|
||||
|
||||
Reference in New Issue
Block a user