Files
kit/internal/extensions/loader_test.go
Ed Zynda 8b7be8b735 fix: use immediate parent dir for main.go extension names
deriveExtensionName was using the full directory path (e.g.
examples/extensions/kit-telegram) to derive the display name for
main.go extensions, producing verbose names like 'Examples Extensions
Kit Telegram Extension'. Now uses filepath.Base(dir) so only the
immediate parent directory is used, giving 'Kit Telegram Extension'.

Also fix TestLoadExtensions_SkipsBadFiles which was flaky when
globally-installed git packages existed — isolate the test from the
host environment by overriding XDG_CONFIG_HOME, XDG_DATA_HOME, and
the working directory.
2026-03-18 17:36:06 +03:00

594 lines
15 KiB
Go

package extensions
import (
"os"
"path/filepath"
"slices"
"testing"
)
// TestDiscoverExtensionPaths_ExplicitFile checks that a .go file passed
// explicitly to discoverExtensionPaths is returned as an absolute path.
func TestDiscoverExtensionPaths_ExplicitFile(t *testing.T) {
	tmp := t.TempDir()
	extFile := filepath.Join(tmp, "my-ext.go")
	if err := os.WriteFile(extFile, []byte("package main"), 0644); err != nil {
		t.Fatal(err)
	}

	found := discoverExtensionPaths([]string{extFile})
	if len(found) == 0 {
		t.Fatal("expected at least 1 path")
	}

	// Abs error ignored: t.TempDir paths are normally already absolute, and
	// an empty result would make the Contains check below fail anyway.
	want, _ := filepath.Abs(extFile)
	if slices.Contains(found, want) {
		return
	}
	t.Errorf("expected %q in discovered paths %v", want, found)
}
// TestDiscoverExtensionPaths_ExplicitDir checks that passing a directory to
// discoverExtensionPaths discovers the .go files directly inside it.
func TestDiscoverExtensionPaths_ExplicitDir(t *testing.T) {
	root := t.TempDir()
	goFile := filepath.Join(root, "ext.go")
	if err := os.WriteFile(goFile, []byte("package main"), 0644); err != nil {
		t.Fatal(err)
	}

	// Abs error ignored: t.TempDir paths are normally already absolute, and
	// an empty result would make the Contains check below fail anyway.
	want, _ := filepath.Abs(goFile)
	if got := discoverExtensionPaths([]string{root}); !slices.Contains(got, want) {
		t.Errorf("expected %q in discovered paths %v", want, got)
	}
}
// TestDiscoverExtensionPaths_SubdirMainGo checks that a main.go located in a
// subdirectory of a searched directory is discovered.
func TestDiscoverExtensionPaths_SubdirMainGo(t *testing.T) {
	root := t.TempDir()
	pluginDir := filepath.Join(root, "my-plugin")
	if err := os.MkdirAll(pluginDir, 0755); err != nil {
		t.Fatal(err)
	}
	entry := filepath.Join(pluginDir, "main.go")
	if err := os.WriteFile(entry, []byte("package main"), 0644); err != nil {
		t.Fatal(err)
	}

	got := discoverExtensionPaths([]string{root})
	// Abs error ignored: t.TempDir paths are normally already absolute, and
	// an empty result would make the Contains check below fail anyway.
	want, _ := filepath.Abs(entry)
	if !slices.Contains(got, want) {
		t.Errorf("expected %q in discovered paths %v", want, got)
	}
}
// TestDiscoverExtensionPaths_Dedup checks that passing the same file twice
// yields that file exactly once in the discovered paths.
func TestDiscoverExtensionPaths_Dedup(t *testing.T) {
	extFile := filepath.Join(t.TempDir(), "ext.go")
	if err := os.WriteFile(extFile, []byte("package main"), 0644); err != nil {
		t.Fatal(err)
	}

	found := discoverExtensionPaths([]string{extFile, extFile})

	// Abs error ignored: an empty result would give zero hits and fail below.
	want, _ := filepath.Abs(extFile)
	hits := 0
	for _, p := range found {
		if p != want {
			continue
		}
		hits++
	}
	if hits != 1 {
		t.Errorf("expected dedup to 1, got %d", hits)
	}
}
// TestDiscoverExtensionPaths_NonGoFileIgnored checks that an explicitly
// passed file without a .go extension is not returned.
func TestDiscoverExtensionPaths_NonGoFileIgnored(t *testing.T) {
	dir := t.TempDir()
	f := filepath.Join(dir, "readme.txt")
	if err := os.WriteFile(f, []byte("hello"), 0644); err != nil {
		t.Fatal(err)
	}
	// Resolve the absolute path once, before searching. If Abs failed and its
	// error were ignored, abs would be "" and this negative check would pass
	// vacuously, so fail loudly instead.
	abs, err := filepath.Abs(f)
	if err != nil {
		t.Fatal(err)
	}
	paths := discoverExtensionPaths([]string{f})
	if slices.Contains(paths, abs) {
		t.Error("non-.go file should not be discovered")
	}
}
// TestDiscoverExtensionPaths_NonexistentIgnored checks that a path that does
// not exist on disk is not reported as a discovered extension.
func TestDiscoverExtensionPaths_NonexistentIgnored(t *testing.T) {
	const missing = "/nonexistent/path/ext.go"
	for _, candidate := range discoverExtensionPaths([]string{missing}) {
		if candidate != missing {
			continue
		}
		t.Error("nonexistent path should not be discovered")
	}
}
// TestFindExtensionsInDir_EmptyDir checks that an empty directory yields no
// extensions.
func TestFindExtensionsInDir_EmptyDir(t *testing.T) {
	if got := findExtensionsInDir(t.TempDir()); len(got) != 0 {
		t.Errorf("expected 0 results, got %d", len(got))
	}
}
// TestFindExtensionsInDir_NonexistentDir checks that a missing directory is
// treated as containing no extensions.
func TestFindExtensionsInDir_NonexistentDir(t *testing.T) {
	const missingDir = "/nonexistent/dir"
	found := findExtensionsInDir(missingDir)
	if n := len(found); n != 0 {
		t.Errorf("expected 0 results, got %d", n)
	}
}
// TestFindExtensionsInDir_MixedContent builds a directory with a top-level
// .go file, a non-.go file, a subdirectory holding main.go and an empty
// subdirectory, and expects exactly two extensions to be found.
func TestFindExtensionsInDir_MixedContent(t *testing.T) {
	root := t.TempDir()
	writeFile := func(path, content string) {
		t.Helper()
		if err := os.WriteFile(path, []byte(content), 0644); err != nil {
			t.Fatal(err)
		}
	}
	makeDir := func(path string) {
		t.Helper()
		if err := os.MkdirAll(path, 0755); err != nil {
			t.Fatal(err)
		}
	}

	// Expected: a .go file at the top level.
	writeFile(filepath.Join(root, "ext.go"), "package main")
	// Expected to be ignored: not a .go file.
	writeFile(filepath.Join(root, "notes.txt"), "hi")
	// Expected: a subdirectory containing main.go.
	pluginDir := filepath.Join(root, "plugin")
	makeDir(pluginDir)
	writeFile(filepath.Join(pluginDir, "main.go"), "package main")
	// Expected to be ignored: a subdirectory without main.go.
	makeDir(filepath.Join(root, "empty"))

	found := findExtensionsInDir(root)
	if len(found) != 2 {
		t.Fatalf("expected 2 results, got %d: %v", len(found), found)
	}
}
// TestLoadSingleExtension_ValidExtension loads a well-formed extension and
// checks its recorded path and that exactly one handler was registered for
// each of ToolCall and SessionStart.
func TestLoadSingleExtension_ValidExtension(t *testing.T) {
	src := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
return nil
})
api.OnSessionStart(func(se ext.SessionStartEvent, ctx ext.Context) {
})
}
`
	path := filepath.Join(t.TempDir(), "valid.go")
	if err := os.WriteFile(path, []byte(src), 0644); err != nil {
		t.Fatal(err)
	}

	loaded, err := loadSingleExtension(path)
	if err != nil {
		t.Fatalf("failed to load extension: %v", err)
	}
	if loaded.Path != path {
		t.Errorf("expected path %q, got %q", path, loaded.Path)
	}

	// Checked in a fixed order so failure output is deterministic.
	checks := []struct {
		event EventType
		label string
	}{
		{ToolCall, "ToolCall"},
		{SessionStart, "SessionStart"},
	}
	for _, c := range checks {
		if n := len(loaded.Handlers[c.event]); n != 1 {
			t.Errorf("expected 1 %s handler, got %d", c.label, n)
		}
	}
}
// TestLoadSingleExtension_NoInitFunction checks that a source file without an
// Init function is rejected with an error.
func TestLoadSingleExtension_NoInitFunction(t *testing.T) {
	const src = `package main
func Hello() string { return "hi" }
`
	target := filepath.Join(t.TempDir(), "noinit.go")
	if err := os.WriteFile(target, []byte(src), 0644); err != nil {
		t.Fatal(err)
	}
	if _, err := loadSingleExtension(target); err == nil {
		t.Fatal("expected error for missing Init function")
	}
}
// TestLoadSingleExtension_SyntaxError checks that a source file that does not
// parse is rejected with an error.
func TestLoadSingleExtension_SyntaxError(t *testing.T) {
	const src = `package main
func Init( { broken }
`
	target := filepath.Join(t.TempDir(), "broken.go")
	if err := os.WriteFile(target, []byte(src), 0644); err != nil {
		t.Fatal(err)
	}
	if _, err := loadSingleExtension(target); err == nil {
		t.Fatal("expected error for syntax error")
	}
}
// TestLoadSingleExtension_WrongSignature checks that an Init function whose
// parameter is not ext.API is rejected with an error.
func TestLoadSingleExtension_WrongSignature(t *testing.T) {
	const src = `package main
func Init(s string) {}
`
	target := filepath.Join(t.TempDir(), "wrongsig.go")
	if err := os.WriteFile(target, []byte(src), 0644); err != nil {
		t.Fatal(err)
	}
	if _, err := loadSingleExtension(target); err == nil {
		t.Fatal("expected error for wrong Init signature")
	}
}
// TestLoadSingleExtension_RegistersTool checks that a tool registered through
// api.RegisterTool during Init appears in the loaded extension's Tools.
func TestLoadSingleExtension_RegistersTool(t *testing.T) {
	src := `package main
import "kit/ext"
func Init(api ext.API) {
api.RegisterTool(ext.ToolDef{
Name: "my_tool",
Description: "does stuff",
Parameters: "{\"type\":\"object\"}",
Execute: func(input string) (string, error) {
return "result: " + input, nil
},
})
}
`
	target := filepath.Join(t.TempDir(), "toolreg.go")
	if err := os.WriteFile(target, []byte(src), 0644); err != nil {
		t.Fatal(err)
	}

	loaded, err := loadSingleExtension(target)
	if err != nil {
		t.Fatalf("failed to load extension: %v", err)
	}
	tools := loaded.Tools
	if len(tools) != 1 {
		t.Fatalf("expected 1 tool, got %d", len(tools))
	}
	if name := tools[0].Name; name != "my_tool" {
		t.Errorf("expected tool name 'my_tool', got %q", name)
	}
}
// TestLoadSingleExtension_RegistersCommand checks that a command registered
// through api.RegisterCommand during Init appears in the loaded extension's
// Commands.
func TestLoadSingleExtension_RegistersCommand(t *testing.T) {
	src := `package main
import "kit/ext"
func Init(api ext.API) {
api.RegisterCommand(ext.CommandDef{
Name: "hello",
Description: "says hello",
Execute: func(args string, ctx ext.Context) (string, error) {
return "hello " + args, nil
},
})
}
`
	target := filepath.Join(t.TempDir(), "cmdreg.go")
	if err := os.WriteFile(target, []byte(src), 0644); err != nil {
		t.Fatal(err)
	}

	loaded, err := loadSingleExtension(target)
	if err != nil {
		t.Fatalf("failed to load extension: %v", err)
	}
	cmds := loaded.Commands
	if len(cmds) != 1 {
		t.Fatalf("expected 1 command, got %d", len(cmds))
	}
	if name := cmds[0].Name; name != "hello" {
		t.Errorf("expected command name 'hello', got %q", name)
	}
}
// TestLoadExtensions_SkipsBadFiles checks that LoadExtensions loads a valid
// extension and skips, without returning an error, one that fails to parse.
func TestLoadExtensions_SkipsBadFiles(t *testing.T) {
	dir := t.TempDir()
	// Isolate from the host environment (XDG config/data dirs and the current
	// working directory) so extensions installed on the machine are not
	// discovered alongside the test fixtures. Failures here must be fatal:
	// a silently failed chdir would leave the test unisolated and flaky.
	isolated := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", filepath.Join(isolated, "config"))
	t.Setenv("XDG_DATA_HOME", filepath.Join(isolated, "data"))
	origWd, err := os.Getwd()
	if err != nil {
		t.Fatalf("getwd: %v", err)
	}
	if err := os.Chdir(isolated); err != nil {
		t.Fatalf("chdir to isolated dir: %v", err)
	}
	t.Cleanup(func() {
		if err := os.Chdir(origWd); err != nil {
			t.Errorf("restore working directory: %v", err)
		}
	})
	// Good extension
	good := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnSessionStart(func(_ ext.SessionStartEvent, _ ext.Context) {})
}
`
	if err := os.WriteFile(filepath.Join(dir, "good.go"), []byte(good), 0644); err != nil {
		t.Fatal(err)
	}
	// Bad extension (syntax error)
	bad := `package main
func Init( { broken }
`
	if err := os.WriteFile(filepath.Join(dir, "bad.go"), []byte(bad), 0644); err != nil {
		t.Fatal(err)
	}
	loaded, err := LoadExtensions([]string{dir})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Should have loaded the good one and skipped the bad one.
	if len(loaded) != 1 {
		t.Fatalf("expected 1 loaded extension, got %d", len(loaded))
	}
	if got := filepath.Base(loaded[0].Path); got != "good.go" {
		t.Errorf("expected good.go to be loaded, got %q", loaded[0].Path)
	}
}
// TestLoadSingleExtension_HandlerExecution loads an extension whose ToolCall
// handler blocks the tool named "banned", then drives it through a Runner to
// confirm the blocking result is returned and other tools get a nil result.
func TestLoadSingleExtension_HandlerExecution(t *testing.T) {
	dir := t.TempDir()
	src := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
if tc.ToolName == "banned" {
return &ext.ToolCallResult{Block: true, Reason: "tool is banned"}
}
return nil
})
}
`
	f := filepath.Join(dir, "blocker.go")
	if err := os.WriteFile(f, []byte(src), 0644); err != nil {
		t.Fatal(err)
	}
	ext, err := loadSingleExtension(f)
	if err != nil {
		t.Fatalf("failed to load extension: %v", err)
	}
	// Build a runner and test the handler actually works.
	r := NewRunner([]LoadedExtension{*ext})
	result, err := r.Emit(ToolCallEvent{ToolName: "banned", Input: "{}"})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	tcr, ok := result.(ToolCallResult)
	if !ok {
		t.Fatalf("expected ToolCallResult, got %T", result)
	}
	if !tcr.Block {
		t.Error("expected Block=true for banned tool")
	}
	if tcr.Reason != "tool is banned" {
		t.Errorf("expected reason 'tool is banned', got %q", tcr.Reason)
	}
	// Non-banned tool should pass through: no error and a nil result.
	result2, err := r.Emit(ToolCallEvent{ToolName: "allowed", Input: "{}"})
	if err != nil {
		t.Fatalf("unexpected error for allowed tool: %v", err)
	}
	if result2 != nil {
		t.Errorf("expected nil result for allowed tool, got %v", result2)
	}
}
func TestGlobalExtensionsDir_XDG(t *testing.T) {
// Save and restore XDG_CONFIG_HOME.
orig := os.Getenv("XDG_CONFIG_HOME")
defer func() { _ = os.Setenv("XDG_CONFIG_HOME", orig) }()
_ = os.Setenv("XDG_CONFIG_HOME", "/custom/config")
dir := globalExtensionsDir()
expected := "/custom/config/kit/extensions"
if dir != expected {
t.Errorf("expected %q, got %q", expected, dir)
}
}
// TestGlobalExtensionsDir_Default checks that, with XDG_CONFIG_HOME empty,
// globalExtensionsDir falls back to ~/.config/kit/extensions.
func TestGlobalExtensionsDir_Default(t *testing.T) {
	// t.Setenv restores (or unsets) the original value after the test,
	// unlike a manual os.Setenv round-trip.
	t.Setenv("XDG_CONFIG_HOME", "")
	home, err := os.UserHomeDir()
	if err != nil {
		// Without a home directory the expected path cannot be built, so
		// the comparison below would be meaningless.
		t.Skipf("cannot determine home directory: %v", err)
	}
	dir := globalExtensionsDir()
	expected := filepath.Join(home, ".config", "kit", "extensions")
	if dir != expected {
		t.Errorf("expected %q, got %q", expected, dir)
	}
}
// TestLoadSingleExtension_ContextPrint checks that an OnInput handler can call
// ctx.Print, and that the Print function installed via Runner.SetContext
// receives the text.
func TestLoadSingleExtension_ContextPrint(t *testing.T) {
	src := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
if ie.Text == "!hello" && ctx.Print != nil {
ctx.Print("Hello from extension!")
return &ext.InputResult{Action: "handled"}
}
return nil
})
}
`
	target := filepath.Join(t.TempDir(), "printer.go")
	if err := os.WriteFile(target, []byte(src), 0644); err != nil {
		t.Fatal(err)
	}
	loaded, err := loadSingleExtension(target)
	if err != nil {
		t.Fatalf("failed to load extension: %v", err)
	}

	// Capture everything the extension sends through ctx.Print.
	var captured []string
	runner := NewRunner([]LoadedExtension{*loaded})
	runner.SetContext(Context{
		Print: func(text string) { captured = append(captured, text) },
	})

	res, err := runner.Emit(InputEvent{Text: "!hello", Source: "interactive"})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	inputRes, ok := res.(InputResult)
	if !ok {
		t.Fatalf("expected InputResult, got %T", res)
	}
	if inputRes.Action != "handled" {
		t.Errorf("expected Action 'handled', got %q", inputRes.Action)
	}
	if !(len(captured) == 1 && captured[0] == "Hello from extension!") {
		t.Errorf("expected Print to capture 'Hello from extension!', got %v", captured)
	}
}
// TestLoadSingleExtension_ContextPrintInfo checks that an OnInput handler can
// call ctx.PrintInfo and ctx.PrintError, and that the functions installed via
// Runner.SetContext receive the text.
func TestLoadSingleExtension_ContextPrintInfo(t *testing.T) {
	dir := t.TempDir()
	src := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
if ie.Text == "!info" && ctx.PrintInfo != nil {
ctx.PrintInfo("Styled info from extension")
return &ext.InputResult{Action: "handled"}
}
if ie.Text == "!error" && ctx.PrintError != nil {
ctx.PrintError("Styled error from extension")
return &ext.InputResult{Action: "handled"}
}
return nil
})
}
`
	f := filepath.Join(dir, "styled.go")
	if err := os.WriteFile(f, []byte(src), 0644); err != nil {
		t.Fatal(err)
	}
	ext, err := loadSingleExtension(f)
	if err != nil {
		t.Fatalf("failed to load extension: %v", err)
	}
	// errs (not "errors") avoids shadowing the standard errors package name.
	var infos, errs []string
	r := NewRunner([]LoadedExtension{*ext})
	r.SetContext(Context{
		PrintInfo:  func(text string) { infos = append(infos, text) },
		PrintError: func(text string) { errs = append(errs, text) },
	})

	result, err := r.Emit(InputEvent{Text: "!info"})
	if err != nil {
		t.Fatalf("unexpected error for !info: %v", err)
	}
	if ir, ok := result.(InputResult); !ok || ir.Action != "handled" {
		t.Fatalf("expected handled result for !info, got %#v", result)
	}
	if len(infos) != 1 || infos[0] != "Styled info from extension" {
		t.Errorf("expected PrintInfo capture, got %v", infos)
	}

	result, err = r.Emit(InputEvent{Text: "!error"})
	if err != nil {
		t.Fatalf("unexpected error for !error: %v", err)
	}
	if ir, ok := result.(InputResult); !ok || ir.Action != "handled" {
		t.Fatalf("expected handled result for !error, got %#v", result)
	}
	if len(errs) != 1 || errs[0] != "Styled error from extension" {
		t.Errorf("expected PrintError capture, got %v", errs)
	}
}
// TestLoadSingleExtension_ContextPrintBlock checks that an OnInput handler can
// call ctx.PrintBlock with options that include ctx.Model, and that the
// PrintBlock function installed via Runner.SetContext receives them unchanged.
func TestLoadSingleExtension_ContextPrintBlock(t *testing.T) {
	dir := t.TempDir()
	src := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
if ie.Text == "!status" && ctx.PrintBlock != nil {
ctx.PrintBlock(ext.PrintBlockOpts{
Text: "All systems go\nModel: " + ctx.Model,
BorderColor: "#a6e3a1",
Subtitle: "test-ext",
})
return &ext.InputResult{Action: "handled"}
}
return nil
})
}
`
	f := filepath.Join(dir, "block.go")
	if err := os.WriteFile(f, []byte(src), 0644); err != nil {
		t.Fatal(err)
	}
	ext, err := loadSingleExtension(f)
	if err != nil {
		t.Fatalf("failed to load extension: %v", err)
	}
	var captured []PrintBlockOpts
	r := NewRunner([]LoadedExtension{*ext})
	r.SetContext(Context{
		Model: "claude-4",
		PrintBlock: func(opts PrintBlockOpts) {
			captured = append(captured, opts)
		},
	})
	result, err := r.Emit(InputEvent{Text: "!status", Source: "interactive"})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if ir, ok := result.(InputResult); !ok || ir.Action != "handled" {
		t.Fatalf("expected handled result for !status, got %#v", result)
	}
	if len(captured) != 1 {
		t.Fatalf("expected 1 PrintBlock call, got %d", len(captured))
	}
	if captured[0].BorderColor != "#a6e3a1" {
		t.Errorf("expected border '#a6e3a1', got %q", captured[0].BorderColor)
	}
	if captured[0].Subtitle != "test-ext" {
		t.Errorf("expected subtitle 'test-ext', got %q", captured[0].Subtitle)
	}
	// Verify the text includes the model from context.
	if captured[0].Text != "All systems go\nModel: claude-4" {
		t.Errorf("unexpected text: %q", captured[0].Text)
	}
}
func TestCountHandlers(t *testing.T) {
ext := &LoadedExtension{
Handlers: map[EventType][]HandlerFunc{
ToolCall: {func(Event, Context) Result { return nil }, func(Event, Context) Result { return nil }},
SessionStart: {func(Event, Context) Result { return nil }},
},
}
if n := countHandlers(ext); n != 3 {
t.Errorf("expected 3 handlers, got %d", n)
}
}