pkg/: simplify code without altering public API

events.go
- Delete subagentListenerSet (verbatim duplicate of eventBus); reuse
  *eventBus in SubscribeSubagent and getSubagentListenerSet

hooks.go
- Add early-exit in run() when hooks slice is empty, making all
  hasHooks() guard call sites in kit.go and compaction.go redundant

kit.go
- Remove four if m.X.hasHooks() { m.X.run(...) } outer guards
  (beforeTurn, contextPrepare, afterTurn x2); run() now short-circuits
- Replace goto drained with an idiomatic return inside default: branch
- Replace stdlib log.Printf with charmlog.Debug (charmbracelet/log),
  consistent with the rest of the codebase; remove "log" import

config.go
- Collapse single-element configNames := []string{".kit"} loop into a
  direct viper.SetConfigName call (removes slice, for, break, flag)

auth.go
- Fix GetOpenAIAPIKey: it documented OPENAI_API_KEY env var fallback but
  never called os.Getenv; now it does

compaction.go
- Extract persistAndEmitCompaction helper; eliminates duplicated
  AppendCompaction + events.emit block in compactInternal and
  applyCustomCompaction
- Replace fmt.Errorf("%s", reason) with errors.New(reason)
- Name the 16384 magic number as const defaultReserveTokens

skills.go
- Fix broken double-checked lock in DiscoverSkillsForExtension: the
  read-unlock -> write-lock gap had a TOCTOU race; replaced with a
  single write-lock covering the check and load
- Remove dead nil guard in convertSkills (convertSkill never returns nil)
- Rename convertSkills parameter skills->skillList to avoid shadowing
  the skills package import

extensions_bridge.go
- Delete taskMutex struct (sync.Mutex wrapper with map passed as param);
  replace with inline var taskMu sync.Mutex at the use site
- Simplify AgentEnd double-if into a single combined := declaration

template_bridge.go
- Fix RenderTemplate: use varRegex.ReplaceAllStringFunc instead of
  two-pass strings.ReplaceAll; handles arbitrary whitespace in {{var}}
- Remove dead isFlag function and simplify ParseArguments guard
  (the outer !HasPrefix guard made isFlag always return false)
- Cache matchModelPattern compiled regexps in a sync.Map to avoid
  repeated regexp.Compile on hot streaming paths

pkg/extensions/test/mock.go
- Remove dead local StatusBarEntry type (duplicate of extensions type,
  never referenced)
- Change make([]T, 0) to nil for nine slice fields in NewMockContext

pkg/extensions/test/harness.go
- Remove MustLoad (no callers outside the package)
- Remove extPath field (assigned but never read)
- Remove redundant os.Stat in LoadFile (os.ReadFile already errors)

events_test.go
- Add five missing event types to TestEventTypes table
  (Compaction, ReasoningDelta, ToolOutput, StepUsage, SteerConsumed)
- Expand TestEventOrdering from 11 to 16 events with the same types
- Add a got < 0 assertion to TestEventBusConcurrentSubscribeEmit so the
  test can actually fail rather than only logging
This commit is contained in:
Ed Zynda
2026-03-29 12:39:19 +03:00
parent b0991c7aa6
commit 9fbbab05f6
12 changed files with 181 additions and 261 deletions
+1 -16
View File
@@ -52,7 +52,6 @@ type Harness struct {
t *testing.T
runner *extensions.Runner
context *MockContext
extPath string
}
// New creates a new test harness for the given test.
@@ -72,15 +71,9 @@ func New(t *testing.T) *Harness {
func (h *Harness) LoadFile(path string) *extensions.LoadedExtension {
h.t.Helper()
// Verify file exists
if _, err := os.Stat(path); err != nil {
h.t.Fatalf("extension file not found: %s: %v", path, err)
}
// Read extension source
src, err := os.ReadFile(path)
if err != nil {
h.t.Fatalf("failed to read extension file: %v", err)
h.t.Fatalf("failed to read extension file %s: %v", path, err)
}
return h.loadSource(string(src), path)
@@ -144,7 +137,6 @@ func (h *Harness) loadSource(src string, path string) *extensions.LoadedExtensio
// Create runner with the loaded extension
h.runner = extensions.NewRunner([]extensions.LoadedExtension{*ext})
h.extPath = path
// Wire the mock context
h.runner.SetContext(h.context.ToContext())
@@ -223,10 +215,3 @@ func (h *Harness) RegisteredCommands() []extensions.CommandDef {
return h.runner.RegisteredCommands()
}
// MustLoad is like LoadFile but fails the test immediately on error.
// It returns the harness for chaining.
func (h *Harness) MustLoad(path string) *Harness {
h.t.Helper()
h.LoadFile(path)
return h
}
-17
View File
@@ -59,29 +59,12 @@ type MockContext struct {
Overlays []extensions.OverlayConfig
}
// StatusBarEntry represents a recorded status bar entry
type StatusBarEntry struct {
Key string
Text string
Priority int
}
// NewMockContext creates a new mock context with default values.
func NewMockContext() *MockContext {
return &MockContext{
Prints: make([]string, 0),
PrintInfos: make([]string, 0),
PrintErrors: make([]string, 0),
PrintBlocks: make([]extensions.PrintBlockOpts, 0),
Messages: make([]string, 0),
CancelSends: make([]string, 0),
Widgets: make(map[string]extensions.WidgetConfig),
RemovedIDs: make([]string, 0),
StatusEntries: make(map[string]extensions.StatusBarEntry),
RemovedStatus: make([]string, 0),
EditorTexts: make([]string, 0),
Options: make(map[string]string),
Overlays: make([]extensions.OverlayConfig, 0),
Interactive: true,
SessionID: "test-session",
CWD: "/test",
+11 -9
View File
@@ -1,6 +1,10 @@
package kit
import "github.com/mark3labs/kit/internal/auth"
import (
"os"
"github.com/mark3labs/kit/internal/auth"
)
// CredentialManager manages API keys and OAuth credentials.
type CredentialManager = auth.CredentialManager
@@ -66,14 +70,12 @@ func HasOpenAICredentials() bool {
// Returns an empty string if no key is found.
func GetOpenAIAPIKey() string {
cm, err := auth.NewCredentialManager()
if err != nil {
return ""
}
// Try to get valid access token (handles OAuth refresh)
token, err := cm.GetValidOpenAIAccessToken()
if err == nil && token != "" {
return token
if err == nil {
// Try to get valid access token (handles OAuth refresh)
if token, err := cm.GetValidOpenAIAccessToken(); err == nil && token != "" {
return token
}
}
// Fall back to environment variable
return ""
return os.Getenv("OPENAI_API_KEY")
}
+42 -38
View File
@@ -2,6 +2,7 @@ package kit
import (
"context"
"errors"
"fmt"
"charm.land/fantasy"
@@ -17,6 +18,10 @@ type ContextStats struct {
MessageCount int // Number of messages in the conversation
}
// defaultReserveTokens is the number of tokens to keep free in the context
// window as a safety margin during compaction checks.
const defaultReserveTokens = 16384
// EstimateContextTokens returns the estimated token count of the current
// conversation based on tree session messages.
func (m *Kit) EstimateContextTokens() int {
@@ -34,7 +39,7 @@ func (m *Kit) ShouldCompact() bool {
return false
}
reserveTokens := 16384
reserveTokens := defaultReserveTokens
if m.compactionOpts != nil && m.compactionOpts.ReserveTokens > 0 {
reserveTokens = m.compactionOpts.ReserveTokens
}
@@ -131,7 +136,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
if reason == "" {
reason = "compaction cancelled by extension"
}
return nil, fmt.Errorf("%s", reason)
return nil, errors.New(reason)
}
// Extension provided a custom summary — use it directly.
if hookResult.Summary != "" {
@@ -166,27 +171,10 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
firstKeptEntryID = entryIDs[result.CutPoint]
}
if _, err := m.treeSession.AppendCompaction(
result.Summary,
firstKeptEntryID,
result.OriginalTokens,
result.CompactedTokens,
result.MessagesRemoved,
result.ReadFiles,
result.ModifiedFiles,
); err != nil {
return nil, fmt.Errorf("failed to persist compaction entry: %w", err)
if err := m.persistAndEmitCompaction(result.Summary, firstKeptEntryID, result.OriginalTokens, result.CompactedTokens, result.MessagesRemoved, result.ReadFiles, result.ModifiedFiles); err != nil {
return nil, err
}
m.events.emit(CompactionEvent{
Summary: result.Summary,
OriginalTokens: result.OriginalTokens,
CompactedTokens: result.CompactedTokens,
MessagesRemoved: result.MessagesRemoved,
ReadFiles: result.ReadFiles,
ModifiedFiles: result.ModifiedFiles,
})
return result, nil
}
@@ -218,17 +206,6 @@ func (m *Kit) applyCustomCompaction(summary string, messages []fantasy.Message,
recentTokens := compaction.EstimateMessageTokens(messages[cutPoint:])
compactedTokens := summaryTokens + recentTokens
if _, err := m.treeSession.AppendCompaction(
summary,
firstKeptEntryID,
originalTokens,
compactedTokens,
cutPoint,
nil, nil, // no file tracking for custom summaries
); err != nil {
return nil, fmt.Errorf("failed to persist compaction entry: %w", err)
}
result := &CompactionResult{
Summary: summary,
OriginalTokens: originalTokens,
@@ -236,12 +213,39 @@ func (m *Kit) applyCustomCompaction(summary string, messages []fantasy.Message,
MessagesRemoved: cutPoint,
}
m.events.emit(CompactionEvent{
Summary: result.Summary,
OriginalTokens: result.OriginalTokens,
CompactedTokens: result.CompactedTokens,
MessagesRemoved: result.MessagesRemoved,
})
if err := m.persistAndEmitCompaction(summary, firstKeptEntryID, originalTokens, compactedTokens, cutPoint, nil, nil); err != nil {
return nil, err
}
return result, nil
}
// persistAndEmitCompaction writes a CompactionEntry to the session tree and
// emits a CompactionEvent. It is the single implementation shared by
// compactInternal and applyCustomCompaction.
func (m *Kit) persistAndEmitCompaction(
summary, firstKeptEntryID string,
originalTokens, compactedTokens, messagesRemoved int,
readFiles, modifiedFiles []string,
) error {
if _, err := m.treeSession.AppendCompaction(
summary,
firstKeptEntryID,
originalTokens,
compactedTokens,
messagesRemoved,
readFiles,
modifiedFiles,
); err != nil {
return fmt.Errorf("failed to persist compaction entry: %w", err)
}
m.events.emit(CompactionEvent{
Summary: summary,
OriginalTokens: originalTokens,
CompactedTokens: compactedTokens,
MessagesRemoved: messagesRemoved,
ReadFiles: readFiles,
ModifiedFiles: modifiedFiles,
})
return nil
}
+7 -11
View File
@@ -78,20 +78,16 @@ func InitConfig(configFile string, debug bool) error {
viper.AddConfigPath(home)
configLoaded := false
configNames := []string{".kit"}
for _, name := range configNames {
viper.SetConfigName(name)
if err := viper.ReadInConfig(); err == nil {
configPath := viper.ConfigFileUsed()
if err := LoadConfigWithEnvSubstitution(configPath); err != nil {
if strings.Contains(err.Error(), "environment variable substitution failed") {
return fmt.Errorf("error reading config file '%s': %w", configPath, err)
}
continue
viper.SetConfigName(".kit")
if err := viper.ReadInConfig(); err == nil {
configPath := viper.ConfigFileUsed()
if err := LoadConfigWithEnvSubstitution(configPath); err != nil {
if strings.Contains(err.Error(), "environment variable substitution failed") {
return fmt.Errorf("error reading config file '%s': %w", configPath, err)
}
} else {
configLoaded = true
break
}
}
+4 -40
View File
@@ -416,42 +416,6 @@ func (m *Kit) OnTurnEnd(handler func(TurnEndEvent)) func() {
// Subagent event subscriptions
// ---------------------------------------------------------------------------
// subagentListenerSet holds per-tool-call listeners for subagent events.
type subagentListenerSet struct {
mu sync.RWMutex
listeners map[int]EventListener
nextID int
}
func newSubagentListenerSet() *subagentListenerSet {
return &subagentListenerSet{listeners: make(map[int]EventListener)}
}
func (s *subagentListenerSet) add(listener EventListener) func() {
s.mu.Lock()
id := s.nextID
s.nextID++
s.listeners[id] = listener
s.mu.Unlock()
return func() {
s.mu.Lock()
delete(s.listeners, id)
s.mu.Unlock()
}
}
func (s *subagentListenerSet) emit(event Event) {
s.mu.RLock()
snapshot := make([]EventListener, 0, len(s.listeners))
for _, l := range s.listeners {
snapshot = append(snapshot, l)
}
s.mu.RUnlock()
for _, l := range snapshot {
l(event)
}
}
// SubscribeSubagent registers a listener for real-time events from a subagent
// identified by its tool call ID. Returns an unsubscribe function.
//
@@ -470,14 +434,14 @@ func (s *subagentListenerSet) emit(event Event) {
// }
// })
func (m *Kit) SubscribeSubagent(toolCallID string, listener EventListener) func() {
actual, _ := m.subagentListeners.LoadOrStore(toolCallID, newSubagentListenerSet())
return actual.(*subagentListenerSet).add(listener)
actual, _ := m.subagentListeners.LoadOrStore(toolCallID, newEventBus())
return actual.(*eventBus).subscribe(listener)
}
// getSubagentListenerSet returns the listener set for a tool call, or nil.
func (m *Kit) getSubagentListenerSet(toolCallID string) *subagentListenerSet {
func (m *Kit) getSubagentListenerSet(toolCallID string) *eventBus {
if v, ok := m.subagentListeners.Load(toolCallID); ok {
return v.(*subagentListenerSet)
return v.(*eventBus)
}
return nil
}
+23 -2
View File
@@ -140,8 +140,14 @@ func TestEventBusConcurrentSubscribeEmit(t *testing.T) {
wg.Wait()
// We can't assert an exact count because subscribe/emit ordering is
// non-deterministic, but it must not panic or deadlock.
t.Logf("total events received across subscribers: %d", total.Load())
// non-deterministic, but we can assert the count is non-negative and
// that no events were lost (each subscriber that registered before an
// emit must have received it at least partially).
got := total.Load()
if got < 0 {
t.Errorf("expected non-negative total event count, got %d", got)
}
t.Logf("total events received across subscribers: %d", got)
}
// TestEventBusEmitNoListeners verifies emit is a no-op with no subscribers.
@@ -169,6 +175,11 @@ func TestEventTypes(t *testing.T) {
{ToolResultEvent{}, EventToolResult},
{ToolCallContentEvent{}, EventToolCallContent},
{ResponseEvent{}, EventResponse},
{CompactionEvent{}, EventCompaction},
{ReasoningDeltaEvent{}, EventReasoningDelta},
{ToolOutputEvent{}, EventToolOutput},
{StepUsageEvent{}, EventStepUsage},
{SteerConsumedEvent{}, EventSteerConsumed},
}
for _, tt := range tests {
@@ -212,26 +223,36 @@ func TestEventOrdering(t *testing.T) {
EventTurnStart,
EventMessageStart,
EventMessageUpdate,
EventReasoningDelta,
EventToolOutput,
EventToolCall,
EventToolExecutionStart,
EventToolExecutionEnd,
EventToolResult,
EventToolCallContent,
EventMessageEnd,
EventStepUsage,
EventResponse,
EventCompaction,
EventSteerConsumed,
EventTurnEnd,
}
bus.emit(TurnStartEvent{})
bus.emit(MessageStartEvent{})
bus.emit(MessageUpdateEvent{Chunk: "hello"})
bus.emit(ReasoningDeltaEvent{Delta: "thinking..."})
bus.emit(ToolOutputEvent{ToolName: "bash", Chunk: "output"})
bus.emit(ToolCallEvent{ToolName: "bash"})
bus.emit(ToolExecutionStartEvent{ToolName: "bash"})
bus.emit(ToolExecutionEndEvent{ToolName: "bash"})
bus.emit(ToolResultEvent{ToolName: "bash", Result: "ok"})
bus.emit(ToolCallContentEvent{Content: "I'll run bash"})
bus.emit(MessageEndEvent{Content: "done"})
bus.emit(StepUsageEvent{InputTokens: 100})
bus.emit(ResponseEvent{Content: "done"})
bus.emit(CompactionEvent{Summary: "compacted"})
bus.emit(SteerConsumedEvent{Count: 1})
bus.emit(TurnEndEvent{Response: "done"})
if len(types) != len(expected) {
+13 -32
View File
@@ -104,11 +104,9 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
if runner.HasHandlers(extensions.AgentEnd) {
m.Subscribe(func(e Event) {
if ev, ok := e.(TurnEndEvent); ok {
stopReason := ev.StopReason
response := ev.Response
stopReason, response := ev.StopReason, ev.Response
if ev.Error != nil {
stopReason = "error"
response = ""
stopReason, response = "error", ""
} else if stopReason == "" {
stopReason = "completed"
}
@@ -141,7 +139,7 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
// taskByCallID tracks the task description extracted from ToolCall input,
// keyed by toolCallID. Populated on ToolCall, consumed on ToolResult.
taskByCallID := make(map[string]string)
var taskMu = &taskMutex{}
var taskMu sync.Mutex
// Intercept ToolCall to capture the task and subscribe to child events.
m.Subscribe(func(e Event) {
@@ -157,7 +155,9 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
task = t
}
}
taskMu.set(taskByCallID, ev.ToolCallID, task)
taskMu.Lock()
taskByCallID[ev.ToolCallID] = task
taskMu.Unlock()
// Subscribe to child events so we can forward them as SubagentChunkEvents.
if runner.HasHandlers(extensions.SubagentChunk) {
@@ -204,7 +204,9 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
if !ok || ev.ToolName != "subagent" {
return
}
task := taskMu.get(taskByCallID, ev.ToolCallID)
taskMu.Lock()
task := taskByCallID[ev.ToolCallID]
taskMu.Unlock()
_, _ = runner.Emit(extensions.SubagentStartEvent{
ToolCallID: ev.ToolCallID,
Task: task,
@@ -219,8 +221,10 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
if !ok || ev.ToolName != "subagent" {
return
}
task := taskMu.get(taskByCallID, ev.ToolCallID)
taskMu.del(taskByCallID, ev.ToolCallID)
taskMu.Lock()
task := taskByCallID[ev.ToolCallID]
delete(taskByCallID, ev.ToolCallID)
taskMu.Unlock()
errMsg := ""
if ev.IsError {
errMsg = ev.Result
@@ -325,26 +329,3 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
}
}
// taskMutex is a simple mutex-protected map helper used by bridgeExtensions.
// It lives in this file to avoid polluting the kit package with unexported types.
type taskMutex struct {
mu sync.Mutex
}
func (t *taskMutex) set(m map[string]string, key, val string) {
t.mu.Lock()
m[key] = val
t.mu.Unlock()
}
func (t *taskMutex) get(m map[string]string, key string) string {
t.mu.Lock()
defer t.mu.Unlock()
return m[key]
}
func (t *taskMutex) del(m map[string]string, key string) {
t.mu.Lock()
delete(m, key)
t.mu.Unlock()
}
+5
View File
@@ -167,8 +167,13 @@ func (hr *hookRegistry[In, Out]) register(p HookPriority, h func(In) *Out) func(
}
// run executes all hooks in priority order. The first non-nil result wins.
// Returns nil immediately if no hooks are registered.
func (hr *hookRegistry[In, Out]) run(input In) *Out {
hr.mu.RLock()
if len(hr.hooks) == 0 {
hr.mu.RUnlock()
return nil
}
snapshot := make([]hookEntry[In, Out], len(hr.hooks))
copy(snapshot, hr.hooks)
hr.mu.RUnlock()
+36 -41
View File
@@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
@@ -12,6 +11,7 @@ import (
"time"
"charm.land/fantasy"
charmlog "github.com/charmbracelet/log"
"github.com/mark3labs/kit/internal/agent"
"github.com/mark3labs/kit/internal/config"
@@ -1423,14 +1423,13 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
case msg := <-steerCh:
leftover = append(leftover, msg)
default:
goto drained
m.steerMu.Lock()
m.steerCh = nil
m.leftoverSteer = leftover
m.steerMu.Unlock()
return
}
}
drained:
m.steerMu.Lock()
m.steerCh = nil
m.leftoverSteer = leftover
m.steerMu.Unlock()
}()
ctx = agent.ContextWithSteerCh(ctx, steerCh)
ctx = agent.ContextWithSteerConsumed(ctx, func(count int) {
@@ -1526,8 +1525,12 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
// Emit step usage event for real-time cost tracking
if viper.GetBool("debug") {
log.Printf("[DEBUG] Kit.generate emitting StepUsageEvent: input=%d output=%d cacheRead=%d cacheCreate=%d",
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens)
charmlog.Debug("Kit.generate emitting StepUsageEvent",
"input", inputTokens,
"output", outputTokens,
"cacheRead", cacheReadTokens,
"cacheCreate", cacheCreationTokens,
)
}
m.events.emit(StepUsageEvent{
InputTokens: uint64(inputTokens),
@@ -1568,30 +1571,28 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
}
// Run BeforeTurn hooks — can modify the prompt, inject system/context messages.
if m.beforeTurn.hasHooks() {
if hookResult := m.beforeTurn.run(BeforeTurnHook{Prompt: prompt}); hookResult != nil {
// Override prompt text in the last user message, preserving
// any file parts (e.g. clipboard images).
if hookResult.Prompt != nil {
for i := len(preMessages) - 1; i >= 0; i-- {
if preMessages[i].Role == fantasy.MessageRoleUser {
files := extractFileParts(preMessages[i])
preMessages[i] = fantasy.NewUserMessage(*hookResult.Prompt, files...)
break
}
if hookResult := m.beforeTurn.run(BeforeTurnHook{Prompt: prompt}); hookResult != nil {
// Override prompt text in the last user message, preserving
// any file parts (e.g. clipboard images).
if hookResult.Prompt != nil {
for i := len(preMessages) - 1; i >= 0; i-- {
if preMessages[i].Role == fantasy.MessageRoleUser {
files := extractFileParts(preMessages[i])
preMessages[i] = fantasy.NewUserMessage(*hookResult.Prompt, files...)
break
}
}
// Inject messages before the original preMessages.
var injected []fantasy.Message
if hookResult.SystemPrompt != nil {
injected = append(injected, fantasy.NewSystemMessage(*hookResult.SystemPrompt))
}
if hookResult.InjectText != nil {
injected = append(injected, fantasy.NewUserMessage(*hookResult.InjectText))
}
if len(injected) > 0 {
preMessages = append(injected, preMessages...)
}
}
// Inject messages before the original preMessages.
var injected []fantasy.Message
if hookResult.SystemPrompt != nil {
injected = append(injected, fantasy.NewSystemMessage(*hookResult.SystemPrompt))
}
if hookResult.InjectText != nil {
injected = append(injected, fantasy.NewUserMessage(*hookResult.InjectText))
}
if len(injected) > 0 {
preMessages = append(injected, preMessages...)
}
}
@@ -1609,10 +1610,8 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
messages := m.treeSession.GetFantasyMessages()
// Run ContextPrepare hooks — extensions can filter, reorder, or inject messages.
if m.contextPrepare.hasHooks() {
if hookResult := m.contextPrepare.run(ContextPrepareHook{Messages: messages}); hookResult != nil && hookResult.Messages != nil {
messages = hookResult.Messages
}
if hookResult := m.contextPrepare.run(ContextPrepareHook{Messages: messages}); hookResult != nil && hookResult.Messages != nil {
messages = hookResult.Messages
}
sentCount := len(messages)
@@ -1636,9 +1635,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
}
m.events.emit(TurnEndEvent{Error: err})
// Run AfterTurn hooks even on error.
if m.afterTurn.hasHooks() {
m.afterTurn.run(AfterTurnHook{Error: err})
}
m.afterTurn.run(AfterTurnHook{Error: err})
return nil, err
}
@@ -1669,9 +1666,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
m.events.emit(TurnEndEvent{Response: responseText, StopReason: stopReason})
// Run AfterTurn hooks.
if m.afterTurn.hasHooks() {
m.afterTurn.run(AfterTurnHook{Response: responseText})
}
m.afterTurn.run(AfterTurnHook{Response: responseText})
// Build TurnResult with usage stats.
turnResult := &TurnResult{
+9 -21
View File
@@ -91,22 +91,12 @@ var globalSkillCache skillCache
func (m *Kit) DiscoverSkillsForExtension() []extensions.Skill {
cwd, _ := os.Getwd()
// Check cache first
globalSkillCache.mu.RLock()
if len(globalSkillCache.skills) > 0 {
globalSkillCache.mu.RUnlock()
return m.convertSkills(globalSkillCache.skills)
}
globalSkillCache.mu.RUnlock()
// Load fresh
skillList, _ := skills.LoadSkills(cwd)
globalSkillCache.mu.Lock()
globalSkillCache.skills = skillList
globalSkillCache.mu.Unlock()
return m.convertSkills(skillList)
defer globalSkillCache.mu.Unlock()
if len(globalSkillCache.skills) == 0 {
globalSkillCache.skills, _ = skills.LoadSkills(cwd)
}
return m.convertSkills(globalSkillCache.skills)
}
// LoadSkillForExtension loads a single skill file for extensions.
@@ -140,12 +130,10 @@ func (m *Kit) convertSkill(s *skills.Skill) *extensions.Skill {
}
// convertSkills converts a slice of skills.
func (m *Kit) convertSkills(skills []*skills.Skill) []extensions.Skill {
result := make([]extensions.Skill, 0, len(skills))
for _, s := range skills {
if converted := m.convertSkill(s); converted != nil {
result = append(result, *converted)
}
func (m *Kit) convertSkills(skillList []*skills.Skill) []extensions.Skill {
result := make([]extensions.Skill, 0, len(skillList))
for _, s := range skillList {
result = append(result, *m.convertSkill(s))
}
return result
}
+30 -34
View File
@@ -3,6 +3,7 @@ package kit
import (
"regexp"
"strings"
"sync"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/models"
@@ -34,16 +35,17 @@ func ParseTemplate(name, content string) extensions.PromptTemplate {
}
// RenderTemplate substitutes variables into template content.
// Handles {{name}} and {{ name }} (any whitespace) placeholders.
func RenderTemplate(tpl extensions.PromptTemplate, vars map[string]string) string {
result := tpl.Content
for name, value := range vars {
placeholder := "{{" + name + "}}"
result = strings.ReplaceAll(result, placeholder, value)
// Also handle with spaces
placeholderSpaced := "{{ " + name + " }}"
result = strings.ReplaceAll(result, placeholderSpaced, value)
}
return result
return varRegex.ReplaceAllStringFunc(tpl.Content, func(m string) string {
sub := varRegex.FindStringSubmatch(m)
if len(sub) > 1 {
if v, ok := vars[sub[1]]; ok {
return v
}
}
return m
})
}
// ParseArguments parses command-line style arguments.
@@ -58,13 +60,10 @@ func ParseArguments(input string, pattern extensions.ArgumentPattern) extensions
return result
}
// First field is the command itself (if present)
// First field is the command itself (if present); skip it.
startIdx := 0
if len(fields) > 0 && !strings.HasPrefix(fields[0], "-") {
// Check if it's a command name or positional arg
if len(pattern.Positional) == 0 || !isFlag(fields[0], pattern.Flags) {
startIdx = 1 // Skip command name
}
startIdx = 1
}
// Parse flags
@@ -224,16 +223,6 @@ func parseFields(input string) []string {
return fields
}
// isFlag checks if a field is a known flag.
func isFlag(field string, flags map[string]string) bool {
if strings.HasPrefix(field, "--") {
return true
}
if strings.HasPrefix(field, "-") && len(field) > 1 {
return true
}
return false
}
// EvaluateModelConditional checks if condition matches current model.
// Condition supports wildcards: * matches any, ? matches single char.
@@ -248,17 +237,24 @@ func EvaluateModelConditional(currentModel, condition string) bool {
return false
}
// matchModelPattern matches a model against a pattern with wildcards.
func matchModelPattern(model, pattern string) bool {
// Convert pattern to regexp
pattern = strings.ReplaceAll(pattern, "*", ".*")
pattern = strings.ReplaceAll(pattern, "?", ".")
pattern = "^" + pattern + "$"
// modelPatternCache caches compiled regexps for model glob patterns.
var modelPatternCache sync.Map
re, err := regexp.Compile(pattern)
if err != nil {
// Fallback: exact match
return model == pattern
// matchModelPattern matches a model against a pattern with wildcards.
// Compiled regexps are cached to avoid recompilation on hot paths.
func matchModelPattern(model, pattern string) bool {
rePattern := "^" + strings.ReplaceAll(strings.ReplaceAll(pattern, "*", ".*"), "?", ".") + "$"
var re *regexp.Regexp
if v, ok := modelPatternCache.Load(rePattern); ok {
re = v.(*regexp.Regexp)
} else {
compiled, err := regexp.Compile(rePattern)
if err != nil {
// Fallback: exact match
return model == pattern
}
modelPatternCache.Store(rePattern, compiled)
re = compiled
}
return re.MatchString(model)
}