From c2f2bdb3d30815041ccb057674716074cea5c70b Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 7 Apr 2026 14:09:59 +0300 Subject: [PATCH] feat: auto-reload custom prompts and skills on file change - Add internal/watcher package with general-purpose ContentWatcher using fsnotify, configurable file extensions, and debouncing - Add ContentReloadEvent and App.NotifyContentReload() for TUI signaling - Add GetPromptTemplates/GetSkillItems callback fields on AppModelOptions following the existing GetExtensionCommands lazy-provider pattern - Add Kit.ReloadSkills() to re-discover skills from disk - Wire fsnotify watcher for .kit/prompts/, .kit/skills/, .agents/skills/, and global config directories, triggering on .md/.txt changes - TUI refreshes autocomplete entries and skill list on reload --- cmd/root.go | 102 +++++++++++++- internal/app/app.go | 13 ++ internal/app/events.go | 5 + internal/ui/model.go | 63 +++++++++ internal/watcher/watcher.go | 230 +++++++++++++++++++++++++++++++ internal/watcher/watcher_test.go | 225 ++++++++++++++++++++++++++++++ pkg/kit/kit.go | 2 + pkg/kit/skills.go | 13 ++ 8 files changed, 648 insertions(+), 5 deletions(-) create mode 100644 internal/watcher/watcher.go create mode 100644 internal/watcher/watcher_test.go diff --git a/cmd/root.go b/cmd/root.go index e61d217c..0dcd0ddb 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -7,6 +7,7 @@ import ( "image/color" "log" "os" + "path/filepath" "strings" tea "charm.land/bubbletea/v2" @@ -18,6 +19,7 @@ import ( "github.com/mark3labs/kit/internal/prompts" "github.com/mark3labs/kit/internal/ui" "github.com/mark3labs/kit/internal/ui/commands" + "github.com/mark3labs/kit/internal/watcher" kit "github.com/mark3labs/kit/pkg/kit" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -1620,6 +1622,49 @@ func runNormalMode(ctx context.Context) error { }) } + // Build prompt template and skill item provider callbacks for hot-reload. + // These are called by the TUI when ContentReloadEvent fires. + getPromptTemplates := func() []*prompts.PromptTemplate { + if noPromptTemplates { + return nil + } + homeDir, _ := os.UserHomeDir() + cwd, _ := os.Getwd() + tpls, _, err := prompts.LoadAll(prompts.LoadOptions{ + Cwd: cwd, + HomeDir: homeDir, + ExtraPaths: promptTemplatePaths, + ConfigPaths: viper.GetStringSlice("prompts"), + IncludeDefaults: true, + }) + if err != nil { + log.Printf("Warning: failed to reload prompt templates: %v", err) + } + return tpls + } + + getSkillItems := func() []ui.SkillItem { + // Re-discover skills from disk. + if err := kitInstance.ReloadSkills(); err != nil { + log.Printf("Warning: failed to reload skills: %v", err) + return nil + } + cwd, _ := os.Getwd() + var items []ui.SkillItem + for _, s := range kitInstance.GetSkills() { + source := "user" + if strings.HasPrefix(s.Path, cwd) { + source = "project" + } + items = append(items, ui.SkillItem{ + Name: s.Name, + Path: s.Path, + Source: source, + }) + } + return items + } + // Build extension UI providers once (shared between both modes). getWidgets := widgetProviderForUI(kitInstance) getHeader := headerProviderForUI(kitInstance) @@ -1715,9 +1760,54 @@ func runNormalMode(ctx context.Context) error { } } + // Start file watchers for automatic prompt and skill hot-reload. + { + homeDir, _ := os.UserHomeDir() + cwd, _ := os.Getwd() + + // Collect prompt template directories. + promptDirs := watcher.CollectDirs( + []string{ + filepath.Join(homeDir, ".kit", "prompts"), + filepath.Join(cwd, ".kit", "prompts"), + }, + append(promptTemplatePaths, viper.GetStringSlice("prompts")...), + ) + + // Collect skill directories. + skillDirs := watcher.CollectDirs( + []string{ + filepath.Join(homeDir, ".config", "kit", "skills"), + filepath.Join(cwd, ".agents", "skills"), + filepath.Join(cwd, ".kit", "skills"), + }, + nil, + ) + + // Combine all content directories and start a single watcher. + allContentDirs := append(promptDirs, skillDirs...) + if len(allContentDirs) > 0 { + contentWatcher, watchErr := watcher.New(watcher.Options{ + Dirs: allContentDirs, + Extensions: []string{".md", ".txt"}, + Label: "prompts/skills", + OnReload: func() { + log.Printf("auto-reloading prompts and skills") + appInstance.NotifyContentReload() + }, + }) + if watchErr != nil { + log.Printf("content file watcher not started: %v", watchErr) + } else { + go contentWatcher.Start(ctx) + defer func() { _ = contentWatcher.Close() }() + } + } + } + // Check if running in non-interactive mode if positionalPrompt != "" { - return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI) + return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI) } // Quiet mode is not allowed in interactive mode @@ -1725,7 +1815,7 @@ func runNormalMode(ctx context.Context) error { return fmt.Errorf("--quiet requires a prompt") } - return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages) + return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages) } // runNonInteractiveModeApp executes a single prompt via the app layer and exits, @@ -1738,7 +1828,7 @@ func runNormalMode(ctx context.Context) error { // // When --no-exit is set, after the prompt completes the interactive BubbleTea // TUI is started so the user can continue the conversation. -func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error { +func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error { // Expand @file references in the prompt before sending to the agent. if cwd, err := os.Getwd(); err == nil { prompt = ui.ProcessFileAttachments(prompt, cwd) @@ -1781,7 +1871,7 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui // If --no-exit was requested, hand off to the interactive TUI. if noExit { - return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil) + return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil) } return nil @@ -1879,7 +1969,7 @@ func writeJSONError(err error) { // 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit). // // SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering. -func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error { +func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error { // Determine terminal size; fall back gracefully. termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd())) if err != nil || termWidth == 0 { @@ -1903,8 +1993,10 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN UsageTracker: usageTracker, ExtensionCommands: extCommands, PromptTemplates: promptTemplates, + GetPromptTemplates: getPromptTemplates, ContextPaths: contextPaths, SkillItems: skillItems, + GetSkillItems: getSkillItems, StartupExtensionMessages: startupExtensionMessages, GetWidgets: getWidgets, GetHeader: getHeader, diff --git a/internal/app/app.go b/internal/app/app.go index 555e0e71..3ca3770b 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -997,6 +997,19 @@ func (a *App) NotifyWidgetUpdate() { } } +// NotifyContentReload sends a ContentReloadEvent to the TUI so it refreshes +// prompt templates and skills from their provider callbacks. Called by file +// watchers when .md/.txt files change in prompt or skill directories. +// In non-interactive mode this is a no-op. +func (a *App) NotifyContentReload() { + a.mu.Lock() + prog := a.program + a.mu.Unlock() + if prog != nil { + prog.Send(ContentReloadEvent{}) + } +} + // SendEvent sends a tea.Msg to the registered program. Safe to call from // any goroutine. No-op when no program is registered. // diff --git a/internal/app/events.go b/internal/app/events.go index f586d9cc..81ac7659 100644 --- a/internal/app/events.go +++ b/internal/app/events.go @@ -167,6 +167,11 @@ type ModelChangedEvent struct { // from its WidgetProvider on the next render cycle. type WidgetUpdateEvent struct{} +// ContentReloadEvent is sent when prompt templates or skills are reloaded +// from disk (e.g. by a file watcher detecting changes). The TUI refreshes +// its autocomplete entries and internal state from the provider callbacks. +type ContentReloadEvent struct{} + // EditorTextSetEvent is sent when an extension calls ctx.SetEditorText to // pre-fill the input editor with text. The TUI handles this by setting the // textarea content and moving the cursor to the end. diff --git a/internal/ui/model.go b/internal/ui/model.go index dc8c008a..4c0f873f 100644 --- a/internal/ui/model.go +++ b/internal/ui/model.go @@ -294,6 +294,11 @@ type AppModelOptions struct { // and are expanded when submitted (e.g., /review → full prompt text). PromptTemplates []*prompts.PromptTemplate + // GetPromptTemplates, if non-nil, returns the current prompt templates. + // Called on ContentReloadEvent to refresh the template list after a file + // watcher detects changes. May be nil if prompt hot-reload is not needed. + GetPromptTemplates func() []*prompts.PromptTemplate + // ContextPaths lists absolute paths of loaded context files (e.g. // AGENTS.md). Displayed in the [Context] startup section. ContextPaths []string @@ -301,6 +306,11 @@ type AppModelOptions struct { // SkillItems lists loaded skills for the [Skills] startup section. SkillItems []SkillItem + // GetSkillItems, if non-nil, returns the current skill items. + // Called on ContentReloadEvent to refresh the skill list after a file + // watcher detects changes. May be nil if skill hot-reload is not needed. + GetSkillItems func() []SkillItem + // MCPToolCount is the number of tools loaded from external MCP servers. MCPToolCount int @@ -500,6 +510,10 @@ type AppModel struct { // They appear in autocomplete and are expanded when submitted. promptTemplates []*prompts.PromptTemplate + // getPromptTemplates returns the current prompt templates. Used to + // refresh the template list after content hot-reload. May be nil. + getPromptTemplates func() []*prompts.PromptTemplate + // treeSelector is the tree navigation overlay, active in stateTreeSelector. treeSelector *TreeSelectorComponent @@ -508,6 +522,10 @@ type AppModel struct { contextPaths []string skillItems []SkillItem + // getSkillItems returns the current skill items. Used to refresh the + // skill list after content hot-reload. May be nil. + getSkillItems func() []SkillItem + // mcpToolCount and extensionToolCount track tool counts by source for // the startup info display. mcpToolCount int @@ -721,6 +739,7 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel { // Store extension commands for dispatch. m.extensionCommands = opts.ExtensionCommands m.promptTemplates = opts.PromptTemplates + m.getPromptTemplates = opts.GetPromptTemplates m.getWidgets = opts.GetWidgets m.getHeader = opts.GetHeader m.getFooter = opts.GetFooter @@ -746,6 +765,7 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel { // Store context/skills metadata and tool counts for startup display. m.contextPaths = opts.ContextPaths m.skillItems = opts.SkillItems + m.getSkillItems = opts.GetSkillItems m.mcpToolCount = opts.MCPToolCount m.extensionToolCount = opts.ExtensionToolCount m.startupExtensionMessages = opts.StartupExtensionMessages @@ -1817,6 +1837,12 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } + case app.ContentReloadEvent: + // Prompt templates or skills changed on disk — refresh from providers. + m.refreshPromptTemplates() + m.refreshSkillItems() + m.printSystemMessage("Prompts and skills reloaded.") + case app.EditorTextSetEvent: // Extension wants to pre-fill the input editor with text. if ic, ok := m.input.(*InputComponent); ok { @@ -2710,6 +2736,43 @@ func (m *AppModel) expandPromptTemplate(text string) (string, bool) { return text, false } +// refreshPromptTemplates reloads prompt templates from the provider callback +// and updates the autocomplete entries. Called on ContentReloadEvent. +func (m *AppModel) refreshPromptTemplates() { + if m.getPromptTemplates == nil { + return + } + newTemplates := m.getPromptTemplates() + m.promptTemplates = newTemplates + + if ic, ok := m.input.(*InputComponent); ok { + // Remove old prompt commands and add fresh ones. + var kept []commands.SlashCommand + for _, sc := range ic.commands { + if sc.Category != "Prompts" { + kept = append(kept, sc) + } + } + for _, tpl := range newTemplates { + kept = append(kept, commands.SlashCommand{ + Name: "/" + tpl.Name, + Description: tpl.Description, + Category: "Prompts", + }) + } + ic.commands = kept + } +} + +// refreshSkillItems reloads skill items from the provider callback. +// Called on ContentReloadEvent. +func (m *AppModel) refreshSkillItems() { + if m.getSkillItems == nil { + return + } + m.skillItems = m.getSkillItems() +} + // printHelpMessage renders the help text listing all available slash commands. func (m *AppModel) printHelpMessage() { help := "## Available Commands\n\n" + diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go new file mode 100644 index 00000000..8f2c65e0 --- /dev/null +++ b/internal/watcher/watcher.go @@ -0,0 +1,230 @@ +// Package watcher provides a general-purpose file watcher that monitors +// directories for changes to files matching specified extensions. It uses +// fsnotify for kernel-level notifications with debouncing to coalesce +// rapid editor writes. +package watcher + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/charmbracelet/log" + "github.com/fsnotify/fsnotify" +) + +// ContentWatcher monitors directories for file changes matching a set of +// extensions and triggers a reload callback when changes are detected. +// It uses fsnotify for kernel-level file notifications (inotify on Linux, +// kqueue on macOS) with debouncing to coalesce rapid editor writes. +type ContentWatcher struct { + watcher *fsnotify.Watcher + onReload func() + extensions []string // e.g. [".md", ".txt"] + label string // for logging (e.g. "prompts", "skills") + debounce time.Duration + cancel context.CancelFunc + done chan struct{} + mu sync.Mutex +} + +// Options configures a ContentWatcher. +type Options struct { + // Dirs are the directories to watch. + Dirs []string + // Extensions are the file extensions to watch for (e.g. ".md", ".txt"). + // Include the leading dot. + Extensions []string + // OnReload is called when a matching file changes (after debouncing). + OnReload func() + // Label is a human-readable name for logging (e.g. "prompts", "skills"). + Label string + // Debounce is the debounce duration. Defaults to 300ms if zero. + Debounce time.Duration +} + +// New creates a ContentWatcher that monitors the given directories for +// file changes matching the specified extensions. When a change is detected +// (after debouncing), onReload is called. The watcher must be started with +// Start() and stopped with Close(). +func New(opts Options) (*ContentWatcher, error) { + if len(opts.Dirs) == 0 { + return nil, fmt.Errorf("no directories to watch") + } + + fsw, err := fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("creating file watcher: %w", err) + } + + for _, dir := range opts.Dirs { + if err := fsw.Add(dir); err != nil { + log.Debug("watcher: skipping directory", "label", opts.Label, "dir", dir, "err", err) + continue + } + + // Also watch immediate subdirectories (for skill/SKILL.md pattern). + entries, err := os.ReadDir(dir) + if err != nil { + continue + } + for _, entry := range entries { + if entry.IsDir() { + subdir := filepath.Join(dir, entry.Name()) + if err := fsw.Add(subdir); err != nil { + log.Debug("watcher: skipping subdirectory", "label", opts.Label, "dir", subdir, "err", err) + } + } + } + } + + debounce := opts.Debounce + if debounce == 0 { + debounce = 300 * time.Millisecond + } + + return &ContentWatcher{ + watcher: fsw, + onReload: opts.OnReload, + extensions: opts.Extensions, + label: opts.Label, + debounce: debounce, + done: make(chan struct{}), + }, nil +} + +// Start begins watching for file changes. It blocks until the context +// is cancelled or Close() is called. Typically called in a goroutine. +func (w *ContentWatcher) Start(ctx context.Context) { + w.mu.Lock() + ctx, w.cancel = context.WithCancel(ctx) + w.mu.Unlock() + + defer close(w.done) + + var timer *time.Timer + var timerC <-chan time.Time + + for { + select { + case <-ctx.Done(): + if timer != nil { + timer.Stop() + } + return + + case event, ok := <-w.watcher.Events: + if !ok { + return + } + + // Only care about files matching our extensions. + if !w.matchesExtension(event.Name) { + continue + } + + // React to write, create, remove, rename events. + if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 { + continue + } + + log.Debug("watcher: file changed", "label", w.label, "file", event.Name, "op", event.Op) + + // Debounce: reset timer on each event. + if timer != nil { + timer.Stop() + } + timer = time.NewTimer(w.debounce) + timerC = timer.C + + case <-timerC: + timerC = nil + timer = nil + log.Debug("watcher: reloading", "label", w.label) + w.onReload() + + case err, ok := <-w.watcher.Errors: + if !ok { + return + } + log.Warn("watcher: error", "label", w.label, "err", err) + } + } +} + +// Close stops the watcher and releases resources. +func (w *ContentWatcher) Close() error { + w.mu.Lock() + cancel := w.cancel + w.mu.Unlock() + + if cancel != nil { + cancel() + } + + // Wait for the event loop to finish. + <-w.done + return w.watcher.Close() +} + +// matchesExtension returns true if the file name ends with one of the +// watched extensions. +func (w *ContentWatcher) matchesExtension(name string) bool { + for _, ext := range w.extensions { + if strings.HasSuffix(name, ext) { + return true + } + } + return false +} + +// CollectDirs returns the directories to watch for a given set of standard +// directories and extra paths. Directories are deduplicated by absolute path +// and verified to exist. For explicit file paths, the parent directory is +// watched instead. +func CollectDirs(standardDirs []string, extraPaths []string) []string { + var dirs []string + seen := make(map[string]bool) + + add := func(dir string) { + abs, err := filepath.Abs(dir) + if err != nil { + return + } + if seen[abs] { + return + } + + // Verify the directory exists. + info, err := os.Stat(abs) + if err != nil || !info.IsDir() { + return + } + + seen[abs] = true + dirs = append(dirs, abs) + } + + for _, d := range standardDirs { + add(d) + } + + for _, p := range extraPaths { + info, err := os.Stat(p) + if err != nil { + continue + } + if info.IsDir() { + add(p) + } else { + // For explicit files, watch the parent directory. + add(filepath.Dir(p)) + } + } + + return dirs +} diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go new file mode 100644 index 00000000..effe5192 --- /dev/null +++ b/internal/watcher/watcher_test.go @@ -0,0 +1,225 @@ +package watcher + +import ( + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" +) + +func TestContentWatcher_ReloadsOnMatchingFile(t *testing.T) { + dir := t.TempDir() + + // Write an initial file so the directory isn't empty. + initial := filepath.Join(dir, "existing.md") + if err := os.WriteFile(initial, []byte("# Hello"), 0644); err != nil { + t.Fatal(err) + } + + var reloadCount atomic.Int32 + w, err := New(Options{ + Dirs: []string{dir}, + Extensions: []string{".md"}, + OnReload: func() { reloadCount.Add(1) }, + Label: "test", + Debounce: 50 * time.Millisecond, + }) + if err != nil { + t.Fatal(err) + } + + go w.Start(t.Context()) + + // Wait for watcher to be ready. + time.Sleep(100 * time.Millisecond) + + // Modify the file. + if err := os.WriteFile(initial, []byte("# Updated"), 0644); err != nil { + t.Fatal(err) + } + + // Wait for debounce + processing. + time.Sleep(200 * time.Millisecond) + + if got := reloadCount.Load(); got != 1 { + t.Errorf("expected 1 reload, got %d", got) + } + + _ = w.Close() +} + +func TestContentWatcher_IgnoresNonMatchingFiles(t *testing.T) { + dir := t.TempDir() + + var reloadCount atomic.Int32 + w, err := New(Options{ + Dirs: []string{dir}, + Extensions: []string{".md"}, + OnReload: func() { reloadCount.Add(1) }, + Label: "test", + Debounce: 50 * time.Millisecond, + }) + if err != nil { + t.Fatal(err) + } + + go w.Start(t.Context()) + + time.Sleep(100 * time.Millisecond) + + // Write a non-matching file. + if err := os.WriteFile(filepath.Join(dir, "readme.txt"), []byte("hello"), 0644); err != nil { + t.Fatal(err) + } + + time.Sleep(200 * time.Millisecond) + + if got := reloadCount.Load(); got != 0 { + t.Errorf("expected 0 reloads for non-matching file, got %d", got) + } + + _ = w.Close() +} + +func TestContentWatcher_MultipleExtensions(t *testing.T) { + dir := t.TempDir() + + var reloadCount atomic.Int32 + w, err := New(Options{ + Dirs: []string{dir}, + Extensions: []string{".md", ".txt"}, + OnReload: func() { reloadCount.Add(1) }, + Label: "test", + Debounce: 50 * time.Millisecond, + }) + if err != nil { + t.Fatal(err) + } + + go w.Start(t.Context()) + + time.Sleep(100 * time.Millisecond) + + // Write a .txt file — should trigger. + if err := os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("notes"), 0644); err != nil { + t.Fatal(err) + } + + time.Sleep(200 * time.Millisecond) + + if got := reloadCount.Load(); got != 1 { + t.Errorf("expected 1 reload for .txt file, got %d", got) + } + + _ = w.Close() +} + +func TestContentWatcher_Debounces(t *testing.T) { + dir := t.TempDir() + + var reloadCount atomic.Int32 + w, err := New(Options{ + Dirs: []string{dir}, + Extensions: []string{".md"}, + OnReload: func() { reloadCount.Add(1) }, + Label: "test", + Debounce: 100 * time.Millisecond, + }) + if err != nil { + t.Fatal(err) + } + + go w.Start(t.Context()) + + time.Sleep(100 * time.Millisecond) + + // Rapid-fire writes — should debounce into 1 reload. + for i := range 5 { + if err := os.WriteFile(filepath.Join(dir, "test.md"), []byte("v"+string(rune('0'+i))), 0644); err != nil { + t.Fatal(err) + } + time.Sleep(30 * time.Millisecond) + } + + time.Sleep(300 * time.Millisecond) + + if got := reloadCount.Load(); got != 1 { + t.Errorf("expected 1 debounced reload, got %d", got) + } + + _ = w.Close() +} + +func TestContentWatcher_WatchesSubdirectories(t *testing.T) { + dir := t.TempDir() + + // Create a subdirectory (simulates skill-name/SKILL.md pattern). + subdir := filepath.Join(dir, "my-skill") + if err := os.MkdirAll(subdir, 0755); err != nil { + t.Fatal(err) + } + + var reloadCount atomic.Int32 + w, err := New(Options{ + Dirs: []string{dir}, + Extensions: []string{".md"}, + OnReload: func() { reloadCount.Add(1) }, + Label: "test", + Debounce: 50 * time.Millisecond, + }) + if err != nil { + t.Fatal(err) + } + + go w.Start(t.Context()) + + time.Sleep(100 * time.Millisecond) + + // Write to subdirectory. + if err := os.WriteFile(filepath.Join(subdir, "SKILL.md"), []byte("# Skill"), 0644); err != nil { + t.Fatal(err) + } + + time.Sleep(200 * time.Millisecond) + + if got := reloadCount.Load(); got != 1 { + t.Errorf("expected 1 reload for subdirectory file, got %d", got) + } + + _ = w.Close() +} + +func TestCollectDirs_Deduplicates(t *testing.T) { + dir := t.TempDir() + + dirs := CollectDirs([]string{dir, dir}, nil) + if len(dirs) != 1 { + t.Errorf("expected 1 deduplicated dir, got %d", len(dirs)) + } +} + +func TestCollectDirs_FileParent(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "test.md") + if err := os.WriteFile(file, []byte("test"), 0644); err != nil { + t.Fatal(err) + } + + dirs := CollectDirs(nil, []string{file}) + if len(dirs) != 1 { + t.Fatalf("expected 1 dir, got %d", len(dirs)) + } + + abs, _ := filepath.Abs(dir) + if dirs[0] != abs { + t.Errorf("expected %s, got %s", abs, dirs[0]) + } +} + +func TestCollectDirs_SkipsNonexistent(t *testing.T) { + dirs := CollectDirs([]string{"/nonexistent/dir"}, nil) + if len(dirs) != 0 { + t.Errorf("expected 0 dirs for nonexistent path, got %d", len(dirs)) + } +} diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index a390cb1f..403c78c1 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -49,6 +49,7 @@ type Kit struct { extRunner *extensions.Runner bufferedLogger *tools.BufferedDebugLogger authHandler MCPAuthHandler // OAuth handler for remote MCP servers (may need Close) + opts *Options // stored for reload operations (skills, etc.) // Hook registries — interception layer (see hooks.go). beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult] @@ -737,6 +738,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) { extRunner: agentResult.ExtRunner, bufferedLogger: agentResult.BufferedLogger, authHandler: setupOpts.AuthHandler, + opts: opts, beforeToolCall: beforeToolCall, afterToolResult: afterToolResult, beforeTurn: beforeTurn, diff --git a/pkg/kit/skills.go b/pkg/kit/skills.go index 167d8631..4c153900 100644 --- a/pkg/kit/skills.go +++ b/pkg/kit/skills.go @@ -1,6 +1,7 @@ package kit import ( + "fmt" "os" "github.com/mark3labs/kit/internal/extensions" @@ -136,3 +137,15 @@ func (m *Kit) ClearSkillCache() { defer m.skillCache.mu.Unlock() m.skillCache.skills = nil } + +// ReloadSkills re-discovers skills from disk, replacing the current set. +// This is called by file watchers when skill files change. +func (m *Kit) ReloadSkills() error { + newSkills, err := loadSkills(m.opts) + if err != nil { + return fmt.Errorf("reloading skills: %w", err) + } + m.skills = newSkills + m.ClearSkillCache() + return nil +}