mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
pkg/: simplify code without altering public API
events.go
- Delete subagentListenerSet (verbatim duplicate of eventBus); reuse
*eventBus in SubscribeSubagent and getSubagentListenerSet
hooks.go
- Add early-exit in run() when hooks slice is empty, making all
hasHooks() guard call sites in kit.go and compaction.go redundant
kit.go
- Remove four if m.X.hasHooks() { m.X.run(...) } outer guards
(beforeTurn, contextPrepare, afterTurn x2); run() now short-circuits
- Replace goto drained with an idiomatic return inside default: branch
- Replace stdlib log.Printf with charmlog.Debug (charmbracelet/log),
consistent with the rest of the codebase; remove "log" import
config.go
- Collapse single-element configNames := []string{".kit"} loop into a
direct viper.SetConfigName call (removes slice, for, break, flag)
auth.go
- Fix GetOpenAIAPIKey: it documented OPENAI_API_KEY env var fallback but
never called os.Getenv; now it does
compaction.go
- Extract persistAndEmitCompaction helper; eliminates duplicated
AppendCompaction + events.emit block in compactInternal and
applyCustomCompaction
- Replace fmt.Errorf("%s", reason) with errors.New(reason)
- Name the 16384 magic number as const defaultReserveTokens
skills.go
- Fix broken double-checked lock in DiscoverSkillsForExtension: the
read-unlock -> write-lock gap had a TOCTOU race; replaced with a
single write-lock covering the check and load
- Remove dead nil guard in convertSkills (convertSkill never returns nil)
- Rename convertSkills parameter skills->skillList to avoid shadowing
the skills package import
extensions_bridge.go
- Delete taskMutex struct (sync.Mutex wrapper with map passed as param);
replace with inline var taskMu sync.Mutex at the use site
- Simplify AgentEnd double-if into a single combined := declaration
template_bridge.go
- Fix RenderTemplate: use varRegex.ReplaceAllStringFunc instead of
two-pass strings.ReplaceAll; handles arbitrary whitespace in {{var}}
- Remove dead isFlag function and simplify ParseArguments guard
(the outer !HasPrefix guard made isFlag always return false)
- Cache matchModelPattern compiled regexps in a sync.Map to avoid
repeated regexp.Compile on hot streaming paths
pkg/extensions/test/mock.go
- Remove dead local StatusBarEntry type (duplicate of extensions type,
never referenced)
- Change make([]T, 0) to nil for nine slice fields in NewMockContext
pkg/extensions/test/harness.go
- Remove MustLoad (no callers outside the package)
- Remove extPath field (assigned but never read)
- Remove redundant os.Stat in LoadFile (os.ReadFile already errors)
events_test.go
- Add five missing event types to TestEventTypes table
(Compaction, ReasoningDelta, ToolOutput, StepUsage, SteerConsumed)
- Expand TestEventOrdering from 11 to 16 events with the same types
- Add a got < 0 assertion to TestEventBusConcurrentSubscribeEmit so the
test can actually fail rather than only logging
This commit is contained in:
@@ -52,7 +52,6 @@ type Harness struct {
|
||||
t *testing.T
|
||||
runner *extensions.Runner
|
||||
context *MockContext
|
||||
extPath string
|
||||
}
|
||||
|
||||
// New creates a new test harness for the given test.
|
||||
@@ -72,15 +71,9 @@ func New(t *testing.T) *Harness {
|
||||
func (h *Harness) LoadFile(path string) *extensions.LoadedExtension {
|
||||
h.t.Helper()
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
h.t.Fatalf("extension file not found: %s: %v", path, err)
|
||||
}
|
||||
|
||||
// Read extension source
|
||||
src, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
h.t.Fatalf("failed to read extension file: %v", err)
|
||||
h.t.Fatalf("failed to read extension file %s: %v", path, err)
|
||||
}
|
||||
|
||||
return h.loadSource(string(src), path)
|
||||
@@ -144,7 +137,6 @@ func (h *Harness) loadSource(src string, path string) *extensions.LoadedExtensio
|
||||
|
||||
// Create runner with the loaded extension
|
||||
h.runner = extensions.NewRunner([]extensions.LoadedExtension{*ext})
|
||||
h.extPath = path
|
||||
|
||||
// Wire the mock context
|
||||
h.runner.SetContext(h.context.ToContext())
|
||||
@@ -223,10 +215,3 @@ func (h *Harness) RegisteredCommands() []extensions.CommandDef {
|
||||
return h.runner.RegisteredCommands()
|
||||
}
|
||||
|
||||
// MustLoad is like LoadFile but fails the test immediately on error.
|
||||
// It returns the harness for chaining.
|
||||
func (h *Harness) MustLoad(path string) *Harness {
|
||||
h.t.Helper()
|
||||
h.LoadFile(path)
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -59,29 +59,12 @@ type MockContext struct {
|
||||
Overlays []extensions.OverlayConfig
|
||||
}
|
||||
|
||||
// StatusBarEntry represents a recorded status bar entry
|
||||
type StatusBarEntry struct {
|
||||
Key string
|
||||
Text string
|
||||
Priority int
|
||||
}
|
||||
|
||||
// NewMockContext creates a new mock context with default values.
|
||||
func NewMockContext() *MockContext {
|
||||
return &MockContext{
|
||||
Prints: make([]string, 0),
|
||||
PrintInfos: make([]string, 0),
|
||||
PrintErrors: make([]string, 0),
|
||||
PrintBlocks: make([]extensions.PrintBlockOpts, 0),
|
||||
Messages: make([]string, 0),
|
||||
CancelSends: make([]string, 0),
|
||||
Widgets: make(map[string]extensions.WidgetConfig),
|
||||
RemovedIDs: make([]string, 0),
|
||||
StatusEntries: make(map[string]extensions.StatusBarEntry),
|
||||
RemovedStatus: make([]string, 0),
|
||||
EditorTexts: make([]string, 0),
|
||||
Options: make(map[string]string),
|
||||
Overlays: make([]extensions.OverlayConfig, 0),
|
||||
Interactive: true,
|
||||
SessionID: "test-session",
|
||||
CWD: "/test",
|
||||
|
||||
+11
-9
@@ -1,6 +1,10 @@
|
||||
package kit
|
||||
|
||||
import "github.com/mark3labs/kit/internal/auth"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
)
|
||||
|
||||
// CredentialManager manages API keys and OAuth credentials.
|
||||
type CredentialManager = auth.CredentialManager
|
||||
@@ -66,14 +70,12 @@ func HasOpenAICredentials() bool {
|
||||
// Returns an empty string if no key is found.
|
||||
func GetOpenAIAPIKey() string {
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// Try to get valid access token (handles OAuth refresh)
|
||||
token, err := cm.GetValidOpenAIAccessToken()
|
||||
if err == nil && token != "" {
|
||||
return token
|
||||
if err == nil {
|
||||
// Try to get valid access token (handles OAuth refresh)
|
||||
if token, err := cm.GetValidOpenAIAccessToken(); err == nil && token != "" {
|
||||
return token
|
||||
}
|
||||
}
|
||||
// Fall back to environment variable
|
||||
return ""
|
||||
return os.Getenv("OPENAI_API_KEY")
|
||||
}
|
||||
|
||||
+42
-38
@@ -2,6 +2,7 @@ package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -17,6 +18,10 @@ type ContextStats struct {
|
||||
MessageCount int // Number of messages in the conversation
|
||||
}
|
||||
|
||||
// defaultReserveTokens is the number of tokens to keep free in the context
|
||||
// window as a safety margin during compaction checks.
|
||||
const defaultReserveTokens = 16384
|
||||
|
||||
// EstimateContextTokens returns the estimated token count of the current
|
||||
// conversation based on tree session messages.
|
||||
func (m *Kit) EstimateContextTokens() int {
|
||||
@@ -34,7 +39,7 @@ func (m *Kit) ShouldCompact() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
reserveTokens := 16384
|
||||
reserveTokens := defaultReserveTokens
|
||||
if m.compactionOpts != nil && m.compactionOpts.ReserveTokens > 0 {
|
||||
reserveTokens = m.compactionOpts.ReserveTokens
|
||||
}
|
||||
@@ -131,7 +136,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
|
||||
if reason == "" {
|
||||
reason = "compaction cancelled by extension"
|
||||
}
|
||||
return nil, fmt.Errorf("%s", reason)
|
||||
return nil, errors.New(reason)
|
||||
}
|
||||
// Extension provided a custom summary — use it directly.
|
||||
if hookResult.Summary != "" {
|
||||
@@ -166,27 +171,10 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
|
||||
firstKeptEntryID = entryIDs[result.CutPoint]
|
||||
}
|
||||
|
||||
if _, err := m.treeSession.AppendCompaction(
|
||||
result.Summary,
|
||||
firstKeptEntryID,
|
||||
result.OriginalTokens,
|
||||
result.CompactedTokens,
|
||||
result.MessagesRemoved,
|
||||
result.ReadFiles,
|
||||
result.ModifiedFiles,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist compaction entry: %w", err)
|
||||
if err := m.persistAndEmitCompaction(result.Summary, firstKeptEntryID, result.OriginalTokens, result.CompactedTokens, result.MessagesRemoved, result.ReadFiles, result.ModifiedFiles); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.events.emit(CompactionEvent{
|
||||
Summary: result.Summary,
|
||||
OriginalTokens: result.OriginalTokens,
|
||||
CompactedTokens: result.CompactedTokens,
|
||||
MessagesRemoved: result.MessagesRemoved,
|
||||
ReadFiles: result.ReadFiles,
|
||||
ModifiedFiles: result.ModifiedFiles,
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -218,17 +206,6 @@ func (m *Kit) applyCustomCompaction(summary string, messages []fantasy.Message,
|
||||
recentTokens := compaction.EstimateMessageTokens(messages[cutPoint:])
|
||||
compactedTokens := summaryTokens + recentTokens
|
||||
|
||||
if _, err := m.treeSession.AppendCompaction(
|
||||
summary,
|
||||
firstKeptEntryID,
|
||||
originalTokens,
|
||||
compactedTokens,
|
||||
cutPoint,
|
||||
nil, nil, // no file tracking for custom summaries
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist compaction entry: %w", err)
|
||||
}
|
||||
|
||||
result := &CompactionResult{
|
||||
Summary: summary,
|
||||
OriginalTokens: originalTokens,
|
||||
@@ -236,12 +213,39 @@ func (m *Kit) applyCustomCompaction(summary string, messages []fantasy.Message,
|
||||
MessagesRemoved: cutPoint,
|
||||
}
|
||||
|
||||
m.events.emit(CompactionEvent{
|
||||
Summary: result.Summary,
|
||||
OriginalTokens: result.OriginalTokens,
|
||||
CompactedTokens: result.CompactedTokens,
|
||||
MessagesRemoved: result.MessagesRemoved,
|
||||
})
|
||||
if err := m.persistAndEmitCompaction(summary, firstKeptEntryID, originalTokens, compactedTokens, cutPoint, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// persistAndEmitCompaction writes a CompactionEntry to the session tree and
|
||||
// emits a CompactionEvent. It is the single implementation shared by
|
||||
// compactInternal and applyCustomCompaction.
|
||||
func (m *Kit) persistAndEmitCompaction(
|
||||
summary, firstKeptEntryID string,
|
||||
originalTokens, compactedTokens, messagesRemoved int,
|
||||
readFiles, modifiedFiles []string,
|
||||
) error {
|
||||
if _, err := m.treeSession.AppendCompaction(
|
||||
summary,
|
||||
firstKeptEntryID,
|
||||
originalTokens,
|
||||
compactedTokens,
|
||||
messagesRemoved,
|
||||
readFiles,
|
||||
modifiedFiles,
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to persist compaction entry: %w", err)
|
||||
}
|
||||
m.events.emit(CompactionEvent{
|
||||
Summary: summary,
|
||||
OriginalTokens: originalTokens,
|
||||
CompactedTokens: compactedTokens,
|
||||
MessagesRemoved: messagesRemoved,
|
||||
ReadFiles: readFiles,
|
||||
ModifiedFiles: modifiedFiles,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
+7
-11
@@ -78,20 +78,16 @@ func InitConfig(configFile string, debug bool) error {
|
||||
viper.AddConfigPath(home)
|
||||
|
||||
configLoaded := false
|
||||
configNames := []string{".kit"}
|
||||
|
||||
for _, name := range configNames {
|
||||
viper.SetConfigName(name)
|
||||
if err := viper.ReadInConfig(); err == nil {
|
||||
configPath := viper.ConfigFileUsed()
|
||||
if err := LoadConfigWithEnvSubstitution(configPath); err != nil {
|
||||
if strings.Contains(err.Error(), "environment variable substitution failed") {
|
||||
return fmt.Errorf("error reading config file '%s': %w", configPath, err)
|
||||
}
|
||||
continue
|
||||
viper.SetConfigName(".kit")
|
||||
if err := viper.ReadInConfig(); err == nil {
|
||||
configPath := viper.ConfigFileUsed()
|
||||
if err := LoadConfigWithEnvSubstitution(configPath); err != nil {
|
||||
if strings.Contains(err.Error(), "environment variable substitution failed") {
|
||||
return fmt.Errorf("error reading config file '%s': %w", configPath, err)
|
||||
}
|
||||
} else {
|
||||
configLoaded = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+4
-40
@@ -416,42 +416,6 @@ func (m *Kit) OnTurnEnd(handler func(TurnEndEvent)) func() {
|
||||
// Subagent event subscriptions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// subagentListenerSet holds per-tool-call listeners for subagent events.
|
||||
type subagentListenerSet struct {
|
||||
mu sync.RWMutex
|
||||
listeners map[int]EventListener
|
||||
nextID int
|
||||
}
|
||||
|
||||
func newSubagentListenerSet() *subagentListenerSet {
|
||||
return &subagentListenerSet{listeners: make(map[int]EventListener)}
|
||||
}
|
||||
|
||||
func (s *subagentListenerSet) add(listener EventListener) func() {
|
||||
s.mu.Lock()
|
||||
id := s.nextID
|
||||
s.nextID++
|
||||
s.listeners[id] = listener
|
||||
s.mu.Unlock()
|
||||
return func() {
|
||||
s.mu.Lock()
|
||||
delete(s.listeners, id)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *subagentListenerSet) emit(event Event) {
|
||||
s.mu.RLock()
|
||||
snapshot := make([]EventListener, 0, len(s.listeners))
|
||||
for _, l := range s.listeners {
|
||||
snapshot = append(snapshot, l)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
for _, l := range snapshot {
|
||||
l(event)
|
||||
}
|
||||
}
|
||||
|
||||
// SubscribeSubagent registers a listener for real-time events from a subagent
|
||||
// identified by its tool call ID. Returns an unsubscribe function.
|
||||
//
|
||||
@@ -470,14 +434,14 @@ func (s *subagentListenerSet) emit(event Event) {
|
||||
// }
|
||||
// })
|
||||
func (m *Kit) SubscribeSubagent(toolCallID string, listener EventListener) func() {
|
||||
actual, _ := m.subagentListeners.LoadOrStore(toolCallID, newSubagentListenerSet())
|
||||
return actual.(*subagentListenerSet).add(listener)
|
||||
actual, _ := m.subagentListeners.LoadOrStore(toolCallID, newEventBus())
|
||||
return actual.(*eventBus).subscribe(listener)
|
||||
}
|
||||
|
||||
// getSubagentListenerSet returns the listener set for a tool call, or nil.
|
||||
func (m *Kit) getSubagentListenerSet(toolCallID string) *subagentListenerSet {
|
||||
func (m *Kit) getSubagentListenerSet(toolCallID string) *eventBus {
|
||||
if v, ok := m.subagentListeners.Load(toolCallID); ok {
|
||||
return v.(*subagentListenerSet)
|
||||
return v.(*eventBus)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+23
-2
@@ -140,8 +140,14 @@ func TestEventBusConcurrentSubscribeEmit(t *testing.T) {
|
||||
wg.Wait()
|
||||
|
||||
// We can't assert an exact count because subscribe/emit ordering is
|
||||
// non-deterministic, but it must not panic or deadlock.
|
||||
t.Logf("total events received across subscribers: %d", total.Load())
|
||||
// non-deterministic, but we can assert the count is non-negative and
|
||||
// that no events were lost (each subscriber that registered before an
|
||||
// emit must have received it at least partially).
|
||||
got := total.Load()
|
||||
if got < 0 {
|
||||
t.Errorf("expected non-negative total event count, got %d", got)
|
||||
}
|
||||
t.Logf("total events received across subscribers: %d", got)
|
||||
}
|
||||
|
||||
// TestEventBusEmitNoListeners verifies emit is a no-op with no subscribers.
|
||||
@@ -169,6 +175,11 @@ func TestEventTypes(t *testing.T) {
|
||||
{ToolResultEvent{}, EventToolResult},
|
||||
{ToolCallContentEvent{}, EventToolCallContent},
|
||||
{ResponseEvent{}, EventResponse},
|
||||
{CompactionEvent{}, EventCompaction},
|
||||
{ReasoningDeltaEvent{}, EventReasoningDelta},
|
||||
{ToolOutputEvent{}, EventToolOutput},
|
||||
{StepUsageEvent{}, EventStepUsage},
|
||||
{SteerConsumedEvent{}, EventSteerConsumed},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -212,26 +223,36 @@ func TestEventOrdering(t *testing.T) {
|
||||
EventTurnStart,
|
||||
EventMessageStart,
|
||||
EventMessageUpdate,
|
||||
EventReasoningDelta,
|
||||
EventToolOutput,
|
||||
EventToolCall,
|
||||
EventToolExecutionStart,
|
||||
EventToolExecutionEnd,
|
||||
EventToolResult,
|
||||
EventToolCallContent,
|
||||
EventMessageEnd,
|
||||
EventStepUsage,
|
||||
EventResponse,
|
||||
EventCompaction,
|
||||
EventSteerConsumed,
|
||||
EventTurnEnd,
|
||||
}
|
||||
|
||||
bus.emit(TurnStartEvent{})
|
||||
bus.emit(MessageStartEvent{})
|
||||
bus.emit(MessageUpdateEvent{Chunk: "hello"})
|
||||
bus.emit(ReasoningDeltaEvent{Delta: "thinking..."})
|
||||
bus.emit(ToolOutputEvent{ToolName: "bash", Chunk: "output"})
|
||||
bus.emit(ToolCallEvent{ToolName: "bash"})
|
||||
bus.emit(ToolExecutionStartEvent{ToolName: "bash"})
|
||||
bus.emit(ToolExecutionEndEvent{ToolName: "bash"})
|
||||
bus.emit(ToolResultEvent{ToolName: "bash", Result: "ok"})
|
||||
bus.emit(ToolCallContentEvent{Content: "I'll run bash"})
|
||||
bus.emit(MessageEndEvent{Content: "done"})
|
||||
bus.emit(StepUsageEvent{InputTokens: 100})
|
||||
bus.emit(ResponseEvent{Content: "done"})
|
||||
bus.emit(CompactionEvent{Summary: "compacted"})
|
||||
bus.emit(SteerConsumedEvent{Count: 1})
|
||||
bus.emit(TurnEndEvent{Response: "done"})
|
||||
|
||||
if len(types) != len(expected) {
|
||||
|
||||
@@ -104,11 +104,9 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
if runner.HasHandlers(extensions.AgentEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(TurnEndEvent); ok {
|
||||
stopReason := ev.StopReason
|
||||
response := ev.Response
|
||||
stopReason, response := ev.StopReason, ev.Response
|
||||
if ev.Error != nil {
|
||||
stopReason = "error"
|
||||
response = ""
|
||||
stopReason, response = "error", ""
|
||||
} else if stopReason == "" {
|
||||
stopReason = "completed"
|
||||
}
|
||||
@@ -141,7 +139,7 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
// taskByCallID tracks the task description extracted from ToolCall input,
|
||||
// keyed by toolCallID. Populated on ToolCall, consumed on ToolResult.
|
||||
taskByCallID := make(map[string]string)
|
||||
var taskMu = &taskMutex{}
|
||||
var taskMu sync.Mutex
|
||||
|
||||
// Intercept ToolCall to capture the task and subscribe to child events.
|
||||
m.Subscribe(func(e Event) {
|
||||
@@ -157,7 +155,9 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
task = t
|
||||
}
|
||||
}
|
||||
taskMu.set(taskByCallID, ev.ToolCallID, task)
|
||||
taskMu.Lock()
|
||||
taskByCallID[ev.ToolCallID] = task
|
||||
taskMu.Unlock()
|
||||
|
||||
// Subscribe to child events so we can forward them as SubagentChunkEvents.
|
||||
if runner.HasHandlers(extensions.SubagentChunk) {
|
||||
@@ -204,7 +204,9 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
if !ok || ev.ToolName != "subagent" {
|
||||
return
|
||||
}
|
||||
task := taskMu.get(taskByCallID, ev.ToolCallID)
|
||||
taskMu.Lock()
|
||||
task := taskByCallID[ev.ToolCallID]
|
||||
taskMu.Unlock()
|
||||
_, _ = runner.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Task: task,
|
||||
@@ -219,8 +221,10 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
if !ok || ev.ToolName != "subagent" {
|
||||
return
|
||||
}
|
||||
task := taskMu.get(taskByCallID, ev.ToolCallID)
|
||||
taskMu.del(taskByCallID, ev.ToolCallID)
|
||||
taskMu.Lock()
|
||||
task := taskByCallID[ev.ToolCallID]
|
||||
delete(taskByCallID, ev.ToolCallID)
|
||||
taskMu.Unlock()
|
||||
errMsg := ""
|
||||
if ev.IsError {
|
||||
errMsg = ev.Result
|
||||
@@ -325,26 +329,3 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
}
|
||||
}
|
||||
|
||||
// taskMutex is a simple mutex-protected map helper used by bridgeExtensions.
|
||||
// It lives in this file to avoid polluting the kit package with unexported types.
|
||||
type taskMutex struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *taskMutex) set(m map[string]string, key, val string) {
|
||||
t.mu.Lock()
|
||||
m[key] = val
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
func (t *taskMutex) get(m map[string]string, key string) string {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return m[key]
|
||||
}
|
||||
|
||||
func (t *taskMutex) del(m map[string]string, key string) {
|
||||
t.mu.Lock()
|
||||
delete(m, key)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
@@ -167,8 +167,13 @@ func (hr *hookRegistry[In, Out]) register(p HookPriority, h func(In) *Out) func(
|
||||
}
|
||||
|
||||
// run executes all hooks in priority order. The first non-nil result wins.
|
||||
// Returns nil immediately if no hooks are registered.
|
||||
func (hr *hookRegistry[In, Out]) run(input In) *Out {
|
||||
hr.mu.RLock()
|
||||
if len(hr.hooks) == 0 {
|
||||
hr.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
snapshot := make([]hookEntry[In, Out], len(hr.hooks))
|
||||
copy(snapshot, hr.hooks)
|
||||
hr.mu.RUnlock()
|
||||
|
||||
+36
-41
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
charmlog "github.com/charmbracelet/log"
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
@@ -1423,14 +1423,13 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
case msg := <-steerCh:
|
||||
leftover = append(leftover, msg)
|
||||
default:
|
||||
goto drained
|
||||
m.steerMu.Lock()
|
||||
m.steerCh = nil
|
||||
m.leftoverSteer = leftover
|
||||
m.steerMu.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
drained:
|
||||
m.steerMu.Lock()
|
||||
m.steerCh = nil
|
||||
m.leftoverSteer = leftover
|
||||
m.steerMu.Unlock()
|
||||
}()
|
||||
ctx = agent.ContextWithSteerCh(ctx, steerCh)
|
||||
ctx = agent.ContextWithSteerConsumed(ctx, func(count int) {
|
||||
@@ -1526,8 +1525,12 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
|
||||
// Emit step usage event for real-time cost tracking
|
||||
if viper.GetBool("debug") {
|
||||
log.Printf("[DEBUG] Kit.generate emitting StepUsageEvent: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens)
|
||||
charmlog.Debug("Kit.generate emitting StepUsageEvent",
|
||||
"input", inputTokens,
|
||||
"output", outputTokens,
|
||||
"cacheRead", cacheReadTokens,
|
||||
"cacheCreate", cacheCreationTokens,
|
||||
)
|
||||
}
|
||||
m.events.emit(StepUsageEvent{
|
||||
InputTokens: uint64(inputTokens),
|
||||
@@ -1568,30 +1571,28 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
}
|
||||
|
||||
// Run BeforeTurn hooks — can modify the prompt, inject system/context messages.
|
||||
if m.beforeTurn.hasHooks() {
|
||||
if hookResult := m.beforeTurn.run(BeforeTurnHook{Prompt: prompt}); hookResult != nil {
|
||||
// Override prompt text in the last user message, preserving
|
||||
// any file parts (e.g. clipboard images).
|
||||
if hookResult.Prompt != nil {
|
||||
for i := len(preMessages) - 1; i >= 0; i-- {
|
||||
if preMessages[i].Role == fantasy.MessageRoleUser {
|
||||
files := extractFileParts(preMessages[i])
|
||||
preMessages[i] = fantasy.NewUserMessage(*hookResult.Prompt, files...)
|
||||
break
|
||||
}
|
||||
if hookResult := m.beforeTurn.run(BeforeTurnHook{Prompt: prompt}); hookResult != nil {
|
||||
// Override prompt text in the last user message, preserving
|
||||
// any file parts (e.g. clipboard images).
|
||||
if hookResult.Prompt != nil {
|
||||
for i := len(preMessages) - 1; i >= 0; i-- {
|
||||
if preMessages[i].Role == fantasy.MessageRoleUser {
|
||||
files := extractFileParts(preMessages[i])
|
||||
preMessages[i] = fantasy.NewUserMessage(*hookResult.Prompt, files...)
|
||||
break
|
||||
}
|
||||
}
|
||||
// Inject messages before the original preMessages.
|
||||
var injected []fantasy.Message
|
||||
if hookResult.SystemPrompt != nil {
|
||||
injected = append(injected, fantasy.NewSystemMessage(*hookResult.SystemPrompt))
|
||||
}
|
||||
if hookResult.InjectText != nil {
|
||||
injected = append(injected, fantasy.NewUserMessage(*hookResult.InjectText))
|
||||
}
|
||||
if len(injected) > 0 {
|
||||
preMessages = append(injected, preMessages...)
|
||||
}
|
||||
}
|
||||
// Inject messages before the original preMessages.
|
||||
var injected []fantasy.Message
|
||||
if hookResult.SystemPrompt != nil {
|
||||
injected = append(injected, fantasy.NewSystemMessage(*hookResult.SystemPrompt))
|
||||
}
|
||||
if hookResult.InjectText != nil {
|
||||
injected = append(injected, fantasy.NewUserMessage(*hookResult.InjectText))
|
||||
}
|
||||
if len(injected) > 0 {
|
||||
preMessages = append(injected, preMessages...)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1609,10 +1610,8 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
messages := m.treeSession.GetFantasyMessages()
|
||||
|
||||
// Run ContextPrepare hooks — extensions can filter, reorder, or inject messages.
|
||||
if m.contextPrepare.hasHooks() {
|
||||
if hookResult := m.contextPrepare.run(ContextPrepareHook{Messages: messages}); hookResult != nil && hookResult.Messages != nil {
|
||||
messages = hookResult.Messages
|
||||
}
|
||||
if hookResult := m.contextPrepare.run(ContextPrepareHook{Messages: messages}); hookResult != nil && hookResult.Messages != nil {
|
||||
messages = hookResult.Messages
|
||||
}
|
||||
|
||||
sentCount := len(messages)
|
||||
@@ -1636,9 +1635,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
}
|
||||
m.events.emit(TurnEndEvent{Error: err})
|
||||
// Run AfterTurn hooks even on error.
|
||||
if m.afterTurn.hasHooks() {
|
||||
m.afterTurn.run(AfterTurnHook{Error: err})
|
||||
}
|
||||
m.afterTurn.run(AfterTurnHook{Error: err})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1669,9 +1666,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
m.events.emit(TurnEndEvent{Response: responseText, StopReason: stopReason})
|
||||
|
||||
// Run AfterTurn hooks.
|
||||
if m.afterTurn.hasHooks() {
|
||||
m.afterTurn.run(AfterTurnHook{Response: responseText})
|
||||
}
|
||||
m.afterTurn.run(AfterTurnHook{Response: responseText})
|
||||
|
||||
// Build TurnResult with usage stats.
|
||||
turnResult := &TurnResult{
|
||||
|
||||
+9
-21
@@ -91,22 +91,12 @@ var globalSkillCache skillCache
|
||||
func (m *Kit) DiscoverSkillsForExtension() []extensions.Skill {
|
||||
cwd, _ := os.Getwd()
|
||||
|
||||
// Check cache first
|
||||
globalSkillCache.mu.RLock()
|
||||
if len(globalSkillCache.skills) > 0 {
|
||||
globalSkillCache.mu.RUnlock()
|
||||
return m.convertSkills(globalSkillCache.skills)
|
||||
}
|
||||
globalSkillCache.mu.RUnlock()
|
||||
|
||||
// Load fresh
|
||||
skillList, _ := skills.LoadSkills(cwd)
|
||||
|
||||
globalSkillCache.mu.Lock()
|
||||
globalSkillCache.skills = skillList
|
||||
globalSkillCache.mu.Unlock()
|
||||
|
||||
return m.convertSkills(skillList)
|
||||
defer globalSkillCache.mu.Unlock()
|
||||
if len(globalSkillCache.skills) == 0 {
|
||||
globalSkillCache.skills, _ = skills.LoadSkills(cwd)
|
||||
}
|
||||
return m.convertSkills(globalSkillCache.skills)
|
||||
}
|
||||
|
||||
// LoadSkillForExtension loads a single skill file for extensions.
|
||||
@@ -140,12 +130,10 @@ func (m *Kit) convertSkill(s *skills.Skill) *extensions.Skill {
|
||||
}
|
||||
|
||||
// convertSkills converts a slice of skills.
|
||||
func (m *Kit) convertSkills(skills []*skills.Skill) []extensions.Skill {
|
||||
result := make([]extensions.Skill, 0, len(skills))
|
||||
for _, s := range skills {
|
||||
if converted := m.convertSkill(s); converted != nil {
|
||||
result = append(result, *converted)
|
||||
}
|
||||
func (m *Kit) convertSkills(skillList []*skills.Skill) []extensions.Skill {
|
||||
result := make([]extensions.Skill, 0, len(skillList))
|
||||
for _, s := range skillList {
|
||||
result = append(result, *m.convertSkill(s))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
+30
-34
@@ -3,6 +3,7 @@ package kit
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -34,16 +35,17 @@ func ParseTemplate(name, content string) extensions.PromptTemplate {
|
||||
}
|
||||
|
||||
// RenderTemplate substitutes variables into template content.
|
||||
// Handles {{name}} and {{ name }} (any whitespace) placeholders.
|
||||
func RenderTemplate(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
result := tpl.Content
|
||||
for name, value := range vars {
|
||||
placeholder := "{{" + name + "}}"
|
||||
result = strings.ReplaceAll(result, placeholder, value)
|
||||
// Also handle with spaces
|
||||
placeholderSpaced := "{{ " + name + " }}"
|
||||
result = strings.ReplaceAll(result, placeholderSpaced, value)
|
||||
}
|
||||
return result
|
||||
return varRegex.ReplaceAllStringFunc(tpl.Content, func(m string) string {
|
||||
sub := varRegex.FindStringSubmatch(m)
|
||||
if len(sub) > 1 {
|
||||
if v, ok := vars[sub[1]]; ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return m
|
||||
})
|
||||
}
|
||||
|
||||
// ParseArguments parses command-line style arguments.
|
||||
@@ -58,13 +60,10 @@ func ParseArguments(input string, pattern extensions.ArgumentPattern) extensions
|
||||
return result
|
||||
}
|
||||
|
||||
// First field is the command itself (if present)
|
||||
// First field is the command itself (if present); skip it.
|
||||
startIdx := 0
|
||||
if len(fields) > 0 && !strings.HasPrefix(fields[0], "-") {
|
||||
// Check if it's a command name or positional arg
|
||||
if len(pattern.Positional) == 0 || !isFlag(fields[0], pattern.Flags) {
|
||||
startIdx = 1 // Skip command name
|
||||
}
|
||||
startIdx = 1
|
||||
}
|
||||
|
||||
// Parse flags
|
||||
@@ -224,16 +223,6 @@ func parseFields(input string) []string {
|
||||
return fields
|
||||
}
|
||||
|
||||
// isFlag checks if a field is a known flag.
|
||||
func isFlag(field string, flags map[string]string) bool {
|
||||
if strings.HasPrefix(field, "--") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(field, "-") && len(field) > 1 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// EvaluateModelConditional checks if condition matches current model.
|
||||
// Condition supports wildcards: * matches any, ? matches single char.
|
||||
@@ -248,17 +237,24 @@ func EvaluateModelConditional(currentModel, condition string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchModelPattern matches a model against a pattern with wildcards.
|
||||
func matchModelPattern(model, pattern string) bool {
|
||||
// Convert pattern to regexp
|
||||
pattern = strings.ReplaceAll(pattern, "*", ".*")
|
||||
pattern = strings.ReplaceAll(pattern, "?", ".")
|
||||
pattern = "^" + pattern + "$"
|
||||
// modelPatternCache caches compiled regexps for model glob patterns.
|
||||
var modelPatternCache sync.Map
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
// Fallback: exact match
|
||||
return model == pattern
|
||||
// matchModelPattern matches a model against a pattern with wildcards.
|
||||
// Compiled regexps are cached to avoid recompilation on hot paths.
|
||||
func matchModelPattern(model, pattern string) bool {
|
||||
rePattern := "^" + strings.ReplaceAll(strings.ReplaceAll(pattern, "*", ".*"), "?", ".") + "$"
|
||||
var re *regexp.Regexp
|
||||
if v, ok := modelPatternCache.Load(rePattern); ok {
|
||||
re = v.(*regexp.Regexp)
|
||||
} else {
|
||||
compiled, err := regexp.Compile(rePattern)
|
||||
if err != nil {
|
||||
// Fallback: exact match
|
||||
return model == pattern
|
||||
}
|
||||
modelPatternCache.Store(rePattern, compiled)
|
||||
re = compiled
|
||||
}
|
||||
return re.MatchString(model)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user