Compare commits

...

29 Commits

Author SHA1 Message Date
kit-agent 0bbccbb0a5 docs(sdk): document Options.DebugLogger and WithDebugLogger
- README.md: add WithDebugLogger to the functional-options helper list
- pkg/kit/README.md: expand the Debug row and add a DebugLogger row in
  the Options field summary
- www/pages/sdk/overview.md: add WithDebugLogger to the helpers table
  with a note that it overrides WithDebug when set
- www/pages/sdk/options.md: surface DebugLogger in the example, expand
  the Debug field description, add a DebugLogger row to the Core
  fields table, and add a "Custom debug logger" section with the
  interface signature and a log/slog adapter example
2026-06-18 14:22:32 +03:00
kit-agent 276c787937 refactor(sdk): drop unreachable kit.AgentConfig surface
kit.AgentConfig (pkg/kit/types.go) and its toInternal converter were
exposed as the documented "low-level / advanced consumer" path for
agent construction, but the converter was unexported and not wired
into any public constructor — neither New(*Options) nor NewAgent(...Option)
accept an AgentConfig. The only call sites were the dedicated
agent_config_internal_test.go (same-package internal test) and two
fantasy-import regression tests in types_test.go.

Net effect today: no SDK consumer outside pkg/kit can populate or use
kit.AgentConfig in any way. The type, the converter, the dedicated
test file, and a chain of godoc cross-references all exist purely for
their own sake — they don't enlarge what consumers can do, but they
do enlarge the SDK's stability contract (every field becomes a public
shape the internal agent layer can't refactor freely).

The companion PR added Options.DebugLogger + WithDebugLogger so the
last functional capability AgentConfig was documented to enable —
installing a custom debug logger — is reachable through the supported
construction path. With that wired, AgentConfig has no remaining
purpose.

Changes:

- pkg/kit/types.go: remove the AgentConfig struct and its toInternal()
  method. Drop the now-unused internal/agent and internal/tools
  imports. Update the DebugLogger godoc to point at Options.DebugLogger
  and WithDebugLogger instead of AgentConfig.
- pkg/kit/agent_config_internal_test.go: deleted (208 LOC). It exercised
  the unexported toInternal() method directly; with the method gone
  the test has no subject.
- pkg/kit/types_test.go: rename TestAgentConfigNoFantasyImport to
  TestOptionsNoFantasyImport and rewrite it against Options
  (SystemPrompt, MaxSteps, Streaming, Tools, ExtraTools,
  DisableCoreTools, OnMCPServerLoaded). The original test also asserted
  ToolWrapper field semantics; that capability migrates to the hook
  system (OnBeforeToolCall / OnAfterToolResult), already covered by
  hooks_test.go, so the assertion is dropped with a pointer in the
  godoc. TestAgentConfigToolWrapperSignature replaced by
  TestToolSliceSignature, which still pins that []kit.Tool is the
  user-visible slice type for every tool-related SDK surface — the
  no-fantasy-import contract the original test guarded.
- pkg/kit/mcp_tasks.go: update the MCPTaskConfig godoc to stop
  referencing AgentConfig. MCPTaskConfig stays — it is still emitted
  through Options.MCPTask* fields and used as the engine-facing
  config type.
- pkg/kit/README.md: drop the kit.AgentConfig line from the type
  inventory.

internal/agent.AgentConfig is untouched and remains the internal
construction shape. With the public type gone the internal one can
evolve freely without breaking the SDK contract.

Verification:
- go build ./pkg/... ./internal/... ./cmd/... — clean
- go vet ./pkg/... ./internal/... ./cmd/... — clean
- go test -race -timeout 300s ./... — all packages pass
2026-06-18 13:45:00 +03:00
kit-agent 40bc710938 feat(sdk): add Options.DebugLogger and WithDebugLogger option
Today the SDK exposes a DebugLogger interface (pkg/kit/types.go) but no
public path to install one — the only consumer of the field is the
unexported kit.AgentConfig.toInternal() method, which itself is not
reachable from outside the package. As a result, embedders that want to
forward Kit's low-level engine + MCP tool plumbing debug output into
their own logging system (slog, zap, charm/log, an in-app TUI panel,
etc.) have no option but the on/off Debug bool, which always installs
the built-in SimpleDebugLogger / BufferedDebugLogger.

This change closes that gap on the supported Options / functional-option
construction path:

- pkg/kit/kit.go: add Options.DebugLogger DebugLogger. When non-nil it
  is used directly and the Debug bool is ignored; the supplied logger's
  IsDebugEnabled() controls whether downstream code emits messages.
- pkg/kit/options.go: add WithDebugLogger(l DebugLogger) Option.
- internal/kitsetup/setup.go: add AgentSetupOptions.DebugLogger and
  switch SetupAgent's logger selection so the caller-supplied logger
  wins unconditionally; otherwise the existing Debug + UseBufferedLogger
  branch picks the built-in implementation. No behaviour change when
  DebugLogger is nil.
- pkg/kit/kit.go: wire opts.DebugLogger into setupOpts so the New()
  path threads it through.
- pkg/kit/viper_isolation_test.go: add TestWithDebugLoggerPlumbing and
  TestWithDebugLoggerNilClears covering the option-to-field contract
  and later-options-override semantics consistent with the other With*
  helpers.
- pkg/kit/README.md: list WithDebugLogger in the helper inventory.

Notes:
- kit.DebugLogger and tools.DebugLogger are structurally identical
  (LogDebug(string) / IsDebugEnabled() bool), so the SDK value flows
  into the internal field without a conversion.
- This is purely additive on the SDK surface and does not touch
  kit.AgentConfig — that field already carried a DebugLogger, but the
  AgentConfig path is unreachable from outside the package today.
2026-06-18 13:38:14 +03:00
Ed Zynda 888c6c7953 chore(models): refresh embedded models database from models.dev
- add GLM-5.2 across 9 providers (alibaba-token-plan-cn, baseten,
  cloudflare-workers-ai, fireworks-ai, neuralwatt, opencode-go,
  openrouter, venice, vercel)
- add moonshotai/Kimi-K2.7-Code on baseten
- drop deprecated neuralwatt models (MiniMax-M2.5,
  Devstral-Small-2-24B-Instruct-2512, gpt-oss-20b)
- pick up new reasoning_options metadata on several models
2026-06-18 12:42:11 +03:00
Ed Zynda a9d808eb9f build(deps): bump go module dependencies
- charm.land/fantasy v0.31.0 -> v0.32.0
- alecthomas/chroma/v2 v2.26.1 -> v2.27.0
- charmbracelet/openai-go to 20260617131321
- mark3labs/mcp-go v0.54.1 -> v0.55.0
- kaptinlin/jsonschema v0.8.0 -> v0.8.1
- pelletier/go-toml/v2 v2.3.1 -> v2.4.0
- google.golang.org/api v0.284.0 -> v0.285.0
- google.golang.org/genai v1.60.0 -> v1.61.0
2026-06-18 12:37:37 +03:00
Ed Zynda d7948a64f3 fix(app): make ctx.NewSession wait for agent idle (#63) (#64)
- Add ErrAgentBusy sentinel (shared between internal/app and
  internal/extensions) so callers can detect the busy condition with
  errors.Is instead of substring-matching the error message.
- Add App.WaitForIdle(timeout) backed by a per-busy-cycle idleCh closed
  by a new setBusyLocked chokepoint; all busy transitions now route
  through it to keep the channel in sync with the busy flag.
- Have RequestNewSessionFromExtension wait for idle (up to
  DefaultNewSessionIdleWait = 10m) instead of failing fast on IsBusy.
  This fixes the v0.79.0 phase-handoff race where OnAgentEnd fires from
  inside the agent loop, before drainQueue clears busy, so
  ctx.NewSession reliably failed with 'agent is busy'.
- Expose ext.ErrAgentBusy to Yaegi via symbols.go.
- Update NewSession godoc and phase-handoff example to document the new
  wait-then-send behavior.
- Add regression tests covering already-idle, blocks-until-drain,
  timeout, zero-timeout, app-close, headless guard, and idleCh
  transitions.

Fixes #63
2026-06-18 12:33:54 +03:00
Michal Hrušecký d2e2e5e9b3 feat(models): add apiModelName field to custom model config (#59)
* feat(models): add apiModelName field to custom model config

Allows custom models to specify an alternative model name to send
in API requests, distinct from the config key. Useful when a local
or custom endpoint expects a different model identifier.

Configures createCustomProvider to use modelInfo.APIModelName
when calling p.LanguageModel(), falling back to the config key.

* docs: document apiModelName field in custom model config
2026-06-17 17:17:50 +03:00
Ed Zynda 2c05280150 feat(ui): support /new <prompt> and ctx.NewSession for phase handoffs
- /new now accepts an optional initial prompt that is submitted as the
  first user turn of the new session, with @file expansion mirroring
  normal input submission
- Add ctx.NewSession(prompt) extension API for ending the current
  session and starting a fresh one from an extension (e.g. on AgentEnd)
- Plumb the prompt through BeforeSessionSwitchEvent.InitialPrompt so
  extensions can inspect or veto the switch
- Bridge extension calls into the TUI via app.NewSessionRequestEvent
  with a response channel so the caller observes success or failure
- Add pkg/kit EmitBeforeSessionSwitchWithPrompt; keep the old method
  as a thin compatibility wrapper
- Ship examples/extensions/phase-handoff.go demonstrating automatic
  session handoff on a <HANDOFF_READY> sentinel plus a /handoff command
- Tests cover the new /new prompt path, the extension request event,
  and the before-hook cancellation flow
2026-06-17 17:16:24 +03:00
Ed Zynda 6a1b061d06 build(deps): update dependencies and bump go to 1.26.4
- upgrade charm.land/fantasy v0.25.0 -> v0.31.0
- upgrade charm.land/lipgloss/v2, ultraviolet, and AWS/Azure SDKs
- bump golang.org/x/image and golang.org/x/term
- raise go directive to 1.26.4 as required by updated deps
2026-06-16 14:20:53 +03:00
Ed Zynda 3a35bc5cec chore(models): refresh embedded models.dev snapshot
- Add umans-ai provider (144 -> 145 providers)
- Add new models (kimi-k2.7-code variants, minimax-m3, gemma-4, glm-5.2)
- Remove stale entries (claude-fable-5, older fireworks/siliconflow models)
- Model count 5244 -> 5270
2026-06-16 14:10:24 +03:00
Ed Zynda 08c3d0fe3a feat(github): make GitHub integration work end-to-end via core command + action
Replace the headless-incompatible handler extension with a real,
out-of-the-box integration modeled on opencode:

- add 'kit github run' core command that reads the Actions event, gates
  on author_association, reacts, runs the agent headlessly, posts a
  comment, and opens a kit-agent[bot] PR when files changed
- add a bundled composite action (action.yml) that installs the Kit
  binary and runs 'kit github run' — no separate action repo needed
- point the generated workflow at the in-repo action (mark3labs/kit@v0)
- maintain floating major/minor tags (v0, v0.x) on release so the action
  reference always resolves
- bound git/gh/agent subprocesses with timeouts; ignore mid-sentence
  /kit mentions
- remove the github-handler example extension (superseded) and refresh
  docs to describe the action + command

Part of #60
2026-06-16 13:29:06 +03:00
Ed Zynda 16662ca208 feat(extensions): add GitHub handler extension and env var access (#62)
* feat(extensions): add GitHub handler extension and env var access

- add github-handler example extension that runs Kit as a GitHub
  collaborator inside Actions: parses the event, gates on
  author_association, drives the agent, posts comments, and opens PRs
- seed the Yaegi interpreter with os.Environ() in the loader and test
  harness so extensions can read env vars (e.g. GITHUB_EVENT_PATH) via
  os.Getenv/LookupEnv/Environ without mutating the host environment
- document env var access, the new extension, and env-aware testing
  across the docs site, README, and kit-extensions skill

Part of #60

* fix(extensions): harden github-handler command parsing and subprocesses

- only trigger on /kit at the start or end of a comment line, ignoring
  incidental mid-sentence mentions like "please review /kit behavior"
- bound git/gh subprocess calls with a 30s timeout via CommandContext so
  a stalled network call or auth prompt cannot hang the Actions job
- add a regression test for the mid-sentence mention case

Part of #60
2026-06-16 00:29:13 +03:00
Ed Zynda 7067c99c84 feat(cmd): add kit github install command (#60) (#61)
* feat(cmd): add kit github install command (#60)

- Add `kit github` parent command and `kit github install` subcommand
  that scaffolds .github/workflows/kit.yml to run Kit as a GitHub
  Actions collaborator/reviewer triggered by `/kit` comments
- Generate a least-privilege workflow with persist-credentials: false,
  resolve the provider secret env var from the model registry, and
  refuse to clobber an existing file unless --force
- Offer to set the provider secret via the gh CLI when available;
  flags: --model, --force, --no-secret
- Add unit tests for secret resolution, workflow rendering, and write
- Document the command in README and the docs site (cli/commands, index)

Fixes #60

* fix(cmd): harden kit github install workflow and secret handling

- Pass the provider secret to `gh secret set` via stdin instead of the
  --body flag so the API key never appears in the process argument list
- Gate the generated workflow on author_association (OWNER, MEMBER,
  COLLABORATOR) so untrusted users cannot trigger privileged runs
- Match `/kit` only as a leading command token instead of an incidental
  substring anywhere in the comment body
- Thread cmd.Context() through to the gh invocation
- Update tests and docs to reflect the refined trigger conditions
2026-06-15 23:46:35 +03:00
Ed Zynda feaec4268e chore(models): update embedded models.dev snapshot
- Refresh internal/models/embedded_models.json from models.dev/api.json
- Providers 139 -> 144, models 5121 -> 5244
2026-06-15 16:31:27 +03:00
Sai Karthik 7f366eab84 cmd: add --no-skills, --skill, and --skills-dir CLI flags & config (#55)
* cmd: add --no-skills, --skill, and --skills-dir CLI flags

The pkg/kit Options struct already had full backend support for skills
control (NoSkills, Skills []string, SkillsDir) wired into loadSkills()
in pkg/kit/kit.go, but there were no corresponding CLI flags to drive
them. This commit closes that gap.

Changes in cmd/root.go:

- Add three package-level flag variables alongside the existing
  noExtensionsFlag/extensionPaths group:
    noSkillsFlag bool
    skillsPaths  []string
    skillsDir    string

- Register three persistent cobra flags in init():
    --no-skills        disable skill loading (auto-discovery and explicit)
    --skill <path>     load a skill file or directory (repeatable)
    --skills-dir <dir> override the project-local skills directory
                       used for auto-discovery

- Wire all three into the kitOpts struct literal in runNormalMode()
  so they flow directly into kit.New() -> loadSkills().

No changes to pkg/kit or internal/skills -- the backend was already
complete. No viper binding is needed because kit.go reads these fields
directly from opts rather than from viper (unlike NoExtensions which
uses the viper fallback path).

Example usage:
  kit --no-skills "prompt"
  kit --skill ./my-skill.md --skill ./other-skill.md "prompt"
  kit --skills-dir /path/to/skills "prompt"

Co-authored-by: Claude <claude@anthropic.com>

* docs: document --no-skills, --skill, and --skills-dir CLI flags

Add the three new skills CLI flags to all relevant documentation:

- README.md: add Skills section under Global Flags CLI reference
- www/pages/cli/flags.md: add Skills table (mirrors Extensions section pattern)
- www/pages/cli/commands.md: expand the Skills section with usage examples
  and a description of auto-discovery vs explicit loading vs --no-skills

Co-authored-by: Claude <claude@anthropic.com>

* feat: add config file support for skills options

Skills could previously only be controlled via CLI flags or SDK Options
fields. This commit wires all three skills settings into viper so they
can also be set in .kit.yml / .kit.yaml / .kit.json and via KIT_*
environment variables — matching the pattern used by no-extensions,
no-core-tools, and prompt-template.

cmd/root.go:
- Bind --no-skills, --skill, and --skills-dir flags to viper keys
  (no-skills, skill, skills-dir) so config file values flow through.

pkg/kit/kit.go:
- At skill-load time, merge opts fields with viper values:
  - noSkills = opts.NoSkills || v.GetBool("no-skills")
  - skillPaths: opts.Skills if non-empty, else v.GetStringSlice("skill")
  - skillsDir: opts.SkillsDir if non-empty, else v.GetString("skills-dir")
- Build a shallow-copied mergedOpts so loadSkills() picks up the
  resolved values without mutating the original Options struct.

docs:
- README.md: add skills keys to the Basic Configuration YAML example
- www/pages/configuration.md: add no-skills, skill, skills-dir rows to
  the All configuration keys table

Config file example (.kit.yml):
  no-skills: false
  skill:
    - /path/to/skill.md
  skills-dir: /path/to/skills/

Co-authored-by: Claude <claude@anthropic.com>

* config: add skills keys to default .kit.yml template

Add no-skills, skill, and skills-dir as commented-out examples in the
default config file generated by EnsureConfigExists(), alongside the
existing application settings block.

Co-authored-by: Claude <claude@anthropic.com>

* test: add test coverage for skills CLI flags and config keys

Four test locations updated:

pkg/kit/export_test.go:
- Add ConfigStringSliceForTest() helper to expose v.GetStringSlice()
  from the Kit's isolated viper store, needed to assert skill list values.

pkg/kit/kit_test.go (TestNewWithSkillsOptions):
- NoSkills=true: GetSkills() returns empty slice
- SkillsDir=<empty dir>: kit.New() succeeds with zero skills
- Skills=[file]: single explicit skill file is loaded and name parsed correctly

pkg/kit/viper_isolation_test.go:
- TestSkillsViperKeys: no-API-key struct-level checks for NoSkills, Skills,
  and SkillsDir fields on Options
- TestSkillsConfigFileKeys: full kit.New() round-trips via a written .kit.yml
  for each of the three config keys:
    no-skills: true  → GetSkills() returns empty
    skill: [path]    → named skill loaded from config file path
    skills-dir: dir  → custom discovery root accepted without error

internal/config/config_test.go (TestEnsureConfigExists):
- Assert generated ~/.kit.yml template contains '# Skills configuration',
  'no-skills:', and 'skills-dir:' comment blocks.

Co-authored-by: Claude <claude@anthropic.com>

---------

Co-authored-by: Claude <claude@anthropic.com>
2026-06-12 16:23:17 +03:00
Ed Zynda e8e99b19a8 refactor: dedupe cross-package logic and remove dead code from audit (#58)
* Remove dead code: 5 unused symbols across internal packages

- internal/models: LoadModelSettingsFromConfig (zero refs)
- internal/prompts: PromptTemplate.ExpandWithArgs (zero refs)
- internal/app: NewMessageStore (tests migrated to NewMessageStoreWithMessages)
- internal/config: HasEnvVars (+ its test)
- internal/core: ContextWithSudoPassword (test migrated to context.WithValue)

* pkg/kit: use TreeManager alias in exported signatures

NewTreeManagerAdapter and InitTreeSession now spell their signatures with
the public kit.TreeManager alias instead of internal/session.TreeManager,
so go doc renders domain types rather than internal paths.

* Consolidate tool-kind classification into internal/extensions

coreToolKinds + toolKindFor were duplicated verbatim in
internal/extensions/wrapper.go and pkg/kit/events.go, risking silent
divergence between extension events and SDK events. Single source of
truth now lives in internal/extensions/toolkinds.go; pkg/kit re-exports
the constants.

* Consolidate Anthropic OAuth detection and usage-tracker refresh

The 'is the active Anthropic credential a stored OAuth token' check was
copy-pasted at 5 sites, all prefix-matching the magic string
'stored OAuth' produced in internal/auth. Now:

- internal/auth: new CredentialSourceOAuth constant + IsAnthropicOAuth()
- internal/ui: new UpdateUsageTrackerForModel(); CreateUsageTracker and
  SetupCLI share lookupTrackableModel (SetupCLI no longer re-inlines the
  tracker construction)
- cmd/root.go + cmd/extension_context.go: verbatim-duplicated tracker
  refresh blocks replaced with ui.UpdateUsageTrackerForModel
- pkg/kit isAnthropicOAuth delegates to auth.IsAnthropicOAuth
- internal/models compares source against the constant

* pkg/kit: consolidate model-path helpers and argument tokenizer

- ExtractModelFromPath mis-parsed model IDs containing '/' (e.g.
  'openrouter/meta/llama' -> 'meta'); it now delegates to
  RemoveProviderFromModel and is deprecated alongside
  ExtractProviderFromPath (-> GetCurrentProvider)
- parseFields delegated to prompts.ParseCommandArgs so extension argument
  parsing and builtin prompt-template parsing share one quote/escape
  grammar; ParseCommandArgs now also splits on tabs (superset of both
  previous tokenizers)

* Unify the two {{variable}} template engines

internal/skills and pkg/kit/template_bridge each had their own grammar:
skills rejected '{{ name }}' (whitespace) but allowed digit-first names;
the bridge was the opposite. A template behaved differently depending on
whether it was loaded as a skill prompt or via the extension API.

internal/skills is now the single engine using the superset grammar
(\{\{\s*(\w+)\s*\}\}); pkg/kit ParseTemplate/RenderTemplate are thin
adapters over it. Expand is now regex-based so whitespace placeholders
expand consistently; missing variables are still left as-is.

* internal/ui: extract switchModel helper for model-switch flow

The model-selector handler (ModelSelectedMsg) and /model slash command
duplicated the full switch sequence (thinking-level fallback, setModel,
display-state update, preference persistence, ModelChange emit) and had
already drifted in ordering. Both now call a single switchModel method.
Display state is still updated directly (no prog.Send from Update).

* extbridge: extract shared BaseContext for extension wiring

cmd/extension_context.go and internal/acpserver/session.go each built a
giant extensions.Context literal, duplicating ~15 delegation closures
(GetContextStats, GetMessages, AppendEntry, options, SetModel core,
Complete, SpawnSubagent, ...) that had to be kept in sync by hand. New
data-access fields had to be wired in both places or ACP-mode extensions
silently got nil function fields.

extbridge.BaseContext now provides the headless half; both call sites
overlay only their UI-specific closures. As a side effect ACP mode gains
previously-missing APIs (state, tree navigation, skills, template
parsing, model resolution) that were nil before. The interactive TUI
keeps its exact SetModel/ReloadExtensions ordering via overrides.

* internal/tools: extract withOAuthRetry and marshalToolResult helpers

ExecuteTool repeated the OAuth-error/re-auth/retry stanza verbatim twice
(sync and task-augmented paths) and the marshal-and-wrap stanza four
times. Both are now single helpers with identical error strings, so a
fix to OAuth retry or error categorization applies everywhere at once.

* internal/ui: extract buildShareFile with defer-based cleanup

handleShareCommand repeated the close/remove/print/return cleanup chain
four times across its temp-file write error paths. File assembly now
lives in buildShareFile with a single deferred cleanup on error.

* cmd: extract flag validation, preference restore, and provider-URL routing from runNormalMode

runNormalMode opened with ~150 lines of policy logic (flag-combination
validation, persisted model/thinking-level preference restoration, and
two subtle --provider-url model-rewrite rules). These are now standalone
functions (validateModeFlags, restorePersistedPreferences,
applyProviderURLRouting) so the routing policy is independently readable
and testable. Behaviour unchanged; ordering preserved.

* fix: address review findings on SDK godoc and nil guard

- pkg/kit: remove internal package paths from exported godoc on
  ParseTemplate and the ToolKind* constants (SDK doc surface must not
  reference internal packages)
- internal/tools: guard marshalToolResult against a nil CallToolResult
  (json.Marshal(nil) succeeds as 'null', then result.IsError panics if
  a client returns nil result with nil error)

Skipped the TreeNode Children deep-copy suggestion: the slice already
comes from TreeManager.GetChildren which returns a fresh copy per call
into a throwaway intermediate, so no internal state is exposed.
2026-06-11 16:13:18 +03:00
Egbert Eich ef072f6e59 Make subagent inherit tools from parent (#51)
While the tool list of the main agent could be controlled by several
options, subagent used to be equipped with all available tools (except
for the subagent tool itself).
With this change the list of tools is taken from the parent, the
subagent tool itself is removed and the remaining tool list is added
to the subagent.

Signed-off-by: Egbert Eich <eich@suse.com>
2026-06-09 16:28:01 +03:00
Ed Zynda 49f8b485be feat(extensions): add OnLLMUsage, SetState, enriched AgentEndEvent (#53) (#54)
* feat(extensions): add OnLLMUsage, SetState, enriched AgentEndEvent (#53)

Three additive primitives to the extension API:

- OnLLMUsage event: per-LLM-call token + cost deltas attributed to the
  specific model/provider used for each round-trip. Derived from the SDK
  StepFinishEvent in the extension bridge. Enables accurate budget
  enforcement between calls instead of only at turn boundaries.

- ctx.SetState / GetState / DeleteState / ListState: session-scoped,
  last-write-wins key-value store backed by a sidecar file
  (<session>.ext-state.json) outside the conversation tree. Reads are
  O(1), writes don't grow the JSONL, and the store is not duplicated on
  fork. State is preserved across hot-reloads.

- Enriched AgentEndEvent: ToolCallCount, ToolNames, LLMCallCount, token
  deltas (input/output/cache-read/cache-write), CostDelta, and
  DurationMs populated by a per-turn aggregator. Existing handlers
  reading only Response/StopReason are unaffected.

Includes unit tests for the state store, LLMUsage registration,
enriched AgentEndEvent, turn aggregator, llmUsageMeta, and sidecar path
derivation. Adds examples/extensions/usage-budget.go demoing all three
primitives together. Documents the additions in README, the docs site
(extensions overview, capabilities, examples), and the kit-extensions
and kit-sdk skill guides.

Fixes #53

* fix(extensions): address review feedback on state store and llmUsageMeta

- Serialize SetState/DeleteState saver invocations through a new saverMu
  so overlapping atomic-rename writes can no longer race on the shared
  .tmp file and persist an older snapshot after a newer one.
- LoadStateFromFile now clears the in-memory store when the sidecar is
  missing or empty, matching the documented "replace … with its
  contents" contract. This makes session-switching safe by preventing
  keys from a prior session leaking into a new one. Tests updated to
  cover both the missing-file and empty-file cases.
- llmUsageMeta now detects Anthropic OAuth credentials and returns
  Cost=0, matching the comment and the existing usage_tracker behavior
  for OAuth users. Mirrors the OAuth detection already used in
  cmd/extension_context.go.
- Document the single-in-flight-turn assumption baked into the
  per-turn aggregator with a clear migration path (per-turn ID) for if
  concurrent turns ever become a supported use case.

* fix(extensions): release saverMu on panic in state store

Extract a runSaver helper that locks saverMu and defers Unlock before
invoking the persistence callback. Without the deferred Unlock, a panic
inside the saver (e.g. disk full mid-write) would leave saverMu held
forever and deadlock the next SetState/DeleteState. Both SetState and
DeleteState now route through the helper. New TestRunner_State_Saver
PanicReleasesSaverMu reproduces the deadlock window with a 2s deadline
and proves the mutex is released after a panic.
2026-06-09 16:18:10 +03:00
Nuno do Carmo febdc530e1 Feat/copilot login (#49)
* feat(auth): add Copilot login

Add experimental GitHub Copilot device login and copilot/* provider support for users with Copilot access but no OpenAI account.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix(copilot): use responses for GPT-5

Route Copilot GPT-5 models through the Responses API because gpt-5.5 is not available on /chat/completions.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix(copilot): honor device flow timing

* docs(copilot): add auth helper docstrings

* fix(auth): address copilot review feedback

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-06-08 00:21:20 +03:00
Ed Zynda e610bdd2d0 fix(cmd): route prefixed models through custom wire when --provider-url is set
When --provider-url was set with an explicit --model that already carried
a provider prefix (e.g. google/gemma-4-12b served by LM Studio), Kit
honored the prefix and routed through the Google wire protocol instead
of the user-supplied endpoint, producing confusing upstream errors.

- Strip any non-custom provider prefix from --model when --provider-url
  is set, so the request always lands on the OpenAI-compatible custom
  wire pointed at the user's URL.
- Leave behavior unchanged when --provider-url is absent.
- Document the rewrite in www/pages/providers.md.
2026-06-07 22:03:51 +03:00
Ed Zynda 6100e8b3a8 feat(ui): add /retry slash command for resubmitting last user message
- Add PopLastUserMessage() on *App: walks the current tree branch back to
  the parent of the most recent user message, syncs the in-memory store,
  and returns the prompt + image parts for resubmission.
- Register /retry (alias /rt) and wire handleRetryCommand which rebuilds
  the visible ScrollList from the truncated branch before resubmitting
  via Run/RunWithFiles. Mirrors SubmitMsg display path (badges, pending
  prints, stateWorking transition).
- Recovers from transient provider errors (overloaded, timeout) without
  duplicating the user message in context — the failed turn's entries
  become orphaned off-branch rather than being re-sent to the LLM.
- Update help text, AppController interface, and stub controller.
- Add unit tests covering busy/closed/no-session guards, the happy-path
  truncation, and the empty-branch error case.
2026-06-07 18:05:20 +03:00
Ed Zynda 9f125f3400 refactor(ui): standardize all popups on shared PopupList
- Extend PopupList with FullScreen mode, RenderItem callback, and
  external-state setters (SetItems/SetCursor/SetSearch) so any popup
  can reuse the same chrome (border, title, search, scroll, footer).
- Rewrite TreeSelector and SessionSelector as thin PopupList wrappers,
  dropping ~500 lines of duplicated rendering. Selector-specific keys
  (filter cycle, scope/named toggles, delete-confirm) are pre-handled;
  everything else delegates to PopupList.
- Migrate the / and @ autocomplete popups in InputComponent to render
  through PopupList, replacing the bespoke renderer.
- Fix /tree and /fork overflow with deep trees: measure tree-art
  prefix width via lipgloss.Width (handles multi-byte box drawing),
  truncate the prefix from the left with an ellipsis when it would
  push text off the row, and collapse multi-line message content to
  a single line so rows never wrap.
- Fix broken selection highlight in /tree, /fork, /sessions: emit a
  plain string from RenderItem for the cursor row so the outer row
  style paints one continuous fg+bg span instead of being shredded
  by mid-row ANSI resets from inner Render calls.
- Center the cursor in the visible window so context is always shown
  above and below the selection.
2026-06-07 17:45:06 +03:00
Ed Zynda 00eab47218 feat(ui): add /edit slash command with fuzzy file picker
- New /edit (alias /ed) opens $EDITOR on a chosen file via tea.ExecProcess
- Typing '/edit ' activates a fuzzy file popup mirroring the @ trigger:
  reuses GetFileSuggestions (git ls-files), supports directory drill-down,
  excludes MCP resources
- Selecting a file auto-submits and runs $EDITOR ($VISUAL preferred);
  on exit prints 'Edited <path>'
- Manual paths supported (~/, relative, absolute); non-existent paths
  pass through so the editor can create them; directories are rejected
- /help updated with the new command
2026-06-07 17:10:34 +03:00
Ed Zynda 06bf6d087a feat(models): resolve SDK default URLs for all registered providers
- Add sdkDefaultBaseURL map covering the 14 npm SDKs that ship a
  hard-coded baseURL (groq, cerebras, mistral, xai, perplexity,
  togetherai, deepinfra, cohere, v0, aihubmix, venice, merge-gateway,
  openrouter, vercel gateway), so providers whose models.dev entry
  omits the api field still auto-route correctly.
- Extend npmToWireProtocol so these thin OpenAI-compatible wrappers
  route through fantasy's openaicompat provider.
- Add resolveTemplatedAPIURL to substitute ${VAR} placeholders for
  cloudflare-workers-ai, databricks, snowflake-cortex from the env,
  with friendly errors that name the missing vars.
- Wire amazon-bedrock and azure-cognitive-services aliases into the
  existing native handlers; add createGoogleVertexProvider for the
  google-vertex case.
- Expose kit.ResolveProviderBaseURL in the public SDK so embedders
  can introspect the effective endpoint before instantiating a Kit.
- Refresh embedded_models.json from models.dev (5113 -> 5121 models;
  139 providers unchanged).
2026-06-07 14:06:05 +03:00
Ed Zynda fd960921ca refactor: address code audit findings across SDK, cmd, and internals
- Remove deprecated GenerateWithLoopAndStreaming and TreeManager
  AppendFantasyMessage / AddFantasyMessages / GetFantasyMessages to
  close the SDK leakage caused by the kit.TreeManager type alias
- Switch extensionAPI method signatures to local Extension* aliases so
  pkg.go.dev signatures no longer expose internal package names
- Bundle runNormalMode dependencies into a runModeDeps struct, shrinking
  the runNonInteractive and runInteractive call sites from 40+ positional
  args to (ctx, deps)
- Add generic subscribeTyped[E Event] helper and collapse ~30 typed OnXxx
  wrappers in pkg/kit/events.go onto it (public signatures unchanged)
- Extract setupBashPipes / interpretBashExit in internal/core/bash.go to
  deduplicate the buffered and streaming execution paths
- Extract resolveAutoRouteAPIKey and wrapProviderErr helpers in
  internal/models/providers.go and uniformly apply them across every
  createXxxProvider site
- Reimplement internal/extensions/watcher.go as a thin wrapper over the
  general-purpose internal/watcher.ContentWatcher, eliminating ~130 LOC
  of duplicated fsnotify logic while preserving the existing test API
- Add ctx.Err() pre-flight checks in executeRead / Write / Edit / Ls so
  cancellation actually short-circuits pure file-IO tools
2026-06-06 19:22:05 +03:00
Ed Zynda 0b651a8df9 build(deps): update dependencies except fantasy
- bump bubbletea v2.0.6 -> v2.0.7, ultraviolet, acp-go-sdk v0.13.0 -> v0.13.5
- bump indirect deps x/exp, charmtone, go-runewidth
- hold fantasy at v0.25.0 (v0.29.1 requires go 1.26.4)
- add no-op Logout method to acpserver.Agent for new acp.Agent interface
2026-06-04 15:48:07 +03:00
Ed Zynda 7315c1dea7 chore(models): update embedded model database from models.dev
- Refresh internal/models/embedded_models.json with latest data
- Add providers: alibaba-token-plan, anyapi, snowflake-cortex
- 139 providers, 5113 models total
2026-06-04 15:35:43 +03:00
Ed Zynda 0313fa03ad fix(ui): show pasted image previews in input and transcript (#48)
* fix(ui): show pasted image previews in input and transcript

The half-block thumbnail preview added in #47 rendered but was clipped
off the bottom of the screen, and submitted images showed only a text
badge in the conversation history.

- Mark the layout dirty when clipboardImageMsg / thumbnailReadyMsg reach
  the parent, so distributeHeight re-measures the now-taller input region
  instead of keeping a stale height that pushed the preview off-screen
- Render thumbnail previews in the transcript after a user message,
  appended as a verbatim ScrollList item (raw ANSI half-blocks would be
  mangled if folded into the word-wrapped user text block)
- Render transcript previews asynchronously via a tea.Cmd so decode +
  resample never blocks the Bubble Tea event loop
- Add regression tests covering the input layout recompute and the
  transcript preview flow

* fix(ui): anchor transcript image preview to its user message

- Insert the async thumbnail preview directly after the originating user
  message (tracked via anchorID) instead of appending, so a streamed
  assistant reply that lands first no longer pushes the preview out of place
- Make the layout regression test deterministic by forcing a truecolor
  profile, avoiding flakes on low-color CI terminals where the thumbnail
  would render empty
- Add tests for anchored insertion and the unknown-anchor append fallback
2026-06-04 15:30:47 +03:00
Ed Zynda d27022bcfb feat(ui): render half-block thumbnails for attached images (#47)
* feat(ui): render half-block thumbnails for attached images (#46)

- Add internal/ui/imagepreview package: Render() draws low-res
  thumbnails using Unicode half-blocks (▀) + truecolor/256-color SGR,
  which survives tmux/zellij (no graphics protocol)
- Cache a rendered thumbnail per pending clipboard image in the input
  component; render once at attach time, never per frame
- Fall back to the existing [N image(s) attached] text pill when the
  terminal lacks truecolor/256-color support
- Document Ctrl+V paste, Ctrl+U clear, and the preview in the docs
  site and README keyboard shortcuts

Fixes #46

* fix(ui): render image thumbnails off the event loop and cap size

- Render thumbnails asynchronously via a tea.Cmd instead of calling
  the decode + resample path synchronously inside Update(), which
  blocked the Bubble Tea event loop
- Add thumbnailReadyMsg + an imageGen generation counter so async
  results land on the correct pendingImages slot and stale renders
  after a clear/re-attach are discarded
- Guard imagepreview.Render against decompression bombs by checking
  DecodeConfig dimensions against a max before full decode

* fix(ui): skip image preview when input width is too small

- Return 0 from thumbCols when width <= 6 so a full-size thumbnail is
  no longer rendered for tiny or uninitialized (width 0) terminals;
  the caller falls back to the text pill
2026-06-04 14:36:39 +03:00
111 changed files with 9176 additions and 3054 deletions
+30
View File
@@ -39,6 +39,36 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# Keep floating major/minor tags (e.g. v1, v1.2) pointing at the latest
# release so the composite action can be referenced as `mark3labs/kit@v1`.
action-tags:
runs-on: ubuntu-latest
needs: goreleaser
if: ${{ github.event_name == 'push' && needs.goreleaser.result == 'success' }}
permissions:
contents: write
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Update floating major/minor tags
env:
FULL_TAG: ${{ github.ref_name }}
run: |
set -euo pipefail
# FULL_TAG looks like v1.2.3 — derive v1 and v1.2.
VER="${FULL_TAG#v}"
MAJOR="v${VER%%.*}"
MINOR="v${VER%.*}"
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
for t in "$MAJOR" "$MINOR"; do
echo "Pointing $t at $FULL_TAG"
git tag -f "$t" "$FULL_TAG"
git push -f origin "refs/tags/$t"
done
npm-publish:
runs-on: ubuntu-latest
needs: goreleaser
+76 -5
View File
@@ -28,6 +28,7 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
- **Interactive TUI**: Rich terminal interface powered by Bubble Tea with streaming, syntax highlighting, and custom rendering
- **Session Management**: Tree-based conversation history with branching support
- **Non-Interactive Mode**: Script-friendly positional args with JSON output
- **GitHub Integration**: Scaffold a GitHub Actions workflow with `kit github install` to run Kit as a collaborator/reviewer on `/kit` comments
- **ACP Server**: Run Kit as an [Agent Client Protocol](https://agentclientprotocol.com) agent over stdio
- **Go SDK**: Embed Kit in your own applications with full agent lifecycle events (30+ event types) and behavior-modifying hooks
@@ -128,6 +129,12 @@ temperature: 0.7
stream: true
thinking-level: off # off, none, minimal, low, medium, high
no-core-tools: false # set to true to disable all built-in core tools
# Skills — all three keys are optional
no-skills: false # set to true to disable all skill loading
skill: # explicit skill files/dirs (disables auto-discovery)
- /path/to/skill.md
skills-dir: "" # override project-local directory for auto-discovery
```
All of the above keys can also be set programmatically via the SDK
@@ -203,6 +210,11 @@ mcpServers:
--prompt-template Load a specific prompt template by name
--no-prompt-templates Disable prompt template loading
# Skills
--skill Load skill file or directory (repeatable)
--skills-dir Override the project-local skills directory for auto-discovery
--no-skills Disable skill loading (auto-discovery and explicit)
# Generation parameters
--max-tokens Maximum tokens in response (default: 8192, auto-raised up to 32768 for models with larger known output limits)
--temperature Randomness 0.0-1.0 (default: 0.7)
@@ -228,6 +240,10 @@ kit auth login [provider] --set-default # Set provider's default model as syste
kit auth logout [provider] # Remove credentials for provider
kit auth status # Check authentication status
# GitHub Copilot login (experimental; requires active Copilot subscription)
kit auth login copilot
kit --model copilot/gpt-5.5 "Hello"
# Model database
kit models [provider] # List available models (optionally filter by provider)
kit models --all # Show all providers (not just LLM-compatible)
@@ -245,6 +261,12 @@ kit install --uninstall <pkg> # Remove an installed package
# Skills
kit skill # Install the Kit extensions skill via skills.sh
# GitHub integration
kit github install # Scaffold .github/workflows/kit.yml (run Kit on '/kit' comments)
kit github install --model anthropic/claude-sonnet-4-5-20250929
kit github install --force # Overwrite an existing workflow file
kit github install --no-secret # Skip the offer to set the provider secret via the gh CLI
# ACP server
kit acp # Start as ACP agent (stdio JSON-RPC)
kit acp --debug # With debug logging to stderr
@@ -308,12 +330,15 @@ kit -e examples/extensions/minimal.go
### Extension Capabilities
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolCallInputStart, OnToolCallInputDelta, OnToolCallInputEnd, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnLLMUsage, OnToolCall, OnToolCallInputStart, OnToolCallInputDelta, OnToolCallInputEnd, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
`OnAgentEnd` carries per-turn aggregates (`ToolCallCount`, `ToolNames`, `LLMCallCount`, `InputTokensDelta`, `OutputTokensDelta`, `CostDelta`, `DurationMs`) so observers don't need to maintain parallel bookkeeping. `OnLLMUsage` fires after each LLM provider call with token + cost deltas attributed to that specific call/model — use it for accurate budget enforcement *between* calls instead of waiting for the turn to finish.
**Custom Components**:
- **Tools**: Add new tools the LLM can invoke
- **Commands**: Register slash commands (e.g., `/mycommand`)
- **Options**: Register configurable extension options
- **Session State**: Last-write-wins key-value store via `ctx.SetState` / `GetState` / `DeleteState` / `ListState`, persisted to a per-session sidecar file outside the conversation tree
- **Widgets**: Persistent status displays above/below input
- **Headers/Footers**: Persistent content above/below the conversation
- **Status Bar**: Custom status bar entries
@@ -369,6 +394,7 @@ See the `examples/extensions/` directory:
- [`tool-logger.go`](examples/extensions/tool-logger.go) - Log all tool calls
- [`neon-theme.go`](examples/extensions/neon-theme.go) - Custom theme registration and switching
- [`tool-renderer-demo.go`](examples/extensions/tool-renderer-demo.go) - Custom tool call rendering
- [`usage-budget.go`](examples/extensions/usage-budget.go) - Per-call usage callback (`OnLLMUsage`), session state, and enriched `OnAgentEnd` per-turn report
- [`widget-status.go`](examples/extensions/widget-status.go) - Persistent status widgets
Also see [`.kit/extensions/go-edit-lint.go`](.kit/extensions/go-edit-lint.go) (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
@@ -459,6 +485,48 @@ Placeholders inside fenced code blocks (```) and inline code spans are ignored.
Disable templates with `--no-prompt-templates` or load a specific template with `--prompt-template <name>`.
## GitHub Integration
Kit can run as an automated collaborator/reviewer inside GitHub Actions. The
`kit github install` command scaffolds a workflow that triggers when someone
comments `/kit ...` on an issue or pull request review, runs the agent
non-interactively in the runner, and lets it respond.
```bash
kit github install
```
This writes `.github/workflows/kit.yml`. By default the command prompts for the
model (pre-filled with a sensible default); pass `--model` to skip the prompt.
If the [`gh` CLI](https://cli.github.com/) is detected on your `PATH` and the
provider API key is present in your environment, you'll be offered the option to
store it as a repository secret automatically.
The generated workflow:
- Triggers only on `issue_comment` and `pull_request_review_comment` (`types: [created]`).
- Runs only when the comment begins with the `/kit` command token.
- Restricts triggers to repository owners, members, and collaborators (via `author_association`).
- Uses least-privilege `permissions` and `persist-credentials: false`.
- Authenticates git/PR operations with the built-in `secrets.GITHUB_TOKEN` and
the provider via a repository secret (e.g. `ANTHROPIC_API_KEY`).
After committing the workflow and setting the provider secret, comment
`/kit <your request>` on any issue or pull request to trigger Kit.
The generated workflow uses the bundled [`mark3labs/kit`](action.yml) composite
action, which installs the Kit binary and runs `kit github run`. That command
reads the triggering event, enforces permissions, reacts with an emoji, runs the
agent against the issue thread or pull request, posts the response as a comment,
and — if the agent changed files — pushes a `kit-agent[bot]` branch and opens a
pull request.
| Flag | Description |
| --- | --- |
| `--model` | Provider/model to write into the workflow |
| `--force` | Overwrite an existing workflow file |
| `--no-secret` | Skip the offer to set the provider secret via the `gh` CLI |
## Session Management
Kit uses a tree-based session model that supports branching and forking conversations.
@@ -509,6 +577,8 @@ During an interactive session, use these slash commands:
| Shortcut | Description |
|----------|-------------|
| `Ctrl+V` | Paste an image from the clipboard — shows an inline low-res thumbnail preview (tmux/zellij-safe) |
| `Ctrl+U` | Clear all pending image attachments |
| `Ctrl+X e` | Open `$VISUAL`/`$EDITOR` to compose or edit your prompt |
| `Ctrl+X s` | Steer — inject a system-level instruction mid-turn |
| `ESC ESC` | Cancel the current operation (tool call or streaming) |
@@ -621,10 +691,10 @@ host, err := kit.NewAgent(ctx,
Available options: `WithModel`, `WithSystemPrompt`, `WithStreaming`,
`WithMaxTokens`, `WithThinkingLevel`, `WithTools`, `WithExtraTools`,
`WithProviderAPIKey`, `WithProviderURL`, `WithConfigFile`, `WithDebug`, and
`Ephemeral`. For advanced configuration not covered by the helpers (custom MCP
config, in-process MCP servers, session backends, MCP task tuning) construct an
`Options` value explicitly and call `kit.New`.
`WithProviderAPIKey`, `WithProviderURL`, `WithConfigFile`, `WithDebug`,
`WithDebugLogger`, and `Ephemeral`. For advanced configuration not covered by
the helpers (custom MCP config, in-process MCP servers, session backends, MCP
task tuning) construct an `Options` value explicitly and call `kit.New`.
### Per-instance config isolation
@@ -947,6 +1017,7 @@ npm/ - NPM package wrapper for distribution
- **Anthropic** - Claude models (native, prompt caching, OAuth)
- **OpenAI** - GPT models
- **Copilot** - GitHub Copilot models (`copilot`, requires active Copilot subscription)
- **Google** - Gemini models
- **Ollama** - Local models
- **Azure OpenAI** - Azure-hosted OpenAI
+75
View File
@@ -0,0 +1,75 @@
name: "Kit"
description: "Run Kit as an automated collaborator/reviewer on GitHub issues and pull requests."
author: "mark3labs"
branding:
icon: "git-merge"
color: "purple"
inputs:
model:
description: "Provider/model Kit should use (e.g. anthropic/claude-sonnet-4-5-20250929). Defaults to Kit's built-in default."
required: false
default: ""
version:
description: "Kit version to install (e.g. v0.77.0). Defaults to the latest release."
required: false
default: "latest"
runs:
using: "composite"
steps:
- name: Install Kit
shell: bash
env:
KIT_VERSION: ${{ inputs.version }}
run: |
set -euo pipefail
VERSION="${KIT_VERSION:-latest}"
if [ -z "$VERSION" ] || [ "$VERSION" = "latest" ]; then
VERSION="$(curl -fsSL https://api.github.com/repos/mark3labs/kit/releases/latest \
| grep -o '"tag_name": *"[^"]*"' | head -1 | cut -d'"' -f4)"
fi
if [ -z "$VERSION" ]; then
echo "::error::could not determine Kit version to install" >&2
exit 1
fi
VER="${VERSION#v}"
case "$(uname -s)" in
Linux) OS=linux ;;
Darwin) OS=darwin ;;
*) echo "::error::unsupported OS $(uname -s)" >&2; exit 1 ;;
esac
case "$(uname -m)" in
x86_64|amd64) ARCH=amd64 ;;
aarch64|arm64) ARCH=arm64 ;;
*) echo "::error::unsupported arch $(uname -m)" >&2; exit 1 ;;
esac
URL="https://github.com/mark3labs/kit/releases/download/${VERSION}/kit_${VER}_${OS}_${ARCH}.tar.gz"
echo "Installing Kit ${VERSION} from ${URL}"
TMP="$(mktemp -d)"
curl -fsSL "$URL" | tar -xz -C "$TMP"
mkdir -p "$HOME/.kit/bin"
mv "$TMP/kit" "$HOME/.kit/bin/kit"
chmod +x "$HOME/.kit/bin/kit"
echo "$HOME/.kit/bin" >> "$GITHUB_PATH"
rm -rf "$TMP"
- name: Verify Kit
shell: bash
run: kit --version
- name: Run Kit
shell: bash
env:
MODEL: ${{ inputs.model }}
run: |
set -euo pipefail
ARGS=()
if [ -n "${MODEL:-}" ]; then
ARGS+=(--model "$MODEL")
fi
kit github run ${ARGS[@]+"${ARGS[@]}"}
+157 -4
View File
@@ -31,10 +31,12 @@ using OAuth flows. Stored credentials take precedence over environment variables
Available providers:
- anthropic: Anthropic Claude API (OAuth)
- openai: OpenAI API (OAuth and API key)
- copilot: GitHub Copilot (GitHub device login)
Examples:
kit auth login anthropic
kit auth login openai
kit auth login copilot
kit auth logout anthropic
kit auth status`,
}
@@ -54,6 +56,7 @@ environment variables when making API calls.
Available providers:
- anthropic: Anthropic Claude API (OAuth)
- openai: OpenAI ChatGPT Plus/Pro (Codex OAuth)
- copilot: GitHub Copilot (GitHub device login, experimental)
Flags:
--set-default Set this provider's default model as the system default
@@ -61,7 +64,8 @@ Flags:
Examples:
kit auth login anthropic
kit auth login openai
kit auth login openai --set-default`,
kit auth login copilot
kit auth login copilot --set-default`,
Args: cobra.ExactArgs(1),
RunE: runAuthLogin,
}
@@ -80,10 +84,12 @@ You will need to use environment variables or command-line flags for authenticat
Available providers:
- anthropic: Anthropic Claude API
- openai: OpenAI API
- copilot: GitHub Copilot
Example:
kit auth logout anthropic
kit auth logout openai`,
kit auth logout openai
kit auth logout copilot`,
Args: cobra.ExactArgs(1),
RunE: runAuthLogout,
}
@@ -113,6 +119,7 @@ var (
var defaultModels = map[string]string{
"anthropic": "anthropic/claude-sonnet-4-5-20250929",
"openai": "openai/gpt-5.4",
"copilot": "copilot/gpt-5.5",
}
// setDefaultModelIfRequested sets the default model for the given provider
@@ -143,6 +150,7 @@ func init() {
authLoginCmd.Flags().BoolVar(&loginSetDefault, "set-default", false, "Set this provider's default model as the system default after login")
}
// runAuthLogin dispatches OAuth login to the selected provider.
func runAuthLogin(cmd *cobra.Command, args []string) error {
provider := strings.ToLower(args[0])
@@ -151,8 +159,10 @@ func runAuthLogin(cmd *cobra.Command, args []string) error {
return loginAnthropic()
case "openai":
return loginOpenAI()
case "copilot":
return loginCopilot(cmd.Context())
default:
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai, copilot", provider)
}
}
@@ -164,8 +174,10 @@ func runAuthLogout(cmd *cobra.Command, args []string) error {
return logoutAnthropic()
case "openai":
return logoutOpenAI()
case "copilot":
return logoutCopilot()
default:
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai, copilot", provider)
}
}
@@ -244,9 +256,31 @@ func runAuthStatus(cmd *cobra.Command, args []string) error {
}
}
// Check GitHub Copilot credentials
fmt.Print("\nGitHub Copilot: ")
if hasCopilotCreds, err := cm.HasCopilotCredentials(); err != nil {
fmt.Printf("Error checking credentials: %v\n", err)
} else if hasCopilotCreds {
if creds, err := cm.GetCopilotCredentials(); err != nil {
fmt.Printf("Error reading credentials: %v\n", err)
} else {
status := "✓ Authenticated"
if creds.IsExpired() {
status = "⚠️ Token expired (will refresh automatically)"
} else if creds.NeedsRefresh() {
status = "⚠️ Token expires soon (will refresh automatically)"
}
fmt.Printf("%s (GitHub OAuth, stored %s)\n", status, creds.CreatedAt.Format("2006-01-02 15:04:05"))
}
} else {
fmt.Println("✗ Not authenticated")
}
fmt.Println("\nTo authenticate with a provider:")
fmt.Println(" kit auth login anthropic")
fmt.Println(" kit auth login openai")
fmt.Println(" kit auth login copilot")
return nil
}
@@ -517,6 +551,85 @@ func loginOpenAI() error {
return nil
}
// loginCopilot authenticates GitHub Copilot using GitHub device flow.
func loginCopilot(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
cm, err := kit.NewCredentialManager()
if err != nil {
return fmt.Errorf("failed to initialize credential manager: %w", err)
}
if hasAuth, err := cm.HasCopilotCredentials(); err == nil && hasAuth {
var reauth bool
err := huh.NewConfirm().
Title("You are already authenticated with GitHub Copilot").
Description("Do you want to re-authenticate?").
Affirmative("Yes").
Negative("No").
Value(&reauth).
Run()
if err != nil {
return fmt.Errorf("failed to prompt for re-authentication: %w", err)
}
if !reauth {
fmt.Println("Authentication cancelled.")
return nil
}
}
client := auth.NewCopilotOAuthClient()
fmt.Println("🔐 Starting GitHub Copilot authentication...")
fmt.Println("This uses GitHub device login and requires an active GitHub Copilot subscription.")
fmt.Println("Experimental: this uses VS Code Copilot Chat client identifiers.")
fmt.Println()
deviceCode, err := client.StartDeviceFlow(ctx)
if err != nil {
return fmt.Errorf("failed to start GitHub device login: %w", err)
}
fmt.Println("📱 Open this page and enter the code:")
fmt.Printf("\n%s\n\n", deviceCode.VerificationURI)
fmt.Printf("Code: %s\n\n", deviceCode.UserCode)
auth.TryOpenBrowser(deviceCode.VerificationURI)
fmt.Println("Waiting for GitHub authorization...")
githubToken, err := client.PollDeviceToken(ctx, deviceCode)
if err != nil {
return fmt.Errorf("failed to complete GitHub device login: %w", err)
}
fmt.Println("\n🔄 Exchanging GitHub token for Copilot access token...")
creds, err := client.ExchangeGitHubToken(ctx, githubToken)
if err != nil {
return fmt.Errorf("failed to get GitHub Copilot token: %w", err)
}
if err := cm.SetCopilotOAuthCredentials(creds); err != nil {
return fmt.Errorf("failed to store credentials: %w", err)
}
fmt.Println("✅ Successfully authenticated with GitHub Copilot!")
fmt.Printf("📁 Credentials stored in: %s\n", cm.GetCredentialsPath())
fmt.Println("\n🎉 Your GitHub Copilot credentials will now be used for copilot/* models.")
fmt.Println("💡 You can check your authentication status with: kit auth status")
if err := setDefaultModelIfRequested("copilot"); err != nil {
return err
}
if !loginSetDefault {
fmt.Println("\n💡 To set Copilot as your default model, run:")
fmt.Println(" kit auth login copilot --set-default")
}
return nil
}
// callbackServer holds the HTTP server and channel for receiving the OAuth callback
type callbackServer struct {
Server *http.Server
@@ -635,3 +748,43 @@ func logoutOpenAI() error {
return nil
}
func logoutCopilot() error {
cm, err := kit.NewCredentialManager()
if err != nil {
return fmt.Errorf("failed to initialize credential manager: %w", err)
}
hasAuth, err := cm.HasCopilotCredentials()
if err != nil {
return fmt.Errorf("failed to check authentication status: %w", err)
}
if !hasAuth {
fmt.Println("You are not currently authenticated with GitHub Copilot.")
return nil
}
var confirm bool
err = huh.NewConfirm().
Title("Remove GitHub Copilot credentials").
Description("Are you sure you want to remove your stored credentials?").
Affirmative("Yes").
Negative("No").
Value(&confirm).
Run()
if err != nil || !confirm {
fmt.Println("Logout cancelled.")
return nil
}
if err := cm.RemoveCopilotCredentials(); err != nil {
return fmt.Errorf("failed to remove credentials: %w", err)
}
fmt.Println("✓ Successfully logged out from GitHub Copilot!")
fmt.Println("You will need to authenticate again with 'kit auth login copilot'.")
fmt.Println("Tip: this removes local credentials only. Revoke the GitHub OAuth grant at https://github.com/settings/applications")
return nil
}
+267 -429
View File
@@ -4,13 +4,11 @@ import (
"context"
"fmt"
"os"
"strings"
"github.com/spf13/viper"
"golang.org/x/term"
"github.com/mark3labs/kit/internal/app"
"github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/extbridge"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/models"
@@ -35,439 +33,279 @@ type extensionContextDeps struct {
// the three print routes appropriately for their phase (startup buffering
// vs. live runtime routing).
//
// This consolidates two near-identical 400-line literal expressions that
// previously appeared inline in runNormalMode.
// The headless half (data access, state, options, tree navigation, skills,
// templates, model resolution, subagents) comes from extbridge.BaseContext;
// this function overlays the TUI-specific fields and overrides SetModel /
// ReloadExtensions with TUI-aware versions.
func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Context {
kitInstance := deps.kitInstance
appInstance := deps.appInstance
usageTracker := deps.usageTracker
ctx := deps.ctx
return extensions.Context{
CWD: deps.cwd,
Model: deps.modelName,
Interactive: deps.interactive,
PrintBlock: func(opts extensions.PrintBlockOpts) {
appInstance.PrintBlockFromExtension(opts)
},
SendMessage: func(text string) { appInstance.Run(text) },
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
Abort: func() { appInstance.Abort() },
IsIdle: func() bool { return !appInstance.IsBusy() },
Compact: func(cfg extensions.CompactConfig) error {
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
},
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
parts := make([]kit.LLMFilePart, len(files))
for i, f := range files {
parts[i] = kit.LLMFilePart{
Filename: f.Filename,
Data: f.Data,
MediaType: f.MediaType,
}
}
appInstance.RunWithFiles(text, parts)
},
GetSessionUsage: func() extensions.SessionUsage {
if usageTracker == nil {
return extensions.SessionUsage{}
}
stats := usageTracker.GetSessionStats()
return extensions.SessionUsage{
TotalInputTokens: stats.TotalInputTokens,
TotalOutputTokens: stats.TotalOutputTokens,
TotalCacheReadTokens: stats.TotalCacheReadTokens,
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
TotalCost: stats.TotalCost,
RequestCount: stats.RequestCount,
}
},
Exit: func() { appInstance.QuitFromExtension() },
SetWidget: func(config extensions.WidgetConfig) {
kitInstance.Extensions().SetWidget(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveWidget: func(id string) {
kitInstance.Extensions().RemoveWidget(id)
go appInstance.NotifyWidgetUpdate()
},
SetHeader: func(config extensions.HeaderFooterConfig) {
kitInstance.Extensions().SetHeader(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveHeader: func() {
kitInstance.Extensions().RemoveHeader()
go appInstance.NotifyWidgetUpdate()
},
SetFooter: func(config extensions.HeaderFooterConfig) {
kitInstance.Extensions().SetFooter(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveFooter: func() {
kitInstance.Extensions().RemoveFooter()
go appInstance.NotifyWidgetUpdate()
},
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
ch := make(chan app.PromptResponse, 1)
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "select",
Message: config.Message,
Options: config.Options,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptSelectResult{Cancelled: true}
}
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
},
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
ch := make(chan app.PromptResponse, 1)
def := "false"
if config.DefaultValue {
def = "true"
}
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "confirm",
Message: config.Message,
Default: def,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptConfirmResult{Cancelled: true}
}
return extensions.PromptConfirmResult{Value: resp.Confirmed}
},
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
ch := make(chan app.PromptResponse, 1)
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "input",
Message: config.Message,
Placeholder: config.Placeholder,
Default: config.Default,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptInputResult{Cancelled: true}
}
return extensions.PromptInputResult{Value: resp.Value}
},
SetUIVisibility: func(v extensions.UIVisibility) {
kitInstance.Extensions().SetUIVisibility(v)
go appInstance.NotifyWidgetUpdate()
},
GetContextStats: func() extensions.ContextStats {
s := kitInstance.GetContextStats()
return extensions.ContextStats{
EstimatedTokens: s.EstimatedTokens,
ContextLimit: s.ContextLimit,
UsagePercent: s.UsagePercent,
MessageCount: s.MessageCount,
}
},
SetEditor: func(config extensions.EditorConfig) {
kitInstance.Extensions().SetEditor(config)
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
// deadlocks if called synchronously from inside BubbleTea's
// Update() handler. All call sites use go-routines uniformly.
go appInstance.NotifyWidgetUpdate()
},
ResetEditor: func() {
kitInstance.Extensions().ResetEditor()
go appInstance.NotifyWidgetUpdate()
},
GetMessages: func() []extensions.SessionMessage {
return kitInstance.Extensions().GetSessionMessages()
},
GetSessionPath: func() string {
return kitInstance.GetSessionPath()
},
AppendEntry: func(entryType string, data string) (string, error) {
return kitInstance.Extensions().AppendEntry(entryType, data)
},
GetEntries: func(entryType string) []extensions.ExtensionEntry {
return kitInstance.Extensions().GetEntries(entryType)
},
SetEditorText: func(text string) {
appInstance.SetEditorTextFromExtension(text)
},
SetStatus: func(key string, text string, priority int) {
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
Key: key,
Text: text,
Priority: priority,
})
go appInstance.NotifyWidgetUpdate()
},
RemoveStatus: func(key string) {
kitInstance.Extensions().RemoveStatus(key)
go appInstance.NotifyWidgetUpdate()
},
GetOption: func(name string) string {
return kitInstance.Extensions().GetOption(name)
},
SetOption: func(name string, value string) {
kitInstance.Extensions().SetOption(name, value)
},
SetModel: func(modelString string) error {
// Capture previous model for the ModelChange event.
previousModel := kitInstance.Extensions().GetContext().Model
err := kitInstance.SetModel(context.Background(), modelString)
if err != nil {
return err
}
// Notify TUI so it updates model in status bar.
p, m, _ := models.ParseModelString(modelString)
appInstance.NotifyModelChanged(p, m)
// Update the context's Model field so handlers see it.
kitInstance.Extensions().UpdateContextModel(modelString)
// Fire OnModelChange event to extensions.
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
// Update usage tracker with new model info for correct token counting.
if usageTracker != nil {
newProvider, newModel, _ := models.ParseModelString(modelString)
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
registry := models.GetGlobalRegistry()
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
// Check OAuth status for Anthropic models
isOAuth := false
if newProvider == "anthropic" {
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
if err == nil && strings.HasPrefix(source, "stored OAuth") {
isOAuth = true
}
}
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
}
}
}
return nil
},
GetAvailableModels: func() []extensions.ModelInfoEntry {
return kitInstance.GetAvailableModels()
},
EmitCustomEvent: func(name string, data string) {
kitInstance.Extensions().EmitCustomEvent(name, data)
},
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
return kitInstance.ExecuteCompletion(context.Background(), req)
},
SuspendTUI: func(callback func()) error {
return appInstance.SuspendTUI(callback)
},
RenderMessage: func(rendererName, content string) {
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
if renderer == nil || renderer.Render == nil {
appInstance.PrintFromExtension("", content)
return
}
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
if w == 0 {
w = 80
}
rendered := renderer.Render(content, w)
appInstance.PrintFromExtension("", rendered)
},
ReloadExtensions: func() error {
err := kitInstance.Extensions().Reload()
if err != nil {
return err
}
// Notify TUI that widgets/status/commands may have changed.
go appInstance.NotifyWidgetUpdate()
return nil
},
GetAllTools: func() []extensions.ToolInfo {
return kitInstance.Extensions().GetToolInfos()
},
SetActiveTools: func(names []string) {
kitInstance.Extensions().SetActiveTools(names)
},
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
ui.RegisterThemeFromConfig(name,
tc(config.Primary), tc(config.Secondary),
tc(config.Success), tc(config.Warning),
tc(config.Error), tc(config.Info),
tc(config.Text), tc(config.Muted),
tc(config.VeryMuted), tc(config.Background),
tc(config.Border), tc(config.MutedBorder),
tc(config.System), tc(config.Tool),
tc(config.Accent), tc(config.Highlight),
tc(config.MdHeading), tc(config.MdLink),
tc(config.MdKeyword), tc(config.MdString),
tc(config.MdNumber), tc(config.MdComment),
)
},
SetTheme: func(name string) error {
return ui.ApplyTheme(name)
},
ListThemes: func() []string {
return ui.ListThemes()
},
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
ch := make(chan app.OverlayResponse, 1)
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
Title: config.Title,
Content: config.Content.Text,
Markdown: config.Content.Markdown,
BorderColor: config.Style.BorderColor,
Background: config.Style.Background,
Width: config.Width,
MaxHeight: config.MaxHeight,
Anchor: string(config.Anchor),
Actions: config.Actions,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.OverlayResult{Cancelled: true, Index: -1}
}
return extensions.OverlayResult{
Action: resp.Action,
Index: resp.Index,
}
},
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
return extbridge.SpawnSubagent(ctx, kitInstance, config)
},
// -------------------------------------------------------------------
// Tree Navigation API
// -------------------------------------------------------------------
GetTreeNode: func(entryID string) *extensions.TreeNode {
node := kitInstance.GetTreeNode(entryID)
if node == nil {
return nil
}
return &extensions.TreeNode{
ID: node.ID,
ParentID: node.ParentID,
Type: node.Type,
Role: node.Role,
Content: node.Content,
Model: node.Model,
Provider: node.Provider,
Timestamp: node.Timestamp,
Children: node.Children,
}
},
GetCurrentBranch: func() []extensions.TreeNode {
nodes := kitInstance.GetCurrentBranch()
result := make([]extensions.TreeNode, len(nodes))
for i, n := range nodes {
result[i] = extensions.TreeNode{
ID: n.ID,
ParentID: n.ParentID,
Type: n.Type,
Role: n.Role,
Content: n.Content,
Model: n.Model,
Provider: n.Provider,
Timestamp: n.Timestamp,
Children: n.Children,
}
}
return result
},
GetChildren: func(parentID string) []string {
return kitInstance.GetChildren(parentID)
},
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
err := kitInstance.NavigateTo(entryID)
if err != nil {
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
}
return extensions.TreeNavigationResult{Success: true}
},
SummarizeBranch: func(fromID, toID string) string {
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
return summary
},
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
err := kitInstance.CollapseBranch(fromID, toID, summary)
if err != nil {
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
}
return extensions.TreeNavigationResult{Success: true}
},
ec := extbridge.BaseContext(deps.ctx, kitInstance)
// -------------------------------------------------------------------
// Skill Loading API
// -------------------------------------------------------------------
LoadSkill: func(path string) (*extensions.Skill, string) {
s, err := kitInstance.LoadSkillForExtension(path)
return s, err
},
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
return kitInstance.LoadSkillsFromDirForExtension(dir)
},
DiscoverSkills: func() extensions.SkillLoadResult {
skills := kitInstance.DiscoverSkillsForExtension()
return extensions.SkillLoadResult{Skills: skills}
},
InjectSkillAsContext: func(skillName string) string {
skills := kitInstance.DiscoverSkillsForExtension()
for _, s := range skills {
if s.Name == skillName {
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
return ""
}
}
return fmt.Sprintf("skill not found: %s", skillName)
},
InjectRawSkillAsContext: func(path string) string {
s, err := kitInstance.LoadSkillForExtension(path)
if err != "" {
return err
}
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
return ""
},
GetAvailableSkills: func() []extensions.Skill {
return kitInstance.DiscoverSkillsForExtension()
},
ec.CWD = deps.cwd
ec.Model = deps.modelName
ec.Interactive = deps.interactive
// -------------------------------------------------------------------
// Template Parsing API
// -------------------------------------------------------------------
ParseTemplate: func(name, content string) extensions.PromptTemplate {
return kit.ParseTemplate(name, content)
},
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
return kit.RenderTemplate(tpl, vars)
},
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
return kit.ParseArguments(input, pattern)
},
SimpleParseArguments: func(input string, count int) []string {
return kit.SimpleParseArguments(input, count)
},
EvaluateModelConditional: func(condition string) bool {
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
},
RenderWithModelConditionals: func(content string) string {
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
},
// -------------------------------------------------------------------
// Model Resolution API
// -------------------------------------------------------------------
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
return kit.ResolveModelChain(preferences)
},
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
return kit.GetModelCapabilities(model)
},
CheckModelAvailable: func(model string) bool {
return kit.CheckModelAvailable(model)
},
GetCurrentProvider: func() string {
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
},
GetCurrentModelID: func() string {
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
},
ec.PrintBlock = func(opts extensions.PrintBlockOpts) {
appInstance.PrintBlockFromExtension(opts)
}
ec.SendMessage = func(text string) { appInstance.Run(text) }
ec.CancelAndSend = func(text string) { appInstance.InterruptAndSend(text) }
ec.Abort = func() { appInstance.Abort() }
ec.IsIdle = func() bool { return !appInstance.IsBusy() }
ec.Compact = func(cfg extensions.CompactConfig) error {
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
}
ec.SendMultimodalMessage = func(text string, files []extensions.FilePart) {
parts := make([]kit.LLMFilePart, len(files))
for i, f := range files {
parts[i] = kit.LLMFilePart{
Filename: f.Filename,
Data: f.Data,
MediaType: f.MediaType,
}
}
appInstance.RunWithFiles(text, parts)
}
ec.NewSession = func(prompt string) error {
return appInstance.RequestNewSessionFromExtension(prompt)
}
ec.GetSessionUsage = func() extensions.SessionUsage {
if usageTracker == nil {
return extensions.SessionUsage{}
}
stats := usageTracker.GetSessionStats()
return extensions.SessionUsage{
TotalInputTokens: stats.TotalInputTokens,
TotalOutputTokens: stats.TotalOutputTokens,
TotalCacheReadTokens: stats.TotalCacheReadTokens,
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
TotalCost: stats.TotalCost,
RequestCount: stats.RequestCount,
}
}
ec.Exit = func() { appInstance.QuitFromExtension() }
// TUI widgets/chrome — mutate runner state, then notify the TUI.
// Always use a goroutine for NotifyWidgetUpdate: prog.Send() deadlocks
// if called synchronously from inside BubbleTea's Update() handler.
// All call sites use go-routines uniformly.
ec.SetWidget = func(config extensions.WidgetConfig) {
kitInstance.Extensions().SetWidget(config)
go appInstance.NotifyWidgetUpdate()
}
ec.RemoveWidget = func(id string) {
kitInstance.Extensions().RemoveWidget(id)
go appInstance.NotifyWidgetUpdate()
}
ec.SetHeader = func(config extensions.HeaderFooterConfig) {
kitInstance.Extensions().SetHeader(config)
go appInstance.NotifyWidgetUpdate()
}
ec.RemoveHeader = func() {
kitInstance.Extensions().RemoveHeader()
go appInstance.NotifyWidgetUpdate()
}
ec.SetFooter = func(config extensions.HeaderFooterConfig) {
kitInstance.Extensions().SetFooter(config)
go appInstance.NotifyWidgetUpdate()
}
ec.RemoveFooter = func() {
kitInstance.Extensions().RemoveFooter()
go appInstance.NotifyWidgetUpdate()
}
ec.SetUIVisibility = func(v extensions.UIVisibility) {
kitInstance.Extensions().SetUIVisibility(v)
go appInstance.NotifyWidgetUpdate()
}
ec.SetEditor = func(config extensions.EditorConfig) {
kitInstance.Extensions().SetEditor(config)
go appInstance.NotifyWidgetUpdate()
}
ec.ResetEditor = func() {
kitInstance.Extensions().ResetEditor()
go appInstance.NotifyWidgetUpdate()
}
ec.SetEditorText = func(text string) {
appInstance.SetEditorTextFromExtension(text)
}
ec.SetStatus = func(key string, text string, priority int) {
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
Key: key,
Text: text,
Priority: priority,
})
go appInstance.NotifyWidgetUpdate()
}
ec.RemoveStatus = func(key string) {
kitInstance.Extensions().RemoveStatus(key)
go appInstance.NotifyWidgetUpdate()
}
// Interactive prompts — channel-based round trips through the TUI.
ec.PromptSelect = func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
ch := make(chan app.PromptResponse, 1)
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "select",
Message: config.Message,
Options: config.Options,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptSelectResult{Cancelled: true}
}
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
}
ec.PromptConfirm = func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
ch := make(chan app.PromptResponse, 1)
def := "false"
if config.DefaultValue {
def = "true"
}
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "confirm",
Message: config.Message,
Default: def,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptConfirmResult{Cancelled: true}
}
return extensions.PromptConfirmResult{Value: resp.Confirmed}
}
ec.PromptInput = func(config extensions.PromptInputConfig) extensions.PromptInputResult {
ch := make(chan app.PromptResponse, 1)
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "input",
Message: config.Message,
Placeholder: config.Placeholder,
Default: config.Default,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptInputResult{Cancelled: true}
}
return extensions.PromptInputResult{Value: resp.Value}
}
ec.ShowOverlay = func(config extensions.OverlayConfig) extensions.OverlayResult {
ch := make(chan app.OverlayResponse, 1)
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
Title: config.Title,
Content: config.Content.Text,
Markdown: config.Content.Markdown,
BorderColor: config.Style.BorderColor,
Background: config.Style.Background,
Width: config.Width,
MaxHeight: config.MaxHeight,
Anchor: string(config.Anchor),
Actions: config.Actions,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.OverlayResult{Cancelled: true, Index: -1}
}
return extensions.OverlayResult{
Action: resp.Action,
Index: resp.Index,
}
}
ec.SuspendTUI = func(callback func()) error {
return appInstance.SuspendTUI(callback)
}
// TUI-aware model switch: also notifies the TUI status bar and
// refreshes the usage tracker for correct token counting.
ec.SetModel = func(modelString string) error {
// Capture previous model for the ModelChange event.
previousModel := kitInstance.Extensions().GetContext().Model
err := kitInstance.SetModel(context.Background(), modelString)
if err != nil {
return err
}
// Notify TUI so it updates model in status bar.
p, m, _ := models.ParseModelString(modelString)
appInstance.NotifyModelChanged(p, m)
// Update the context's Model field so handlers see it.
kitInstance.Extensions().UpdateContextModel(modelString)
// Fire OnModelChange event to extensions.
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
// Update usage tracker with new model info for correct token counting.
ui.UpdateUsageTrackerForModel(usageTracker, modelString, viper.GetString("provider-api-key"))
return nil
}
ec.RenderMessage = func(rendererName, content string) {
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
if renderer == nil || renderer.Render == nil {
appInstance.PrintFromExtension("", content)
return
}
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
if w == 0 {
w = 80
}
rendered := renderer.Render(content, w)
appInstance.PrintFromExtension("", rendered)
}
ec.ReloadExtensions = func() error {
err := kitInstance.Extensions().Reload()
if err != nil {
return err
}
// Notify TUI that widgets/status/commands may have changed.
go appInstance.NotifyWidgetUpdate()
return nil
}
// Theme management (TUI only).
ec.RegisterTheme = func(name string, config extensions.ThemeColorConfig) {
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
ui.RegisterThemeFromConfig(name,
tc(config.Primary), tc(config.Secondary),
tc(config.Success), tc(config.Warning),
tc(config.Error), tc(config.Info),
tc(config.Text), tc(config.Muted),
tc(config.VeryMuted), tc(config.Background),
tc(config.Border), tc(config.MutedBorder),
tc(config.System), tc(config.Tool),
tc(config.Accent), tc(config.Highlight),
tc(config.MdHeading), tc(config.MdLink),
tc(config.MdKeyword), tc(config.MdString),
tc(config.MdNumber), tc(config.MdComment),
)
}
ec.SetTheme = func(name string) error {
return ui.ApplyTheme(name)
}
ec.ListThemes = func() []string {
return ui.ListThemes()
}
// Skill context-injection (drives a new agent turn through the TUI).
ec.InjectSkillAsContext = func(skillName string) string {
skills := kitInstance.DiscoverSkillsForExtension()
for _, s := range skills {
if s.Name == skillName {
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
return ""
}
}
return fmt.Sprintf("skill not found: %s", skillName)
}
ec.InjectRawSkillAsContext = func(path string) string {
s, err := kitInstance.LoadSkillForExtension(path)
if err != "" {
return err
}
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
return ""
}
return ec
}
+255
View File
@@ -0,0 +1,255 @@
package cmd
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"charm.land/huh/v2"
"github.com/charmbracelet/log"
kit "github.com/mark3labs/kit/pkg/kit"
"github.com/spf13/cobra"
)
// defaultGitHubModel is the model written into the generated workflow when the
// user does not specify one and runs non-interactively.
const defaultGitHubModel = "anthropic/claude-sonnet-4-5-20250929"
// githubWorkflowPath is the repository-relative location of the generated
// GitHub Actions workflow that wires Kit into a repository as a collaborator.
const githubWorkflowPath = ".github/workflows/kit.yml"
var (
githubInstallModel string
githubInstallForce bool
githubInstallNoSecret bool
)
// githubCmd is the parent command for GitHub integration subcommands. It groups
// the turnkey setup tooling that wires Kit into a repository as an automated
// collaborator/reviewer driven by GitHub Actions.
var githubCmd = &cobra.Command{
Use: "github",
Short: "Set up Kit as a GitHub collaborator/reviewer",
Long: `Set up Kit as an automated collaborator/reviewer in a GitHub repository.
Kit runs inside a GitHub Actions runner, reads the relevant context (an issue
thread or pull request), runs the agent non-interactively, and responds by
posting comments and opening pull requests.
Use 'kit github install' to scaffold the GitHub Actions workflow.`,
}
// githubInstallCmd scaffolds the GitHub Actions workflow that runs Kit on
// '/kit' comment triggers. It writes .github/workflows/kit.yml and, when the
// 'gh' CLI is available, offers to set the provider API key as a repository
// secret.
var githubInstallCmd = &cobra.Command{
Use: "install",
Short: "Scaffold the GitHub Actions workflow that runs Kit",
Long: `Scaffold the GitHub Actions workflow that runs Kit as a collaborator.
This writes .github/workflows/kit.yml configured to trigger when someone
comments '/kit ...' on an issue or pull request review. The workflow runs Kit
inside an ephemeral Actions runner with least-privilege permissions and
'persist-credentials: false', mirroring established security practice.
If the GitHub CLI ('gh') is detected on your PATH, you will be offered the
option to store your provider API key as a repository secret automatically.
Flags:
--model Provider/model to write into the workflow (e.g. anthropic/claude-sonnet-4-5)
--force Overwrite an existing workflow file
--no-secret Skip the offer to set the provider secret via the gh CLI
Examples:
kit github install
kit github install --model anthropic/claude-sonnet-4-5-20250929
kit github install --force --no-secret`,
Args: cobra.NoArgs,
RunE: runGitHubInstall,
}
func init() {
githubInstallCmd.Flags().StringVarP(&githubInstallModel, "model", "m", "", "provider/model to write into the workflow")
githubInstallCmd.Flags().BoolVar(&githubInstallForce, "force", false, "overwrite an existing workflow file")
githubInstallCmd.Flags().BoolVar(&githubInstallNoSecret, "no-secret", false, "skip setting the provider secret via the gh CLI")
githubCmd.AddCommand(githubInstallCmd)
rootCmd.AddCommand(githubCmd)
}
func runGitHubInstall(cmd *cobra.Command, _ []string) error {
model, err := resolveGitHubModel()
if err != nil {
return err
}
provider, _, err := kit.ParseModelString(model)
if err != nil {
return fmt.Errorf("invalid model %q: %w", model, err)
}
secretName := providerSecretEnvVar(provider)
if err := writeGitHubWorkflow(model, secretName, githubInstallForce); err != nil {
return err
}
fmt.Printf("✅ Wrote %s\n", githubWorkflowPath)
maybeSetProviderSecret(cmd.Context(), secretName)
printGitHubInstallNextSteps(secretName)
log.Info("github workflow scaffolded", "model", model, "secret", secretName)
return nil
}
// resolveGitHubModel determines the model to embed in the workflow. The
// --model flag takes precedence; otherwise an interactive prompt is shown
// (pre-filled with the default), and non-interactive runs use the default.
func resolveGitHubModel() (string, error) {
if githubInstallModel != "" {
return strings.TrimSpace(githubInstallModel), nil
}
if !isInteractive() {
return defaultGitHubModel, nil
}
model := defaultGitHubModel
err := huh.NewInput().
Title("Model").
Description("Provider/model Kit should use in CI (e.g. anthropic/claude-sonnet-4-5)").
Value(&model).
Run()
if err != nil {
return "", fmt.Errorf("model selection cancelled: %w", err)
}
model = strings.TrimSpace(model)
if model == "" {
return "", fmt.Errorf("model cannot be empty")
}
return model, nil
}
// providerSecretEnvVar returns the environment variable / repository secret
// name that holds the API key for the given provider. It consults the model
// registry and falls back to "<PROVIDER>_API_KEY" for unknown providers.
func providerSecretEnvVar(provider string) string {
if info := kit.GetProviderInfo(provider); info != nil && len(info.Env) > 0 {
return info.Env[0]
}
sanitized := strings.ToUpper(strings.NewReplacer("-", "_", ".", "_").Replace(provider))
return sanitized + "_API_KEY"
}
// renderGitHubWorkflow builds the workflow YAML for the given model and
// provider secret name.
func renderGitHubWorkflow(model, secretName string) string {
return fmt.Sprintf(`name: kit
on:
issue_comment:
types: [created]
pull_request_review_comment:
types: [created]
jobs:
kit:
if: |
(github.event.comment.author_association == 'OWNER' ||
github.event.comment.author_association == 'MEMBER' ||
github.event.comment.author_association == 'COLLABORATOR') &&
(startsWith(github.event.comment.body, '/kit ') ||
github.event.comment.body == '/kit')
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
issues: write
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: mark3labs/kit@v0
with:
model: %s
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
%s: ${{ secrets.%s }}
`, model, secretName, secretName)
}
// writeGitHubWorkflow writes the generated workflow to githubWorkflowPath,
// creating parent directories as needed. It refuses to overwrite an existing
// file unless force is true.
func writeGitHubWorkflow(model, secretName string, force bool) error {
if _, err := os.Stat(githubWorkflowPath); err == nil && !force {
return fmt.Errorf("%s already exists; re-run with --force to overwrite", githubWorkflowPath)
} else if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("checking %s: %w", githubWorkflowPath, err)
}
if err := os.MkdirAll(filepath.Dir(githubWorkflowPath), 0o755); err != nil {
return fmt.Errorf("creating %s: %w", filepath.Dir(githubWorkflowPath), err)
}
content := renderGitHubWorkflow(model, secretName)
if err := os.WriteFile(githubWorkflowPath, []byte(content), 0o644); err != nil {
return fmt.Errorf("writing %s: %w", githubWorkflowPath, err)
}
return nil
}
// maybeSetProviderSecret offers to set the provider API key as a repository
// secret via the gh CLI when it is available, interactive, the secret value is
// present in the environment, and the user did not pass --no-secret.
func maybeSetProviderSecret(ctx context.Context, secretName string) {
if githubInstallNoSecret || !isInteractive() {
return
}
if _, err := exec.LookPath("gh"); err != nil {
return
}
value := os.Getenv(secretName)
if value == "" {
fmt.Printf("️ %s is not set in your environment; set the repository secret manually with:\n", secretName)
fmt.Printf(" gh secret set %s\n", secretName)
return
}
var confirm bool
if err := huh.NewConfirm().
Title(fmt.Sprintf("Set the %s repository secret via gh?", secretName)).
Description("Uses the value from your current environment.").
Value(&confirm).
Run(); err != nil || !confirm {
return
}
// Feed the secret value via stdin rather than a command-line argument so
// the API key never appears in the process argument list.
cmd := exec.CommandContext(ctx, "gh", "secret", "set", secretName)
cmd.Stdin = strings.NewReader(value)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
fmt.Printf("⚠️ Failed to set secret via gh: %v\n", err)
fmt.Printf(" Set it manually with: gh secret set %s\n", secretName)
return
}
fmt.Printf("✅ Set repository secret %s\n", secretName)
}
// printGitHubInstallNextSteps prints the manual follow-up actions a user must
// take after the workflow is scaffolded.
func printGitHubInstallNextSteps(secretName string) {
fmt.Println("\nNext steps:")
fmt.Printf(" 1. Commit the workflow: git add %s && git commit -m \"ci: add kit workflow\"\n", githubWorkflowPath)
fmt.Printf(" 2. Set the %s repository secret (Settings → Secrets → Actions), if not already set.\n", secretName)
fmt.Println(" 3. Comment '/kit <your request>' on an issue or pull request to trigger Kit.")
}
+498
View File
@@ -0,0 +1,498 @@
package cmd
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"strings"
"time"
"github.com/charmbracelet/log"
"github.com/spf13/cobra"
)
// commandToken is the mention that triggers Kit from a comment, mirroring the
// `if:` guard in the generated workflow (.github/workflows/kit.yml).
const commandToken = "/kit"
// subprocessTimeout bounds each git/gh invocation so a stalled network call or
// an unexpected auth prompt cannot hang the Actions job indefinitely.
const subprocessTimeout = 30 * time.Second
// agentTimeout bounds the headless agent run so a runaway turn cannot block the
// job forever. GitHub Actions jobs have their own ceiling, but a tighter bound
// keeps feedback fast and costs predictable.
const agentTimeout = 20 * time.Minute
// botName / botEmail are the dedicated identity commits are attributed to, so
// Kit's changes are clearly distinguishable from human authors in history.
const (
botName = "kit-agent[bot]"
botEmail = "kit-agent[bot]@users.noreply.github.com"
)
// writeAssociations are the GitHub author_association values that imply
// write/admin access. Only these may trigger the handler.
var writeAssociations = map[string]bool{
"OWNER": true,
"MEMBER": true,
"COLLABORATOR": true,
}
var (
githubRunModel string
githubRunDryRun bool
)
// githubRunCmd is the runtime half of the GitHub integration. It is invoked by
// the bundled composite action (action.yml) inside a GitHub Actions runner once
// a collaborator comments '/kit <request>' on an issue or pull request. It reads
// the triggering event, enforces permissions, runs the agent headlessly against
// the comment/PR context, and responds by posting a comment and — when the agent
// leaves changes — opening a pull request.
var githubRunCmd = &cobra.Command{
Use: "run",
Short: "Run Kit against the current GitHub Actions event (used by the kit action)",
Long: `Run Kit against the current GitHub Actions event.
This command is normally invoked by the bundled composite action inside a
GitHub Actions runner; you rarely run it by hand. It reads the triggering
event from GITHUB_EVENT_PATH, verifies the commenter has write/admin access,
reacts with an emoji while it works, runs the agent non-interactively against
the issue thread or pull request, posts the response as a comment, and — if the
agent modified files — pushes a kit-agent[bot] branch and opens a pull request.
Set --dry-run (or KIT_GITHUB_DRY_RUN=1) to log every git/gh side effect and
skip the agent run instead of executing them.`,
Args: cobra.NoArgs,
RunE: runGitHubRun,
}
func init() {
githubRunCmd.Flags().StringVarP(&githubRunModel, "model", "m", "", "provider/model the agent should use (falls back to $MODEL, then a default)")
githubRunCmd.Flags().BoolVar(&githubRunDryRun, "dry-run", false, "log git/gh side effects and skip the agent run instead of executing them")
githubCmd.AddCommand(githubRunCmd)
}
// --- GitHub event types ------------------------------------------------------
type ghUser struct {
Login string `json:"login"`
}
type ghComment struct {
ID int64 `json:"id"`
Body string `json:"body"`
AuthorAssociation string `json:"author_association"`
User ghUser `json:"user"`
}
type ghIssue struct {
Number int `json:"number"`
Title string `json:"title"`
Body string `json:"body"`
PullRequest json.RawMessage `json:"pull_request"`
}
type ghPull struct {
Number int `json:"number"`
Title string `json:"title"`
Body string `json:"body"`
}
type ghRepo struct {
FullName string `json:"full_name"`
DefaultBranch string `json:"default_branch"`
}
type ghEvent struct {
Action string `json:"action"`
Comment *ghComment `json:"comment"`
Issue *ghIssue `json:"issue"`
PullRequest *ghPull `json:"pull_request"`
Repository ghRepo `json:"repository"`
}
// trigger normalises a single invocation across issue_comment and
// pull_request_review_comment events.
type trigger struct {
repo string
defaultBranch string
number int // issue or PR number
isPR bool // true when the target is a pull request
commentID int64 // triggering comment id (for reactions)
commentKind string // "issues" or "pulls" — reaction API path segment
author string
association string
request string // the user's instruction (comment body minus the token)
title string
body string
}
// runGitHubRun is the entry point wired to `kit github run`.
func runGitHubRun(cmd *cobra.Command, _ []string) error {
ctx := cmd.Context()
if !inGitHubActions() && !githubDryRun() {
return fmt.Errorf("kit github run is meant to run inside GitHub Actions (set GITHUB_ACTIONS=true or pass --dry-run)")
}
event, err := loadGitHubEvent()
if err != nil {
return err
}
tr, err := buildTrigger(event)
if err != nil {
// Not an actionable trigger (the workflow `if:` normally prevents this).
log.Info("github run: nothing to do", "reason", err)
return nil
}
if !writeAssociations[strings.ToUpper(tr.association)] {
log.Warn("github run: ignoring /kit from unauthorized author",
"author", tr.author, "association", tr.association)
return nil
}
model := resolveRunModel()
log.Info("github run: handling trigger",
"repo", tr.repo, "number", tr.number, "pr", tr.isPR, "author", tr.author, "model", model)
// React with 👀 so the human sees Kit picked up the request.
addReaction(ctx, tr, "eyes")
gathered := gatherContext(ctx, tr)
prompt := buildPrompt(tr, gathered)
response, runErr := runAgent(ctx, model, prompt)
if runErr != nil {
postComment(ctx, tr, "⚠️ Kit hit an error while processing this request:\n\n```\n"+runErr.Error()+"\n```")
addReaction(ctx, tr, "confused")
return runErr
}
response = strings.TrimSpace(response)
if response == "" {
response = "Kit finished without a textual response."
}
prURL := ""
if hasUncommittedChanges(ctx) {
prURL = openPullRequest(ctx, tr, response)
}
comment := response
if prURL != "" {
comment += "\n\n---\nOpened a pull request with the changes: " + prURL
}
postComment(ctx, tr, comment)
addReaction(ctx, tr, "rocket")
return nil
}
// resolveRunModel picks the model: --model flag, then $MODEL, then the default.
func resolveRunModel() string {
if m := strings.TrimSpace(githubRunModel); m != "" {
return m
}
if m := strings.TrimSpace(os.Getenv("MODEL")); m != "" {
return m
}
return defaultGitHubModel
}
func inGitHubActions() bool {
return os.Getenv("GITHUB_ACTIONS") == "true"
}
// githubDryRun reports whether side effects should be logged instead of run.
func githubDryRun() bool {
return githubRunDryRun || os.Getenv("KIT_GITHUB_DRY_RUN") != ""
}
// loadGitHubEvent reads and decodes the GitHub Actions event payload.
func loadGitHubEvent() (*ghEvent, error) {
path := os.Getenv("GITHUB_EVENT_PATH")
if path == "" {
return nil, fmt.Errorf("GITHUB_EVENT_PATH is not set")
}
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("reading event payload: %w", err)
}
var event ghEvent
if err := json.Unmarshal(data, &event); err != nil {
return nil, fmt.Errorf("parsing event payload: %w", err)
}
return &event, nil
}
// buildTrigger normalises an event into a trigger, or returns an error when the
// event is not an actionable `/kit` comment.
func buildTrigger(event *ghEvent) (*trigger, error) {
if event.Comment == nil {
return nil, fmt.Errorf("event has no comment; nothing to do")
}
request, ok := extractRequest(event.Comment.Body)
if !ok {
return nil, fmt.Errorf("comment does not contain the %q command", commandToken)
}
tr := &trigger{
repo: event.Repository.FullName,
defaultBranch: event.Repository.DefaultBranch,
commentID: event.Comment.ID,
author: event.Comment.User.Login,
association: event.Comment.AuthorAssociation,
request: request,
}
if tr.defaultBranch == "" {
tr.defaultBranch = "main"
}
switch {
case event.Issue != nil:
tr.number = event.Issue.Number
tr.title = event.Issue.Title
tr.body = event.Issue.Body
tr.isPR = len(event.Issue.PullRequest) > 0
tr.commentKind = "issues"
case event.PullRequest != nil:
tr.number = event.PullRequest.Number
tr.title = event.PullRequest.Title
tr.body = event.PullRequest.Body
tr.isPR = true
tr.commentKind = "pulls"
default:
return nil, fmt.Errorf("event has no issue or pull_request target")
}
if tr.repo == "" {
return nil, fmt.Errorf("event is missing repository.full_name")
}
return tr, nil
}
// extractRequest pulls the instruction text out of a comment body that mentions
// the command token. It only recognizes the token at the start of a line
// (mirroring the workflow guard) or at the very end, so incidental mid-sentence
// mentions like "please review /kit behavior" do not trigger the handler. It
// returns the remainder of the matching line as the request.
func extractRequest(body string) (string, bool) {
for line := range strings.SplitSeq(body, "\n") {
trimmed := strings.TrimSpace(line)
var rest string
switch {
case trimmed == commandToken:
return "", true
case strings.HasPrefix(trimmed, commandToken+" "):
rest = trimmed[len(commandToken):]
case strings.HasSuffix(trimmed, " "+commandToken):
return "", true
default:
continue
}
return strings.TrimSpace(rest), true
}
return "", false
}
// gatherContext assembles the issue thread or PR diff to give the agent. It
// always includes the title/body from the event payload, and — outside dry-run,
// when `gh` is available — enriches with the comment thread and PR diff.
func gatherContext(ctx context.Context, tr *trigger) string {
var b strings.Builder
target := "Issue"
if tr.isPR {
target = "Pull request"
}
fmt.Fprintf(&b, "%s #%d: %s\n", target, tr.number, tr.title)
if strings.TrimSpace(tr.body) != "" {
fmt.Fprintf(&b, "\n%s\n", strings.TrimSpace(tr.body))
}
if githubDryRun() || !commandExists("gh") {
return b.String()
}
num := fmt.Sprint(tr.number)
if tr.isPR {
if diff := ghOutput(ctx, "pr", "diff", num, "--repo", tr.repo); diff != "" {
fmt.Fprintf(&b, "\n## Diff\n```diff\n%s\n```\n", strings.TrimSpace(diff))
}
if comments := ghOutput(ctx, "pr", "view", num, "--repo", tr.repo, "--json", "comments", "--jq", ".comments[] | \"@\\(.author.login): \\(.body)\""); comments != "" {
fmt.Fprintf(&b, "\n## Comments\n%s\n", strings.TrimSpace(comments))
}
} else {
if comments := ghOutput(ctx, "issue", "view", num, "--repo", tr.repo, "--json", "comments", "--jq", ".comments[] | \"@\\(.author.login): \\(.body)\""); comments != "" {
fmt.Fprintf(&b, "\n## Comments\n%s\n", strings.TrimSpace(comments))
}
}
return b.String()
}
// buildPrompt constructs the instruction sent to the agent.
func buildPrompt(tr *trigger, gathered string) string {
target := "issue"
if tr.isPR {
target = "pull request"
}
request := tr.request
if request == "" {
request = "(no explicit instruction — review the " + target + " and respond helpfully)"
}
var b strings.Builder
fmt.Fprintf(&b, "You are Kit, operating as an automated collaborator on the GitHub repository %s.\n\n", tr.repo)
fmt.Fprintf(&b, "@%s (access: %s) triggered you on %s #%d with this request:\n\n", tr.author, tr.association, target, tr.number)
fmt.Fprintf(&b, "%s\n\n", request)
fmt.Fprintf(&b, "## Context\n%s\n\n", strings.TrimSpace(gathered))
b.WriteString("Carry out the request. If you modify files, they will be committed to a new ")
b.WriteString("branch and a pull request will be opened automatically, so you do not need to ")
b.WriteString("commit or push yourself. Finish with a concise summary of what you did.")
return b.String()
}
// runAgent drives the agent headlessly by invoking this same binary in quiet,
// ephemeral mode against the constructed prompt, and returns its response. In
// dry-run it returns a canned response without spawning anything.
func runAgent(ctx context.Context, model, prompt string) (string, error) {
if githubDryRun() {
log.Info("github run: [dry-run] would run agent", "model", model, "promptChars", len(prompt))
return "[dry-run] agent response", nil
}
exe, err := os.Executable()
if err != nil || exe == "" {
exe = "kit"
}
runCtx, cancel := context.WithTimeout(ctx, agentTimeout)
defer cancel()
args := []string{"--quiet", "--no-session", "--no-extensions"}
if model != "" {
args = append(args, "--model", model)
}
args = append(args, prompt)
cmd := exec.CommandContext(runCtx, exe, args...)
cmd.Stderr = os.Stderr // surface agent progress/errors in the Actions log
out, err := cmd.Output()
if err != nil {
return "", fmt.Errorf("agent run failed: %w", err)
}
return string(out), nil
}
// hasUncommittedChanges reports whether the agent produced working-tree changes.
func hasUncommittedChanges(ctx context.Context) bool {
if githubDryRun() {
return os.Getenv("KIT_GITHUB_FAKE_DIRTY") != ""
}
return strings.TrimSpace(gitOutput(ctx, "status", "--porcelain")) != ""
}
// openPullRequest commits the working tree as kit-agent[bot], pushes a branch,
// and opens a PR. It returns the PR URL, or "" on failure / dry-run.
func openPullRequest(ctx context.Context, tr *trigger, summary string) string {
branch := fmt.Sprintf("kit/issue-%d-%d", tr.number, time.Now().Unix())
runGit(ctx, "checkout", "-b", branch)
runGit(ctx, "add", "-A")
runGit(ctx, "-c", "user.name="+botName, "-c", "user.email="+botEmail,
"commit", "-m", fmt.Sprintf("kit: address #%d", tr.number))
// `persist-credentials: false` in the workflow means the checkout left no
// push credentials behind. Re-establish them from GITHUB_TOKEN via gh's git
// credential helper, then push over the existing origin remote.
if !githubDryRun() {
runCmd(ctx, "gh", "auth", "setup-git")
}
runGit(ctx, "push", "origin", "HEAD:"+branch)
title := fmt.Sprintf("kit: changes for #%d", tr.number)
body := fmt.Sprintf("Automated changes from Kit in response to #%d.\n\n%s", tr.number, summary)
if githubDryRun() {
log.Info("github run: [dry-run] would open PR", "branch", branch, "base", tr.defaultBranch)
return ""
}
return strings.TrimSpace(ghOutput(ctx, "pr", "create", "--repo", tr.repo,
"--head", branch, "--base", tr.defaultBranch, "--title", title, "--body", body))
}
// addReaction adds an emoji reaction to the trigger comment.
func addReaction(ctx context.Context, tr *trigger, content string) {
path := fmt.Sprintf("/repos/%s/%s/comments/%d/reactions", tr.repo, tr.commentKind, tr.commentID)
if githubDryRun() || !commandExists("gh") {
log.Info("github run: [dry-run] react", "content", content, "path", path)
return
}
runCmd(ctx, "gh", "api", "-X", "POST", path, "-f", "content="+content)
}
// postComment posts a comment back on the triggering issue or pull request.
func postComment(ctx context.Context, tr *trigger, body string) {
sub := "issue"
if tr.isPR {
sub = "pr"
}
if githubDryRun() || !commandExists("gh") {
log.Info("github run: [dry-run] comment", "sub", sub, "number", tr.number, "chars", len(body))
return
}
runCmd(ctx, "gh", sub, "comment", fmt.Sprint(tr.number), "--repo", tr.repo, "--body", body)
}
// --- thin subprocess helpers -------------------------------------------------
func commandExists(name string) bool {
_, err := exec.LookPath(name)
return err == nil
}
// runGit runs a mutating git command, logging instead of executing in dry-run.
func runGit(ctx context.Context, args ...string) {
if githubDryRun() {
log.Info("github run: [dry-run] git", "args", strings.Join(args, " "))
return
}
runCmd(ctx, "git", args...)
}
// gitOutput runs a read-only git command and returns its stdout.
func gitOutput(ctx context.Context, args ...string) string {
cmdCtx, cancel := context.WithTimeout(ctx, subprocessTimeout)
defer cancel()
out, err := exec.CommandContext(cmdCtx, "git", args...).Output()
if err != nil {
log.Error("github run: git failed", "args", strings.Join(args, " "), "err", err)
return ""
}
return string(out)
}
// ghOutput runs a gh command and returns its stdout.
func ghOutput(ctx context.Context, args ...string) string {
cmdCtx, cancel := context.WithTimeout(ctx, subprocessTimeout)
defer cancel()
out, err := exec.CommandContext(cmdCtx, "gh", args...).Output()
if err != nil {
log.Error("github run: gh failed", "args", strings.Join(args, " "), "err", err)
return ""
}
return string(out)
}
// runCmd runs a command for its side effects, surfacing failures in the log.
func runCmd(ctx context.Context, name string, args ...string) {
cmdCtx, cancel := context.WithTimeout(ctx, subprocessTimeout)
defer cancel()
if out, err := exec.CommandContext(cmdCtx, name, args...).CombinedOutput(); err != nil {
log.Error("github run: command failed", "cmd", name, "err", err, "output", strings.TrimSpace(string(out)))
}
}
+190
View File
@@ -0,0 +1,190 @@
package cmd
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// setupEvent writes a GitHub event payload to a temp file, points
// GITHUB_EVENT_PATH at it, and forces dry-run + Actions mode. It also resets
// the run command's package-level flag state so tests are independent.
func setupEvent(t *testing.T, payload string) {
t.Helper()
path := filepath.Join(t.TempDir(), "event.json")
if err := os.WriteFile(path, []byte(payload), 0o644); err != nil {
t.Fatalf("write event: %v", err)
}
t.Setenv("GITHUB_ACTIONS", "true")
t.Setenv("KIT_GITHUB_DRY_RUN", "1")
t.Setenv("GITHUB_EVENT_PATH", path)
t.Cleanup(func() {
githubRunModel = ""
githubRunDryRun = false
})
}
const issueCommentEvent = `{
"action": "created",
"comment": {
"id": 555,
"body": "/kit fix the broken parser",
"author_association": "OWNER",
"user": {"login": "alice"}
},
"issue": {"number": 42, "title": "Parser crashes on empty input", "body": "It panics."},
"repository": {"full_name": "acme/widgets", "default_branch": "main"}
}`
func TestExtractRequest(t *testing.T) {
tests := []struct {
name string
body string
want string
wantHit bool
}{
{"start with request", "/kit fix the bug", "fix the bug", true},
{"bare token", "/kit", "", true},
{"trailing token", "hey /kit", "", true},
{"mid-sentence ignored", "please review /kit behavior in the docs", "", false},
{"no token", "just a normal comment", "", false},
{"token in second line", "thanks!\n/kit add tests", "add tests", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, hit := extractRequest(tt.body)
if hit != tt.wantHit || got != tt.want {
t.Errorf("extractRequest(%q) = (%q, %v), want (%q, %v)", tt.body, got, hit, tt.want, tt.wantHit)
}
})
}
}
func TestBuildTrigger_IssueComment(t *testing.T) {
event, err := func() (*ghEvent, error) {
setupEvent(t, issueCommentEvent)
return loadGitHubEvent()
}()
if err != nil {
t.Fatalf("loadGitHubEvent: %v", err)
}
tr, err := buildTrigger(event)
if err != nil {
t.Fatalf("buildTrigger: %v", err)
}
if tr.repo != "acme/widgets" || tr.number != 42 || tr.isPR || tr.request != "fix the broken parser" {
t.Errorf("unexpected trigger: %+v", tr)
}
if tr.commentKind != "issues" {
t.Errorf("commentKind = %q, want issues", tr.commentKind)
}
}
func TestBuildPrompt_ContainsContext(t *testing.T) {
setupEvent(t, issueCommentEvent)
event, _ := loadGitHubEvent()
tr, _ := buildTrigger(event)
prompt := buildPrompt(tr, gatherContext(context.Background(), tr))
for _, want := range []string{
"fix the broken parser", // the request
"acme/widgets", // the repo
"issue #42", // the target
"@alice", // the author
"Parser crashes on empty input", // context: title
"It panics.", // context: body
} {
if !strings.Contains(prompt, want) {
t.Errorf("prompt missing %q\n---\n%s", want, prompt)
}
}
}
func TestRunGitHub_AuthorizedIssueComment(t *testing.T) {
setupEvent(t, issueCommentEvent)
if err := runGitHubRun(githubRunCmd, nil); err != nil {
t.Fatalf("runGitHubRun: %v", err)
}
}
func TestRunGitHub_UnauthorizedAssociation(t *testing.T) {
setupEvent(t, strings.Replace(issueCommentEvent, `"OWNER"`, `"NONE"`, 1))
// Should return nil (no-op) without attempting the agent run.
if err := runGitHubRun(githubRunCmd, nil); err != nil {
t.Fatalf("runGitHubRun should be a no-op for unauthorized authors, got: %v", err)
}
}
func TestRunGitHub_CommentWithoutToken(t *testing.T) {
setupEvent(t, strings.Replace(issueCommentEvent,
`"/kit fix the broken parser"`, `"just a normal comment"`, 1))
if err := runGitHubRun(githubRunCmd, nil); err != nil {
t.Fatalf("runGitHubRun should be a no-op without /kit, got: %v", err)
}
}
func TestRunGitHub_MidSentenceMentionIgnored(t *testing.T) {
setupEvent(t, strings.Replace(issueCommentEvent,
`"/kit fix the broken parser"`, `"please review /kit behavior in the docs"`, 1))
if err := runGitHubRun(githubRunCmd, nil); err != nil {
t.Fatalf("runGitHubRun should ignore mid-sentence mentions, got: %v", err)
}
}
func TestRunGitHub_PullRequestReviewComment(t *testing.T) {
setupEvent(t, `{
"action": "created",
"comment": {
"id": 999,
"body": "/kit review this change",
"author_association": "COLLABORATOR",
"user": {"login": "bob"}
},
"pull_request": {"number": 7, "title": "Add caching", "body": "Speeds things up."},
"repository": {"full_name": "acme/widgets", "default_branch": "main"}
}`)
event, _ := loadGitHubEvent()
tr, err := buildTrigger(event)
if err != nil {
t.Fatalf("buildTrigger: %v", err)
}
if !tr.isPR || tr.number != 7 || tr.commentKind != "pulls" {
t.Errorf("unexpected PR trigger: %+v", tr)
}
if err := runGitHubRun(githubRunCmd, nil); err != nil {
t.Fatalf("runGitHubRun (PR): %v", err)
}
}
func TestRunGitHub_RequiresActionsOrDryRun(t *testing.T) {
// Neither GITHUB_ACTIONS nor dry-run set → must error rather than act.
t.Setenv("GITHUB_ACTIONS", "")
t.Setenv("KIT_GITHUB_DRY_RUN", "")
githubRunDryRun = false
t.Cleanup(func() { githubRunDryRun = false })
if err := runGitHubRun(githubRunCmd, nil); err == nil {
t.Fatal("expected an error when run outside Actions without --dry-run")
}
}
func TestResolveRunModel(t *testing.T) {
t.Cleanup(func() { githubRunModel = "" })
t.Setenv("MODEL", "")
githubRunModel = ""
if got := resolveRunModel(); got != defaultGitHubModel {
t.Errorf("default model = %q, want %q", got, defaultGitHubModel)
}
t.Setenv("MODEL", "openai/gpt-5")
if got := resolveRunModel(); got != "openai/gpt-5" {
t.Errorf("MODEL env model = %q, want openai/gpt-5", got)
}
githubRunModel = "anthropic/claude-sonnet-4-5"
if got := resolveRunModel(); got != "anthropic/claude-sonnet-4-5" {
t.Errorf("flag model = %q, want anthropic/claude-sonnet-4-5", got)
}
}
+102
View File
@@ -0,0 +1,102 @@
package cmd
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestProviderSecretEnvVar(t *testing.T) {
tests := []struct {
provider string
want string
}{
{"anthropic", "ANTHROPIC_API_KEY"},
{"openai", "OPENAI_API_KEY"},
// Unknown provider falls back to "<PROVIDER>_API_KEY" with sanitization.
{"my-custom.provider", "MY_CUSTOM_PROVIDER_API_KEY"},
}
for _, tt := range tests {
t.Run(tt.provider, func(t *testing.T) {
got := providerSecretEnvVar(tt.provider)
if got != tt.want {
t.Errorf("providerSecretEnvVar(%q) = %q, want %q", tt.provider, got, tt.want)
}
})
}
}
func TestRenderGitHubWorkflow(t *testing.T) {
out := renderGitHubWorkflow("anthropic/claude-sonnet-4-5-20250929", "ANTHROPIC_API_KEY")
wantSubstrings := []string{
"name: kit",
"issue_comment:",
"pull_request_review_comment:",
"startsWith(github.event.comment.body, '/kit ')",
"github.event.comment.body == '/kit'",
"github.event.comment.author_association == 'OWNER'",
"github.event.comment.author_association == 'COLLABORATOR'",
"persist-credentials: false",
"uses: mark3labs/kit@v0",
"model: anthropic/claude-sonnet-4-5-20250929",
"GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}",
"ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}",
"contents: write",
"pull-requests: write",
"issues: write",
}
for _, want := range wantSubstrings {
if !strings.Contains(out, want) {
t.Errorf("rendered workflow missing %q\n---\n%s", want, out)
}
}
}
func TestWriteGitHubWorkflow(t *testing.T) {
dir := t.TempDir()
cwd, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = os.Chdir(cwd) })
if err := os.Chdir(dir); err != nil {
t.Fatal(err)
}
// First write succeeds and creates nested directories.
if err := writeGitHubWorkflow("anthropic/claude-sonnet-4-5", "ANTHROPIC_API_KEY", false); err != nil {
t.Fatalf("writeGitHubWorkflow: %v", err)
}
data, err := os.ReadFile(githubWorkflowPath)
if err != nil {
t.Fatalf("reading workflow: %v", err)
}
if !strings.Contains(string(data), "model: anthropic/claude-sonnet-4-5") {
t.Errorf("workflow missing model line:\n%s", data)
}
// Second write without force must refuse to clobber.
if err := writeGitHubWorkflow("anthropic/claude-sonnet-4-5", "ANTHROPIC_API_KEY", false); err == nil {
t.Error("expected error when overwriting without --force, got nil")
}
// With force it overwrites.
if err := writeGitHubWorkflow("openai/gpt-5", "OPENAI_API_KEY", true); err != nil {
t.Fatalf("writeGitHubWorkflow with force: %v", err)
}
data, err = os.ReadFile(githubWorkflowPath)
if err != nil {
t.Fatalf("reading workflow: %v", err)
}
if !strings.Contains(string(data), "OPENAI_API_KEY") {
t.Errorf("forced overwrite did not update content:\n%s", data)
}
// Sanity: the file lives at the expected nested path.
if _, err := os.Stat(filepath.Join(dir, githubWorkflowPath)); err != nil {
t.Errorf("workflow not at expected path: %v", err)
}
}
+241 -89
View File
@@ -12,7 +12,6 @@ import (
tea "charm.land/bubbletea/v2"
"github.com/mark3labs/kit/internal/app"
"github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/config"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/models"
@@ -74,6 +73,11 @@ var (
noCoreToolsFlag bool
extensionPaths []string
// Skills control
noSkillsFlag bool
skillsPaths []string
skillsDir string
// TLS configuration
tlsSkipVerify bool
@@ -284,6 +288,14 @@ func init() {
rootCmd.PersistentFlags().
StringSliceVarP(&extensionPaths, "extension", "e", nil, "load additional extension file(s)")
// Skills flags
rootCmd.PersistentFlags().
BoolVar(&noSkillsFlag, "no-skills", false, "disable skill loading (auto-discovery and explicit)")
rootCmd.PersistentFlags().
StringSliceVar(&skillsPaths, "skill", nil, "load skill file or directory (repeatable)")
rootCmd.PersistentFlags().
StringVar(&skillsDir, "skills-dir", "", "override the project-local skills directory for auto-discovery")
flags := rootCmd.PersistentFlags()
flags.StringVar(&providerURL, "provider-url", "", "base URL for the provider API (applies to OpenAI, Anthropic, Ollama, and Google)")
flags.StringVar(&providerAPIKey, "provider-api-key", "", "API key for the provider (applies to OpenAI, Anthropic, and Google)")
@@ -334,6 +346,9 @@ func init() {
_ = viper.BindPFlag("extension", rootCmd.PersistentFlags().Lookup("extension"))
_ = viper.BindPFlag("prompt-template", rootCmd.PersistentFlags().Lookup("prompt-template"))
_ = viper.BindPFlag("no-prompt-templates", rootCmd.PersistentFlags().Lookup("no-prompt-templates"))
_ = viper.BindPFlag("no-skills", rootCmd.PersistentFlags().Lookup("no-skills"))
_ = viper.BindPFlag("skill", rootCmd.PersistentFlags().Lookup("skill"))
_ = viper.BindPFlag("skills-dir", rootCmd.PersistentFlags().Lookup("skills-dir"))
// Defaults are already set in flag definitions, no need to duplicate in viper
@@ -655,13 +670,16 @@ func beforeForkProviderForUI(k *kit.Kit) func(string, bool, string) (bool, strin
// beforeSessionSwitchProviderForUI returns a callback that emits a
// BeforeSessionSwitch event and returns (cancelled, reason). Returns nil
// if extensions are disabled — the UI treats nil as "no hook".
func beforeSessionSwitchProviderForUI(k *kit.Kit) func(string) (bool, string) {
// if extensions are disabled — the UI treats nil as "no hook". The
// initialPrompt argument is forwarded to the event so extensions can
// inspect the prompt that will be submitted as the first turn of the
// new session.
func beforeSessionSwitchProviderForUI(k *kit.Kit) func(switchReason, initialPrompt string) (bool, string) {
if !k.Extensions().HasExtensions() {
return nil
}
return func(switchReason string) (bool, string) {
return k.Extensions().EmitBeforeSessionSwitch(switchReason)
return func(switchReason, initialPrompt string) (bool, string) {
return k.Extensions().EmitBeforeSessionSwitchWithPrompt(switchReason, initialPrompt)
}
}
@@ -677,8 +695,8 @@ func globalShortcutsProviderForUI(k *kit.Kit) func() map[string]func() {
}
}
func runNormalMode(ctx context.Context) error {
// Validate flag combinations
// validateModeFlags rejects invalid flag combinations for the root command.
func validateModeFlags() error {
if quietFlag && positionalPrompt == "" {
return fmt.Errorf("--quiet requires a prompt (e.g. kit \"your question\" --quiet)")
}
@@ -691,21 +709,14 @@ func runNormalMode(ctx context.Context) error {
if noExitFlag && positionalPrompt == "" {
return fmt.Errorf("--no-exit requires a prompt (e.g. kit \"your question\" --no-exit)")
}
return nil
}
// Set up logging
if debugMode {
log.SetFlags(log.LstdFlags | log.Lshortfile)
}
// Update debug mode from viper
if viper.GetBool("debug") && !debugMode {
debugMode = viper.GetBool("debug")
log.SetFlags(log.LstdFlags | log.Lshortfile)
}
// Restore persisted model preference when no explicit --model flag or
// config file model is set. Precedence: CLI flag > config file > saved
// preference > built-in default. This mirrors how themes are persisted.
// restorePersistedPreferences applies saved model / thinking-level
// preferences into viper when neither a CLI flag nor a config-file value
// takes precedence. Precedence: CLI flag > config file > saved preference >
// built-in default. This mirrors how themes are persisted.
func restorePersistedPreferences() {
// Skip custom/* models unless --provider-url is also provided, since the
// custom provider requires a URL that was only valid for the previous session.
if !modelFlagChanged && !viper.InConfig("model") {
@@ -724,6 +735,15 @@ func runNormalMode(ctx context.Context) error {
viper.Set("thinking-level", pref)
}
}
}
// applyProviderURLRouting rewrites the model in viper when --provider-url
// is set, routing requests through the "custom" (OpenAI-compatible)
// provider. Must run after restorePersistedPreferences.
func applyProviderURLRouting() {
if viper.GetString("provider-url") == "" {
return
}
// When --provider-url is set but no explicit --model was provided,
// default to "custom/custom" so the user doesn't need to remember a
@@ -731,18 +751,53 @@ func runNormalMode(ctx context.Context) error {
// This intentionally overrides saved preferences but respects config-file
// models — if you specify a model in ~/.kit.yml, it will be used with
// custom/custom's provider routing.
if viper.GetString("provider-url") != "" && !modelFlagChanged && !viper.InConfig("model") {
if !modelFlagChanged && !viper.InConfig("model") {
viper.Set("model", "custom/custom")
}
// When --provider-url is set with an explicit --model that lacks a provider
// prefix (no "/"), auto-prefix with "custom/" for OpenAI-compatible endpoints.
if viper.GetString("provider-url") != "" && modelFlagChanged {
// When --provider-url is set with an explicit --model, route through the
// "custom" provider (OpenAI-compatible wire). This honors the user's
// intent: passing a custom URL means "use THIS endpoint", not "speak
// the Google/Anthropic/etc. wire protocol against this endpoint".
//
// Any provider prefix on the model is stripped so a model name that
// happens to collide with a known provider (e.g. `google/gemma-4-12b`
// served by LM Studio) still resolves correctly. If you genuinely need
// to point a non-OpenAI wire (Anthropic, Google, ...) at a proxy URL,
// use the explicit `custom/<name>` form to opt out of the rewrite by
// configuring the proxy as that provider in your config file instead.
if modelFlagChanged {
model := viper.GetString("model")
if model != "" && !strings.Contains(model, "/") {
viper.Set("model", "custom/"+model)
if model != "" {
name := model
if _, after, ok := strings.Cut(model, "/"); ok {
name = after
}
if !strings.HasPrefix(model, "custom/") {
viper.Set("model", "custom/"+name)
}
}
}
}
func runNormalMode(ctx context.Context) error {
if err := validateModeFlags(); err != nil {
return err
}
// Set up logging
if debugMode {
log.SetFlags(log.LstdFlags | log.Lshortfile)
}
// Update debug mode from viper
if viper.GetBool("debug") && !debugMode {
debugMode = viper.GetBool("debug")
log.SetFlags(log.LstdFlags | log.Lshortfile)
}
restorePersistedPreferences()
applyProviderURLRouting()
// Load MCP configuration.
mcpConfig, err := config.LoadAndValidateConfig()
@@ -784,6 +839,9 @@ func runNormalMode(ctx context.Context) error {
AutoCompact: autoCompactFlag,
MCPAuthHandler: authHandler,
DisableCoreTools: viper.GetBool("no-core-tools"),
NoSkills: noSkillsFlag,
Skills: skillsPaths,
SkillsDir: skillsDir,
// This callback is called when each MCP server finishes loading.
// We use a closure that captures appInstancePtr which is set after
// app.New() is called below.
@@ -916,6 +974,9 @@ func runNormalMode(ctx context.Context) error {
startupExtensionMessages = append(startupExtensionMessages, text)
}
kitInstance.Extensions().SetContext(extCtx)
if err := kitInstance.Extensions().InitStatePersistence(); err != nil {
log.Printf("WARN extension state init failed: %v", err)
}
kitInstance.Extensions().EmitSessionStart()
// Restore normal print functions for runtime use.
@@ -1146,23 +1207,7 @@ func runNormalMode(ctx context.Context) error {
// NotifyModelChanged calls prog.Send() which deadlocks. The UI layer
// updates m.providerName and m.modelName directly after setModel returns.
// Update usage tracker with new model info for correct token counting.
if usageTracker != nil {
newProvider, newModel, _ := models.ParseModelString(modelString)
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
registry := models.GetGlobalRegistry()
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
// Check OAuth status for Anthropic models
isOAuth := false
if newProvider == "anthropic" {
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
if err == nil && strings.HasPrefix(source, "stored OAuth") {
isOAuth = true
}
}
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
}
}
}
ui.UpdateUsageTrackerForModel(usageTracker, modelString, viper.GetString("provider-api-key"))
return nil
}
emitModelChangeForUI := func(newModel, previousModel, source string) {
@@ -1262,9 +1307,57 @@ func runNormalMode(ctx context.Context) error {
}
}
// Bundle all the shared dependencies into a single struct that both
// run-mode entry points consume. This keeps the dispatch site and the
// function signatures readable.
deps := runModeDeps{
appInstance: appInstance,
cli: cli,
modelName: modelName,
providerName: parsedProvider,
loadingMessage: kitInstance.GetLoadingMessage(),
serverNames: serverNames,
toolNames: toolNames,
mcpToolCount: mcpToolCount,
extensionToolCount: extensionToolCount,
usageTracker: usageTracker,
extCommands: extCommands,
promptTemplates: promptTemplates,
contextPaths: contextPaths,
skillItems: skillItems,
extensionItems: extensionItems,
getPromptTemplates: getPromptTemplates,
getSkillItems: getSkillItems,
getExtensionItems: getExtensionItems,
getToolNames: getToolNames,
getMCPToolCount: getMCPToolCount,
mcpPrompts: mcpPrompts,
getMCPPrompts: getMCPPrompts,
expandMCPPrompt: expandMCPPrompt,
getWidgets: getWidgets,
getHeader: getHeader,
getFooter: getFooter,
getToolRenderer: getToolRenderer,
getEditorInterceptor: getEditorInterceptor,
getUIVisibility: getUIVisibility,
getStatusBarEntries: getStatusBarEntries,
emitBeforeFork: emitBeforeFork,
emitBeforeSessionSwitch: emitBeforeSessionSwitch,
getGlobalShortcuts: getGlobalShortcuts,
getExtensionCommands: getExtensionCommands,
setModel: setModelForUI,
emitModelChange: emitModelChangeForUI,
isReasoningModel: kitInstance.IsReasoningModel(),
thinkingLevel: kitInstance.GetThinkingLevel(),
setThinkingLevel: setThinkingLevelForUI,
switchSession: switchSessionForUI,
reloadExtensions: reloadExtensionsForUI,
startupExtensionMessages: startupExtensionMessages,
}
// Check if running in non-interactive mode
if positionalPrompt != "" {
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, extensionItems, getPromptTemplates, getSkillItems, getExtensionItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI)
return runNonInteractiveModeApp(ctx, deps, positionalPrompt, quietFlag, jsonFlag, noExitFlag)
}
// Quiet mode is not allowed in interactive mode
@@ -1272,7 +1365,7 @@ func runNormalMode(ctx context.Context) error {
return fmt.Errorf("--quiet requires a prompt")
}
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, extensionItems, getPromptTemplates, getSkillItems, getExtensionItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages)
return runInteractiveModeBubbleTea(ctx, deps)
}
// runNonInteractiveModeApp executes a single prompt via the app layer and exits,
@@ -1285,7 +1378,10 @@ func runNormalMode(ctx context.Context) error {
//
// When --no-exit is set, after the prompt completes the interactive BubbleTea
// TUI is started so the user can continue the conversation.
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, extensionItems []ui.ExtensionItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getExtensionItems func() []ui.ExtensionItem, getToolNames func() []string, getMCPToolCount func() int, mcpPrompts []ui.MCPPromptInfo, getMCPPrompts func() []ui.MCPPromptInfo, expandMCPPrompt func(string, string, map[string]string) (*ui.MCPPromptExpandResult, error), getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error {
func runNonInteractiveModeApp(ctx context.Context, deps runModeDeps, prompt string, quiet, jsonOutput, noExit bool) error {
appInstance := deps.appInstance
cli := deps.cli
modelName := deps.modelName
// Expand @file references in the prompt before sending to the agent.
// Text files are XML-inlined; binary files are extracted as multimodal parts.
var fileParts []kit.LLMFilePart
@@ -1346,12 +1442,67 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
// If --no-exit was requested, hand off to the interactive TUI.
if noExit {
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, extensionItems, getPromptTemplates, getSkillItems, getExtensionItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil)
// Drop the cli (interactive mode doesn't use it) and clear the
// interactive-only fields explicitly; deps carries everything else.
interactive := deps
interactive.cli = nil
interactive.startupExtensionMessages = nil
return runInteractiveModeBubbleTea(ctx, interactive)
}
return nil
}
// runModeDeps bundles the shared dependencies that runNormalMode wires up
// once and threads to both runNonInteractiveModeApp and
// runInteractiveModeBubbleTea. Grouping them into a single struct keeps the
// call sites and signatures readable and makes it trivial to add a new
// provider callback without touching every call chain.
type runModeDeps struct {
appInstance *app.App
cli *ui.CLI // non-interactive only
modelName string
providerName string
loadingMessage string
serverNames []string
toolNames []string
mcpToolCount int
extensionToolCount int
usageTracker *ui.UsageTracker
extCommands []commands.ExtensionCommand
promptTemplates []*prompts.PromptTemplate
contextPaths []string
skillItems []ui.SkillItem
extensionItems []ui.ExtensionItem
getPromptTemplates func() []*prompts.PromptTemplate
getSkillItems func() []ui.SkillItem
getExtensionItems func() []ui.ExtensionItem
getToolNames func() []string
getMCPToolCount func() int
mcpPrompts []ui.MCPPromptInfo
getMCPPrompts func() []ui.MCPPromptInfo
expandMCPPrompt func(string, string, map[string]string) (*ui.MCPPromptExpandResult, error)
getWidgets func(string) []ui.WidgetData
getHeader func() *ui.WidgetData
getFooter func() *ui.WidgetData
getToolRenderer func(string) *ui.ToolRendererData
getEditorInterceptor func() *ui.EditorInterceptor
getUIVisibility func() *ui.UIVisibility
getStatusBarEntries func() []ui.StatusBarEntryData
emitBeforeFork func(string, bool, string) (bool, string)
emitBeforeSessionSwitch func(string, string) (bool, string)
getGlobalShortcuts func() map[string]func()
getExtensionCommands func() []commands.ExtensionCommand
setModel func(string) error
emitModelChange func(string, string, string)
isReasoningModel bool
thinkingLevel string
setThinkingLevel func(string) error
switchSession func(string) error
reloadExtensions func() error
startupExtensionMessages []string // interactive only
}
// ---------------------------------------------------------------------------
// JSON output helpers (--json mode)
// ---------------------------------------------------------------------------
@@ -1444,7 +1595,8 @@ func writeJSONError(err error) {
// 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit).
//
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, extensionItems []ui.ExtensionItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getExtensionItems func() []ui.ExtensionItem, getToolNames func() []string, getMCPToolCount func() int, mcpPrompts []ui.MCPPromptInfo, getMCPPrompts func() []ui.MCPPromptInfo, expandMCPPrompt func(string, string, map[string]string) (*ui.MCPPromptExpandResult, error), getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
func runInteractiveModeBubbleTea(_ context.Context, deps runModeDeps) error {
appInstance := deps.appInstance
// Redirect all log output (stdlib and charm) to a file so that log
// messages don't write to stderr and corrupt the TUI. Bubble Tea
// captures stdout for rendering; any stray stderr output from
@@ -1467,49 +1619,49 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
cwd, _ := os.Getwd()
appModel := ui.NewAppModel(appInstance, ui.AppModelOptions{
ModelName: modelName,
ProviderName: providerName,
LoadingMessage: loadingMessage,
ModelName: deps.modelName,
ProviderName: deps.providerName,
LoadingMessage: deps.loadingMessage,
Cwd: cwd,
Width: termWidth,
Height: termHeight,
ServerNames: serverNames,
ToolNames: toolNames,
GetToolNames: getToolNames,
GetMCPToolCount: getMCPToolCount,
MCPToolCount: mcpToolCount,
ExtensionToolCount: extensionToolCount,
UsageTracker: usageTracker,
ExtensionCommands: extCommands,
PromptTemplates: promptTemplates,
GetPromptTemplates: getPromptTemplates,
MCPPrompts: mcpPrompts,
GetMCPPrompts: getMCPPrompts,
ExpandMCPPrompt: expandMCPPrompt,
ContextPaths: contextPaths,
SkillItems: skillItems,
GetSkillItems: getSkillItems,
ExtensionItems: extensionItems,
GetExtensionItems: getExtensionItems,
StartupExtensionMessages: startupExtensionMessages,
GetWidgets: getWidgets,
GetHeader: getHeader,
GetFooter: getFooter,
GetToolRenderer: getToolRenderer,
GetEditorInterceptor: getEditorInterceptor,
GetUIVisibility: getUIVisibility,
GetStatusBarEntries: getStatusBarEntries,
EmitBeforeFork: emitBeforeFork,
EmitBeforeSessionSwitch: emitBeforeSessionSwitch,
GetGlobalShortcuts: getGlobalShortcuts,
GetExtensionCommands: getExtensionCommands,
SetModel: setModel,
EmitModelChange: emitModelChange,
ThinkingLevel: thinkingLevel,
IsReasoningModel: isReasoningModel,
SetThinkingLevel: setThinkingLevel,
SwitchSession: switchSession,
ReloadExtensions: reloadExtensions,
ServerNames: deps.serverNames,
ToolNames: deps.toolNames,
GetToolNames: deps.getToolNames,
GetMCPToolCount: deps.getMCPToolCount,
MCPToolCount: deps.mcpToolCount,
ExtensionToolCount: deps.extensionToolCount,
UsageTracker: deps.usageTracker,
ExtensionCommands: deps.extCommands,
PromptTemplates: deps.promptTemplates,
GetPromptTemplates: deps.getPromptTemplates,
MCPPrompts: deps.mcpPrompts,
GetMCPPrompts: deps.getMCPPrompts,
ExpandMCPPrompt: deps.expandMCPPrompt,
ContextPaths: deps.contextPaths,
SkillItems: deps.skillItems,
GetSkillItems: deps.getSkillItems,
ExtensionItems: deps.extensionItems,
GetExtensionItems: deps.getExtensionItems,
StartupExtensionMessages: deps.startupExtensionMessages,
GetWidgets: deps.getWidgets,
GetHeader: deps.getHeader,
GetFooter: deps.getFooter,
GetToolRenderer: deps.getToolRenderer,
GetEditorInterceptor: deps.getEditorInterceptor,
GetUIVisibility: deps.getUIVisibility,
GetStatusBarEntries: deps.getStatusBarEntries,
EmitBeforeFork: deps.emitBeforeFork,
EmitBeforeSessionSwitch: deps.emitBeforeSessionSwitch,
GetGlobalShortcuts: deps.getGlobalShortcuts,
GetExtensionCommands: deps.getExtensionCommands,
SetModel: deps.setModel,
EmitModelChange: deps.emitModelChange,
ThinkingLevel: deps.thinkingLevel,
IsReasoningModel: deps.isReasoningModel,
SetThinkingLevel: deps.setThinkingLevel,
SwitchSession: deps.switchSession,
ReloadExtensions: deps.reloadExtensions,
ShowSessionPicker: resumeFlag,
GetMCPResources: mcpGetResources,
MCPResourceReader: mcpResourceReader,
+1
View File
@@ -58,6 +58,7 @@ kit install github.com/mark3labs/kit/examples/extensions --local
| `project-rules.go` | Project-specific rules | Session data, file reading |
| `protected-paths.go` | Block dangerous operations | `OnToolCall` with blocking |
| `permission-gate.go` | Confirm destructive actions | `OnToolCall` with confirmation |
| `usage-budget.go` | Soft cost cap + per-turn report | `OnLLMUsage`, `SetState`/`GetState`, enriched `AgentEndEvent` |
### Tools & Commands
+110
View File
@@ -0,0 +1,110 @@
//go:build ignore
// phase-handoff.go demonstrates ctx.NewSession by automating the multi-phase
// workflow pattern: the agent works through a spec, writes a HANDOFF.md at
// the end of each phase, then a fresh session picks up where the last one
// left off.
//
// Two trigger modes are provided:
//
// 1. Automatic — when an assistant message ends with the sentinel
// "<HANDOFF_READY>", the extension starts a new session and pre-loads
// HANDOFF.md as the first prompt. Use this when you want the agent to
// hand off control to itself with no user intervention.
//
// 2. Manual — the /handoff slash command starts a new session immediately
// with the same handoff prompt. Useful when you finish a phase by hand
// and want to clear the context window before the next one starts.
//
// Usage:
//
// kit -e examples/extensions/phase-handoff.go
//
// Have your spec-driving agent write a HANDOFF.md at the end of each phase
// and finish its message with the literal string `<HANDOFF_READY>`. The
// next session boots automatically and reads HANDOFF.md as @file context.
package main
import (
"strings"
"kit/ext"
)
// HANDOFFSentinel is the marker the agent appends to its last message to
// request an automatic session switch. Change this to whatever fits your
// workflow.
const HANDOFFSentinel = "<HANDOFF_READY>"
// HANDOFFPrompt is the first prompt the new session receives. The leading
// "@HANDOFF.md" triggers Kit's @file expansion, inlining the handoff file's
// contents as XML-wrapped context.
const HANDOFFPrompt = "Read @HANDOFF.md and continue with the next phase."
func Init(api ext.API) {
// Automatic trigger: detect the sentinel at the end of an agent turn.
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
msgs := ctx.GetMessages()
if len(msgs) == 0 {
return
}
last := msgs[len(msgs)-1]
if last.Role != "assistant" || !strings.Contains(last.Content, HANDOFFSentinel) {
return
}
// NewSession blocks while the agent finishes settling and then while
// the TUI completes the switch; run it in a goroutine so the agent's
// turn-end pipeline isn't stalled. The internal wait-for-idle (added
// in response to issue #63) makes this reliable even when post-turn
// tooling (formatters, on-save hooks, hidden tool calls) extends the
// busy window past AgentEnd.
go func() {
if err := ctx.NewSession(HANDOFFPrompt); err != nil {
ctx.PrintError("phase-handoff: " + err.Error())
return
}
ctx.PrintInfo("phase-handoff: started a fresh session from HANDOFF.md")
}()
})
// Manual trigger: /handoff [optional override prompt]
api.RegisterCommand(ext.CommandDef{
Name: "handoff",
Description: "Start a new session, optionally with a custom prompt",
Execute: func(args string, ctx ext.Context) (string, error) {
prompt := strings.TrimSpace(args)
if prompt == "" {
prompt = HANDOFFPrompt
}
if err := ctx.NewSession(prompt); err != nil {
return "", err
}
return "", nil
},
})
// Optional safeguard: surface the next prompt so the user can confirm
// before the auto-handoff proceeds. Set kit option "handoff.confirm=1"
// to enable.
api.OnBeforeSessionSwitch(func(e ext.BeforeSessionSwitchEvent, ctx ext.Context) *ext.BeforeSessionSwitchResult {
if ctx.GetOption("handoff.confirm") != "1" {
return nil
}
if e.InitialPrompt == "" {
return nil
}
resp := ctx.PromptConfirm(ext.PromptConfirmConfig{
Message: "Start a new session with prompt:\n " + e.InitialPrompt + "\n\nProceed?",
DefaultValue: true,
})
if resp.Cancelled || !resp.Value {
return &ext.BeforeSessionSwitchResult{
Cancel: true,
Reason: "handoff cancelled by user",
}
}
return nil
})
}
+87
View File
@@ -0,0 +1,87 @@
//go:build ignore
package main
import (
"fmt"
"strconv"
"kit/ext"
)
// Init demonstrates the three primitives added in issue #53:
//
// 1. api.OnLLMUsage(...) — per-LLM-call usage callback with token + cost
// deltas. Use this for budget enforcement that reacts between calls
// within a single agent turn, rather than only at turn boundaries.
//
// 2. ctx.SetState / ctx.GetState / ctx.DeleteState / ctx.ListState —
// last-write-wins, session-scoped key-value store backed by a sidecar
// file. Use this for snapshot state (current value of X) instead of
// ctx.AppendEntry, which is append-only and bloats branch reads.
//
// 3. ext.AgentEndEvent.ToolCallCount / .ToolNames / .LLMCallCount /
// .InputTokensDelta / .OutputTokensDelta / .CostDelta / .DurationMs —
// per-turn aggregates so observer extensions don't need to maintain
// parallel bookkeeping.
//
// Together these support a simple soft-budget cap: warn when the
// cumulative cost in this session exceeds a threshold, and print a
// per-turn report on AgentEnd.
//
// Usage: kit -e examples/extensions/usage-budget.go
func Init(api ext.API) {
const warnAtKey = "usage-budget:warn-at-usd"
// 1. Print per-LLM-call usage with provider, model, and cost.
api.OnLLMUsage(func(e ext.LLMUsageEvent, ctx ext.Context) {
ctx.Print(fmt.Sprintf(
"[usage] step=%d %s/%s tokens=↑%d ↓%d cache=↑%d/↓%d cost=$%.4f (%s)",
e.StepNumber, e.Provider, e.Model,
e.InputTokens, e.OutputTokens,
e.CacheWriteTokens, e.CacheReadTokens,
e.Cost, e.FinishReason,
))
// 2. Persist running total in last-write-wins state.
current := 0.0
if raw, ok := ctx.GetState("usage-budget:total-cost"); ok {
current, _ = strconv.ParseFloat(raw, 64)
}
current += e.Cost
ctx.SetState("usage-budget:total-cost", strconv.FormatFloat(current, 'f', 6, 64))
// Soft warn-at threshold (configurable via state).
warnAt := 0.50
if raw, ok := ctx.GetState(warnAtKey); ok {
if v, err := strconv.ParseFloat(raw, 64); err == nil {
warnAt = v
}
}
if current > warnAt {
ctx.PrintError(fmt.Sprintf(
"[usage] session cost $%.4f exceeds soft cap $%.2f",
current, warnAt,
))
}
})
// 3. Print a per-turn summary using the enriched AgentEndEvent.
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
ctx.Print(fmt.Sprintf(
"[turn] stop=%s tools=%d llm-calls=%d tokens=↑%d ↓%d cost=$%.4f duration=%dms",
e.StopReason, e.ToolCallCount, e.LLMCallCount,
e.InputTokensDelta, e.OutputTokensDelta, e.CostDelta, e.DurationMs,
))
if len(e.ToolNames) > 0 {
ctx.Print(fmt.Sprintf("[turn] tool order: %v", e.ToolNames))
}
})
// Bootstrap default soft cap once per session.
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
if _, ok := ctx.GetState(warnAtKey); !ok {
ctx.SetState(warnAtKey, "0.50")
}
})
}
+46 -47
View File
@@ -1,32 +1,34 @@
module github.com/mark3labs/kit
go 1.26.3
go 1.26.4
require (
charm.land/bubbles/v2 v2.1.0
charm.land/bubbletea/v2 v2.0.6
charm.land/fantasy v0.25.0
charm.land/bubbletea/v2 v2.0.7
charm.land/fantasy v0.32.0
charm.land/huh/v2 v2.0.3
charm.land/lipgloss/v2 v2.0.3
github.com/alecthomas/chroma/v2 v2.26.1
charm.land/lipgloss/v2 v2.0.4
github.com/alecthomas/chroma/v2 v2.27.0
github.com/atotto/clipboard v0.1.4
github.com/aymanbagabas/go-udiff v0.4.1
github.com/charmbracelet/colorprofile v0.4.3
github.com/charmbracelet/fang v1.0.0
github.com/charmbracelet/log v1.0.0
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
github.com/charmbracelet/ultraviolet v0.0.0-20260525132238-948f4557a654
github.com/charmbracelet/openai-go v0.0.0-20260617131321-5e4b9c18c4be
github.com/charmbracelet/ultraviolet v0.0.0-20260615092913-2399af76d5b1
github.com/charmbracelet/x/editor v0.2.0
github.com/clipperhouse/displaywidth v0.11.0
github.com/clipperhouse/uax29/v2 v2.7.0
github.com/coder/acp-go-sdk v0.13.0
github.com/coder/acp-go-sdk v0.13.5
github.com/fsnotify/fsnotify v1.10.1
github.com/indaco/herald v0.13.0
github.com/indaco/herald-md v0.3.0
github.com/mark3labs/mcp-go v0.54.1
github.com/mark3labs/mcp-go v0.55.0
github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0
github.com/traefik/yaegi v0.16.1
golang.org/x/term v0.43.0
golang.org/x/image v0.42.0
golang.org/x/term v0.44.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -35,42 +37,41 @@ require (
cloud.google.com/go/auth v0.20.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.22.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.8 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 // indirect
github.com/aws/aws-sdk-go-v2/config v1.32.19 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.18 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2 // indirect
github.com/aws/smithy-go v1.26.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.42.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.13 // indirect
github.com/aws/aws-sdk-go-v2/config v1.32.25 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.24 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.29 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.29 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.29 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.30 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.12 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.29 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.2.0 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.31.3 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.6 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.43.3 // indirect
github.com/aws/smithy-go v1.27.2 // indirect
github.com/catppuccin/go v0.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
github.com/charmbracelet/colorprofile v0.4.3 // indirect
github.com/charmbracelet/harmonica v0.2.0 // indirect
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260527151214-009e6338d40d // indirect
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260615092313-b57e5e6d29bb // indirect
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20260615092313-b57e5e6d29bb // indirect
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
github.com/charmbracelet/x/json v0.2.0 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect
github.com/charmbracelet/x/windows v0.2.2 // indirect
github.com/dlclark/regexp2 v1.12.0 // indirect
github.com/dlclark/regexp2/v2 v2.1.1 // indirect
github.com/dlclark/regexp2/v2 v2.2.2 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686 // indirect
github.com/felixge/httpsnoop v1.1.0 // indirect
github.com/go-json-experiment/json v0.0.0-20260601182631-00ed12fed2a6 // indirect
github.com/go-logfmt/logfmt v0.6.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
@@ -83,16 +84,14 @@ require (
github.com/googleapis/enterprise-certificate-proxy v0.3.16 // indirect
github.com/googleapis/gax-go/v2 v2.22.0 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/kaptinlin/go-i18n v0.4.5 // indirect
github.com/kaptinlin/jsonpointer v0.4.25 // indirect
github.com/kaptinlin/jsonschema v0.7.13 // indirect
github.com/kaptinlin/messageformat-go v0.6.0 // indirect
github.com/kaptinlin/jsonpointer v0.4.26 // indirect
github.com/kaptinlin/jsonschema v0.8.1 // indirect
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
github.com/muesli/mango v0.2.0 // indirect
github.com/muesli/mango-cobra v1.3.0 // indirect
github.com/muesli/mango-pflag v0.2.0 // indirect
github.com/muesli/roff v0.1.0 // indirect
github.com/pelletier/go-toml/v2 v2.3.1 // indirect
github.com/pelletier/go-toml/v2 v2.4.0 // indirect
github.com/sagikazarmark/locafero v0.12.0 // indirect
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect
github.com/spf13/afero v1.15.0 // indirect
@@ -112,14 +111,14 @@ require (
go.opentelemetry.io/otel/metric v1.44.0 // indirect
go.opentelemetry.io/otel/trace v1.44.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.52.0 // indirect
golang.org/x/exp v0.0.0-20260528193900-50dc527dd6c7 // indirect
golang.org/x/net v0.55.0 // indirect
golang.org/x/crypto v0.53.0 // indirect
golang.org/x/exp v0.0.0-20260611194520-c48552f49976 // indirect
golang.org/x/net v0.56.0 // indirect
golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/time v0.15.0 // indirect
google.golang.org/api v0.282.0 // indirect
google.golang.org/genai v1.58.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa // indirect
google.golang.org/api v0.285.0 // indirect
google.golang.org/genai v1.61.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260615183401-62b3387ff324 // indirect
google.golang.org/grpc v1.81.1 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
@@ -132,12 +131,12 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.22 // indirect
github.com/mattn/go-runewidth v0.0.23 // indirect
github.com/mattn/go-runewidth v0.0.24 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.10
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.45.0 // indirect
golang.org/x/text v0.37.0
golang.org/x/sync v0.21.0 // indirect
golang.org/x/sys v0.46.0 // indirect
golang.org/x/text v0.38.0
)
+92 -94
View File
@@ -1,13 +1,13 @@
charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g=
charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY=
charm.land/bubbletea/v2 v2.0.6 h1:UHN/91OyuhaOFGSrBXQ/hMZD8IO1Uc4BvHlgHXL2WJo=
charm.land/bubbletea/v2 v2.0.6/go.mod h1:MH/D8ZLlN3op37vQvijKuU29g3rqTp+aQapURFonF9g=
charm.land/fantasy v0.25.0 h1:oXOWY1ivmTSnhYGzAolscF8zKtavWZyBWv0LHRSwN5Q=
charm.land/fantasy v0.25.0/go.mod h1:8QrWUzIcKwZQP+aAnC9vLu3iID6hu9/Jt+rPMiieBkc=
charm.land/bubbletea/v2 v2.0.7 h1:7qw2tTAVar7m7klOPBYfTB0mniv/RuexsYwMRNxSeL0=
charm.land/bubbletea/v2 v2.0.7/go.mod h1:DGW2q8gvzHnOpMpZTORs0aySVHCox5C+2Svk0fci1qs=
charm.land/fantasy v0.32.0 h1:tlC1qlOdXi2CkF6KB0x8YAAm3hiarI2/69u6pZmOZk8=
charm.land/fantasy v0.32.0/go.mod h1:CWAFEOB21guhmt4qWN9sOnAHkZzVWjKbhxbPHG+oRs8=
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU=
charm.land/lipgloss/v2 v2.0.3/go.mod h1:7myLU9iG/3xluAWzpY/fSxYYHCgoKTie7laxk6ATwXA=
charm.land/lipgloss/v2 v2.0.4 h1:lcPeVtcp23SNra7lHy8iYE4UC2aIipVQ47sbGyyxR5Q=
charm.land/lipgloss/v2 v2.0.4/go.mod h1:0653x8epbZSzdDfO/XPS1a/uYPOBeSsCssOpJOqDzik=
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
@@ -16,8 +16,8 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIi
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 h1:jHb/wfvRikGdxMXYV3QG/SzUOPYN9KEUUuC0Yd0/vC0=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1/go.mod h1:pzBXCYn05zvYIrwLgtK8Ap8QcjRg+0i76tMQdWN6wOk=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.22.0 h1:aokoqcHvaGjiM3VpjKDfMMnF/8epJ+Q1HLJ7CudztqE=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.22.0/go.mod h1:/WYEx9pcM9Y+Dd/APJaNlSvVSvzl54rrMdZT5+Oi2LM=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4=
@@ -28,42 +28,42 @@ github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.26.1 h1:2X21EdxGZNv5GF9mG5u+uzc02GCFyGxbcBm3Grd9A78=
github.com/alecthomas/chroma/v2 v2.26.1/go.mod h1:lxhRRa9H4hPmRLOOdYga4zkQIQjq3dtrrdwQeCfu78Y=
github.com/alecthomas/chroma/v2 v2.27.0 h1:FodwmyOBgJULFYmDqibcp9pvfDLWdtPRh9v/r5BXYZs=
github.com/alecthomas/chroma/v2 v2.27.0/go.mod h1:NjJ3ciIgrqBNeIkWZ4e46nseoLDslxU1LmfCoL+wcY8=
github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aws/aws-sdk-go-v2 v1.41.8 h1:sRs7nG6/RiEBZ/K5UO2sNw0w40U02Nmz1VtARloTZXk=
github.com/aws/aws-sdk-go-v2 v1.41.8/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 h1:gx1AwW1Iyk9Z9dD9F4akX5gnN3QZwUB20GGKH/I+Rho=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10/go.mod h1:qqY157uZoqm5OXq/amuaBJyC9hgBCBQnsaWnPe905GY=
github.com/aws/aws-sdk-go-v2/config v1.32.19 h1:qRhIJMbevHUvIE7X4TK8N8zye5+5AhapcslPrvB+qKE=
github.com/aws/aws-sdk-go-v2/config v1.32.19/go.mod h1:RbJ24nfoya63+Mf5VI+CGCGk9vEdv28xPeii+gojRYs=
github.com/aws/aws-sdk-go-v2/credentials v1.19.18 h1:GcXQz2M/0ZvMo0v5DakUqbDBeBM1ZNaivkolEF4Esgw=
github.com/aws/aws-sdk-go-v2/credentials v1.19.18/go.mod h1:sHJ06tMGcD3ZpmMyJqV+VBsGilhSIZPIN+ZFy5Dg0C4=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24 h1:FQm5ApnyzkuJdXLGskPce83CK1CQKC4RUnIHKVe4BU4=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24/go.mod h1:JsC7dqQc55MlZ5mvNsDMMge71u8pVcSzU3RNz2h/5yQ=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24 h1:u6kJU2i0va1AgtJsH3RdWKWqHULlTh7zHwb35Womf74=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24/go.mod h1:7GY+xLcXOFUpCkNwDReft9qOAVg54A4/AnjHIU7sSAY=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24 h1:Xhbcf3KugX6vX7SDyUK205Oicyfg7EGuvoVNyP5L6DM=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24/go.mod h1:rwDgb2HNOGZsnTHylOUedM7Vnl+bCfnXDqUNPsFWYfk=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25 h1:54CTMmlJ71Rk2dYvM9qZOob+39wjlVja2zDLxCu69Ew=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25/go.mod h1:BZaHqxsS9vN1fvV5EfEl0OBLOk5+AajWsMu6MjqnZB4=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24 h1:CQW2FTrflfoslYWLf3fv7vG28Q219+v8YJS5QTQb2+Y=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24/go.mod h1:Xfx13T+u3nH6EEzgl9fBSO6nDRmze1FvnZNYkctQ2zw=
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0 h1:yQo3eZ5qFaL1sJWqs1nL6j3yPHA2/R7c6tQ4T+0IO10=
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0/go.mod h1:3Zzou41Qt/ueXfIzHvTEjDNuR5IjCUBVF01SNhrt1e8=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18 h1:ApLTFdAZfDhZSiY5uskwECKHkSNNF83y2Ru2r7SezWA=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18/go.mod h1:A9K9qx2l6nK89hp+a350FdGfRkrkH5HdiEjHbiy/Q/c=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1 h1:4VD7TIZOGzehrgQ8vDE+1c6BQW4ErZPGY8ohZT5LXEE=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1/go.mod h1:er0SFJfdV89Rit5hIJu/EXtv+qC2XMnxoksLmcUFkqM=
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2 h1:XKnxlM4KZH1gktcsh3zSWc7GW4KivEv/OkifmHOhCUY=
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2/go.mod h1:KJYmkQaFB3SUW2j3aBkPsxNmAb4ZsSOvbvCpuxzHJA0=
github.com/aws/smithy-go v1.26.0 h1:9ouqbi+NyKP7fV3Te7UElCwdAb6Y8uk7LGwPE5tVe/s=
github.com/aws/smithy-go v1.26.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/aws/aws-sdk-go-v2 v1.42.0 h1:XvXMJTkFQtpBKIWZnmr9ZEOc2InWM2yldjXEJ/bymhA=
github.com/aws/aws-sdk-go-v2 v1.42.0/go.mod h1:27+ACypSLljLAEKsCYOmrjKh83vuTRkuAe9Uv/3A4bg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.13 h1:p1BBrg/Hhp6uK7zpejeI8QFXHJeC/mynzi04Sl03k9g=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.13/go.mod h1:8cIfkE9MDhkRZGpQ22aV6/lkYeYSozpz16Smrs5x4Ls=
github.com/aws/aws-sdk-go-v2/config v1.32.25 h1:ACCejvStYoilgwrfegSt5ZntCbPrk52qfwyNcnl3omM=
github.com/aws/aws-sdk-go-v2/config v1.32.25/go.mod h1:LJyU8sDRbXUxFn8xMJIGP+v9QYYwveNLI8a/giAOiAs=
github.com/aws/aws-sdk-go-v2/credentials v1.19.24 h1:2hQqYCV9yqyePQ9o6dCrZc/zO8U3TwPr9mIKlZnPu/I=
github.com/aws/aws-sdk-go-v2/credentials v1.19.24/go.mod h1:IDwpACtwqHLISdzfwUUNq4P9DsB/h5BLg4FwJPNfqFY=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.29 h1:r6qZHbT+wxgWO/e9vYNUEtg7lv5+UN3pRqKhLXvnArg=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.29/go.mod h1:QRnaRcTVGKPGRy8w78HMQtKUGRYcnMZAANATkeVA6Mo=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.29 h1:f3vKqSo13fhTYb+JEcXwXefZQE26I1FB5eTSniU67ko=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.29/go.mod h1:MzoLFUArKGpGD+ukmPiTPG1X5x4o6M2kq4v2dr1FiEc=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.29 h1:RdwIf/CuUsvJX3RgJagbOyotl/cxoLY4xviKuE7p2GY=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.29/go.mod h1:71wt8W2EgswdZy9Mf9KNnzxZ3TiZlv4caKghPktDOkA=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.30 h1:VTGy885W5DKBxWRUJbym9hytNaYzsyaPkCHGRRMAOhU=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.30/go.mod h1:AS0HycUvJRFvTt613AYDOgO2jzw+00cVSMny8XB3yMY=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.12 h1:ZD2+BSw9vFsNlKYIasSNt3uDbjqqXIBcM13UJv/Lx2k=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.12/go.mod h1:Ms4zlcVBbXbiP7EVLhl+lgjvA/a7YphqQ3Ih3174EmI=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.29 h1:DRebniUGZ2MqiiIVmQJ04vIXr918hubdHMnarSLEWyU=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.29/go.mod h1:LfRkPCD8YHDM2E5eTkos2UpwYeZnBcVarTa8L59bJHA=
github.com/aws/aws-sdk-go-v2/service/signin v1.2.0 h1:3nXpRcFwRCW8n7HgO2QGy0Dc20eQNfBuUemGQhpF8m8=
github.com/aws/aws-sdk-go-v2/service/signin v1.2.0/go.mod h1:LxYujSTLPRlp2vTtcUO/+1ilrew8ytt6SvQyOgejzFQ=
github.com/aws/aws-sdk-go-v2/service/sso v1.31.3 h1:ey1XLTYXb9PcLt4535632o5kCGXNXEhNb620Dqwuylo=
github.com/aws/aws-sdk-go-v2/service/sso v1.31.3/go.mod h1:Lk7PlmoTYryQmyBG0EXqj5BcUbj3whXdU2s3yGI3EAc=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.6 h1:yLr03zQE/5Eu5l3QU0Si+xMbLMbSDF2YXsigqXngs6g=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.6/go.mod h1:Q5N6icH+KJZDLh+ESNwzdv6cZ6vLFF/egy3IOxWhmz4=
github.com/aws/aws-sdk-go-v2/service/sts v1.43.3 h1:VrIhKRCSK1umelSgB9RghvA9RTUYeQffyAS5ApXehNI=
github.com/aws/aws-sdk-go-v2/service/sts v1.43.3/go.mod h1:r8wkDOuLaaMFqFiYAb8dGY2A3gJCOujMc6CFOVC4Zhc=
github.com/aws/smithy-go v1.27.2 h1:y9NPmSE6am6LjEFPfqHqG/jJk7AauQvhCJONKh7kpzk=
github.com/aws/smithy-go v1.27.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
@@ -84,10 +84,10 @@ github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0r
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA=
github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdRc4=
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 h1:BW/sZtyd1JyYy0h5adMm3tzpNyL857LWjuTRET6OhpY=
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
github.com/charmbracelet/ultraviolet v0.0.0-20260525132238-948f4557a654 h1:FpSYhY28ucg9ZRr+2wj67FAQ0Ey5yiK0072PmRDJNek=
github.com/charmbracelet/ultraviolet v0.0.0-20260525132238-948f4557a654/go.mod h1:hFpumms29Smx3LStRfku8vcCTBe1Kq8aCXtHUJa3mjY=
github.com/charmbracelet/openai-go v0.0.0-20260617131321-5e4b9c18c4be h1:pg+OWlIkk9HOe/8P5J95aKe2wGDzFUiiyFOUpwR30B4=
github.com/charmbracelet/openai-go v0.0.0-20260617131321-5e4b9c18c4be/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
github.com/charmbracelet/ultraviolet v0.0.0-20260615092913-2399af76d5b1 h1:4+r3uOJ69ueRBt4okgEfWZeXs3BD36HcDBmOIAUlETk=
github.com/charmbracelet/ultraviolet v0.0.0-20260615092913-2399af76d5b1/go.mod h1:f/jRa757WUmaOZrbPspXymbg/GnbF+rwe4OLsG7aXYo=
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
@@ -98,14 +98,14 @@ github.com/charmbracelet/x/editor v0.2.0 h1:7XLUKtaRaB8jN7bWU2p2UChiySyaAuIfYiIR
github.com/charmbracelet/x/editor v0.2.0/go.mod h1:p3oQ28TSL3YPd+GKJ1fHWcp+7bVGpedHpXmo0D6t1dY=
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260527151214-009e6338d40d h1:sMilwx1YIYTrQva6jsB522AoRYAerNaDIKP4ZPtUq0A=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260527151214-009e6338d40d/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260615092313-b57e5e6d29bb h1:hoqNT54vrpXamSaQe5GxupakGgvvqFmVgmLJjotpHco=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260615092313-b57e5e6d29bb/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d h1:RxcAR+vJCoD8QqT1cqLtkQKw+1cqvjqnu5IpPqYzPco=
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
github.com/charmbracelet/x/exp/slice v0.0.0-20260615092313-b57e5e6d29bb h1:fr6DwrfJB2XQ3zM2fCwumXPE5G+hegnkEpl1KUuPsQI=
github.com/charmbracelet/x/exp/slice v0.0.0-20260615092313-b57e5e6d29bb/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
@@ -124,8 +124,8 @@ github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJ
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 h1:aBangftG7EVZoUb69Os8IaYg++6uMOdKK83QtkkvJik=
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2/go.mod h1:qwXFYgsP6T7XnJtbKlf1HP8AjxZZyzxMmc+Lq5GjlU4=
github.com/coder/acp-go-sdk v0.13.0 h1:IAKBDIbe/iBfKAGikeIndzb8fowt4ioD+gCtSU4HwMA=
github.com/coder/acp-go-sdk v0.13.0/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
github.com/coder/acp-go-sdk v0.13.5 h1:LI9jq5xon7xslaYlnoktvTVyDlE37yIk2daT7N9ASYk=
github.com/coder/acp-go-sdk v0.13.5/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
@@ -133,8 +133,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.12.0 h1:0j4c5qQmnC6XOWNjP3PIXURXN2gWx76rd3KvgdPkCz8=
github.com/dlclark/regexp2 v1.12.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2/v2 v2.1.1 h1:LCUGyd9Wf+r+VVOl8Ny38JTpWJcAsdVnCIuhhtthmKw=
github.com/dlclark/regexp2/v2 v2.1.1/go.mod h1:avUrQvPaLz2DrFNHJF0taWAFFX2C1GMSSoeiqFjcBmU=
github.com/dlclark/regexp2/v2 v2.2.2 h1:MYWvNYw8okuqNhwTYO587EZMiDruVa2vhV6fsGpfya0=
github.com/dlclark/regexp2/v2 v2.2.2/go.mod h1:avUrQvPaLz2DrFNHJF0taWAFFX2C1GMSSoeiqFjcBmU=
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
@@ -144,14 +144,14 @@ github.com/envoyproxy/go-control-plane/envoy v1.37.0 h1:u3riX6BoYRfF4Dr7dwSOroNf
github.com/envoyproxy/go-control-plane/envoy v1.37.0/go.mod h1:DReE9MMrmecPy+YvQOAOHNYMALuowAnbjjEMkkWOi6A=
github.com/envoyproxy/protoc-gen-validate v1.3.3 h1:MVQghNeW+LZcmXe7SY1V36Z+WFMDjpqGAGacLe2T0ds=
github.com/envoyproxy/protoc-gen-validate v1.3.3/go.mod h1:TsndJ/ngyIdQRhMcVVGDDHINPLWB7C82oDArY51KfB0=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/felixge/httpsnoop v1.1.0 h1:3YtUj32ZZkqZtt3sZZsClsymw/QDuVfpNhoA31zeORc=
github.com/felixge/httpsnoop v1.1.0/go.mod h1:Zqxgdd+1Rkcz8euOqdr7lqgCRJztwr5hp9vDSi5UZCE=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho=
github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo=
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686 h1:NZBJxCpbHS1gzS6xAmyxbJznosZIIPk9IB42v62UvKA=
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
github.com/go-json-experiment/json v0.0.0-20260601182631-00ed12fed2a6 h1:nxP4pPoyqOAgX8lYDFCfl3DyKeXErCvSvhcyzwGV9CE=
github.com/go-json-experiment/json v0.0.0-20260601182631-00ed12fed2a6/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
github.com/go-logfmt/logfmt v0.6.1 h1:4hvbpePJKnIzH1B+8OR/JPbTx37NktoI9LE2QZBBkvE=
github.com/go-logfmt/logfmt v0.6.1/go.mod h1:EV2pOAQoZaT1ZXZbqDl5hrymndi4SY9ED9/z6CO0XAk=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
@@ -189,14 +189,10 @@ github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
github.com/kaptinlin/go-i18n v0.4.5 h1:9tIlo5A0RXth+yZJO2MG7Bhpu/X9PlzQnGz/qyYWNoY=
github.com/kaptinlin/go-i18n v0.4.5/go.mod h1:mU/7BH4molY5lGZYBwBRKAaiJ70dWRHuqmQ0/pFLGno=
github.com/kaptinlin/jsonpointer v0.4.25 h1:iJ197e8n+WwqaqBsa53FqG3rPJCg5oijyFXEXNWWC3E=
github.com/kaptinlin/jsonpointer v0.4.25/go.mod h1:wVOBaXGGnP42YsMb6zev/3W5POTvspdNfh8DXzf8XS8=
github.com/kaptinlin/jsonschema v0.7.13 h1:kahVXTy/rURL0XJjyQ9WELm59wEmXi6IY0TWswQEFvU=
github.com/kaptinlin/jsonschema v0.7.13/go.mod h1:Uh0aUBusnhXDCEXJ2oimL/hx7YTo7F+sKniE+tM0ERc=
github.com/kaptinlin/messageformat-go v0.6.0 h1:D6jiXFsKW4/JG2CMddv/F6Rev9KVbCRKEzzV5QOAcpc=
github.com/kaptinlin/messageformat-go v0.6.0/go.mod h1:NKjwS6e9u7DRhAK+vydjDDwJ7UbdHhYjk/yk2WPuZPs=
github.com/kaptinlin/jsonpointer v0.4.26 h1:tw616yszHek+B3/GtDSia+uzBa3sLXGpmo4tYeMhBZw=
github.com/kaptinlin/jsonpointer v0.4.26/go.mod h1:wVOBaXGGnP42YsMb6zev/3W5POTvspdNfh8DXzf8XS8=
github.com/kaptinlin/jsonschema v0.8.1 h1:Krhuq1HpE+olHoPfcxkohqKKCnXfixUPv+aUYRegBBQ=
github.com/kaptinlin/jsonschema v0.8.1/go.mod h1:mCH2W5lXd29tdDjvoFfY32nedPORnlk7pCVrrcs/NkQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -205,12 +201,12 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mark3labs/mcp-go v0.54.1 h1:Ap/ptEB9FtWzFKM8NDsTA7QDxerQOC06eZigrTldVj0=
github.com/mark3labs/mcp-go v0.54.1/go.mod h1:+8WclSK1ZUweCP3hvktSji8n8ABG/95QaEkeVE/Uwas=
github.com/mark3labs/mcp-go v0.55.0 h1:lJfz2aoctiwK+sI991+uIYwmKNIBciI+O7zsyDsa4U8=
github.com/mark3labs/mcp-go v0.55.0/go.mod h1:+8WclSK1ZUweCP3hvktSji8n8ABG/95QaEkeVE/Uwas=
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-runewidth v0.0.24 h1:cpokDiIn0MGnhdHwuWnJBITySJ20QyNGnY2kR/ay2DU=
github.com/mattn/go-runewidth v0.0.24/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
@@ -225,8 +221,8 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pelletier/go-toml/v2 v2.4.0 h1:Mwu0mAkUKbittDs3/ADDWXqMmq3EOK2VHiuCkV00Row=
github.com/pelletier/go-toml/v2 v2.4.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
@@ -294,36 +290,38 @@ go.opentelemetry.io/otel/trace v1.44.0 h1:jxF5CsGYCe74MCRx2X4g7WsY/VBKRqqpNvXlX/
go.opentelemetry.io/otel/trace v1.44.0/go.mod h1:oLl1jrMQAVo6v3GAggN+1VH9VIz9iUSvW53sW1Q8PIE=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
golang.org/x/exp v0.0.0-20260528193900-50dc527dd6c7 h1:cHpkPjp4TILjdZxz/O4ykwCpeS+dDqNuDGse4zgQDCk=
golang.org/x/exp v0.0.0-20260528193900-50dc527dd6c7/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw=
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
golang.org/x/exp v0.0.0-20260611194520-c48552f49976 h1:X8Hz2ImujgbmetVuW+w2YkyZChE3cBpZi2P158rTG9M=
golang.org/x/exp v0.0.0-20260611194520-c48552f49976/go.mod h1:vnf4pv9iKZXY58sQE1L86zmNWJ4159e1RkcWiLCkeEY=
golang.org/x/image v0.42.0 h1:1gSs6ehNWXLbkHBIPcWztk3D/6aIA/8hauiAYtlodVY=
golang.org/x/image v0.42.0/go.mod h1:rrpelvGFt+kLPAjPM4HeWPgrl0FtafueU//e5N0qk/Q=
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/api v0.282.0 h1:WmJiSVqUnKqJCpJOx7YADbXaC+9DDsnGSfllFSj7R2I=
google.golang.org/api v0.282.0/go.mod h1:6Wssta4c5n9qHq5CBhmlai5h/PUa1djdDAIhYEHyvcM=
google.golang.org/genai v1.58.0 h1:MNA3ZkRyr7MnRwZ9RNZ60p4+UMKV3yYRw6pyHq4pp0U=
google.golang.org/genai v1.58.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto v0.0.0-20260504160031-60b97b32f348 h1:JjVGDZYWkJWZcxveJGzfkXC5myDVWAd4dZdgbzrDUv8=
google.golang.org/genproto v0.0.0-20260504160031-60b97b32f348/go.mod h1:95PqD4xM+AdOcBGsmgfaofXsiA37uXDtDufVbntT3TU=
google.golang.org/genproto/googleapis/api v0.0.0-20260504160031-60b97b32f348 h1:U8orV30l6KpDsi9dxU0CoJZGbjS8EEpw+6ba+XwGPQA=
google.golang.org/genproto/googleapis/api v0.0.0-20260504160031-60b97b32f348/go.mod h1:Yzdzr5OOZFgSsEV2D/Xi9NL3bszpXFAg0hFJiRohcD8=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa h1:mZHHdPZl0dbGHCflZgAq/Q468DWVFcU2whhB2KAo8fk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/api v0.285.0 h1:B7eHHoKGAX/LrPkQvhQqnGwjgWxofbdGwCTQvpm8FkM=
google.golang.org/api v0.285.0/go.mod h1:NlOlUIr8MPoIhT9Bb/oUnRuHbJOLwxb6JSYJM8Yz+jQ=
google.golang.org/genai v1.61.0 h1:wCyNGiaC9q5A59B80zuEtNBhq3ypEvICFkZYOfK7IO0=
google.golang.org/genai v1.61.0/go.mod h1:mDdPDFXo1Ats7f1WXVyZgWb/CkMzFWTWJruIMy7hGIU=
google.golang.org/genproto v0.0.0-20260610212136-7ab31c22f7ad h1:cYL1DPJAQr4JMvhfGao0PDXoaf03ifMljAuDyrbMBd0=
google.golang.org/genproto v0.0.0-20260610212136-7ab31c22f7ad/go.mod h1:cVHIikDNAdx8ISZeW+2rYkEMf3xn0GSaBYmVnWXQBUo=
google.golang.org/genproto/googleapis/api v0.0.0-20260610212136-7ab31c22f7ad h1:3iLyITS/sySRwbUKoC7ogfj2Yr1Cjs0pfaRKj5U5HEw=
google.golang.org/genproto/googleapis/api v0.0.0-20260610212136-7ab31c22f7ad/go.mod h1:KdNqO+rCIWgFumrNBSEDlDNrkrQnpkax7Tv1WxNY8V4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260615183401-62b3387ff324 h1:9HZDLIdYBJXAnaFOr9WHrKVycfpY+75s9HGadC0305A=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260615183401-62b3387ff324/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ=
google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
+6
View File
@@ -61,6 +61,12 @@ func (a *Agent) Authenticate(_ context.Context, _ acp.AuthenticateRequest) (acp.
return acp.AuthenticateResponse{}, nil
}
// Logout handles logout requests. Kit doesn't require auth for local stdio
// usage, so this is a no-op.
func (a *Agent) Logout(_ context.Context, _ acp.LogoutRequest) (acp.LogoutResponse, error) {
return acp.LogoutResponse{}, nil
}
// Initialize negotiates capabilities with the ACP client.
func (a *Agent) Initialize(_ context.Context, params acp.InitializeRequest) (acp.InitializeResponse, error) {
log.Debug("acp: initialize", "protocol_version", params.ProtocolVersion)
+59 -97
View File
@@ -73,111 +73,73 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
// Wire extension context with headless implementations so extensions
// work in ACP mode. TUI-dependent features (widgets, prompts, editor)
// become no-ops or return cancelled; all data/model/tool APIs work
// identically to interactive mode.
// become no-ops or return cancelled; all data/model/tool APIs come from
// extbridge.BaseContext and work identically to interactive mode.
if kitInstance.Extensions().HasExtensions() {
kitInstance.Extensions().SetContext(extensions.Context{
SessionID: sessionID,
CWD: cwd,
Model: kitInstance.GetModelString(),
Interactive: false,
// Use a background context for subagent spawns: the create() ctx is
// request-scoped and may be cancelled before extensions spawn anything.
ec := extbridge.BaseContext(context.Background(), kitInstance)
// Output — route through structured logger.
Print: func(text string) { log.Debug("extension: print", "text", text) },
PrintInfo: func(text string) { log.Info("extension: info", "text", text) },
PrintError: func(text string) { log.Error("extension: error", "text", text) },
PrintBlock: func(opts extensions.PrintBlockOpts) {
log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text)
},
ec.SessionID = sessionID
ec.CWD = cwd
ec.Model = kitInstance.GetModelString()
ec.Interactive = false
// Message injection — no-ops for now; ACP clients drive prompts.
SendMessage: func(string) {},
CancelAndSend: func(string) {},
Exit: func() {},
// Output — route through structured logger.
ec.Print = func(text string) { log.Debug("extension: print", "text", text) }
ec.PrintInfo = func(text string) { log.Info("extension: info", "text", text) }
ec.PrintError = func(text string) { log.Error("extension: error", "text", text) }
ec.PrintBlock = func(opts extensions.PrintBlockOpts) {
log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text)
}
// TUI widgets/chrome — silent no-ops (no TUI in ACP).
SetWidget: func(extensions.WidgetConfig) {},
RemoveWidget: func(string) {},
SetHeader: func(extensions.HeaderFooterConfig) {},
RemoveHeader: func() {},
SetFooter: func(extensions.HeaderFooterConfig) {},
RemoveFooter: func() {},
SetEditor: func(extensions.EditorConfig) {},
ResetEditor: func() {},
SetEditorText: func(string) {},
SetUIVisibility: func(extensions.UIVisibility) {},
SetStatus: func(string, string, int) {},
RemoveStatus: func(string) {},
// Message injection — no-ops for now; ACP clients drive prompts.
ec.SendMessage = func(string) {}
ec.CancelAndSend = func(string) {}
ec.NewSession = func(string) error {
return fmt.Errorf("new session not available in ACP mode")
}
ec.Exit = func() {}
// Interactive prompts — return cancelled (no user to prompt).
PromptSelect: func(extensions.PromptSelectConfig) extensions.PromptSelectResult {
return extensions.PromptSelectResult{Cancelled: true}
},
PromptConfirm: func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
return extensions.PromptConfirmResult{Cancelled: true}
},
PromptInput: func(extensions.PromptInputConfig) extensions.PromptInputResult {
return extensions.PromptInputResult{Cancelled: true}
},
ShowOverlay: func(extensions.OverlayConfig) extensions.OverlayResult {
return extensions.OverlayResult{Cancelled: true, Index: -1}
},
SuspendTUI: func(callback func()) error { callback(); return nil },
// TUI widgets/chrome — silent no-ops (no TUI in ACP).
ec.SetWidget = func(extensions.WidgetConfig) {}
ec.RemoveWidget = func(string) {}
ec.SetHeader = func(extensions.HeaderFooterConfig) {}
ec.RemoveHeader = func() {}
ec.SetFooter = func(extensions.HeaderFooterConfig) {}
ec.RemoveFooter = func() {}
ec.SetEditor = func(extensions.EditorConfig) {}
ec.ResetEditor = func() {}
ec.SetEditorText = func(string) {}
ec.SetUIVisibility = func(extensions.UIVisibility) {}
ec.SetStatus = func(string, string, int) {}
ec.RemoveStatus = func(string) {}
// Data access — delegate to Kit instance.
GetContextStats: func() extensions.ContextStats {
s := kitInstance.GetContextStats()
return extensions.ContextStats{
EstimatedTokens: s.EstimatedTokens,
ContextLimit: s.ContextLimit,
UsagePercent: s.UsagePercent,
MessageCount: s.MessageCount,
}
},
GetMessages: func() []extensions.SessionMessage { return kitInstance.Extensions().GetSessionMessages() },
GetSessionPath: func() string { return kitInstance.GetSessionPath() },
AppendEntry: func(entryType, data string) (string, error) {
return kitInstance.Extensions().AppendEntry(entryType, data)
},
GetEntries: func(entryType string) []extensions.ExtensionEntry {
return kitInstance.Extensions().GetEntries(entryType)
},
// Interactive prompts — return cancelled (no user to prompt).
ec.PromptSelect = func(extensions.PromptSelectConfig) extensions.PromptSelectResult {
return extensions.PromptSelectResult{Cancelled: true}
}
ec.PromptConfirm = func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
return extensions.PromptConfirmResult{Cancelled: true}
}
ec.PromptInput = func(extensions.PromptInputConfig) extensions.PromptInputResult {
return extensions.PromptInputResult{Cancelled: true}
}
ec.ShowOverlay = func(extensions.OverlayConfig) extensions.OverlayResult {
return extensions.OverlayResult{Cancelled: true, Index: -1}
}
ec.SuspendTUI = func(callback func()) error { callback(); return nil }
// Options, model, and tool management.
GetOption: func(name string) string { return kitInstance.Extensions().GetOption(name) },
SetOption: func(name, value string) { kitInstance.Extensions().SetOption(name, value) },
SetModel: func(modelString string) error {
previousModel := kitInstance.Extensions().GetContext().Model
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
return err
}
kitInstance.Extensions().UpdateContextModel(modelString)
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
return nil
},
GetAvailableModels: func() []extensions.ModelInfoEntry { return kitInstance.GetAvailableModels() },
EmitCustomEvent: func(name, data string) { kitInstance.Extensions().EmitCustomEvent(name, data) },
GetAllTools: func() []extensions.ToolInfo { return kitInstance.Extensions().GetToolInfos() },
SetActiveTools: func(names []string) { kitInstance.Extensions().SetActiveTools(names) },
// Render — fall back to logging.
ec.RenderMessage = func(name, content string) {
renderer := kitInstance.Extensions().GetMessageRenderer(name)
if renderer != nil && renderer.Render != nil {
content = renderer.Render(content, 80)
}
log.Info("extension: message", "renderer", name, "content", content)
}
// LLM completions and subagents.
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
return kitInstance.ExecuteCompletion(context.Background(), req)
},
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
return extbridge.SpawnSubagent(context.Background(), kitInstance, config)
},
// Render — fall back to logging.
RenderMessage: func(name, content string) {
renderer := kitInstance.Extensions().GetMessageRenderer(name)
if renderer != nil && renderer.Render != nil {
content = renderer.Render(content, 80)
}
log.Info("extension: message", "renderer", name, "content", content)
},
ReloadExtensions: func() error { return kitInstance.Extensions().Reload() },
})
kitInstance.Extensions().SetContext(ec)
kitInstance.Extensions().EmitSessionStart()
}
+3 -41
View File
@@ -169,9 +169,9 @@ type RetryHandler func(attempt int, err error)
type PrepareStepHandler func(stepNumber int, messages []fantasy.Message) []fantasy.Message
// GenerateCallbacks consolidates all callback functions for
// GenerateWithLoopAndStreaming into a single struct. This replaces the previous
// 16+ positional callback parameters, making it easier to add new callbacks
// without breaking existing callers (new fields default to nil).
// GenerateWithCallbacks into a single struct, replacing what was previously
// 16+ positional callback parameters. New fields default to nil, so adding
// new callbacks does not break existing callers.
type GenerateCallbacks struct {
OnToolCall ToolCallHandler
OnToolExecution ToolExecutionHandler
@@ -522,44 +522,6 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
})
}
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
// The agent handles the tool call loop internally.
//
// Deprecated: Use GenerateWithCallbacks instead, which takes a GenerateCallbacks
// struct and is easier to extend with new callbacks.
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fantasy.Message,
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
onStreamingResponse StreamingResponseHandler,
onReasoningDelta ReasoningDeltaHandler,
onReasoningComplete ReasoningCompleteHandler,
onToolOutput ToolOutputHandler,
onStepMessages StepMessagesHandler,
onStepUsage StepUsageHandler,
onPasswordPrompt PasswordPromptHandler,
onToolCallStart ToolCallStartHandler,
onToolCallDelta ToolCallDeltaHandler,
onToolCallEnd ToolCallEndHandler,
) (*GenerateWithLoopResult, error) {
return a.GenerateWithCallbacks(ctx, messages, GenerateCallbacks{
OnToolCall: onToolCall,
OnToolExecution: onToolExecution,
OnToolResult: onToolResult,
OnResponse: onResponse,
OnToolCallContent: onToolCallContent,
OnStreamingResponse: onStreamingResponse,
OnReasoningDelta: onReasoningDelta,
OnReasoningComplete: onReasoningComplete,
OnToolOutput: onToolOutput,
OnStepMessages: onStepMessages,
OnStepUsage: onStepUsage,
OnPasswordPrompt: onPasswordPrompt,
OnToolCallStart: onToolCallStart,
OnToolCallDelta: onToolCallDelta,
OnToolCallEnd: onToolCallEnd,
})
}
// GenerateWithCallbacks processes messages using the agent with streaming and callbacks.
// The agent handles the tool call loop internally. We map the rich callback system
// to kit's existing callback interface for UI integration.
+253 -9
View File
@@ -2,6 +2,7 @@ package app
import (
"context"
"errors"
"fmt"
"log"
"os"
@@ -13,6 +14,7 @@ import (
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/message"
"github.com/mark3labs/kit/internal/session"
kit "github.com/mark3labs/kit/pkg/kit"
)
@@ -23,6 +25,26 @@ type queueItem struct {
Files []kit.LLMFilePart
}
// ErrAgentBusy is returned when an operation cannot proceed because the agent
// is still processing a turn (including any post-turn extension hooks) and did
// not become idle before the operation's deadline.
//
// This is an alias for extensions.ErrAgentBusy so the extension API and the
// app layer share a single sentinel value — callers can detect the condition
// with errors.Is(err, app.ErrAgentBusy) without substring-matching the error
// message.
var ErrAgentBusy = extensions.ErrAgentBusy
// DefaultNewSessionIdleWait bounds how long RequestNewSessionFromExtension
// will block waiting for the agent to settle. It needs to be generous enough
// to cover real-world post-turn tooling (project formatters, on-save linters,
// hidden tool calls) which routinely hold the busy flag for seconds and
// occasionally minutes — yet still short enough to surface a wedged agent.
//
// Issue #63 reported workloads where the busy window regularly exceeded
// 6 seconds; ten minutes is the same bound the workaround in that issue used.
const DefaultNewSessionIdleWait = 10 * time.Minute
// App is the application-layer orchestrator. It owns the agentic loop,
// conversation history (via MessageStore), and queue management. It is
// designed to be created once per session and reused across multiple prompts.
@@ -54,11 +76,25 @@ type App struct {
// each new step and called by CancelCurrentStep().
cancelStep context.CancelFunc
// mu protects busy, queue, and cancelStep.
// mu protects busy, queue, cancelStep, and idleCh.
mu sync.Mutex
busy bool
queue []queueItem
// idleCh is closed when the agent transitions from busy back to idle.
// While the agent is idle the channel is already closed (recv returns
// immediately). When busy transitions to true a fresh open channel is
// allocated so callers blocked on the previous one are released. All
// transitions are funnelled through setBusyLocked to keep the channel
// pointer in sync with the busy flag.
//
// This is the underlying primitive WaitForIdle and
// RequestNewSessionFromExtension wait on to fix the AgentEnd→NewSession
// race described in issue #63: AgentEnd is emitted from inside the agent
// loop, before drainQueue clears busy, so any extension hook that calls
// ctx.NewSession synchronously would otherwise observe busy==true.
idleCh chan struct{}
// wg tracks in-flight goroutines; Close() waits on it.
wg sync.WaitGroup
@@ -94,6 +130,10 @@ type App struct {
// initialMessages may be nil or empty for a fresh session.
func New(opts Options, initialMessages []kit.LLMMessage) *App {
rootCtx, rootCancel := context.WithCancel(context.Background())
// idleCh starts already closed: the freshly constructed App is idle, so
// any caller blocking on it via WaitForIdle should be released immediately.
idleCh := make(chan struct{})
close(idleCh)
return &App{
opts: opts,
store: NewMessageStoreWithMessages(initialMessages),
@@ -101,6 +141,90 @@ func New(opts Options, initialMessages []kit.LLMMessage) *App {
rootCancel: rootCancel,
// cancelStep starts as a no-op so CancelCurrentStep() is always safe.
cancelStep: func() {},
idleCh: idleCh,
}
}
// setBusyLocked is the single chokepoint for mutating a.busy. It keeps the
// idleCh signalling channel in sync with the busy flag:
//
// - false → true: allocate a fresh open channel so future WaitForIdle
// callers block until the next idle transition.
// - true → false: close the current channel so any waiters wake up.
//
// No-op when the requested state already matches. The caller must hold a.mu.
func (a *App) setBusyLocked(busy bool) {
if a.busy == busy {
return
}
a.busy = busy
if busy {
a.idleCh = make(chan struct{})
} else {
close(a.idleCh)
}
}
// idleSnapshot returns the current busy state and the channel that will be
// closed on the next idle transition. The snapshot is taken under a.mu so the
// pair is consistent (busy==true ⇒ ch is the open channel for *this* busy
// cycle, not a stale one).
func (a *App) idleSnapshot() (busy bool, ch chan struct{}) {
a.mu.Lock()
defer a.mu.Unlock()
return a.busy, a.idleCh
}
// WaitForIdle blocks until the agent is idle, the given timeout elapses, or
// the app shuts down. Returns nil on idle, ErrAgentBusy on timeout, or the
// rootCtx error if the app is closing.
//
// A non-positive timeout disables the deadline and waits indefinitely (until
// idle or app shutdown). Safe to call from any goroutine, but never from
// inside the Bubble Tea Update() loop — it blocks.
//
// Idiomatic use from extensions:
//
// if err := app.WaitForIdle(0); err != nil { /* shutdown */ }
//
// The loop guards against the agent re-arming itself between wakeups: if
// another prompt is queued (or a steer message lands) while we're waiting,
// setBusyLocked allocates a fresh idleCh and we wait again.
func (a *App) WaitForIdle(timeout time.Duration) error {
var deadline time.Time
if timeout > 0 {
deadline = time.Now().Add(timeout)
}
for {
busy, ch := a.idleSnapshot()
if !busy {
return nil
}
var timer *time.Timer
var timerCh <-chan time.Time
if timeout > 0 {
remaining := time.Until(deadline)
if remaining <= 0 {
return ErrAgentBusy
}
timer = time.NewTimer(remaining)
timerCh = timer.C
}
select {
case <-ch:
// Idle transition observed — loop and re-check under the
// mutex in case a new busy cycle started immediately after.
case <-timerCh:
return ErrAgentBusy
case <-a.rootCtx.Done():
if timer != nil {
timer.Stop()
}
return a.rootCtx.Err()
}
if timer != nil {
timer.Stop()
}
}
}
@@ -154,7 +278,7 @@ func (a *App) RunWithFiles(prompt string, files []kit.LLMFilePart) int {
return qLen
}
a.busy = true
a.setBusyLocked(true)
a.wg.Add(1)
a.mu.Unlock()
go a.drainQueue(item)
@@ -234,7 +358,7 @@ func (a *App) SteerWithFiles(prompt string, files []kit.LLMFilePart) int {
if !a.busy {
// Not busy — start immediately, same as RunWithFiles().
item := queueItem{Prompt: prompt, Files: files}
a.busy = true
a.setBusyLocked(true)
a.wg.Add(1)
a.mu.Unlock()
go a.drainQueue(item)
@@ -270,7 +394,7 @@ func (a *App) InterruptAndSend(prompt string) {
if !a.busy {
// Not busy — start immediately, same as Run().
a.busy = true
a.setBusyLocked(true)
a.wg.Add(1)
a.mu.Unlock()
go a.drainQueue(item)
@@ -343,6 +467,90 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
}
}
// PopLastUserMessage truncates the tree session back to the parent of the
// most recent user message on the current branch, syncs the in-memory
// message store, and returns the user prompt text plus any image file
// parts so the caller can resubmit via Run/RunWithFiles.
//
// This is the building block for /retry: the user message and any orphaned
// assistant/tool entries produced by a failed turn become unreachable on
// the current branch (they remain in the session file under a different
// leaf) and are excluded from the next LLM context.
//
// Returns an error when:
// - the agent is currently working (busy)
// - the app has been closed
// - no tree session is active (sessions disabled via --no-session)
// - no user message exists on the current branch
//
// Satisfies ui.AppController.
func (a *App) PopLastUserMessage() (string, []kit.LLMFilePart, error) {
a.mu.Lock()
if a.closed {
a.mu.Unlock()
return "", nil, fmt.Errorf("app is closed")
}
if a.busy {
a.mu.Unlock()
return "", nil, fmt.Errorf("cannot retry while the agent is working")
}
a.mu.Unlock()
ts := a.opts.TreeSession
if ts == nil {
return "", nil, fmt.Errorf("no tree session active; /retry requires a session")
}
// Walk the current branch backwards to find the most recent user message.
branch := ts.GetBranch("")
var target *session.MessageEntry
for i := len(branch) - 1; i >= 0; i-- {
me, ok := branch[i].(*session.MessageEntry)
if !ok {
continue
}
if me.Role == string(message.RoleUser) {
target = me
break
}
}
if target == nil {
return "", nil, fmt.Errorf("no user message to retry")
}
// Extract the prompt text and any image parts from the target entry.
msg, err := target.ToMessage()
if err != nil {
return "", nil, fmt.Errorf("decode user message: %w", err)
}
prompt := msg.Content()
var files []kit.LLMFilePart
for _, part := range msg.Parts {
if ic, ok := part.(message.ImageContent); ok {
files = append(files, kit.LLMFilePart{
Data: ic.Data,
MediaType: ic.MediaType,
})
}
}
// Move the leaf to the parent of the user message. The failed turn's
// entries (user message + any partial assistant/tool entries) are still
// in the tree file but no longer on the active branch, so they will not
// be re-sent to the LLM. runTurn() will append a fresh user message on
// the next call.
if err := ts.Branch(target.ParentID); err != nil {
return "", nil, fmt.Errorf("branch to parent: %w", err)
}
// Sync the in-memory store with the new branch position so subsequent
// reads (and ReloadMessagesFromTree() consumers) see the truncated view.
a.store.Clear()
a.store.Replace(ts.GetLLMMessages())
return prompt, files, nil
}
// AddContextMessage adds a user-role message to the conversation history
// without triggering an LLM response. Used by the ! shell command prefix
// to inject command output into context so the LLM can reference it in
@@ -385,7 +593,7 @@ func (a *App) CompactConversation(customInstructions string) error {
a.mu.Unlock()
return fmt.Errorf("SDK instance not available")
}
a.busy = true
a.setBusyLocked(true)
a.wg.Add(1)
a.mu.Unlock()
@@ -447,7 +655,7 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
a.mu.Unlock()
return fmt.Errorf("SDK instance not available")
}
a.busy = true
a.setBusyLocked(true)
a.wg.Add(1)
a.mu.Unlock()
@@ -536,7 +744,7 @@ func (a *App) releaseBusyAfterCompact() {
// in just before closed was set.
if a.closed {
a.queue = a.queue[:0]
a.busy = false
a.setBusyLocked(false)
a.mu.Unlock()
return
}
@@ -548,7 +756,7 @@ func (a *App) releaseBusyAfterCompact() {
a.queue = a.queue[:0]
if len(pending) == 0 {
a.busy = false
a.setBusyLocked(false)
a.mu.Unlock()
return
}
@@ -765,7 +973,7 @@ func (a *App) drainQueue(first queueItem) {
// Mark as no longer busy
a.mu.Lock()
a.busy = false
a.setBusyLocked(false)
a.mu.Unlock()
}
@@ -1145,6 +1353,42 @@ func (a *App) SetEditorTextFromExtension(text string) {
}
}
// RequestNewSessionFromExtension sends a NewSessionRequestEvent to the TUI
// to end the current session and start a fresh one. If initialPrompt is
// non-empty it is submitted as the first user turn of the new session.
//
// If the agent is currently busy (e.g. the caller is an OnAgentEnd hook that
// fires before drainQueue clears the busy flag, or there are queued prompts
// still being processed) the call blocks until the agent becomes idle, up to
// DefaultNewSessionIdleWait. If that deadline elapses, ErrAgentBusy is
// returned and callers can detect it with errors.Is. This wait-then-send
// behavior fixes the v0.79.0 phase-handoff race documented in issue #63.
//
// Returns an error when running headless (no TUI attached), when the wait
// for idle times out (ErrAgentBusy), when the app is shutting down, or when
// a BeforeSessionSwitch extension hook cancels the switch.
//
// This is the implementation behind ctx.NewSession(prompt) for the
// interactive TUI. It blocks the caller until the TUI processes the
// switch, so it must be invoked from a goroutine outside Update().
func (a *App) RequestNewSessionFromExtension(initialPrompt string) error {
a.mu.Lock()
prog := a.program
a.mu.Unlock()
if prog == nil {
return fmt.Errorf("new session unavailable: no interactive TUI attached")
}
if err := a.WaitForIdle(DefaultNewSessionIdleWait); err != nil {
if errors.Is(err, ErrAgentBusy) {
return fmt.Errorf("cannot start new session: %w", err)
}
return err
}
ch := make(chan error, 1)
prog.Send(NewSessionRequestEvent{InitialPrompt: initialPrompt, ResponseCh: ch})
return <-ch
}
// NotifyModelChanged sends a ModelChangedEvent to the TUI so it updates
// the model name in the status bar and message attribution.
func (a *App) NotifyModelChanged(provider, model string) {
+428 -4
View File
@@ -9,7 +9,10 @@ import (
"time"
tea "charm.land/bubbletea/v2"
"charm.land/fantasy"
kit "github.com/mark3labs/kit/pkg/kit"
"github.com/mark3labs/kit/internal/session"
)
// --------------------------------------------------------------------------
@@ -791,7 +794,7 @@ func TestReleaseBusyAfterCompact_flushesQueuedMessages(t *testing.T) {
// summarising. (Run() would have appended them and returned a queue
// length > 0 to the caller.)
app.mu.Lock()
app.busy = true
app.setBusyLocked(true)
app.queue = append(app.queue,
queueItem{Prompt: "queued during compact #1"},
queueItem{Prompt: "queued during compact #2"},
@@ -831,7 +834,7 @@ func TestReleaseBusyAfterCompact_idleWhenQueueEmpty(t *testing.T) {
defer app.Close()
app.mu.Lock()
app.busy = true
app.setBusyLocked(true)
app.mu.Unlock()
app.releaseBusyAfterCompact()
@@ -898,7 +901,7 @@ func TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue(t *testing.T) {
// Simulate the state at the end of compaction: busy is set and a couple
// of regular Run() prompts have piled up after the steer messages.
app.mu.Lock()
app.busy = true
app.setBusyLocked(true)
app.queue = append(app.queue,
queueItem{Prompt: "queued-1"},
queueItem{Prompt: "queued-2"},
@@ -947,7 +950,7 @@ func TestReleaseBusyAfterCompact_dropsQueueWhenClosed(t *testing.T) {
app := newTestApp(stub)
app.mu.Lock()
app.busy = true
app.setBusyLocked(true)
app.queue = append(app.queue, queueItem{Prompt: "would have run"})
app.closed = true
app.mu.Unlock()
@@ -969,3 +972,424 @@ func TestReleaseBusyAfterCompact_dropsQueueWhenClosed(t *testing.T) {
t.Fatalf("expected 0 PromptFunc calls on closed app, got %d", n)
}
}
// --------------------------------------------------------------------------
// PopLastUserMessage (/retry building block)
// --------------------------------------------------------------------------
// TestPopLastUserMessage_NoTreeSession verifies that PopLastUserMessage
// returns an error when no tree session is active.
func TestPopLastUserMessage_NoTreeSession(t *testing.T) {
app := newTestApp(newStub())
defer app.Close()
prompt, files, err := app.PopLastUserMessage()
if err == nil {
t.Fatal("expected error when no tree session is active")
}
if prompt != "" || files != nil {
t.Fatalf("expected zero values on error, got prompt=%q files=%v", prompt, files)
}
}
// TestPopLastUserMessage_WhileBusy verifies that PopLastUserMessage
// refuses to truncate while the agent is busy (would race with executeBatch).
func TestPopLastUserMessage_WhileBusy(t *testing.T) {
app := newTestApp(newStub())
defer app.Close()
app.mu.Lock()
app.setBusyLocked(true)
app.mu.Unlock()
_, _, err := app.PopLastUserMessage()
if err == nil {
t.Fatal("expected error when agent is busy")
}
if !strings.Contains(err.Error(), "working") {
t.Fatalf("expected error mentioning busy/working, got %q", err.Error())
}
}
// TestPopLastUserMessage_WhenClosed verifies that PopLastUserMessage
// returns an error after Close().
func TestPopLastUserMessage_WhenClosed(t *testing.T) {
app := newTestApp(newStub())
app.Close()
_, _, err := app.PopLastUserMessage()
if err == nil {
t.Fatal("expected error on closed app")
}
}
// TestPopLastUserMessage_TruncatesAndReturnsPrompt verifies the happy path:
// a real tree session with user→assistant→user→assistant entries is
// truncated back to before the most recent user message, and that user's
// text is returned.
func TestPopLastUserMessage_TruncatesAndReturnsPrompt(t *testing.T) {
dir := t.TempDir()
ts, err := session.CreateTreeSession(dir)
if err != nil {
t.Fatalf("create tree session: %v", err)
}
defer func() { _ = ts.Close() }()
// Build history: user "first" → assistant "ack 1" → user "second" → assistant "ack 2".
if _, err := ts.AppendLLMMessage(fantasy.NewUserMessage("first")); err != nil {
t.Fatal(err)
}
if _, err := ts.AppendLLMMessage(fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "ack 1"}},
}); err != nil {
t.Fatal(err)
}
if _, err := ts.AppendLLMMessage(fantasy.NewUserMessage("second")); err != nil {
t.Fatal(err)
}
if _, err := ts.AppendLLMMessage(fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "ack 2"}},
}); err != nil {
t.Fatal(err)
}
app := New(Options{TreeSession: ts, PromptFunc: newStub().fn}, nil)
defer app.Close()
prompt, files, err := app.PopLastUserMessage()
if err != nil {
t.Fatalf("PopLastUserMessage: %v", err)
}
if prompt != "second" {
t.Fatalf("expected prompt=%q, got %q", "second", prompt)
}
if files != nil {
t.Fatalf("expected no files, got %v", files)
}
// After truncation the branch should only contain the first user
// message and its assistant response (the "second" turn is orphaned).
msgs := ts.GetLLMMessages()
if len(msgs) != 2 {
t.Fatalf("expected 2 messages on truncated branch, got %d", len(msgs))
}
if got := messageText(msgs[0]); got != "first" {
t.Fatalf("expected first message %q, got %q", "first", got)
}
if got := messageText(msgs[1]); got != "ack 1" {
t.Fatalf("expected second message %q, got %q", "ack 1", got)
}
}
// messageText extracts concatenated TextPart content from a fantasy.Message.
func messageText(m fantasy.Message) string {
var out strings.Builder
for _, p := range m.Content {
if tp, ok := p.(fantasy.TextPart); ok {
out.WriteString(tp.Text)
}
}
return out.String()
}
// TestPopLastUserMessage_NoUserOnBranch verifies that an empty tree (no
// user messages at all) returns a friendly error rather than panicking.
func TestPopLastUserMessage_NoUserOnBranch(t *testing.T) {
dir := t.TempDir()
ts, err := session.CreateTreeSession(dir)
if err != nil {
t.Fatalf("create tree session: %v", err)
}
defer func() { _ = ts.Close() }()
app := New(Options{TreeSession: ts, PromptFunc: newStub().fn}, nil)
defer app.Close()
_, _, err = app.PopLastUserMessage()
if err == nil {
t.Fatal("expected error when no user message exists on branch")
}
if !strings.Contains(err.Error(), "no user message") {
t.Fatalf("expected error mentioning missing user message, got %q", err.Error())
}
}
// --------------------------------------------------------------------------
// WaitForIdle / RequestNewSessionFromExtension (issue #63)
// --------------------------------------------------------------------------
// TestWaitForIdle_AlreadyIdle verifies the fast path: a freshly constructed
// App is idle and WaitForIdle returns immediately without consulting the
// timeout.
func TestWaitForIdle_AlreadyIdle(t *testing.T) {
app := newTestApp(newStub())
defer app.Close()
start := time.Now()
if err := app.WaitForIdle(2 * time.Second); err != nil {
t.Fatalf("WaitForIdle on idle app: %v", err)
}
if elapsed := time.Since(start); elapsed > 100*time.Millisecond {
t.Fatalf("WaitForIdle blocked for %s on already-idle app", elapsed)
}
}
// TestWaitForIdle_BlocksUntilDrain reproduces the issue #63 race: while
// drainQueue holds busy==true the call should block, then return nil as soon
// as the drain completes.
func TestWaitForIdle_BlocksUntilDrain(t *testing.T) {
gate := make(chan struct{})
var gateOnce sync.Once
closeGate := func() { gateOnce.Do(func() { close(gate) }) }
stub := newStubWithFuncs(
func(ctx context.Context) (*kit.TurnResult, error) {
select {
case <-gate:
case <-ctx.Done():
return nil, ctx.Err()
}
return turnResult("done"), nil
},
)
app := newTestApp(stub)
t.Cleanup(func() {
closeGate()
app.Close()
})
app.Run("hello")
// Confirm the agent is busy before we start waiting.
if !waitForCondition(2*time.Second, func() bool { return app.IsBusy() }) {
t.Fatal("app never became busy after Run()")
}
errCh := make(chan error, 1)
go func() {
errCh <- app.WaitForIdle(5 * time.Second)
}()
// Should not return while the stub is blocked.
select {
case err := <-errCh:
t.Fatalf("WaitForIdle returned early (err=%v) while agent still busy", err)
case <-time.After(150 * time.Millisecond):
}
closeGate()
select {
case err := <-errCh:
if err != nil {
t.Fatalf("WaitForIdle: %v", err)
}
case <-time.After(3 * time.Second):
t.Fatal("WaitForIdle did not return after drain completed")
}
if app.IsBusy() {
t.Fatal("app still reports busy after WaitForIdle returned")
}
}
// TestWaitForIdle_TimeoutReturnsErrAgentBusy verifies that a slow turn yields
// ErrAgentBusy (detectable via errors.Is) when the deadline elapses.
func TestWaitForIdle_TimeoutReturnsErrAgentBusy(t *testing.T) {
gate := make(chan struct{})
stub := newStubWithFuncs(
func(ctx context.Context) (*kit.TurnResult, error) {
select {
case <-gate:
case <-ctx.Done():
return nil, ctx.Err()
}
return turnResult("done"), nil
},
)
app := newTestApp(stub)
// Release the stub before Close so wg.Wait() can return.
t.Cleanup(func() {
close(gate)
app.Close()
})
app.Run("hello")
if !waitForCondition(2*time.Second, func() bool { return app.IsBusy() }) {
t.Fatal("app never became busy after Run()")
}
err := app.WaitForIdle(50 * time.Millisecond)
if !errors.Is(err, ErrAgentBusy) {
t.Fatalf("expected ErrAgentBusy on timeout, got %v", err)
}
}
// TestWaitForIdle_ZeroTimeoutWaitsIndefinitely verifies that a non-positive
// timeout still blocks until idle (or shutdown) — not an instant ErrAgentBusy.
func TestWaitForIdle_ZeroTimeoutWaitsIndefinitely(t *testing.T) {
gate := make(chan struct{})
var gateOnce sync.Once
closeGate := func() { gateOnce.Do(func() { close(gate) }) }
stub := newStubWithFuncs(
func(ctx context.Context) (*kit.TurnResult, error) {
select {
case <-gate:
case <-ctx.Done():
return nil, ctx.Err()
}
return turnResult("done"), nil
},
)
app := newTestApp(stub)
t.Cleanup(func() {
closeGate()
app.Close()
})
app.Run("hello")
if !waitForCondition(2*time.Second, func() bool { return app.IsBusy() }) {
t.Fatal("app never became busy after Run()")
}
errCh := make(chan error, 1)
go func() { errCh <- app.WaitForIdle(0) }()
select {
case err := <-errCh:
t.Fatalf("WaitForIdle(0) returned early with %v while agent was busy", err)
case <-time.After(150 * time.Millisecond):
}
closeGate()
select {
case err := <-errCh:
if err != nil {
t.Fatalf("WaitForIdle(0) returned %v after idle", err)
}
case <-time.After(3 * time.Second):
t.Fatal("WaitForIdle(0) did not return after drain completed")
}
}
// TestWaitForIdle_AppClose verifies that shutting down the app while a
// caller is blocked in WaitForIdle releases the wait.
func TestWaitForIdle_AppClose(t *testing.T) {
gate := make(chan struct{})
stub := newStubWithFuncs(
func(ctx context.Context) (*kit.TurnResult, error) {
select {
case <-gate:
case <-ctx.Done():
return nil, ctx.Err()
}
return turnResult("done"), nil
},
)
app := newTestApp(stub)
app.Run("hello")
if !waitForCondition(2*time.Second, func() bool { return app.IsBusy() }) {
t.Fatal("app never became busy after Run()")
}
errCh := make(chan error, 1)
go func() { errCh <- app.WaitForIdle(5 * time.Second) }()
// Give the goroutine a moment to enter the wait.
time.Sleep(50 * time.Millisecond)
// rootCancel is called by Close, which should release the waiter
// before drainQueue itself observes the cancellation and clears busy.
go func() {
// Unblock the stub so Close() can proceed past wg.Wait().
close(gate)
}()
app.Close()
select {
case err := <-errCh:
// Either rootCtx cancellation propagated first (err = context.Canceled)
// or the drain finished cleanly first (err == nil); both are
// acceptable terminations. The key invariant is that WaitForIdle
// does not hang past Close.
if err != nil && !errors.Is(err, context.Canceled) {
t.Fatalf("WaitForIdle returned unexpected error: %v", err)
}
case <-time.After(3 * time.Second):
t.Fatal("WaitForIdle did not return after Close()")
}
}
// TestRequestNewSessionFromExtension_NoTUI verifies the headless guard: with
// no Bubble Tea program registered the call fails fast (no busy-wait).
func TestRequestNewSessionFromExtension_NoTUI(t *testing.T) {
app := newTestApp(newStub())
defer app.Close()
err := app.RequestNewSessionFromExtension("hello")
if err == nil {
t.Fatal("expected error in headless mode")
}
if !strings.Contains(err.Error(), "no interactive TUI") {
t.Fatalf("expected 'no interactive TUI' error, got %q", err.Error())
}
}
// TestBusyTransitionsSignalIdleCh exercises the setBusyLocked invariants
// directly: a fresh App is idle (closed channel); Run() opens a new channel
// that is then closed when drainQueue exits.
func TestBusyTransitionsSignalIdleCh(t *testing.T) {
app := newTestApp(newStub("ok"))
defer app.Close()
// Initial state: closed channel, busy==false.
busy, ch := app.idleSnapshot()
if busy {
t.Fatal("freshly constructed App should not be busy")
}
select {
case <-ch:
default:
t.Fatal("initial idleCh should already be closed")
}
gate := make(chan struct{})
var gateOnce sync.Once
closeGate := func() { gateOnce.Do(func() { close(gate) }) }
stub := newStubWithFuncs(func(ctx context.Context) (*kit.TurnResult, error) {
select {
case <-gate:
case <-ctx.Done():
return nil, ctx.Err()
}
return turnResult("ok"), nil
})
app2 := newTestApp(stub)
t.Cleanup(func() {
closeGate()
app2.Close()
})
app2.Run("hello")
if !waitForCondition(2*time.Second, func() bool { return app2.IsBusy() }) {
t.Fatal("app2 never became busy")
}
_, ch2 := app2.idleSnapshot()
select {
case <-ch2:
t.Fatal("idleCh should be open while busy")
default:
}
closeGate()
select {
case <-ch2:
case <-time.After(3 * time.Second):
t.Fatal("idleCh was never closed after drain completed")
}
}
+15
View File
@@ -247,6 +247,21 @@ type EditorTextSetEvent struct {
Text string
}
// NewSessionRequestEvent is sent when an extension calls ctx.NewSession to
// end the current session and start a fresh one. The TUI routes this into
// the same /new code path (including the BeforeSessionSwitch hook and any
// @file expansion in InitialPrompt). ResponseCh, when non-nil, receives a
// single result so the extension goroutine can observe success or failure.
type NewSessionRequestEvent struct {
// InitialPrompt, when non-empty, is the first user turn to submit
// after the session switch. @file references are expanded.
InitialPrompt string
// ResponseCh receives the outcome (nil error on success). Must be
// buffered (cap >= 1) so the TUI never blocks. May be nil if the
// caller does not need the result.
ResponseCh chan<- error
}
// ExtensionPrintEvent is sent when an extension calls ctx.Print, ctx.PrintInfo,
// ctx.PrintError, or ctx.PrintBlock. The TUI renders it via the appropriate
// renderer and tea.Println (scrollback); the CLI handler uses
-5
View File
@@ -13,11 +13,6 @@ type MessageStore struct {
messages []kit.LLMMessage
}
// NewMessageStore creates an empty MessageStore.
func NewMessageStore() *MessageStore {
return &MessageStore{}
}
// NewMessageStoreWithMessages creates a MessageStore pre-populated with the
// given messages. This is used when loading an existing session at startup.
func NewMessageStoreWithMessages(msgs []kit.LLMMessage) *MessageStore {
+10 -10
View File
@@ -29,7 +29,7 @@ func textOf(msg kit.LLMMessage) string {
// --------------------------------------------------------------------------
func TestNewMessageStore_empty(t *testing.T) {
s := NewMessageStore()
s := NewMessageStoreWithMessages(nil)
if s == nil {
t.Fatal("expected non-nil store")
}
@@ -72,7 +72,7 @@ func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
// --------------------------------------------------------------------------
func TestAdd_appendsMessage(t *testing.T) {
s := NewMessageStore()
s := NewMessageStoreWithMessages(nil)
s.Add(makeTextMsg("user", "first"))
s.Add(makeTextMsg("assistant", "second"))
@@ -82,7 +82,7 @@ func TestAdd_appendsMessage(t *testing.T) {
}
func TestAdd_preservesOrder(t *testing.T) {
s := NewMessageStore()
s := NewMessageStoreWithMessages(nil)
texts := []string{"a", "b", "c"}
for _, t2 := range texts {
s.Add(makeTextMsg("user", t2))
@@ -100,7 +100,7 @@ func TestAdd_preservesOrder(t *testing.T) {
// --------------------------------------------------------------------------
func TestReplace_swapsHistory(t *testing.T) {
s := NewMessageStore()
s := NewMessageStoreWithMessages(nil)
s.Add(makeTextMsg("user", "old"))
replacement := []kit.LLMMessage{
@@ -120,7 +120,7 @@ func TestReplace_swapsHistory(t *testing.T) {
// Replace must deep-copy the incoming slice.
func TestReplace_isolatesInput(t *testing.T) {
s := NewMessageStore()
s := NewMessageStoreWithMessages(nil)
replacement := []kit.LLMMessage{makeTextMsg("user", "original")}
s.Replace(replacement)
@@ -137,7 +137,7 @@ func TestReplace_isolatesInput(t *testing.T) {
// --------------------------------------------------------------------------
func TestGetAll_returnsCopy(t *testing.T) {
s := NewMessageStore()
s := NewMessageStoreWithMessages(nil)
s.Add(makeTextMsg("user", "hello"))
got := s.GetAll()
@@ -151,7 +151,7 @@ func TestGetAll_returnsCopy(t *testing.T) {
}
func TestGetAll_emptyStore(t *testing.T) {
s := NewMessageStore()
s := NewMessageStoreWithMessages(nil)
got := s.GetAll()
if len(got) != 0 {
t.Fatalf("expected empty slice, got %d elements", len(got))
@@ -163,7 +163,7 @@ func TestGetAll_emptyStore(t *testing.T) {
// --------------------------------------------------------------------------
func TestClear_removesAllMessages(t *testing.T) {
s := NewMessageStore()
s := NewMessageStoreWithMessages(nil)
s.Add(makeTextMsg("user", "a"))
s.Add(makeTextMsg("user", "b"))
s.Clear()
@@ -174,7 +174,7 @@ func TestClear_removesAllMessages(t *testing.T) {
}
func TestClear_allowsSubsequentAdds(t *testing.T) {
s := NewMessageStore()
s := NewMessageStoreWithMessages(nil)
s.Add(makeTextMsg("user", "before"))
s.Clear()
s.Add(makeTextMsg("user", "after"))
@@ -193,7 +193,7 @@ func TestClear_allowsSubsequentAdds(t *testing.T) {
// --------------------------------------------------------------------------
func TestConcurrentAccess(t *testing.T) {
s := NewMessageStore()
s := NewMessageStoreWithMessages(nil)
done := make(chan struct{})
// Writer goroutine.
+138 -5
View File
@@ -1,6 +1,7 @@
package auth
import (
"context"
"encoding/json"
"fmt"
"os"
@@ -9,11 +10,11 @@ import (
"time"
)
// CredentialStore holds all stored credentials for various providers.
// Currently supports Anthropic and OpenAI credentials with both OAuth and API key authentication methods.
// CredentialStore holds stored credentials for Anthropic, OpenAI, and GitHub Copilot.
type CredentialStore struct {
Anthropic *AnthropicCredentials `json:"anthropic,omitempty"`
OpenAI *OpenAICredentials `json:"openai,omitempty"`
Copilot *CopilotCredentials `json:"copilot,omitempty"`
}
// AnthropicCredentials holds Anthropic API credentials supporting both OAuth
@@ -43,6 +44,16 @@ type OpenAICredentials struct {
CreatedAt time.Time `json:"created_at"`
}
// CopilotCredentials holds GitHub OAuth credentials and the short-lived
// GitHub Copilot API token derived from them.
type CopilotCredentials struct {
Type string `json:"type"` // "oauth"
GitHubToken string `json:"github_token,omitempty"` // GitHub device-flow OAuth token
CopilotAccessToken string `json:"copilot_access_token,omitempty"` // Short-lived Copilot API token
ExpiresAt int64 `json:"expires_at,omitempty"` // Copilot token expiry
CreatedAt time.Time `json:"created_at"`
}
// oauthTokenExpired reports whether an OAuth token with the given type and
// expiry unix timestamp is past its expiry. Returns false for API key
// credentials or when no expiry is set.
@@ -91,6 +102,16 @@ func (c *OpenAICredentials) NeedsRefresh() bool {
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
}
// IsExpired checks if the Copilot API token is expired.
func (c *CopilotCredentials) IsExpired() bool {
return oauthTokenExpired(c.Type, c.ExpiresAt)
}
// NeedsRefresh reports whether the Copilot API token should be renewed.
func (c *CopilotCredentials) NeedsRefresh() bool {
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
}
// CredentialManager handles secure storage and retrieval of authentication credentials.
// It manages a JSON file stored in the user's config directory with appropriate
// file permissions for security.
@@ -222,7 +243,7 @@ func (cm *CredentialManager) RemoveAnthropicCredentials() error {
store.Anthropic = nil
// If store is empty, remove the file entirely
if store.Anthropic == nil {
if store.Anthropic == nil && store.OpenAI == nil && store.Copilot == nil {
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove credentials file: %w", err)
}
@@ -279,7 +300,7 @@ func (cm *CredentialManager) RemoveOpenAICredentials() error {
store.OpenAI = nil
// If store is empty, remove the file entirely
if store.Anthropic == nil && store.OpenAI == nil {
if store.Anthropic == nil && store.OpenAI == nil && store.Copilot == nil {
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove credentials file: %w", err)
}
@@ -289,6 +310,104 @@ func (cm *CredentialManager) RemoveOpenAICredentials() error {
return cm.SaveCredentials(store)
}
// GetCopilotCredentials retrieves stored GitHub Copilot credentials.
func (cm *CredentialManager) GetCopilotCredentials() (*CopilotCredentials, error) {
store, err := cm.LoadCredentials()
if err != nil {
return nil, err
}
return store.Copilot, nil
}
// RemoveCopilotCredentials removes stored GitHub Copilot credentials.
func (cm *CredentialManager) RemoveCopilotCredentials() error {
store, err := cm.LoadCredentials()
if err != nil {
return err
}
store.Copilot = nil
if store.Anthropic == nil && store.OpenAI == nil && store.Copilot == nil {
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove credentials file: %w", err)
}
return nil
}
return cm.SaveCredentials(store)
}
// HasCopilotCredentials checks if valid GitHub Copilot credentials are stored.
func (cm *CredentialManager) HasCopilotCredentials() (bool, error) {
creds, err := cm.GetCopilotCredentials()
if err != nil {
return false, err
}
if creds == nil {
return false, nil
}
return creds.Type == "oauth" && creds.GitHubToken != "", nil
}
// SetCopilotOAuthCredentials stores GitHub Copilot OAuth credentials.
func (cm *CredentialManager) SetCopilotOAuthCredentials(creds *CopilotCredentials) error {
store, err := cm.LoadCredentials()
if err != nil {
return err
}
store.Copilot = creds
return cm.SaveCredentials(store)
}
// GetValidCopilotAccessToken returns a fresh Copilot API token, renewing it
// with the stored GitHub OAuth token when needed.
func (cm *CredentialManager) GetValidCopilotAccessToken() (string, error) {
return cm.GetValidCopilotAccessTokenContext(context.Background())
}
// GetValidCopilotAccessTokenContext returns a fresh Copilot API token, renewing
// it with the stored GitHub OAuth token when needed.
func (cm *CredentialManager) GetValidCopilotAccessTokenContext(ctx context.Context) (string, error) {
if ctx == nil {
ctx = context.Background()
}
creds, err := cm.GetCopilotCredentials()
if err != nil {
return "", err
}
if creds == nil {
return "", fmt.Errorf("no Copilot credentials found")
}
if creds.Type != "oauth" {
return "", fmt.Errorf("unknown credential type: %s", creds.Type)
}
if creds.GitHubToken == "" {
return "", fmt.Errorf("GitHub OAuth token missing from Copilot credentials")
}
if creds.CopilotAccessToken == "" || creds.NeedsRefresh() {
client := NewCopilotOAuthClient()
newCreds, err := client.RefreshCopilotToken(ctx, creds.GitHubToken)
if err != nil {
return "", fmt.Errorf("failed to refresh Copilot token: %w", err)
}
newCreds.CreatedAt = creds.CreatedAt
if err := cm.SetCopilotOAuthCredentials(newCreds); err != nil {
return "", fmt.Errorf("failed to save refreshed Copilot token: %w", err)
}
return newCreds.CopilotAccessToken, nil
}
return creds.CopilotAccessToken, nil
}
// HasOpenAICredentials checks if valid OpenAI credentials are stored.
// Returns true if either a non-empty OAuth access token or API key is present,
// false otherwise. Returns an error if credentials cannot be loaded.
@@ -394,6 +513,20 @@ func validateAnthropicAPIKey(apiKey string) error {
return nil
}
// CredentialSourceOAuth is the source description returned by
// GetAnthropicAPIKey when the key resolves to stored OAuth credentials.
// Consumers should compare against this constant (or use IsAnthropicOAuth)
// rather than matching the string literal.
const CredentialSourceOAuth = "stored OAuth credentials"
// IsAnthropicOAuth reports whether the active Anthropic credential resolves
// to a stored OAuth token (in which case the user is not billed per-token).
// flagValue is the --provider-api-key flag value (may be empty).
func IsAnthropicOAuth(flagValue string) bool {
_, source, err := GetAnthropicAPIKey(flagValue)
return err == nil && source == CredentialSourceOAuth
}
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
// 1. Command-line flag value (highest priority)
// 2. Stored credentials (OAuth or API key)
@@ -416,7 +549,7 @@ func GetAnthropicAPIKey(flagValue string) (string, string, error) {
if err != nil {
return "", "", fmt.Errorf("failed to get valid OAuth token: %w", err)
}
return token, "stored OAuth credentials", nil
return token, CredentialSourceOAuth, nil
} else if creds.Type == "api_key" && creds.APIKey != "" {
return creds.APIKey, "stored API key", nil
}
+97
View File
@@ -4,6 +4,7 @@ import (
"os"
"path/filepath"
"testing"
"time"
)
func TestCredentialManager(t *testing.T) {
@@ -215,6 +216,7 @@ func TestCredentialStorePersistence(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() { _ = os.RemoveAll(tempDir) }()
credentialsPath := filepath.Join(tempDir, "credentials.json")
@@ -252,3 +254,98 @@ func TestCredentialStorePersistence(t *testing.T) {
t.Errorf("Expected file permissions 0600, got %v", info.Mode().Perm())
}
}
func TestCopilotCredentials(t *testing.T) {
tempDir, err := os.MkdirTemp("", "kit-auth-test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() { _ = os.RemoveAll(tempDir) }()
cm := &CredentialManager{
credentialsPath: filepath.Join(tempDir, "credentials.json"),
}
creds := &CopilotCredentials{
Type: "oauth",
GitHubToken: "github-token",
CopilotAccessToken: "copilot-token",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now(),
}
if err := cm.SetCopilotOAuthCredentials(creds); err != nil {
t.Fatalf("SetCopilotOAuthCredentials failed: %v", err)
}
hasAuth, err := cm.HasCopilotCredentials()
if err != nil {
t.Fatalf("HasCopilotCredentials failed: %v", err)
}
if !hasAuth {
t.Fatal("Expected Copilot credentials")
}
token, err := cm.GetValidCopilotAccessToken()
if err != nil {
t.Fatalf("GetValidCopilotAccessToken failed: %v", err)
}
if token != creds.CopilotAccessToken {
t.Fatalf("Expected Copilot token %q, got %q", creds.CopilotAccessToken, token)
}
if err := cm.RemoveCopilotCredentials(); err != nil {
t.Fatalf("RemoveCopilotCredentials failed: %v", err)
}
hasAuth, err = cm.HasCopilotCredentials()
if err != nil {
t.Fatalf("HasCopilotCredentials after removal failed: %v", err)
}
if hasAuth {
t.Fatal("Expected no Copilot credentials after removal")
}
}
func TestRemoveCredentialsPreservesOtherProviders(t *testing.T) {
tempDir, err := os.MkdirTemp("", "kit-auth-test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() { _ = os.RemoveAll(tempDir) }()
cm := &CredentialManager{
credentialsPath: filepath.Join(tempDir, "credentials.json"),
}
if err := cm.SetOpenAIOAuthCredentials(&OpenAICredentials{
Type: "oauth",
AccessToken: "openai-token",
RefreshToken: "refresh-token",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
AccountID: "account",
CreatedAt: time.Now(),
}); err != nil {
t.Fatalf("SetOpenAIOAuthCredentials failed: %v", err)
}
if err := cm.SetCopilotOAuthCredentials(&CopilotCredentials{
Type: "oauth",
GitHubToken: "github-token",
CopilotAccessToken: "copilot-token",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now(),
}); err != nil {
t.Fatalf("SetCopilotOAuthCredentials failed: %v", err)
}
if err := cm.RemoveCopilotCredentials(); err != nil {
t.Fatalf("RemoveCopilotCredentials failed: %v", err)
}
hasOpenAI, err := cm.HasOpenAICredentials()
if err != nil {
t.Fatalf("HasOpenAICredentials failed: %v", err)
}
if !hasOpenAI {
t.Fatal("Expected OpenAI credentials to remain after removing Copilot credentials")
}
}
+257
View File
@@ -10,6 +10,7 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
@@ -211,6 +212,262 @@ type OpenAIOAuthClient struct {
Scopes string
}
// CopilotOAuthClient handles GitHub device-flow OAuth and exchanges the
// GitHub token for a short-lived GitHub Copilot API token.
//
// The GitHub token comes from GitHub's OAuth device flow. It is then presented
// to GitHub's internal Copilot token endpoint, which returns the bearer token
// used by api.githubcopilot.com.
type CopilotOAuthClient struct {
ClientID string
DeviceURL string
TokenURL string
CopilotURL string
Scopes string
PollTimeout time.Duration
ClientTimeout time.Duration
}
// CopilotDeviceCode contains data returned by GitHub's device-code endpoint.
type CopilotDeviceCode struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
}
// NewCopilotOAuthClient creates a GitHub Copilot OAuth client.
func NewCopilotOAuthClient() *CopilotOAuthClient {
return &CopilotOAuthClient{
ClientID: "Iv1.b507a08c87ecfe98",
DeviceURL: "https://github.com/login/device/code",
TokenURL: "https://github.com/login/oauth/access_token",
CopilotURL: "https://api.github.com/copilot_internal/v2/token",
Scopes: "read:user",
PollTimeout: 15 * time.Minute,
ClientTimeout: 30 * time.Second,
}
}
// StartDeviceFlow requests a GitHub device code for browser login.
//
// The returned user code and verification URI are displayed by loginCopilot.
// GitHub's response may omit interval, so this method normalizes it to the
// documented five-second default.
func (c *CopilotOAuthClient) StartDeviceFlow(ctx context.Context) (*CopilotDeviceCode, error) {
if ctx == nil {
ctx = context.Background()
}
data := url.Values{
"client_id": {c.ClientID},
"scope": {c.Scopes},
}
req, err := http.NewRequestWithContext(ctx, "POST", c.DeviceURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create device-code request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := (&http.Client{Timeout: c.ClientTimeout}).Do(req)
if err != nil {
return nil, fmt.Errorf("failed to request device code: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("device-code request failed with status %d: %s", resp.StatusCode, string(body))
}
var code CopilotDeviceCode
if err := json.NewDecoder(resp.Body).Decode(&code); err != nil {
return nil, fmt.Errorf("failed to decode device-code response: %w", err)
}
if code.DeviceCode == "" || code.UserCode == "" || code.VerificationURI == "" {
return nil, fmt.Errorf("device-code response missing required fields")
}
if code.Interval <= 0 {
code.Interval = 5
}
return &code, nil
}
// PollDeviceToken waits until the user authorizes the device code and returns
// the resulting GitHub OAuth token.
//
// It follows GitHub's device-flow polling contract: authorization_pending keeps
// polling, slow_down increases the interval, and polling stops at the earlier of
// the client timeout or the device-code expiry.
func (c *CopilotOAuthClient) PollDeviceToken(ctx context.Context, deviceCode *CopilotDeviceCode) (string, error) {
if ctx == nil {
ctx = context.Background()
}
if deviceCode == nil || deviceCode.DeviceCode == "" {
return "", fmt.Errorf("device code missing")
}
deadline := time.Now().Add(c.PollTimeout)
if deviceCode.ExpiresIn > 0 {
expiresAt := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
if expiresAt.Before(deadline) {
deadline = expiresAt
}
}
interval := time.Duration(deviceCode.Interval) * time.Second
if interval <= 0 {
interval = 5 * time.Second
}
for time.Now().Before(deadline) {
wait := interval
if remaining := time.Until(deadline); remaining < wait {
wait = remaining
}
select {
case <-ctx.Done():
return "", ctx.Err()
case <-time.After(wait):
}
data := url.Values{
"client_id": {c.ClientID},
"device_code": {deviceCode.DeviceCode},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
}
req, err := http.NewRequestWithContext(ctx, "POST", c.TokenURL, strings.NewReader(data.Encode()))
if err != nil {
return "", fmt.Errorf("failed to create device-token request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := (&http.Client{Timeout: c.ClientTimeout}).Do(req)
if err != nil {
return "", fmt.Errorf("failed to poll device token: %w", err)
}
var tokenResp struct {
AccessToken string `json:"access_token"`
Error string `json:"error"`
Description string `json:"error_description"`
}
decodeErr := json.NewDecoder(resp.Body).Decode(&tokenResp)
_ = resp.Body.Close()
if decodeErr != nil {
return "", fmt.Errorf("failed to decode device-token response: %w", decodeErr)
}
if tokenResp.AccessToken != "" {
return tokenResp.AccessToken, nil
}
switch tokenResp.Error {
case "authorization_pending":
continue
case "slow_down":
interval += 5 * time.Second
continue
case "expired_token":
return "", fmt.Errorf("device code expired; restart login")
case "access_denied":
return "", fmt.Errorf("github login denied")
case "":
return "", fmt.Errorf("device-token request failed with status %d", resp.StatusCode)
default:
if tokenResp.Description != "" {
return "", fmt.Errorf("device-token request failed: %s: %s", tokenResp.Error, tokenResp.Description)
}
return "", fmt.Errorf("device-token request failed: %s", tokenResp.Error)
}
}
return "", fmt.Errorf("timed out waiting for github device authorization")
}
// ExchangeGitHubToken converts a GitHub OAuth token into a Copilot API token.
// It is a semantic wrapper over RefreshCopilotToken used by the login flow.
func (c *CopilotOAuthClient) ExchangeGitHubToken(ctx context.Context, githubToken string) (*CopilotCredentials, error) {
return c.RefreshCopilotToken(ctx, githubToken)
}
// RefreshCopilotToken obtains a fresh short-lived Copilot token from GitHub.
//
// GitHub may return expires_at as either a Unix timestamp or RFC3339 string.
// parseCopilotExpiry handles both forms and falls back to a conservative
// 20-minute lifetime when the field is absent or unrecognized.
func (c *CopilotOAuthClient) RefreshCopilotToken(ctx context.Context, githubToken string) (*CopilotCredentials, error) {
if ctx == nil {
ctx = context.Background()
}
req, err := http.NewRequestWithContext(ctx, "GET", c.CopilotURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create copilot token request: %w", err)
}
req.Header.Set("Authorization", "token "+githubToken)
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", "kit")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
resp, err := (&http.Client{Timeout: c.ClientTimeout}).Do(req)
if err != nil {
return nil, fmt.Errorf("failed to request copilot token: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("copilot token request failed with status %d: %s", resp.StatusCode, string(body))
}
var tokenResp struct {
Token string `json:"token"`
ExpiresAt any `json:"expires_at"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return nil, fmt.Errorf("failed to decode copilot token response: %w", err)
}
if tokenResp.Token == "" {
return nil, fmt.Errorf("copilot token response missing token")
}
expiresAt := parseCopilotExpiry(tokenResp.ExpiresAt)
if expiresAt == 0 {
expiresAt = time.Now().Add(20 * time.Minute).Unix()
}
return &CopilotCredentials{
Type: "oauth",
GitHubToken: githubToken,
CopilotAccessToken: tokenResp.Token,
ExpiresAt: expiresAt,
CreatedAt: time.Now(),
}, nil
}
// parseCopilotExpiry normalizes GitHub's expires_at variants to a Unix second.
func parseCopilotExpiry(value any) int64 {
switch v := value.(type) {
case float64:
return int64(v)
case string:
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
return parsed
}
if parsed, err := time.Parse(time.RFC3339, v); err == nil {
return parsed.Unix()
}
}
return 0
}
// NewOpenAIOAuthClient creates a new OAuth client configured for OpenAI Codex OAuth.
// This uses the public client ID for CLI applications with PKCE for security.
func NewOpenAIOAuthClient() *OpenAIOAuthClient {
+124
View File
@@ -0,0 +1,124 @@
package auth
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestCopilotStartDeviceFlow(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("ParseForm failed: %v", err)
}
if r.Form.Get("client_id") != "client-id" {
t.Fatalf("expected client id, got %q", r.Form.Get("client_id"))
}
if r.Form.Get("scope") != "read:user" {
t.Fatalf("expected scope, got %q", r.Form.Get("scope"))
}
_ = json.NewEncoder(w).Encode(map[string]any{
"device_code": "device-code",
"user_code": "USER-CODE",
"verification_uri": "https://github.com/login/device",
"expires_in": 600,
"interval": 1,
})
}))
defer server.Close()
client := NewCopilotOAuthClient()
client.ClientID = "client-id"
client.DeviceURL = server.URL
code, err := client.StartDeviceFlow(context.Background())
if err != nil {
t.Fatalf("StartDeviceFlow failed: %v", err)
}
if code.DeviceCode != "device-code" || code.UserCode != "USER-CODE" || code.Interval != 1 {
t.Fatalf("unexpected device code: %#v", code)
}
}
func TestCopilotPollDeviceToken(t *testing.T) {
polls := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
polls++
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("ParseForm failed: %v", err)
}
if r.Form.Get("grant_type") != "urn:ietf:params:oauth:grant-type:device_code" {
t.Fatalf("unexpected grant type: %q", r.Form.Get("grant_type"))
}
if polls == 1 {
_ = json.NewEncoder(w).Encode(map[string]any{"error": "authorization_pending"})
return
}
_ = json.NewEncoder(w).Encode(map[string]any{"access_token": "github-token"})
}))
defer server.Close()
client := NewCopilotOAuthClient()
client.ClientID = "client-id"
client.TokenURL = server.URL
client.PollTimeout = 5 * time.Second
client.ClientTimeout = time.Second
token, err := client.PollDeviceToken(context.Background(), &CopilotDeviceCode{
DeviceCode: "device-code",
ExpiresIn: 10,
Interval: 1,
})
if err != nil {
t.Fatalf("PollDeviceToken failed: %v", err)
}
if token != "github-token" {
t.Fatalf("expected github-token, got %q", token)
}
if polls != 2 {
t.Fatalf("expected 2 polls, got %d", polls)
}
}
func TestCopilotRefreshToken(t *testing.T) {
expiresAt := time.Now().Add(time.Hour).Unix()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Fatalf("expected GET, got %s", r.Method)
}
if r.Header.Get("Authorization") != "token github-token" {
t.Fatalf("unexpected authorization header: %q", r.Header.Get("Authorization"))
}
if r.Header.Get("User-Agent") != "kit" {
t.Fatalf("unexpected user agent: %q", r.Header.Get("User-Agent"))
}
_ = json.NewEncoder(w).Encode(map[string]any{
"token": "copilot-token",
"expires_at": expiresAt,
})
}))
defer server.Close()
client := NewCopilotOAuthClient()
client.CopilotURL = server.URL
creds, err := client.RefreshCopilotToken(context.Background(), "github-token")
if err != nil {
t.Fatalf("RefreshCopilotToken failed: %v", err)
}
if creds.GitHubToken != "github-token" || creds.CopilotAccessToken != "copilot-token" {
t.Fatalf("unexpected credentials: %#v", creds)
}
if creds.ExpiresAt != expiresAt {
t.Fatalf("expected expires_at %d, got %d", expiresAt, creds.ExpiresAt)
}
}
+17 -10
View File
@@ -227,16 +227,17 @@ type GenerationParams struct {
// or other custom/ prefixed models. These models are loaded from the config file
// and merged into the custom provider in the model registry.
type CustomModelConfig struct {
Name string `json:"name" yaml:"name"`
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
Family string `json:"family,omitempty" yaml:"family,omitempty"`
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
Cost CostConfig `json:"cost" yaml:"cost"`
Limit LimitConfig `json:"limit" yaml:"limit"`
Name string `json:"name" yaml:"name"`
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
APIModelName string `json:"apiModelName,omitempty" yaml:"apiModelName,omitempty"`
Family string `json:"family,omitempty" yaml:"family,omitempty"`
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
Cost CostConfig `json:"cost" yaml:"cost"`
Limit LimitConfig `json:"limit" yaml:"limit"`
// Generation parameter defaults for this model.
// These are applied when the user hasn't explicitly set the corresponding
@@ -493,6 +494,12 @@ mcpServers:
# maxTokens: 16384
# systemPrompt: "You are a deep reasoning assistant." # or a file path
# Skills configuration (all optional)
# no-skills: false # Set to true to disable all skill loading
# skill: # Explicit skill files/dirs (disables auto-discovery)
# - "/path/to/skill.md"
# skills-dir: "/path/to/skills" # Override project-local directory for auto-discovery
# API Configuration (can also use environment variables)
# provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google
# provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama
+3
View File
@@ -205,6 +205,9 @@ func TestEnsureConfigExists(t *testing.T) {
"type: \"local\"",
"type: \"remote\"",
"Core tools",
"# Skills configuration",
"no-skills:",
"skills-dir:",
}
for _, expected := range expectedSections {
-6
View File
@@ -56,9 +56,3 @@ func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) {
return result, nil
}
// HasEnvVars checks if content contains environment variable patterns (${env://...}).
// This is useful for determining if substitution is needed before processing.
func HasEnvVars(content string) bool {
return envVarPattern.MatchString(content)
}
-38
View File
@@ -187,41 +187,3 @@ func TestEnvSubstituter_SubstituteEnvVars(t *testing.T) {
})
}
}
func TestHasEnvVars(t *testing.T) {
tests := []struct {
name string
content string
expected bool
}{
{
name: "has env vars",
content: `{"token": "${env://GITHUB_TOKEN}"}`,
expected: true,
},
{
name: "has env vars with default",
content: `{"debug": "${env://DEBUG:-false}"}`,
expected: true,
},
{
name: "no env vars",
content: `{"name": "${username}", "normal": "value"}`,
expected: false,
},
{
name: "empty content",
content: "",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := HasEnvVars(tt.content)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
+57 -69
View File
@@ -59,12 +59,6 @@ func passwordPromptFromContext(ctx context.Context) PasswordPromptCallback {
return nil
}
// ContextWithSudoPassword returns a new context with the sudo password set.
// When present, the bash tool will use sudo -S to pipe this password to sudo commands.
func ContextWithSudoPassword(ctx context.Context, password string) context.Context {
return context.WithValue(ctx, sudoPasswordKey, password)
}
// sudoPasswordFromContext retrieves the sudo password from context.
func sudoPasswordFromContext(ctx context.Context) string {
if pw, ok := ctx.Value(sudoPasswordKey).(string); ok {
@@ -249,34 +243,37 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
return executeBashBuffered(cmdCtx, call, cmd, sudoPassword)
}
// executeBashBuffered collects all output before returning (original behavior).
// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly
// close them when grandchild processes hold pipe handles open after the
// direct child exits.
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) {
// setupBashPipes opens stdout/stderr pipes (plus an optional sudo stdin),
// starts the command, and asynchronously writes the sudo password if any.
// Returns the readers ready for the caller to consume. If setup fails,
// errResp is non-nil and the readers must not be used; the caller should
// return the response directly.
func setupBashPipes(cmd *exec.Cmd, sudoPassword string) (stdout, stderr io.Reader, errResp *fantasy.ToolResponse) {
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
r := fantasy.NewTextErrorResponse("failed to create stdout pipe")
return nil, nil, &r
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
r := fantasy.NewTextErrorResponse("failed to create stderr pipe")
return nil, nil, &r
}
// If we have a sudo password, create a stdin pipe and write the password
var stdinPipe io.WriteCloser
if sudoPassword != "" {
stdinPipe, err = cmd.StdinPipe()
if err != nil {
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
r := fantasy.NewTextErrorResponse("failed to create stdin pipe")
return nil, nil, &r
}
}
if err := cmd.Start(); err != nil {
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
r := fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err))
return nil, nil, &r
}
// Write password to stdin if needed, then close stdin
if sudoPassword != "" && stdinPipe != nil {
go func() {
defer func() { _ = stdinPipe.Close() }()
@@ -284,19 +281,49 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
}()
}
return stdoutPipe, stderrPipe, nil
}
// interpretBashExit decodes cmd.Wait()'s error into an exit code, mapping
// context-deadline-exceeded to a friendly "command timed out" response.
// errResp is non-nil only when the caller should short-circuit and return
// it directly (e.g. timeout).
func interpretBashExit(waitErr error, cmdCtx context.Context) (exitCode int, errResp *fantasy.ToolResponse) {
if waitErr == nil {
return 0, nil
}
if exitErr, ok := waitErr.(*exec.ExitError); ok {
return exitErr.ExitCode(), nil
}
if cmdCtx.Err() == context.DeadlineExceeded {
r := fantasy.NewTextErrorResponse("command timed out")
return 0, &r
}
return 0, nil
}
// executeBashBuffered collects all output before returning (original behavior).
// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly
// close them when grandchild processes hold pipe handles open after the
// direct child exits.
func executeBashBuffered(cmdCtx context.Context, _ fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) {
stdoutPipe, stderrPipe, errResp := setupBashPipes(cmd, sudoPassword)
if errResp != nil {
return *errResp, nil
}
// Read pipes concurrently
var wg sync.WaitGroup
var stdout, stderr strings.Builder
var stdoutErr, stderrErr error
wg.Add(2)
go func() {
defer wg.Done()
_, stdoutErr = io.Copy(&stdout, stdoutPipe)
_, _ = io.Copy(&stdout, stdoutPipe)
}()
go func() {
defer wg.Done()
_, stderrErr = io.Copy(&stderr, stderrPipe)
_, _ = io.Copy(&stderr, stderrPipe)
}()
// Wait for the process to exit first. cmd.WaitDelay ensures that if
@@ -307,18 +334,9 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
// Wait for pipe readers to finish draining.
wg.Wait()
// Ignore pipe read errors caused by WaitDelay force-closing —
// we still have whatever was read before the close.
_ = stdoutErr
_ = stderrErr
exitCode := 0
if waitErr != nil {
if exitErr, ok := waitErr.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else if cmdCtx.Err() == context.DeadlineExceeded {
return fantasy.NewTextErrorResponse("command timed out"), nil
}
exitCode, errResp := interpretBashExit(waitErr, cmdCtx)
if errResp != nil {
return *errResp, nil
}
return buildBashResponse(stdout.String(), stderr.String(), exitCode)
@@ -326,35 +344,9 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
// executeBashStreaming streams output as it arrives via the callback.
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback, sudoPassword string) (fantasy.ToolResponse, error) {
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
}
// If we have a sudo password, create a stdin pipe
var stdinPipe io.WriteCloser
if sudoPassword != "" {
stdinPipe, err = cmd.StdinPipe()
if err != nil {
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
}
}
// Start command execution
if err := cmd.Start(); err != nil {
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
}
// Write password to stdin if needed, then close stdin
if sudoPassword != "" && stdinPipe != nil {
go func() {
defer func() { _ = stdinPipe.Close() }()
_, _ = io.WriteString(stdinPipe, sudoPassword+"\n")
}()
stdoutPipe, stderrPipe, errResp := setupBashPipes(cmd, sudoPassword)
if errResp != nil {
return *errResp, nil
}
// Stream stdout and stderr concurrently
@@ -391,20 +383,16 @@ func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *ex
// Wait for the process to exit. cmd.WaitDelay ensures that if pipes
// remain open (held by grandchild processes), they'll be forcibly closed
// after the grace period, which unblocks the scanners above.
err = cmd.Wait()
waitErr := cmd.Wait()
// Wait for the pipe readers to finish draining. This will complete
// quickly since cmd.Wait() (with WaitDelay) has already ensured
// the pipes are closed.
wg.Wait()
exitCode := 0
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else if cmdCtx.Err() == context.DeadlineExceeded {
return fantasy.NewTextErrorResponse("command timed out"), nil
}
exitCode, errResp := interpretBashExit(waitErr, cmdCtx)
if errResp != nil {
return *errResp, nil
}
return buildBashResponse(strings.Join(stdoutChunks, "\n"), strings.Join(stderrChunks, "\n"), exitCode)
+1 -1
View File
@@ -183,7 +183,7 @@ func TestRewriteSudoForStdin(t *testing.T) {
func TestSudoPasswordFromContext(t *testing.T) {
// Test with password in context
ctx := ContextWithSudoPassword(context.Background(), "secret123")
ctx := context.WithValue(context.Background(), sudoPasswordKey, "secret123")
pw := sudoPasswordFromContext(ctx)
if pw != "secret123" {
t.Errorf("expected password 'secret123', got %q", pw)
+3
View File
@@ -83,6 +83,9 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
}
func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
if err := ctx.Err(); err != nil {
return fantasy.ToolResponse{}, err
}
var args editArgs
if err := parseArgs(call.Input, &args); err != nil {
return fantasy.NewTextErrorResponse("failed to parse arguments: " + err.Error()), nil
+3
View File
@@ -42,6 +42,9 @@ func NewLsTool(opts ...ToolOption) fantasy.AgentTool {
}
func executeLs(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
if err := ctx.Err(); err != nil {
return fantasy.ToolResponse{}, err
}
var args lsArgs
_ = parseArgs(call.Input, &args) // optional args
+3
View File
@@ -47,6 +47,9 @@ func NewReadTool(opts ...ToolOption) fantasy.AgentTool {
}
func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
if err := ctx.Err(); err != nil {
return fantasy.ToolResponse{}, err
}
var args readArgs
if err := parseArgs(call.Input, &args); err != nil {
return fantasy.NewTextErrorResponse("path parameter is required"), nil
+3
View File
@@ -41,6 +41,9 @@ func NewWriteTool(opts ...ToolOption) fantasy.AgentTool {
}
func executeWrite(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
if err := ctx.Err(); err != nil {
return fantasy.ToolResponse{}, err
}
var args writeArgs
if err := parseArgs(call.Input, &args); err != nil {
return fantasy.NewTextErrorResponse("path and content parameters are required"), nil
+234
View File
@@ -0,0 +1,234 @@
package extbridge
import (
"context"
"github.com/mark3labs/kit/internal/extensions"
kit "github.com/mark3labs/kit/pkg/kit"
)
// BaseContext returns an extensions.Context populated with the headless,
// TUI-independent delegation fields: data access, state, options,
// model/tool management, completions, subagents, tree navigation, skills,
// template parsing, and model resolution.
//
// Callers overlay their UI-specific fields (print routes, widgets, prompts,
// editor, TUI-aware SetModel/ReloadExtensions, etc.) on the returned value:
// cmd/extension_context.go for the interactive TUI and
// internal/acpserver/session.go for headless ACP mode. Keeping the shared
// half here means a new data-access Context field only has to be wired once.
//
// ctx is used for subagent spawns; pass a long-lived context (not a
// per-request one) so later spawns aren't cancelled prematurely.
func BaseContext(ctx context.Context, kitInstance *kit.Kit) extensions.Context {
return extensions.Context{
// -------------------------------------------------------------------
// Data access
// -------------------------------------------------------------------
GetContextStats: func() extensions.ContextStats {
s := kitInstance.GetContextStats()
return extensions.ContextStats{
EstimatedTokens: s.EstimatedTokens,
ContextLimit: s.ContextLimit,
UsagePercent: s.UsagePercent,
MessageCount: s.MessageCount,
}
},
GetMessages: func() []extensions.SessionMessage {
return kitInstance.Extensions().GetSessionMessages()
},
GetSessionPath: func() string {
return kitInstance.GetSessionPath()
},
AppendEntry: func(entryType string, data string) (string, error) {
return kitInstance.Extensions().AppendEntry(entryType, data)
},
GetEntries: func(entryType string) []extensions.ExtensionEntry {
return kitInstance.Extensions().GetEntries(entryType)
},
// -------------------------------------------------------------------
// Extension state
// -------------------------------------------------------------------
SetState: func(key string, value string) {
kitInstance.Extensions().SetState(key, value)
},
GetState: func(key string) (string, bool) {
return kitInstance.Extensions().GetState(key)
},
DeleteState: func(key string) {
kitInstance.Extensions().DeleteState(key)
},
ListState: func() []string {
return kitInstance.Extensions().ListState()
},
// -------------------------------------------------------------------
// Options, model, and tool management
// -------------------------------------------------------------------
GetOption: func(name string) string {
return kitInstance.Extensions().GetOption(name)
},
SetOption: func(name string, value string) {
kitInstance.Extensions().SetOption(name, value)
},
// Headless model switch. The interactive TUI overrides this with a
// version that also notifies the TUI and refreshes the usage tracker.
SetModel: func(modelString string) error {
previousModel := kitInstance.Extensions().GetContext().Model
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
return err
}
kitInstance.Extensions().UpdateContextModel(modelString)
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
return nil
},
GetAvailableModels: func() []extensions.ModelInfoEntry {
return kitInstance.GetAvailableModels()
},
EmitCustomEvent: func(name string, data string) {
kitInstance.Extensions().EmitCustomEvent(name, data)
},
GetAllTools: func() []extensions.ToolInfo {
return kitInstance.Extensions().GetToolInfos()
},
SetActiveTools: func(names []string) {
kitInstance.Extensions().SetActiveTools(names)
},
// Headless reload. The interactive TUI overrides this to also
// refresh widgets/status/commands.
ReloadExtensions: func() error {
return kitInstance.Extensions().Reload()
},
// -------------------------------------------------------------------
// LLM completions and subagents
// -------------------------------------------------------------------
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
return kitInstance.ExecuteCompletion(context.Background(), req)
},
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
return SpawnSubagent(ctx, kitInstance, config)
},
// -------------------------------------------------------------------
// Tree Navigation API
// -------------------------------------------------------------------
GetTreeNode: func(entryID string) *extensions.TreeNode {
node := kitInstance.GetTreeNode(entryID)
if node == nil {
return nil
}
return &extensions.TreeNode{
ID: node.ID,
ParentID: node.ParentID,
Type: node.Type,
Role: node.Role,
Content: node.Content,
Model: node.Model,
Provider: node.Provider,
Timestamp: node.Timestamp,
Children: node.Children,
}
},
GetCurrentBranch: func() []extensions.TreeNode {
nodes := kitInstance.GetCurrentBranch()
result := make([]extensions.TreeNode, len(nodes))
for i, n := range nodes {
result[i] = extensions.TreeNode{
ID: n.ID,
ParentID: n.ParentID,
Type: n.Type,
Role: n.Role,
Content: n.Content,
Model: n.Model,
Provider: n.Provider,
Timestamp: n.Timestamp,
Children: n.Children,
}
}
return result
},
GetChildren: func(parentID string) []string {
return kitInstance.GetChildren(parentID)
},
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
err := kitInstance.NavigateTo(entryID)
if err != nil {
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
}
return extensions.TreeNavigationResult{Success: true}
},
SummarizeBranch: func(fromID, toID string) string {
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
return summary
},
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
err := kitInstance.CollapseBranch(fromID, toID, summary)
if err != nil {
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
}
return extensions.TreeNavigationResult{Success: true}
},
// -------------------------------------------------------------------
// Skill Loading API (context-injection variants are TUI-specific and
// wired by the interactive overlay)
// -------------------------------------------------------------------
LoadSkill: func(path string) (*extensions.Skill, string) {
s, err := kitInstance.LoadSkillForExtension(path)
return s, err
},
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
return kitInstance.LoadSkillsFromDirForExtension(dir)
},
DiscoverSkills: func() extensions.SkillLoadResult {
skills := kitInstance.DiscoverSkillsForExtension()
return extensions.SkillLoadResult{Skills: skills}
},
GetAvailableSkills: func() []extensions.Skill {
return kitInstance.DiscoverSkillsForExtension()
},
// -------------------------------------------------------------------
// Template Parsing API
// -------------------------------------------------------------------
ParseTemplate: func(name, content string) extensions.PromptTemplate {
return kit.ParseTemplate(name, content)
},
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
return kit.RenderTemplate(tpl, vars)
},
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
return kit.ParseArguments(input, pattern)
},
SimpleParseArguments: func(input string, count int) []string {
return kit.SimpleParseArguments(input, count)
},
EvaluateModelConditional: func(condition string) bool {
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
},
RenderWithModelConditionals: func(content string) string {
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
},
// -------------------------------------------------------------------
// Model Resolution API
// -------------------------------------------------------------------
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
return kit.ResolveModelChain(preferences)
},
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
return kit.GetModelCapabilities(model)
},
CheckModelAvailable: func(model string) bool {
return kit.CheckModelAvailable(model)
},
GetCurrentProvider: func() string {
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
},
GetCurrentModelID: func() string {
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
},
}
}
+1
View File
@@ -66,6 +66,7 @@ func SpawnSubagent(ctx context.Context, k *kit.Kit, cfg extensions.SubagentConfi
SystemPrompt: cfg.SystemPrompt,
Timeout: cfg.Timeout,
NoSession: cfg.NoSession,
Tools: k.GetToolsForSubagent(),
}
if cfg.OnEvent != nil {
sdkCfg.OnEvent = func(e kit.Event) {
+202 -1
View File
@@ -1,5 +1,24 @@
package extensions
import (
"errors"
)
// ErrAgentBusy is returned (wrapped) when an extension API call that requires
// the agent to be idle cannot proceed because the agent is still processing a
// turn or post-turn hooks. Most notably, ctx.NewSession waits for idle
// internally; if its wait deadline elapses it returns an error that wraps
// this sentinel.
//
// Extensions can detect the condition with errors.Is:
//
// if err := ctx.NewSession(prompt); err != nil {
// if errors.Is(err, ext.ErrAgentBusy) {
// // agent never settled — fall back to a queued message instead
// }
// }
var ErrAgentBusy = errors.New("agent is busy")
// ---------------------------------------------------------------------------
// Internal types (used by runner, NOT exposed to Yaegi)
// ---------------------------------------------------------------------------
@@ -124,6 +143,48 @@ type Context struct {
// })
SendMultimodalMessage func(text string, files []FilePart)
// NewSession ends the current session and starts a fresh one (matching
// the /new slash command). When prompt is non-empty it is submitted as
// the first user turn of the new session, with @file references
// expanded the same way they are for normal user input. Pass an empty
// string to start an empty session.
//
// If the agent is currently busy when NewSession is called (for example,
// from an OnAgentEnd hook that fires before the agent fully settles, or
// while post-turn formatters/linters are still running), the call blocks
// until the agent transitions to idle. This avoids the v0.79.0
// phase-handoff race where NewSession from OnAgentEnd would fail with
// "agent is busy" because TurnEnd fires before the busy flag clears.
// The wait has a generous internal timeout; if it elapses the returned
// error wraps ErrAgentBusy (detectable with errors.Is).
//
// Returns an error if the agent does not become idle within the wait
// window, if a registered BeforeSessionSwitch handler cancels the
// switch, or if the new session file cannot be created. In
// non-interactive (ACP / headless) mode this is a no-op that returns
// an error.
//
// Because NewSession may block, call it from a goroutine — not
// directly from inside an event handler that the agent loop is waiting
// on.
//
// Typical pattern — start a fresh session at the end of a phase by
// reading a handoff file:
//
// api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
// msgs := ctx.GetMessages()
// if len(msgs) == 0 {
// return
// }
// last := msgs[len(msgs)-1].Content
// if strings.Contains(last, "<HANDOFF_READY>") {
// go func() {
// _ = ctx.NewSession("Read @HANDOFF.md and continue the next phase.")
// }()
// }
// })
NewSession func(prompt string) error
// GetSessionUsage returns aggregated token usage and cost statistics
// for the current session. This includes total input/output tokens,
// cache read/write tokens, total cost, and request count.
@@ -341,6 +402,13 @@ type Context struct {
// The data survives across session restarts and can be retrieved via
// GetEntries. Use entryType to namespace your data (e.g. "myext:state").
//
// AppendEntry is append-only and lives in the conversation tree, which
// makes it the right tool for audit logs and event histories. For
// last-write-wins snapshot state — "what's the current value of X?" —
// prefer SetState / GetState instead. Those primitives store data in a
// sidecar file outside the conversation tree, are O(1) to read/write,
// and do not bloat branch reads or duplicate on fork.
//
// Example:
//
// data, _ := json.Marshal(myState)
@@ -360,6 +428,45 @@ type Context struct {
// }
GetEntries func(entryType string) []ExtensionEntry
// SetState stores a key-value pair in session-scoped, last-write-wins
// extension state. Unlike AppendEntry the value is kept in a sidecar
// file outside the conversation tree, so:
// - reads are O(1) (no branch walk)
// - writes don't bloat the session JSONL
// - state is not duplicated on fork (branches share the sidecar)
// - state is invisible to the LLM
//
// Use SetState for snapshot state ("current value of X"); use
// AppendEntry for audit logs and event histories. Namespace keys with
// your extension name to avoid collisions (e.g. "myext:budget-cap").
//
// State persists for the lifetime of the session. For ephemeral or
// in-memory sessions the state lives only in memory.
//
// Example:
//
// ctx.SetState("myext:budget-cap", "10.00")
SetState func(key string, value string)
// GetState returns the value previously stored via SetState. The bool
// is false when the key was never written. Returns ("", false) when
// state is unavailable.
//
// Example:
//
// if cap, ok := ctx.GetState("myext:budget-cap"); ok {
// fmt.Println("current cap:", cap)
// }
GetState func(key string) (string, bool)
// DeleteState removes a key from session-scoped extension state.
// No-op when the key is missing.
DeleteState func(key string)
// ListState returns all keys currently stored in session-scoped
// extension state, in unspecified order.
ListState func() []string
// SetEditorText sets the text content of the input editor. This can
// be used to pre-fill the editor with suggested text (e.g. extracted
// questions, handoff prompts). The cursor is moved to the end.
@@ -1102,6 +1209,7 @@ type API struct {
onError func(func(ErrorEvent, Context))
onRetry func(func(RetryEvent, Context))
onPrepareStep func(func(PrepareStepEvent, Context) *PrepareStepResult)
onLLMUsage func(func(LLMUsageEvent, Context))
}
// OnToolCall registers a handler that fires before a tool executes.
@@ -1359,6 +1467,19 @@ func (a *API) OnPrepareStep(handler func(PrepareStepEvent, Context) *PrepareStep
a.onPrepareStep(handler)
}
// OnLLMUsage registers a handler that fires after each LLM provider call
// with the token and cost deltas for that single call. Use this for
// per-call usage attribution, real-time budget enforcement, and cost
// dashboards that need to react between calls within a single agent turn.
//
// Handlers receive an LLMUsageEvent describing the call's input/output
// tokens, cache tokens, computed cost, model, and provider. A single agent
// turn typically fires multiple LLMUsageEvents (one per tool-loop
// iteration).
func (a *API) OnLLMUsage(handler func(LLMUsageEvent, Context)) {
a.onLLMUsage(handler)
}
// RegisterToolRenderer registers a custom renderer for a specific tool's
// display in the TUI. The renderer controls the header (parameter summary)
// and/or body (result display) of the tool's output block. If multiple
@@ -2091,10 +2212,47 @@ type AgentStartEvent struct {
func (e AgentStartEvent) Type() EventType { return AgentStart }
// AgentEndEvent fires when the agent finishes responding.
// AgentEndEvent fires when the agent finishes responding. In addition to the
// final response and stop reason, the event carries per-turn aggregates so
// observer-style extensions don't have to maintain parallel bookkeeping in
// OnToolResult / OnStepFinish handlers.
type AgentEndEvent struct {
Response string
StopReason string // "completed", "cancelled", "error"
// ToolCallCount is the total number of tool invocations observed during
// this turn (sum across all steps).
ToolCallCount int
// ToolNames lists the tool names invoked during this turn, in call order.
// Duplicates are preserved (e.g. two bash calls produce ["bash", "bash"]).
ToolNames []string
// LLMCallCount is the number of LLM round-trips (tool-loop iterations)
// performed during this turn. Always >= 1 for a successful turn.
LLMCallCount int
// InputTokensDelta is the sum of input tokens consumed during this turn
// across every LLM call (including cache-hit input tokens).
InputTokensDelta int
// OutputTokensDelta is the sum of output tokens generated during this turn.
OutputTokensDelta int
// CacheReadTokensDelta is the sum of cache-read tokens during this turn.
CacheReadTokensDelta int
// CacheWriteTokensDelta is the sum of cache-write tokens during this turn.
CacheWriteTokensDelta int
// CostDelta is the total cost in USD attributable to this turn. Computed
// from per-step usage and current model pricing. Zero when pricing is
// unknown or OAuth credentials are in use.
CostDelta float64
// DurationMs is the elapsed wall-clock time from AgentStart to AgentEnd,
// in milliseconds.
DurationMs int64
}
func (e AgentEndEvent) Type() EventType { return AgentEnd }
@@ -2199,6 +2357,12 @@ type BeforeSessionSwitchEvent struct {
// Reason describes why the switch is happening: "new" for /new command,
// "clear" for /clear command.
Reason string
// InitialPrompt, when non-empty, is the prompt that will be submitted
// as the first user turn of the new session. Set when /new is invoked
// with an argument (e.g. "/new continue from HANDOFF.md") or when an
// extension calls ctx.NewSession(prompt). Extensions may inspect this
// to decide whether to allow the switch.
InitialPrompt string
}
func (e BeforeSessionSwitchEvent) Type() EventType { return BeforeSessionSwitch }
@@ -2403,6 +2567,43 @@ type PrepareStepResult struct {
func (PrepareStepResult) isResult() {}
// LLMUsageEvent fires after each LLM provider call with the per-call token
// and cost deltas. Use this for accurate budget tracking, cost dashboards,
// and any logic that needs to react between LLM calls within a single agent
// turn (rather than only at turn boundaries).
//
// A single agent turn typically produces multiple LLMUsageEvents (one per
// tool-loop iteration). The Model and Provider fields reflect the model used
// for that specific call, which may differ from earlier calls if the
// extension switched models mid-turn via ctx.SetModel().
type LLMUsageEvent struct {
// InputTokens is the number of input tokens for this call.
InputTokens int
// OutputTokens is the number of output tokens generated by this call.
OutputTokens int
// CacheReadTokens is the number of cache-hit input tokens (provider-specific).
CacheReadTokens int
// CacheWriteTokens is the number of cache-write tokens.
CacheWriteTokens int
// Cost is the USD cost of this call computed from the model's per-token
// pricing. Zero when pricing is unknown or OAuth credentials are in use.
Cost float64
// Model is the model identifier used for this call (e.g. "claude-sonnet-4-5-20250929").
Model string
// Provider is the provider identifier (e.g. "anthropic", "openai").
Provider string
// RequestID is an optional correlation id for the underlying provider
// call. May be empty when the provider does not surface one.
RequestID string
// StepNumber is the zero-based step index within the current agent turn.
StepNumber int
// FinishReason mirrors the provider's finish reason for this call
// (e.g. "stop", "tool_calls", "length"). May be empty.
FinishReason string
}
func (e LLMUsageEvent) Type() EventType { return LLMUsage }
// ThemeColor is an adaptive color pair with light and dark hex values.
// Either field may be empty to inherit from the default theme.
type ThemeColor struct {
+6 -1
View File
@@ -125,6 +125,11 @@ const (
// after steering messages are injected and before messages are sent
// to the LLM. Handlers can replace the context window for this step.
PrepareStep EventType = "prepare_step"
// LLMUsage fires after each LLM provider call with the token and cost
// deltas for that single call. Extensions use it to attribute usage to
// specific calls/models and to drive budget enforcement between calls.
LLMUsage EventType = "llm_usage"
)
// AllEventTypes returns every supported event type.
@@ -139,7 +144,7 @@ func AllEventTypes() []EventType {
BeforeFork, BeforeSessionSwitch, BeforeCompact,
SubagentStart, SubagentChunk, SubagentEnd,
StepStart, StepFinish, ReasoningStart, Warnings, Source, Error, Retry,
PrepareStep,
PrepareStep, LLMUsage,
}
}
+2 -2
View File
@@ -4,8 +4,8 @@ import "testing"
func TestAllEventTypes_Count(t *testing.T) {
all := AllEventTypes()
if len(all) != 32 {
t.Fatalf("expected 32 event types, got %d", len(all))
if len(all) != 33 {
t.Fatalf("expected 33 event types, got %d", len(all))
}
}
+119
View File
@@ -0,0 +1,119 @@
package extensions
import "testing"
func TestRunner_EmitLLMUsage(t *testing.T) {
var got LLMUsageEvent
var called bool
ext := makeHandlerExt("llmusage.go", map[EventType][]HandlerFunc{
LLMUsage: {
func(e Event, c Context) Result {
got = e.(LLMUsageEvent)
called = true
return nil
},
},
})
r := makeRunner(ext)
_, err := r.Emit(LLMUsageEvent{
InputTokens: 100,
OutputTokens: 50,
Cost: 0.0012,
Model: "claude-sonnet-4-5-20250929",
Provider: "anthropic",
StepNumber: 2,
FinishReason: "tool_calls",
})
if err != nil {
t.Fatalf("emit: %v", err)
}
if !called {
t.Fatal("expected LLMUsage handler to be called")
}
if got.InputTokens != 100 || got.OutputTokens != 50 {
t.Errorf("token fields not propagated: %+v", got)
}
if got.Cost != 0.0012 {
t.Errorf("cost not propagated, got %v", got.Cost)
}
if got.Model != "claude-sonnet-4-5-20250929" || got.Provider != "anthropic" {
t.Errorf("model/provider not propagated: %+v", got)
}
if got.StepNumber != 2 || got.FinishReason != "tool_calls" {
t.Errorf("step/finish reason not propagated: %+v", got)
}
}
func TestRunner_LLMUsageRegisteredViaTestAPI(t *testing.T) {
// Verify NewTestAPI wires up onLLMUsage so the extension can call
// api.OnLLMUsage during Init.
ext := &LoadedExtension{Handlers: make(map[EventType][]HandlerFunc)}
api := NewTestAPI(ext)
var calls int
api.OnLLMUsage(func(e LLMUsageEvent, c Context) {
calls++
})
if len(ext.Handlers[LLMUsage]) != 1 {
t.Fatalf("expected 1 LLMUsage handler registered, got %d", len(ext.Handlers[LLMUsage]))
}
r := makeRunner(*ext)
_, _ = r.Emit(LLMUsageEvent{InputTokens: 1})
if calls != 1 {
t.Errorf("expected handler called once, got %d", calls)
}
}
func TestAgentEndEvent_EnrichedFields(t *testing.T) {
// Verify the enriched event carries through Emit without mangling.
var got AgentEndEvent
ext := makeHandlerExt("end.go", map[EventType][]HandlerFunc{
AgentEnd: {
func(e Event, c Context) Result {
got = e.(AgentEndEvent)
return nil
},
},
})
r := makeRunner(ext)
_, err := r.Emit(AgentEndEvent{
Response: "done",
StopReason: "completed",
ToolCallCount: 3,
ToolNames: []string{"bash", "read", "bash"},
LLMCallCount: 4,
InputTokensDelta: 1500,
OutputTokensDelta: 400,
CacheReadTokensDelta: 200,
CacheWriteTokensDelta: 100,
CostDelta: 0.0123,
DurationMs: 2500,
})
if err != nil {
t.Fatalf("emit: %v", err)
}
if got.ToolCallCount != 3 {
t.Errorf("ToolCallCount: got %d want 3", got.ToolCallCount)
}
if len(got.ToolNames) != 3 || got.ToolNames[0] != "bash" || got.ToolNames[2] != "bash" {
t.Errorf("ToolNames: %v", got.ToolNames)
}
if got.LLMCallCount != 4 {
t.Errorf("LLMCallCount: got %d want 4", got.LLMCallCount)
}
if got.InputTokensDelta != 1500 || got.OutputTokensDelta != 400 {
t.Errorf("token deltas: %+v", got)
}
if got.CacheReadTokensDelta != 200 || got.CacheWriteTokensDelta != 100 {
t.Errorf("cache deltas: %+v", got)
}
if got.CostDelta != 0.0123 {
t.Errorf("CostDelta: got %v", got.CostDelta)
}
if got.DurationMs != 2500 {
t.Errorf("DurationMs: got %d", got.DurationMs)
}
}
+12 -2
View File
@@ -372,8 +372,12 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
Handlers: make(map[EventType][]HandlerFunc),
}
// Create a fresh interpreter.
i := interp.New(interp.Options{})
// Create a fresh interpreter. Yaegi runs extensions in restricted mode,
// where os.Getenv/os.LookupEnv/os.Environ read from a virtualized
// environment rather than the real one. Seed it with the process
// environment so extensions can read variables (e.g. CI-provided ones
// like GITHUB_EVENT_PATH) without being able to mutate the host's env.
i := interp.New(interp.Options{Env: os.Environ()})
// Expose the Go stdlib. The base set covers most packages; the
// unrestricted set adds os/exec so extensions can spawn processes.
@@ -669,6 +673,12 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
return *r
})
},
onLLMUsage: func(h func(LLMUsageEvent, Context)) {
reg(LLMUsage, func(e Event, c Context) Result {
h(e.(LLMUsageEvent), c)
return nil
})
},
}
// Call Init — the extension registers its handlers, tools, commands.
+187 -1
View File
@@ -2,9 +2,12 @@ package extensions
import (
"bytes"
"encoding/json"
"fmt"
"log"
"maps"
"os"
"path/filepath"
"runtime"
"sort"
"strconv"
@@ -99,6 +102,10 @@ type Runner struct {
customEventSubs map[string][]func(string) // inter-extension event bus
optionOverrides map[string]string // runtime option overrides
configStore *viper.Viper // per-instance config store (nil = global)
state map[string]string // session-scoped extension state (last-write-wins)
stateMu sync.RWMutex // guards state independently of mu
saverMu sync.Mutex // serializes stateSaver invocations so atomic-rename writes don't interleave
stateSaver func() // optional persistence hook invoked after each state mutation
mu sync.RWMutex
}
@@ -185,6 +192,9 @@ func normalizeContext(ctx Context) Context {
if ctx.SendMultimodalMessage == nil {
ctx.SendMultimodalMessage = func(string, []FilePart) {}
}
if ctx.NewSession == nil {
ctx.NewSession = func(string) error { return fmt.Errorf("new session not available") }
}
if ctx.GetSessionUsage == nil {
ctx.GetSessionUsage = func() SessionUsage { return SessionUsage{} }
}
@@ -264,6 +274,18 @@ func normalizeContext(ctx Context) Context {
if ctx.GetEntries == nil {
ctx.GetEntries = func(string) []ExtensionEntry { return nil }
}
if ctx.SetState == nil {
ctx.SetState = func(string, string) {}
}
if ctx.GetState == nil {
ctx.GetState = func(string) (string, bool) { return "", false }
}
if ctx.DeleteState == nil {
ctx.DeleteState = func(string) {}
}
if ctx.ListState == nil {
ctx.ListState = func() []string { return nil }
}
if ctx.GetOption == nil {
ctx.GetOption = func(string) string { return "" }
}
@@ -745,6 +767,168 @@ func (r *Runner) GetMessageRenderer(name string) *MessageRendererConfig {
return nil
}
// ---------------------------------------------------------------------------
// Extension state store (session-scoped, last-write-wins)
// ---------------------------------------------------------------------------
// SetState records a key-value pair in the runner's session-scoped extension
// state store. The store is in-memory; callers wire SetStateSaver to persist
// changes to a sidecar file. Thread-safe.
//
// When a saver is installed, concurrent SetState/DeleteState invocations are
// serialized through saverMu so that overlapping snapshot-and-rename writes
// cannot interleave (which would otherwise race on the shared tmp file and
// risk persisting an older snapshot after a newer one).
func (r *Runner) SetState(key, value string) {
r.stateMu.Lock()
if r.state == nil {
r.state = make(map[string]string)
}
r.state[key] = value
saver := r.stateSaver
r.stateMu.Unlock()
r.runSaver(saver)
}
// GetState returns the value previously stored via SetState, plus a bool
// indicating whether the key was present. Thread-safe.
func (r *Runner) GetState(key string) (string, bool) {
r.stateMu.RLock()
defer r.stateMu.RUnlock()
v, ok := r.state[key]
return v, ok
}
// DeleteState removes a key from the state store. No-op if the key is
// missing. Thread-safe. Saver invocations are serialized via saverMu — see
// SetState for the rationale.
func (r *Runner) DeleteState(key string) {
r.stateMu.Lock()
_, existed := r.state[key]
if existed {
delete(r.state, key)
}
saver := r.stateSaver
r.stateMu.Unlock()
if !existed {
return
}
r.runSaver(saver)
}
// runSaver invokes the optional persistence callback under saverMu so
// concurrent SetState/DeleteState writers cannot race on the shared tmp
// file used by SaveStateToFile's atomic rename. The deferred Unlock
// guarantees saverMu is released even if the saver panics.
func (r *Runner) runSaver(saver func()) {
if saver == nil {
return
}
r.saverMu.Lock()
defer r.saverMu.Unlock()
saver()
}
// ListState returns all keys currently in the state store, in unspecified
// order. Thread-safe.
func (r *Runner) ListState() []string {
r.stateMu.RLock()
defer r.stateMu.RUnlock()
if len(r.state) == 0 {
return nil
}
keys := make([]string, 0, len(r.state))
for k := range r.state {
keys = append(keys, k)
}
return keys
}
// SetStateSaver installs an optional persistence hook invoked after each
// mutation to the state store (SetState / DeleteState / LoadStateFromFile).
// Pass nil to disable persistence. Thread-safe.
func (r *Runner) SetStateSaver(saver func()) {
r.stateMu.Lock()
defer r.stateMu.Unlock()
r.stateSaver = saver
}
// SnapshotState returns a copy of the current state store as a
// fresh map. Useful for persisting to disk without holding the lock.
// Thread-safe.
func (r *Runner) SnapshotState() map[string]string {
r.stateMu.RLock()
defer r.stateMu.RUnlock()
if len(r.state) == 0 {
return nil
}
copyMap := make(map[string]string, len(r.state))
maps.Copy(copyMap, r.state)
return copyMap
}
// LoadStateFromFile reads a JSON map from path and replaces the in-memory
// state store with its contents. Missing or empty files are treated as
// "no prior state": the in-memory store is replaced with an empty map so
// callers can safely switch sessions without leaking keys from a prior
// session into a new one. Malformed JSON returns the parse error without
// touching the existing store. Thread-safe.
func (r *Runner) LoadStateFromFile(path string) error {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
r.stateMu.Lock()
r.state = map[string]string{}
r.stateMu.Unlock()
return nil
}
return fmt.Errorf("reading extension state: %w", err)
}
if len(data) == 0 {
r.stateMu.Lock()
r.state = map[string]string{}
r.stateMu.Unlock()
return nil
}
var loaded map[string]string
if err := json.Unmarshal(data, &loaded); err != nil {
return fmt.Errorf("parsing extension state: %w", err)
}
r.stateMu.Lock()
r.state = loaded
r.stateMu.Unlock()
return nil
}
// SaveStateToFile writes the current state store to path as JSON, creating
// parent directories as needed. An empty store writes an empty object so
// that consumers can distinguish "loaded but empty" from "never saved".
// Writes are atomic via a tmp-file-and-rename sequence. Thread-safe.
func (r *Runner) SaveStateToFile(path string) error {
snap := r.SnapshotState()
if snap == nil {
snap = map[string]string{}
}
data, err := json.MarshalIndent(snap, "", " ")
if err != nil {
return fmt.Errorf("marshalling extension state: %w", err)
}
if dir := filepath.Dir(path); dir != "." && dir != "" {
if err := os.MkdirAll(dir, 0o755); err != nil {
return fmt.Errorf("creating state directory: %w", err)
}
}
tmp := path + ".tmp"
if err := os.WriteFile(tmp, data, 0o644); err != nil {
return fmt.Errorf("writing extension state: %w", err)
}
if err := os.Rename(tmp, path); err != nil {
_ = os.Remove(tmp)
return fmt.Errorf("renaming extension state: %w", err)
}
return nil
}
// ---------------------------------------------------------------------------
// Hot-reload
// ---------------------------------------------------------------------------
@@ -768,7 +952,9 @@ func (r *Runner) Reload(exts []LoadedExtension) {
r.uiVisibility = nil
r.disabledTools = nil
r.customEventSubs = nil
// optionOverrides are intentionally preserved.
// optionOverrides and state are intentionally preserved across reloads:
// they represent user/session intent (not extension code) and would be
// surprising to lose on a hot-reload.
}
// ---------------------------------------------------------------------------
+262
View File
@@ -0,0 +1,262 @@
package extensions
import (
"encoding/json"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
func TestRunner_State_BasicSetGetDelete(t *testing.T) {
r := NewRunner(nil)
if _, ok := r.GetState("missing"); ok {
t.Fatal("expected GetState to return ok=false for missing key")
}
r.SetState("a", "1")
r.SetState("b", "2")
r.SetState("a", "3") // last-write-wins
if v, ok := r.GetState("a"); !ok || v != "3" {
t.Errorf("expected GetState(a)=(3,true), got (%q,%v)", v, ok)
}
if v, ok := r.GetState("b"); !ok || v != "2" {
t.Errorf("expected GetState(b)=(2,true), got (%q,%v)", v, ok)
}
keys := r.ListState()
if len(keys) != 2 {
t.Errorf("expected 2 keys, got %d (%v)", len(keys), keys)
}
r.DeleteState("a")
if _, ok := r.GetState("a"); ok {
t.Error("expected key a to be gone after DeleteState")
}
if len(r.ListState()) != 1 {
t.Errorf("expected 1 key after delete, got %v", r.ListState())
}
// Deleting missing key is a no-op.
r.DeleteState("never-there")
}
func TestRunner_State_SaverFires(t *testing.T) {
r := NewRunner(nil)
var calls int
var mu sync.Mutex
r.SetStateSaver(func() {
mu.Lock()
calls++
mu.Unlock()
})
r.SetState("a", "1")
r.SetState("a", "2")
r.DeleteState("a")
r.DeleteState("a") // missing → no save
mu.Lock()
defer mu.Unlock()
if calls != 3 {
t.Errorf("expected saver to fire 3 times (2 sets + 1 delete), got %d", calls)
}
}
func TestRunner_State_SaveAndLoadRoundTrip(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "ext-state.json")
r1 := NewRunner(nil)
r1.SetState("k1", "v1")
r1.SetState("k2", `{"json":"value"}`)
if err := r1.SaveStateToFile(path); err != nil {
t.Fatalf("SaveStateToFile: %v", err)
}
// Verify file contains JSON map.
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("reading saved file: %v", err)
}
var parsed map[string]string
if err := json.Unmarshal(data, &parsed); err != nil {
t.Fatalf("unmarshalling: %v", err)
}
if parsed["k1"] != "v1" || parsed["k2"] != `{"json":"value"}` {
t.Errorf("unexpected file contents: %v", parsed)
}
r2 := NewRunner(nil)
if err := r2.LoadStateFromFile(path); err != nil {
t.Fatalf("LoadStateFromFile: %v", err)
}
if v, ok := r2.GetState("k1"); !ok || v != "v1" {
t.Errorf("expected k1=v1 after load, got (%q,%v)", v, ok)
}
if v, ok := r2.GetState("k2"); !ok || v != `{"json":"value"}` {
t.Errorf("expected k2 to round-trip, got %q", v)
}
}
func TestRunner_State_LoadMissingFileClearsState(t *testing.T) {
// LoadStateFromFile is documented to "replace the in-memory state store
// with its contents"; for a missing file that means clearing the store.
// This is what makes session-switching safe: a new session that has not
// yet written a sidecar must not inherit keys from a prior session.
r := NewRunner(nil)
r.SetState("a", "1")
if err := r.LoadStateFromFile(filepath.Join(t.TempDir(), "does-not-exist.json")); err != nil {
t.Errorf("expected nil error for missing file, got %v", err)
}
if _, ok := r.GetState("a"); ok {
t.Error("expected pre-existing state to be cleared when target file is missing")
}
if keys := r.ListState(); keys != nil {
t.Errorf("expected ListState() to be nil after clearing, got %v", keys)
}
}
func TestRunner_State_LoadEmptyFileClearsState(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "empty.json")
if err := os.WriteFile(path, nil, 0o644); err != nil {
t.Fatal(err)
}
r := NewRunner(nil)
r.SetState("a", "1")
if err := r.LoadStateFromFile(path); err != nil {
t.Errorf("expected nil error for empty file, got %v", err)
}
if _, ok := r.GetState("a"); ok {
t.Error("expected pre-existing state to be cleared when target file is empty")
}
}
func TestRunner_State_LoadMalformedFileError(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "bad.json")
if err := os.WriteFile(path, []byte("{not json"), 0o644); err != nil {
t.Fatal(err)
}
r := NewRunner(nil)
if err := r.LoadStateFromFile(path); err == nil {
t.Error("expected error loading malformed JSON, got nil")
}
}
func TestRunner_State_PersistenceViaSaver(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "ext-state.json")
r := NewRunner(nil)
r.SetStateSaver(func() {
_ = r.SaveStateToFile(path)
})
r.SetState("hello", "world")
// File should exist with the value already.
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("reading saved file: %v", err)
}
var parsed map[string]string
if err := json.Unmarshal(data, &parsed); err != nil {
t.Fatalf("unmarshalling: %v", err)
}
if parsed["hello"] != "world" {
t.Errorf("expected file to contain hello=world, got %v", parsed)
}
}
func TestRunner_State_ConcurrentSet(t *testing.T) {
r := NewRunner(nil)
var wg sync.WaitGroup
const goroutines = 16
const iterations = 100
wg.Add(goroutines)
for range goroutines {
go func() {
defer wg.Done()
for range iterations {
r.SetState("k", "v")
_, _ = r.GetState("k")
}
}()
}
wg.Wait()
if v, ok := r.GetState("k"); !ok || v != "v" {
t.Errorf("expected k=v after concurrent writes, got (%q,%v)", v, ok)
}
}
func TestRunner_State_ContextNoOpsWhenUnset(t *testing.T) {
// Verify normalizeContext installs safe no-ops for SetState/GetState/etc.
// when not provided by the caller.
ext := makeHandlerExt("state.go", map[EventType][]HandlerFunc{
SessionStart: {
func(e Event, c Context) Result {
// All four state functions should be non-nil and safe to call.
c.SetState("a", "b")
if v, ok := c.GetState("a"); ok || v != "" {
t.Errorf("no-op GetState should return (\"\", false); got (%q,%v)", v, ok)
}
c.DeleteState("a")
if keys := c.ListState(); keys != nil {
t.Errorf("no-op ListState should return nil; got %v", keys)
}
return nil
},
},
})
r := makeRunner(ext)
// SetContext with empty Context to exercise normalizeContext defaults.
r.SetContext(Context{})
_, err := r.Emit(SessionStartEvent{})
if err != nil {
t.Fatalf("emit: %v", err)
}
}
func TestRunner_State_SaverPanicReleasesSaverMu(t *testing.T) {
// If the saver callback panics (e.g. disk full mid-write), runSaver
// must still release saverMu so subsequent SetState/DeleteState calls
// can make progress. Without `defer Unlock()` the lock would be
// permanently held and the next write would deadlock.
r := NewRunner(nil)
var calls int
r.SetStateSaver(func() {
calls++
if calls == 1 {
panic("simulated disk-write failure")
}
})
// First call panics. Recover, then verify a follow-up call still works
// without blocking (proving saverMu was released).
func() {
defer func() {
if rec := recover(); rec == nil {
t.Fatal("expected panic from first saver invocation")
}
}()
r.SetState("a", "1")
}()
done := make(chan struct{})
go func() {
r.SetState("b", "2") // would deadlock if saverMu were still held
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("SetState after saver panic blocked — saverMu was not released")
}
if calls != 2 {
t.Errorf("expected saver to fire twice (panic + recovery write), got %d", calls)
}
}
+6
View File
@@ -28,6 +28,11 @@ func Symbols() interp.Exports {
"CommandDef": reflect.ValueOf((*CommandDef)(nil)),
"PrintBlockOpts": reflect.ValueOf((*PrintBlockOpts)(nil)),
// Sentinel errors. Extensions detect them with errors.Is:
//
// if errors.Is(err, ext.ErrAgentBusy) { ... }
"ErrAgentBusy": reflect.ValueOf(&ErrAgentBusy).Elem(),
// Session types
"SessionMessage": reflect.ValueOf((*SessionMessage)(nil)),
"ExtensionEntry": reflect.ValueOf((*ExtensionEntry)(nil)),
@@ -183,6 +188,7 @@ func Symbols() interp.Exports {
"RetryEvent": reflect.ValueOf((*RetryEvent)(nil)),
"PrepareStepEvent": reflect.ValueOf((*PrepareStepEvent)(nil)),
"PrepareStepResult": reflect.ValueOf((*PrepareStepResult)(nil)),
"LLMUsageEvent": reflect.ValueOf((*LLMUsageEvent)(nil)),
},
}
}
+6
View File
@@ -189,5 +189,11 @@ func NewTestAPI(ext *LoadedExtension) API {
return nil
})
},
onLLMUsage: func(h func(LLMUsageEvent, Context)) {
reg(LLMUsage, func(e Event, c Context) Result {
h(e.(LLMUsageEvent), c)
return nil
})
},
}
}
+38
View File
@@ -0,0 +1,38 @@
package extensions
// ToolKind constants classify what a tool does, enabling UIs to render
// appropriate visualizations (e.g. diff view for edit tools, command+output
// for execute tools) and file trackers to identify which results contain
// modifications.
//
// This is the single source of truth for tool-kind classification; the
// pkg/kit SDK re-exports these constants.
const (
ToolKindExecute = "execute" // Shell execution (bash)
ToolKindEdit = "edit" // File modification (edit, write)
ToolKindRead = "read" // File reading (read, ls)
ToolKindSearch = "search" // Content/file search (grep, find)
ToolKindSubagent = "agent" // Subagent spawning (subagent)
)
// coreToolKinds maps built-in tool names to their kind classification.
// MCP and extension tools without an entry default to ToolKindExecute.
var coreToolKinds = map[string]string{
"bash": ToolKindExecute,
"edit": ToolKindEdit,
"write": ToolKindEdit,
"read": ToolKindRead,
"ls": ToolKindRead,
"grep": ToolKindSearch,
"find": ToolKindSearch,
"subagent": ToolKindSubagent,
}
// ToolKindFor returns the ToolKind for a given tool name, defaulting to
// ToolKindExecute for unknown tools (including MCP tools).
func ToolKindFor(toolName string) string {
if kind, ok := coreToolKinds[toolName]; ok {
return kind
}
return ToolKindExecute
}
+24 -157
View File
@@ -1,143 +1,32 @@
package extensions
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"github.com/mark3labs/kit/internal/watcher"
)
// Watcher monitors extension directories for file changes and triggers
// a reload callback when .go files are created, modified, or removed.
// It uses fsnotify for kernel-level file notifications (inotify on Linux,
// kqueue on macOS) with debouncing to coalesce rapid editor writes.
type Watcher struct {
watcher *fsnotify.Watcher
onReload func()
debounce time.Duration
cancel context.CancelFunc
done chan struct{}
mu sync.Mutex
}
// Watcher monitors extension directories for .go file changes and triggers
// a reload callback when changes are detected. It is implemented in terms
// of the general-purpose internal/watcher.ContentWatcher.
//
// Type-aliasing here lets existing call sites (cmd/root.go and the
// watcher_test.go suite) keep using `extensions.NewWatcher` / `*Watcher`
// without knowing about the underlying implementation.
type Watcher = watcher.ContentWatcher
// NewWatcher creates a file watcher that monitors the given directories
// for .go file changes. When a change is detected (after debouncing),
// onReload is called. The watcher must be started with Start() and
// stopped with Close().
func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
fsw, err := fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("creating file watcher: %w", err)
}
for _, dir := range dirs {
// Watch the directory itself.
if err := fsw.Add(dir); err != nil {
log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err)
continue
}
// Also watch immediate subdirectories (for */main.go pattern).
entries, err := os.ReadDir(dir)
if err != nil {
continue
}
for _, entry := range entries {
if entry.IsDir() {
subdir := filepath.Join(dir, entry.Name())
if err := fsw.Add(subdir); err != nil {
log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err)
}
}
}
}
return &Watcher{
watcher: fsw,
onReload: onReload,
debounce: 300 * time.Millisecond,
done: make(chan struct{}),
}, nil
}
// Start begins watching for file changes. It blocks until the context
// is cancelled or Close() is called. Typically called in a goroutine.
func (w *Watcher) Start(ctx context.Context) {
w.mu.Lock()
ctx, w.cancel = context.WithCancel(ctx)
w.mu.Unlock()
defer close(w.done)
var timer *time.Timer
var timerC <-chan time.Time
for {
select {
case <-ctx.Done():
if timer != nil {
timer.Stop()
}
return
case event, ok := <-w.watcher.Events:
if !ok {
return
}
// Only care about .go files.
if !strings.HasSuffix(event.Name, ".go") {
continue
}
// React to write, create, remove, rename events.
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 {
continue
}
log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op)
// Debounce: reset timer on each event.
if timer != nil {
timer.Stop()
}
timer = time.NewTimer(w.debounce)
timerC = timer.C
case <-timerC:
timerC = nil
timer = nil
log.Printf("DEBUG watcher: reloading extensions")
w.onReload()
case err, ok := <-w.watcher.Errors:
if !ok {
return
}
log.Printf("WARN watcher: error: %v", err)
}
}
}
// Close stops the watcher and releases resources.
func (w *Watcher) Close() error {
w.mu.Lock()
cancel := w.cancel
w.mu.Unlock()
if cancel != nil {
cancel()
}
// Wait for the event loop to finish.
<-w.done
return w.watcher.Close()
return watcher.New(watcher.Options{
Dirs: dirs,
Extensions: []string{".go"},
OnReload: onReload,
Label: "extensions",
})
}
// WatchedDirs returns the directories to watch for extension changes.
@@ -146,47 +35,25 @@ func (w *Watcher) Close() error {
// point to directories are also included; explicit file paths cause
// their parent directory to be watched instead.
func WatchedDirs(extraPaths []string) []string {
var dirs []string
seen := make(map[string]bool)
add := func(dir string) {
abs, err := filepath.Abs(dir)
if err != nil {
return
}
if seen[abs] {
return
}
// Verify the directory exists.
info, err := os.Stat(abs)
if err != nil || !info.IsDir() {
return
}
seen[abs] = true
dirs = append(dirs, abs)
standard := []string{
globalExtensionsDir(),
filepath.Join(".kit", "extensions"),
}
// Global extensions dir.
add(globalExtensionsDir())
// Project-local extensions dir.
add(filepath.Join(".kit", "extensions"))
// Explicit paths that are directories.
// Filter explicit paths into directories (passed through) and files
// (parent dir watched) for CollectDirs to dedupe.
var extras []string
for _, p := range extraPaths {
info, err := os.Stat(p)
if err != nil {
continue
}
if info.IsDir() {
add(p)
extras = append(extras, p)
} else {
// For explicit files, watch the parent directory.
add(filepath.Dir(p))
extras = append(extras, filepath.Dir(p))
}
}
return dirs
return watcher.CollectDirs(standard, extras)
}
+1 -22
View File
@@ -40,27 +40,6 @@ func ExtensionToolsAsLLMTools(defs []ToolDef, runner *Runner) []fantasy.AgentToo
return tools
}
// coreToolKinds maps built-in tool names to their kind classification.
var coreToolKinds = map[string]string{
"bash": "execute",
"edit": "edit",
"write": "edit",
"read": "read",
"ls": "read",
"grep": "search",
"find": "search",
"subagent": "agent",
}
// toolKindFor returns the ToolKind for a given tool name, defaulting to
// "execute" for unknown tools (including MCP tools).
func toolKindFor(toolName string) string {
if kind, ok := coreToolKinds[toolName]; ok {
return kind
}
return "execute"
}
// parseToolArgsJSON attempts to parse JSON-encoded tool args into a map.
// Returns nil on failure (non-fatal convenience parsing).
func parseToolArgsJSON(input string) map[string]any {
@@ -93,7 +72,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
fmt.Sprintf("Error: tool %q is currently disabled", toolName)), nil
}
kind := toolKindFor(toolName)
kind := ToolKindFor(toolName)
// 1. Emit ToolCall — extensions can block execution.
if w.runner.HasHandlers(ToolCall) {
+11 -1
View File
@@ -53,6 +53,11 @@ type AgentSetupOptions struct {
// Debug enables debug logging. When zero-value, viper is consulted.
// Only meaningful when ProviderConfig is also set.
Debug bool
// DebugLogger, if non-nil, is used directly as the engine/MCP debug
// logger — overriding the built-in SimpleDebugLogger / BufferedDebugLogger
// selected by Debug + UseBufferedLogger. Callers supply this when they
// want to route debug output into their own logging system.
DebugLogger tools.DebugLogger
// NoExtensions skips extension loading. When false, viper is consulted.
// Only meaningful when ProviderConfig is also set.
NoExtensions bool
@@ -192,7 +197,12 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
// Create the appropriate debug logger.
var debugLogger tools.DebugLogger
var bufferedLogger *tools.BufferedDebugLogger
if debugEnabled {
switch {
case opts.DebugLogger != nil:
// Caller-supplied logger wins unconditionally. Its IsDebugEnabled()
// is the source of truth for whether downstream code emits messages.
debugLogger = opts.DebugLogger
case debugEnabled:
if opts.UseBufferedLogger {
bufferedLogger = tools.NewBufferedDebugLogger(true)
debugLogger = bufferedLogger
+22 -6
View File
@@ -10,15 +10,31 @@ import (
)
// TestNpmToWireProtocol documents the wire protocols that the auto-router
// understands. Provider-specific bundles (azure, bedrock, vercel, openrouter,
// google-vertex*) are intentionally absent — they have native top-level cases
// in CreateProvider and never reach the auto-router.
// understands. Provider-specific bundles that need bespoke auth or URL
// templating (azure, bedrock, openrouter, google-vertex*, @ai-sdk/gateway)
// are intentionally absent — they have native top-level cases in
// CreateProvider and never reach the auto-router.
func TestNpmToWireProtocol(t *testing.T) {
want := map[string]wireProtocol{
"@ai-sdk/openai": wireOpenAI,
"@ai-sdk/openai-compatible": wireOpenAI,
"@ai-sdk/anthropic": wireAnthropic,
"@ai-sdk/google": wireGoogle,
// Thin OpenAI-compatible wrappers — routed via openaicompat using
// the SDK's hard-coded default base URL (sdkDefaultBaseURL).
"@ai-sdk/groq": wireOpenAI,
"@ai-sdk/cerebras": wireOpenAI,
"@ai-sdk/perplexity": wireOpenAI,
"@ai-sdk/togetherai": wireOpenAI,
"@ai-sdk/xai": wireOpenAI,
"@ai-sdk/deepinfra": wireOpenAI,
"@ai-sdk/mistral": wireOpenAI,
"@ai-sdk/cohere": wireOpenAI,
"@ai-sdk/vercel": wireOpenAI,
"@aihubmix/ai-sdk-provider": wireOpenAI,
"venice-ai-sdk-provider": wireOpenAI,
"merge-gateway-ai-sdk-provider": wireOpenAI,
}
for npm, wire := range want {
if got := npmToWireProtocol[npm]; got != wire {
@@ -26,15 +42,15 @@ func TestNpmToWireProtocol(t *testing.T) {
}
}
// Bundle packages must NOT be in the table (regression guard against the
// old npmToLLMProvider map that listed 10 entries but only handled 3).
// Bundle packages must NOT be in the table — they need bespoke auth or
// URL templating that the auto-router cannot satisfy.
for _, npm := range []string{
"@ai-sdk/google-vertex",
"@ai-sdk/google-vertex/anthropic",
"@ai-sdk/amazon-bedrock",
"@ai-sdk/azure",
"@openrouter/ai-sdk-provider",
"@ai-sdk/vercel",
"@ai-sdk/gateway",
} {
if _, ok := npmToWireProtocol[npm]; ok {
t.Errorf("npmToWireProtocol unexpectedly contains bundle package %q", npm)
+84
View File
@@ -0,0 +1,84 @@
package models
import (
"net/http"
"testing"
"time"
)
func TestCopilotProviderAliasUsesCatalog(t *testing.T) {
registry := NewModelsRegistry()
models, err := registry.GetModelsForProvider("copilot")
if err != nil {
t.Fatalf("GetModelsForProvider(copilot) failed: %v", err)
}
if len(models) == 0 {
t.Fatal("expected copilot alias to return github-copilot catalog models")
}
if registry.LookupModel("copilot", "gpt-5.5") == nil {
t.Fatal("expected copilot/gpt-5.5 to resolve through github-copilot catalog")
}
if registry.GetProviderInfo("copilot") == nil {
t.Fatal("expected copilot alias to return github-copilot provider info")
}
}
func TestCopilotRejectsNonGPTModels(t *testing.T) {
_, err := CreateProvider(t.Context(), &ProviderConfig{ModelString: "copilot/claude-sonnet-4.6"})
if err == nil {
t.Fatal("expected non-GPT Copilot model to be rejected")
}
}
func TestCopilotHTTPClientCachesToken(t *testing.T) {
client := createCopilotHTTPClient("cached-token", time.Now().Add(time.Hour).Unix(), false)
transport, ok := client.Transport.(*copilotTransport)
if !ok {
t.Fatal("expected *copilotTransport")
}
token := transport.cachedToken(t.Context())
if token != "cached-token" {
t.Fatalf("expected cached token, got %q", token)
}
}
func TestCopilotTransportHeaders(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
if err != nil {
t.Fatalf("NewRequest failed: %v", err)
}
transport := &copilotTransport{
base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.Header.Get("Authorization") != "Bearer cached-token" {
t.Fatalf("unexpected Authorization header: %q", req.Header.Get("Authorization"))
}
if req.Header.Get("Copilot-Integration-Id") != copilotIntegrationID {
t.Fatalf("unexpected Copilot-Integration-Id header: %q", req.Header.Get("Copilot-Integration-Id"))
}
if req.Header.Get("Editor-Version") != copilotEditorVersion {
t.Fatalf("unexpected Editor-Version header: %q", req.Header.Get("Editor-Version"))
}
if req.Header.Get("User-Agent") != copilotUserAgent {
t.Fatalf("unexpected User-Agent header: %q", req.Header.Get("User-Agent"))
}
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
}),
token: "cached-token",
expiresAt: time.Now().Add(time.Hour).Unix(),
}
resp, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("RoundTrip failed: %v", err)
}
_ = resp.Body.Close()
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
+20 -25
View File
@@ -44,13 +44,14 @@ func loadCustomModelsFrom(v *viper.Viper) map[string]ModelInfo {
// modelConfigToModelInfo converts a CustomModelConfig to a ModelInfo.
func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
info := ModelInfo{
ID: modelID,
Name: cfg.Name,
Attachment: cfg.Attachment,
Reasoning: cfg.Reasoning,
Temperature: cfg.Temperature,
BaseURL: cfg.BaseURL,
APIKey: cfg.APIKey,
ID: modelID,
Name: cfg.Name,
Attachment: cfg.Attachment,
Reasoning: cfg.Reasoning,
Temperature: cfg.Temperature,
BaseURL: cfg.BaseURL,
APIKey: cfg.APIKey,
APIModelName: cfg.APIModelName,
Cost: Cost{
Input: cfg.Cost.Input,
Output: cfg.Cost.Output,
@@ -69,13 +70,6 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
return info
}
// LoadModelSettingsFromConfig loads per-model generation parameter overrides
// from the process-global viper store. Keys are "provider/model" strings.
// Returns nil if no model settings are configured.
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
return LoadModelSettingsFrom(viper.GetViper())
}
// LoadModelSettingsFrom loads per-model generation parameter overrides from the
// supplied per-instance store. When v is nil the process-global store is used.
// Keys are "provider/model" strings. Returns nil if no model settings are
@@ -294,17 +288,18 @@ type GenerationParams struct {
// CustomModelConfig defines a custom model configuration loaded from the config file.
// This is a duplicate here to avoid circular dependencies with internal/config.
type CustomModelConfig struct {
Name string `json:"name" yaml:"name"`
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
Family string `json:"family,omitempty" yaml:"family,omitempty"`
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
Cost CostConfig `json:"cost" yaml:"cost"`
Limit LimitConfig `json:"limit" yaml:"limit"`
Params GenerationParamsConfig `json:"params,omitzero" yaml:"params,omitempty"`
Name string `json:"name" yaml:"name"`
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
APIModelName string `json:"apiModelName,omitempty" yaml:"apiModelName,omitempty"`
Family string `json:"family,omitempty" yaml:"family,omitempty"`
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
Cost CostConfig `json:"cost" yaml:"cost"`
Limit LimitConfig `json:"limit" yaml:"limit"`
Params GenerationParamsConfig `json:"params,omitzero" yaml:"params,omitempty"`
}
// GenerationParamsConfig is the JSON/YAML-serializable form of generation
File diff suppressed because one or more lines are too long
+64 -5
View File
@@ -62,14 +62,73 @@ const (
)
// npmToWireProtocol maps npm package names from models.dev to the wire
// protocol they speak. Provider-specific bundles (azure, bedrock, vercel,
// openrouter, google-vertex, google-vertex-anthropic) are intentionally
// absent — they have native top-level cases in CreateProvider and never
// reach the auto-router. Providers not in this map but with an api URL
// are auto-routed through the OpenAI-compatible wire.
// protocol they speak. Provider-specific bundles that need bespoke auth or
// URL templating (azure, bedrock, openrouter, google-vertex, google-vertex-
// anthropic, and @ai-sdk/gateway which is the Vercel AI Gateway) are
// intentionally absent — they have native top-level cases in CreateProvider
// and never reach the auto-router. Providers not in this map but with an
// api URL are auto-routed through the OpenAI-compatible wire.
//
// The thin OpenAI-compatible npm wrappers (groq, cerebras, mistral, …) are
// listed explicitly so that auto-routing can recover their hard-coded base
// URL from sdkDefaultBaseURL when the registry entry has no api field.
var npmToWireProtocol = map[string]wireProtocol{
// Native wires.
"@ai-sdk/openai": wireOpenAI,
"@ai-sdk/openai-compatible": wireOpenAI,
"@ai-sdk/anthropic": wireAnthropic,
"@ai-sdk/google": wireGoogle,
// Thin OpenAI-compatible wrappers. Each ships with a hard-coded base URL
// in its JS SDK (see sdkDefaultBaseURL) but speaks the plain OpenAI chat
// completions wire — so we can route them all through fantasy's
// openaicompat provider once we supply the URL.
"@ai-sdk/groq": wireOpenAI,
"@ai-sdk/cerebras": wireOpenAI,
"@ai-sdk/perplexity": wireOpenAI,
"@ai-sdk/togetherai": wireOpenAI,
"@ai-sdk/xai": wireOpenAI,
"@ai-sdk/deepinfra": wireOpenAI,
"@ai-sdk/mistral": wireOpenAI,
"@ai-sdk/cohere": wireOpenAI,
"@ai-sdk/vercel": wireOpenAI, // v0 API (api.v0.dev), distinct from @ai-sdk/gateway
"@aihubmix/ai-sdk-provider": wireOpenAI,
"venice-ai-sdk-provider": wireOpenAI,
"merge-gateway-ai-sdk-provider": wireOpenAI,
}
// sdkDefaultBaseURL maps an npm package name to the base URL its JavaScript
// SDK uses by default. This lets us recover a working endpoint for providers
// whose models.dev entry omits the `api` field because the JS SDK hard-codes
// the URL (e.g. groq, cerebras, mistral, x.ai…).
//
// Only OpenAI-compatible and native-wire SDKs are listed; providers needing
// bespoke auth or URL templating (bedrock SigV4, azure resource URLs,
// google-vertex project/location, cloudflare gateway account IDs, gitlab,
// sap-ai-core) are handled by native CreateProvider cases or surface a
// targeted error that asks the user to supply --provider-url.
var sdkDefaultBaseURL = map[string]string{
// Native wires.
"@ai-sdk/openai": "https://api.openai.com/v1",
"@ai-sdk/anthropic": "https://api.anthropic.com/v1",
"@ai-sdk/google": "https://generativelanguage.googleapis.com/v1beta",
// Thin OpenAI-compatible wrappers.
"@ai-sdk/groq": "https://api.groq.com/openai/v1",
"@ai-sdk/cerebras": "https://api.cerebras.ai/v1",
"@ai-sdk/perplexity": "https://api.perplexity.ai",
"@ai-sdk/togetherai": "https://api.together.xyz/v1",
"@ai-sdk/xai": "https://api.x.ai/v1",
"@ai-sdk/deepinfra": "https://api.deepinfra.com/v1/openai",
"@ai-sdk/mistral": "https://api.mistral.ai/v1",
"@ai-sdk/cohere": "https://api.cohere.com/compatibility/v1",
"@ai-sdk/vercel": "https://api.v0.dev/v1",
"@aihubmix/ai-sdk-provider": "https://aihubmix.com/v1",
"venice-ai-sdk-provider": "https://api.venice.ai/api/v1",
"merge-gateway-ai-sdk-provider": "https://api-gateway.merge.dev/v1/ai-sdk",
// Native handlers — included for ResolveProviderBaseURL introspection
// even though CreateProvider routes these via dedicated cases.
"@ai-sdk/gateway": "https://ai-gateway.vercel.sh/v1",
"@openrouter/ai-sdk-provider": "https://openrouter.ai/api/v1",
}
+293 -58
View File
@@ -13,6 +13,7 @@ import (
"os"
"regexp"
"strings"
"sync"
"time"
"charm.land/fantasy"
@@ -33,6 +34,24 @@ import (
const (
// ClaudeCodePrompt is the required system prompt for OAuth authentication.
ClaudeCodePrompt = "You are Claude Code, Anthropic's official CLI for Claude."
// copilotProviderID is the canonical models.dev provider key. The CLI also
// accepts the shorter "copilot" alias for user-facing model strings.
copilotProviderID = "github-copilot"
// copilotAliasProviderID is the short provider prefix accepted by kit.
copilotAliasProviderID = "copilot"
// copilotBaseURL is the fallback API URL if the model catalog has no API URL.
copilotBaseURL = "https://api.githubcopilot.com"
// GitHub Copilot currently expects VS Code Copilot Chat client identifiers.
// Keep these centralized so they are easy to audit and update when GitHub
// changes accepted client metadata.
copilotIntegrationID = "vscode-chat"
copilotEditorVersion = "vscode/1.104.1"
copilotEditorPluginVersion = "copilot-chat/0.31.0"
copilotUserAgent = "GitHubCopilotChat/0.31.0"
copilotOpenAIIntent = "conversation-agent"
copilotGitHubAPIVersion = "2026-01-09"
)
// resolveModelAlias resolves model aliases to their full names using the registry
@@ -215,6 +234,20 @@ func ParseModelString(modelString string) (provider, model string, err error) {
return "", "", fmt.Errorf("invalid model format %q: expected provider/model (e.g. anthropic/claude-sonnet-4-5)", modelString)
}
// isCopilotProvider reports whether provider is the canonical catalog key or
// the user-facing shorthand alias.
func isCopilotProvider(provider string) bool {
return provider == copilotAliasProviderID || provider == copilotProviderID
}
// catalogProviderID maps supported provider aliases to their models.dev keys.
func catalogProviderID(provider string) string {
if isCopilotProvider(provider) {
return copilotProviderID
}
return provider
}
// CreateProvider creates a fantasy LanguageModel based on the provider configuration.
// Model metadata is looked up from the models.dev database for cost tracking and
// capability detection, but unknown models are passed through to the provider
@@ -238,17 +271,30 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
}
registry := GetGlobalRegistry()
lookupProvider := catalogProviderID(provider)
// Look up model metadata (advisory, not blocking).
// Look up model metadata (advisory for most providers, strict for Copilot).
// When the model is known we validate config limits and print
// suggestions on likely typos; when unknown we let the provider
// API be the authority.
modelInfo := registry.LookupModel(provider, modelName)
if modelInfo == nil && provider != "ollama" && config.ProviderURL == "" {
// API be the authority except for Copilot, whose non-GPT catalog entries
// require unsupported wire protocols.
modelInfo := registry.LookupModel(lookupProvider, modelName)
if isCopilotProvider(provider) {
providerInfo := registry.GetProviderInfo(copilotProviderID)
if providerInfo == nil {
return nil, fmt.Errorf("unsupported provider: %s (not found in model database)", copilotProviderID)
}
if modelInfo == nil {
if suggestions := registry.SuggestModels(copilotProviderID, modelName); len(suggestions) > 0 {
return nil, fmt.Errorf("model %q not found for provider %s. Did you mean one of: %s", modelName, copilotProviderID, strings.Join(suggestions, ", "))
}
return nil, fmt.Errorf("model %q not found for provider %s", modelName, copilotProviderID)
}
} else if modelInfo == nil && provider != "ollama" && config.ProviderURL == "" {
// Model not in database — warn with suggestions but don't block.
if suggestions := registry.SuggestModels(provider, modelName); len(suggestions) > 0 {
if suggestions := registry.SuggestModels(lookupProvider, modelName); len(suggestions) > 0 {
fmt.Fprintf(os.Stderr, "Warning: model %q not found in model database for provider %s. Similar models: %s\n",
modelName, provider, strings.Join(suggestions, ", "))
modelName, lookupProvider, strings.Join(suggestions, ", "))
}
}
@@ -282,17 +328,21 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
result, createErr = createAnthropicProvider(ctx, config, modelName)
case "openai":
result, createErr = createOpenAIProvider(ctx, config, modelName)
case "copilot", "github-copilot":
result, createErr = createCopilotProvider(ctx, config, modelName)
case "google", "gemini":
result, createErr = createGoogleProvider(ctx, config, modelName)
case "ollama":
result, createErr = createOllamaProvider(ctx, config, modelName)
case "azure":
case "azure", "azure-cognitive-services":
result, createErr = createAzureProvider(ctx, config, modelName)
case "google-vertex-anthropic":
result, createErr = createVertexAnthropicProvider(ctx, config, modelName)
case "google-vertex":
result, createErr = createGoogleVertexProvider(ctx, config, modelName)
case "openrouter":
result, createErr = createOpenRouterProvider(ctx, config, modelName)
case "bedrock":
case "bedrock", "amazon-bedrock":
result, createErr = createBedrockProvider(ctx, config, modelName)
case "vercel":
result, createErr = createVercelProvider(ctx, config, modelName)
@@ -376,8 +426,27 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
}
// All three wires use the provider's API URL from models.dev as the base.
if config.ProviderURL == "" && providerInfo.API != "" {
config.ProviderURL = providerInfo.API
// When the registry has none, fall back to the SDK's hard-coded default for
// this npm package (covers groq, cerebras, mistral, x.ai, etc. — providers
// whose JS SDK ships a built-in baseURL that models.dev doesn't restate).
if config.ProviderURL == "" {
if providerInfo.API != "" {
config.ProviderURL = providerInfo.API
} else if defaultURL, ok := sdkDefaultBaseURL[npmPackage]; ok {
config.ProviderURL = defaultURL
providerInfo.API = defaultURL // for downstream helpers that read info.API
}
}
// Provider templates a runtime account/region/deployment segment into the
// URL (cloudflare-ai-gateway, databricks, snowflake-cortex, gitlab,
// sap-ai-core). Resolve via environment variables, or surface a targeted
// error pointing the user at the right knobs.
if resolved, err := resolveTemplatedAPIURL(config.ProviderURL, providerInfo); err != nil {
return nil, err
} else if resolved != "" {
config.ProviderURL = resolved
providerInfo.API = resolved
}
switch wire {
@@ -398,6 +467,24 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
}
}
// resolveAutoRouteAPIKey looks up the API key for an auto-routed provider,
// returning a uniform error message when none can be resolved.
func resolveAutoRouteAPIKey(config *ProviderConfig, info *ProviderInfo) (string, error) {
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
if apiKey == "" {
return "", fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
info.Name, strings.Join(info.Env, " / "))
}
return apiKey, nil
}
// wrapProviderErr produces the uniform "failed to create X provider/model: %w"
// error wrap used by every createXxxProvider path. kind is typically
// "provider" or "model".
func wrapProviderErr(name, kind string, err error) error {
return fmt.Errorf("failed to create %s %s: %w", name, kind, err)
}
// createAutoRoutedOpenAICompatProvider creates an openaicompat provider using
// the api URL and env vars from models.dev.
func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) {
@@ -409,10 +496,9 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC
return nil, fmt.Errorf("provider %s requires --provider-url (no API URL in database)", info.ID)
}
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
if apiKey == "" {
return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
info.Name, strings.Join(info.Env, " / "))
apiKey, err := resolveAutoRouteAPIKey(config, info)
if err != nil {
return nil, err
}
var opts []openaicompat.Option
@@ -426,12 +512,12 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC
p, err := openaicompat.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err)
return nil, wrapProviderErr(info.Name, "provider", err)
}
model, err := p.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err)
return nil, wrapProviderErr(info.Name, "model", err)
}
return &ProviderResult{Model: model}, nil
@@ -442,10 +528,9 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC
func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) {
clearConflictingAnthropicSamplingParams(config)
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
if apiKey == "" {
return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
info.Name, strings.Join(info.Env, " / "))
apiKey, err := resolveAutoRouteAPIKey(config, info)
if err != nil {
return nil, err
}
var opts []anthropic.Option
@@ -464,12 +549,12 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf
p, err := anthropic.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err)
return nil, wrapProviderErr(info.Name, "provider", err)
}
model, err := p.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err)
return nil, wrapProviderErr(info.Name, "model", err)
}
return &ProviderResult{Model: model}, nil
@@ -478,10 +563,9 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf
// createAutoRoutedOpenAIProvider creates an openai provider for
// third-party providers with openai-compatible APIs.
func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) {
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
if apiKey == "" {
return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
info.Name, strings.Join(info.Env, " / "))
apiKey, err := resolveAutoRouteAPIKey(config, info)
if err != nil {
return nil, err
}
var opts []openai.Option
@@ -498,12 +582,12 @@ func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig,
p, err := openai.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err)
return nil, wrapProviderErr(info.Name, "provider", err)
}
model, err := p.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err)
return nil, wrapProviderErr(info.Name, "model", err)
}
providerOpts := buildOpenAIProviderOptions(config, modelName)
@@ -522,10 +606,9 @@ func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig,
// path that the proxy rejects. In that case we install a transport that
// strips the injected segment so the proxy's own version is used.
func createAutoRoutedGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) {
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
if apiKey == "" {
return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
info.Name, strings.Join(info.Env, " / "))
apiKey, err := resolveAutoRouteAPIKey(config, info)
if err != nil {
return nil, err
}
opts := []google.Option{
@@ -550,12 +633,12 @@ func createAutoRoutedGoogleProvider(ctx context.Context, config *ProviderConfig,
p, err := google.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err)
return nil, wrapProviderErr(info.Name, "provider", err)
}
model, err := p.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err)
return nil, wrapProviderErr(info.Name, "model", err)
}
return &ProviderResult{Model: model}, nil
@@ -849,7 +932,7 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN
}
// Handle OAuth vs API key authentication
if strings.HasPrefix(source, "stored OAuth") {
if source == auth.CredentialSourceOAuth {
httpClient := createOAuthHTTPClient(apiKey, config.TLSSkipVerify)
opts = append(opts, anthropic.WithHTTPClient(httpClient))
// Note: For OAuth, the API key is set as a placeholder; the transport handles auth
@@ -859,12 +942,12 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN
provider, err := anthropic.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create Anthropic provider: %w", err)
return nil, wrapProviderErr("Anthropic", "provider", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create Anthropic model: %w", err)
return nil, wrapProviderErr("Anthropic", "model", err)
}
// Build provider options for extended thinking (reasoning budget).
@@ -901,12 +984,12 @@ func createVertexAnthropicProvider(ctx context.Context, config *ProviderConfig,
provider, err := anthropic.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create Vertex Anthropic provider: %w", err)
return nil, wrapProviderErr("Vertex Anthropic", "provider", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create Vertex Anthropic model: %w", err)
return nil, wrapProviderErr("Vertex Anthropic", "model", err)
}
return &ProviderResult{Model: model}, nil
@@ -974,12 +1057,12 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
provider, err := openai.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI provider: %w", err)
return nil, wrapProviderErr("OpenAI", "provider", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI model: %w", err)
return nil, wrapProviderErr("OpenAI", "model", err)
}
// Build provider options for OpenAI Responses API reasoning models.
@@ -988,6 +1071,72 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
}
// createCopilotProvider builds a GitHub Copilot provider through fantasy's
// OpenAI-compatible provider. The catalog key is github-copilot, but the public
// model prefix may be either copilot/ or github-copilot/.
//
// Only gpt-* Copilot models are enabled here. The catalog also lists Claude and
// Gemini Copilot models, but those require different wire protocols and must be
// routed explicitly before they can be safely accepted.
func createCopilotProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
if !strings.HasPrefix(modelName, "gpt-") {
return nil, fmt.Errorf("GitHub Copilot model %q is not supported yet: only gpt-* models use the OpenAI-compatible protocol", modelName)
}
cm, err := auth.NewCredentialManager()
if err != nil {
return nil, fmt.Errorf("failed to initialize credential manager: %w", err)
}
token, err := cm.GetValidCopilotAccessTokenContext(ctx)
if err != nil {
return nil, fmt.Errorf("GitHub Copilot credentials not available. Use 'kit auth login copilot': %w", err)
}
expiresAt := int64(0)
if creds, err := cm.GetCopilotCredentials(); err == nil && creds != nil && creds.CopilotAccessToken == token {
expiresAt = creds.ExpiresAt
}
baseURL := copilotBaseURL
if providerInfo := GetGlobalRegistry().GetProviderInfo(copilotProviderID); providerInfo != nil && providerInfo.API != "" {
baseURL = providerInfo.API
}
if config.ProviderURL != "" {
baseURL = config.ProviderURL
}
opts := []openai.Option{
openai.WithName(copilotAliasProviderID),
openai.WithBaseURL(baseURL),
openai.WithAPIKey(token),
openai.WithHTTPClient(createCopilotHTTPClient(token, expiresAt, config.TLSSkipVerify)),
openai.WithUseResponsesAPI(),
openai.WithResponsesAPIFunc(copilotUsesResponsesAPI),
openai.WithObjectMode(fantasy.ObjectModeTool),
}
provider, err := openai.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create GitHub Copilot provider: %w", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create GitHub Copilot model: %w", err)
}
providerOpts := buildOpenAIProviderOptions(config, modelName)
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
}
// copilotUsesResponsesAPI selects the OpenAI Responses API for Copilot models
// known to support it. Non-gpt models are rejected before provider creation.
func copilotUsesResponsesAPI(modelID string) bool {
return strings.HasPrefix(modelID, "gpt-5")
}
// createOpenAICodexProvider creates a provider for ChatGPT/Codex OAuth tokens.
// Uses the chatgpt.com/backend-api/codex endpoint with special headers.
func createOpenAICodexProvider(ctx context.Context, config *ProviderConfig, modelName, token, accountID string) (*ProviderResult, error) {
@@ -1015,12 +1164,12 @@ func createOpenAICodexProvider(ctx context.Context, config *ProviderConfig, mode
provider, err := openai.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI Codex provider: %w", err)
return nil, wrapProviderErr("OpenAI Codex", "provider", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI Codex model: %w", err)
return nil, wrapProviderErr("OpenAI Codex", "model", err)
}
providerOpts := buildCodexProviderOptions(config, modelName)
@@ -1117,6 +1266,87 @@ func (t *codexTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.base.RoundTrip(newReq)
}
// createCopilotHTTPClient returns an HTTP client that injects Copilot-specific
// authorization and client metadata headers. The token and expiry are cached in
// the transport so streaming requests do not hit credentials.json on every
// RoundTrip; the credential manager is consulted only near expiry.
func createCopilotHTTPClient(token string, expiresAt int64, skipVerify bool) *http.Client {
var base http.RoundTripper
if skipVerify {
base = &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
} else {
base = http.DefaultTransport
}
return &http.Client{
Transport: &copilotTransport{
base: base,
token: token,
expiresAt: expiresAt,
},
Timeout: 120 * time.Second,
}
}
// copilotTransport decorates requests for api.githubcopilot.com.
//
// It owns a cached Copilot access token. When the token is still valid, the hot
// path is in-memory only. Near expiry it refreshes through CredentialManager,
// which updates both the cache here and credentials.json.
type copilotTransport struct {
base http.RoundTripper
token string
expiresAt int64
mu sync.Mutex
}
func (t *copilotTransport) RoundTrip(req *http.Request) (*http.Response, error) {
token := t.cachedToken(req.Context())
newReq := req.Clone(req.Context())
newReq.Header.Set("Authorization", "Bearer "+token)
newReq.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
newReq.Header.Set("Editor-Version", copilotEditorVersion)
newReq.Header.Set("Editor-Plugin-Version", copilotEditorPluginVersion)
newReq.Header.Set("Openai-Intent", copilotOpenAIIntent)
newReq.Header.Set("User-Agent", copilotUserAgent)
newReq.Header.Set("X-GitHub-Api-Version", copilotGitHubAPIVersion)
return t.base.RoundTrip(newReq)
}
// cachedToken returns the cached token unless it is within the five-minute
// refresh window. Refresh errors fall back to the last token so the request can
// surface any authoritative auth failure from the Copilot API.
func (t *copilotTransport) cachedToken(ctx context.Context) string {
t.mu.Lock()
defer t.mu.Unlock()
if t.expiresAt == 0 || time.Now().Unix() < t.expiresAt-300 {
return t.token
}
cm, err := auth.NewCredentialManager()
if err != nil {
return t.token
}
fresh, err := cm.GetValidCopilotAccessTokenContext(ctx)
if err != nil || fresh == "" {
return t.token
}
t.token = fresh
if creds, err := cm.GetCopilotCredentials(); err == nil && creds != nil && creds.CopilotAccessToken == fresh {
t.expiresAt = creds.ExpiresAt
}
return t.token
}
func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
apiKey := firstNonEmpty(
config.ProviderAPIKey,
@@ -1133,12 +1363,12 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName
provider, err := google.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create Google provider: %w", err)
return nil, wrapProviderErr("Google", "provider", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create Google model: %w", err)
return nil, wrapProviderErr("Google", "model", err)
}
return &ProviderResult{Model: model}, nil
@@ -1171,12 +1401,12 @@ func createAzureProvider(ctx context.Context, config *ProviderConfig, modelName
provider, err := azure.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create Azure OpenAI provider: %w", err)
return nil, wrapProviderErr("Azure OpenAI", "provider", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create Azure OpenAI model: %w", err)
return nil, wrapProviderErr("Azure OpenAI", "model", err)
}
return &ProviderResult{Model: model}, nil
@@ -1196,12 +1426,12 @@ func createOpenRouterProvider(ctx context.Context, config *ProviderConfig, model
provider, err := openrouter.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create OpenRouter provider: %w", err)
return nil, wrapProviderErr("OpenRouter", "provider", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create OpenRouter model: %w", err)
return nil, wrapProviderErr("OpenRouter", "model", err)
}
return &ProviderResult{Model: model}, nil
@@ -1213,12 +1443,12 @@ func createBedrockProvider(ctx context.Context, config *ProviderConfig, modelNam
// Bedrock uses AWS SDK default credential chain (env vars, shared config, etc.)
provider, err := bedrock.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create Bedrock provider: %w", err)
return nil, wrapProviderErr("Bedrock", "provider", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create Bedrock model: %w", err)
return nil, wrapProviderErr("Bedrock", "model", err)
}
return &ProviderResult{Model: model}, nil
@@ -1242,12 +1472,12 @@ func createVercelProvider(ctx context.Context, config *ProviderConfig, modelName
provider, err := vercel.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create Vercel provider: %w", err)
return nil, wrapProviderErr("Vercel", "provider", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create Vercel model: %w", err)
return nil, wrapProviderErr("Vercel", "model", err)
}
return &ProviderResult{Model: model}, nil
@@ -1300,12 +1530,17 @@ func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName
p, err := openai.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create custom provider: %w", err)
return nil, wrapProviderErr("custom", "provider", err)
}
model, err := p.LanguageModel(ctx, modelName)
apiModelName := modelName
if modelInfo != nil && modelInfo.APIModelName != "" {
apiModelName = modelInfo.APIModelName
}
model, err := p.LanguageModel(ctx, apiModelName)
if err != nil {
return nil, fmt.Errorf("failed to create custom model: %w", err)
return nil, wrapProviderErr("custom", "model", err)
}
return &ProviderResult{Model: model}, nil
@@ -1349,12 +1584,12 @@ func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName
provider, err := openaicompat.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create Ollama provider: %w", err)
return nil, wrapProviderErr("Ollama", "provider", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create Ollama model: %w", err)
return nil, wrapProviderErr("Ollama", "model", err)
}
return &ProviderResult{
+27 -11
View File
@@ -16,17 +16,18 @@ var embeddedModelsJSON []byte
// ModelInfo represents information about a specific model.
type ModelInfo struct {
ID string
Name string
Family string // Model family (e.g., "claude", "gpt", "gemini")
Attachment bool
Reasoning bool
Temperature bool
Cost Cost
Limit Limit
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
BaseURL string // Per-model base URL override (custom models only)
APIKey string // Per-model API key override (custom models only)
ID string
Name string
Family string // Model family (e.g., "claude", "gpt", "gemini")
Attachment bool
Reasoning bool
Temperature bool
Cost Cost
Limit Limit
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
BaseURL string // Per-model base URL override (custom models only)
APIKey string // Per-model API key override (custom models only)
APIModelName string // Per-model API model name override (custom models only)
// Params holds per-model generation parameter defaults. These are applied
// when the user hasn't explicitly set the corresponding CLI flag or global
@@ -246,6 +247,7 @@ func loadEmbeddedProviders() map[string]modelsDBProvider {
// doesn't track yet. Callers should treat a nil return as "unknown model"
// and continue with sensible defaults.
func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
provider = catalogProviderID(provider)
providerInfo, exists := r.providers[provider]
if !exists {
return nil
@@ -273,6 +275,7 @@ func LookupModelForSettings(modelString string) *ModelInfo {
// getRequiredEnvVars returns the required environment variables for a provider.
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
provider = catalogProviderID(provider)
providerInfo, exists := r.providers[provider]
if !exists {
return nil, fmt.Errorf("unsupported provider: %s", provider)
@@ -287,6 +290,7 @@ func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
// variables. Returns nil for providers not in the registry (unknown
// providers are assumed to handle auth themselves or via --provider-api-key).
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
provider = catalogProviderID(provider)
if apiKey != "" {
return nil
}
@@ -311,6 +315,15 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
}
}
// For GitHub Copilot, check stored GitHub OAuth credentials.
if provider == copilotProviderID {
if cm, err := auth.NewCredentialManager(); err == nil {
if has, _ := cm.HasCopilotCredentials(); has {
return nil
}
}
}
envVars, err := r.getRequiredEnvVars(provider)
if err != nil {
// Unknown provider — nothing to validate
@@ -350,6 +363,7 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
// SuggestModels returns similar model names when an invalid model is provided.
func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
provider = catalogProviderID(provider)
providerInfo, exists := r.providers[provider]
if !exists {
return nil
@@ -415,6 +429,7 @@ func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
// GetModelsForProvider returns all models for a specific provider.
func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]ModelInfo, error) {
provider = catalogProviderID(provider)
providerInfo, exists := r.providers[provider]
if !exists {
return nil, fmt.Errorf("unsupported provider: %s", provider)
@@ -425,6 +440,7 @@ func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]Model
// GetProviderInfo returns the full provider info, or nil if not found.
func (r *ModelsRegistry) GetProviderInfo(provider string) *ProviderInfo {
provider = catalogProviderID(provider)
info, exists := r.providers[provider]
if !exists {
return nil
+170
View File
@@ -0,0 +1,170 @@
package models
import (
"context"
"fmt"
"os"
"regexp"
"strings"
"charm.land/fantasy/providers/google"
)
// templatePlaceholderRe matches "${NAME}" placeholders in URL templates from
// models.dev (e.g. "https://${DATABRICKS_HOST}/ai-gateway/mlflow/v1").
var templatePlaceholderRe = regexp.MustCompile(`\$\{([A-Z0-9_]+)\}`)
// templateEnvVarOverrides supplies fallback environment variable names for
// placeholders that providers commonly use under non-obvious env names.
// The placeholder name itself is always tried first; this map adds extra
// names to try when the placeholder doesn't match the canonical env var.
var templateEnvVarOverrides = map[string][]string{
"CLOUDFLARE_ACCOUNT_ID": {"CF_ACCOUNT_ID"},
"CLOUDFLARE_GATEWAY_NAME": {"CF_GATEWAY", "CLOUDFLARE_GATEWAY"},
"DATABRICKS_HOST": {"DATABRICKS_WORKSPACE_URL"},
"SNOWFLAKE_ACCOUNT": {"SNOWFLAKE_ACCOUNT_ID"},
}
// resolveTemplatedAPIURL substitutes "${VAR}" placeholders in apiURL with the
// values of the named environment variables. Returns:
// - ("", nil) when apiURL contains no placeholders (caller keeps current URL),
// - (resolved, nil) when every placeholder was resolved,
// - ("", error) when one or more placeholders are unset, with a message that
// names the missing env vars and points at the relevant provider.
//
// The info parameter is used purely for error messaging (provider name).
func resolveTemplatedAPIURL(apiURL string, info *ProviderInfo) (string, error) {
if apiURL == "" || !strings.Contains(apiURL, "${") {
return "", nil
}
var missing []string
resolved := templatePlaceholderRe.ReplaceAllStringFunc(apiURL, func(match string) string {
// match is "${NAME}". Extract NAME.
name := match[2 : len(match)-1]
if v := os.Getenv(name); v != "" {
return v
}
for _, alt := range templateEnvVarOverrides[name] {
if v := os.Getenv(alt); v != "" {
return v
}
}
missing = append(missing, name)
return match
})
if len(missing) > 0 {
providerName := info.ID
if info.Name != "" {
providerName = info.Name
}
return "", fmt.Errorf(
"provider %s requires environment variable(s) %s to construct its API URL (%s); "+
"set them or pass --provider-url to override",
providerName, strings.Join(missing, ", "), apiURL,
)
}
return resolved, nil
}
// ResolveProviderBaseURL returns the base API URL kit will use when talking to
// the given provider, applying the same resolution order as CreateProvider:
//
// 1. The provider's `api` field from the models.dev registry.
// 2. The hard-coded default base URL of its npm SDK package (e.g.
// @ai-sdk/groq → https://api.groq.com/openai/v1).
// 3. Template substitution against the current process environment when the
// URL contains "${VAR}" placeholders (e.g. cloudflare-workers-ai needs
// CLOUDFLARE_ACCOUNT_ID).
//
// It returns an error when the provider is unknown, when no URL can be derived,
// or when a templated URL has unset placeholders. The error message is suitable
// for direct display to end users.
//
// Note: providers handled by bespoke auth schemes (amazon-bedrock SigV4,
// azure resource URLs, google-vertex project/location, sap-ai-core customer
// deployments) may return either an empty URL or a regional/templated URL —
// the actual endpoint is finalised inside their native handlers and depends on
// runtime credentials.
func ResolveProviderBaseURL(providerID string) (string, error) {
registry := GetGlobalRegistry()
info := registry.GetProviderInfo(providerID)
if info == nil {
return "", fmt.Errorf("unknown provider: %s", providerID)
}
apiURL := info.API
if apiURL == "" {
if defaultURL, ok := sdkDefaultBaseURL[info.NPM]; ok {
apiURL = defaultURL
}
}
if apiURL == "" {
return "", fmt.Errorf(
"provider %s has no default API URL: its npm package %q does not "+
"ship a built-in baseURL (likely Bedrock SigV4, Azure deployment, "+
"Vertex project/location, or a customer-hosted endpoint). "+
"Pass --provider-url or set the provider's URL env var",
providerID, info.NPM,
)
}
if strings.Contains(apiURL, "${") {
resolved, err := resolveTemplatedAPIURL(apiURL, info)
if err != nil {
return apiURL, err
}
return resolved, nil
}
return apiURL, nil
}
// createGoogleVertexProvider creates a Google Gemini provider that targets the
// Vertex AI backend (rather than the public generativelanguage.googleapis.com
// endpoint). It requires the same project/region environment variables as
// google-vertex-anthropic.
func createGoogleVertexProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
projectID := firstNonEmpty(
os.Getenv("GOOGLE_VERTEX_PROJECT"),
os.Getenv("GOOGLE_CLOUD_PROJECT"),
os.Getenv("GCLOUD_PROJECT"),
os.Getenv("CLOUDSDK_CORE_PROJECT"),
)
if projectID == "" {
return nil, fmt.Errorf(
"google Vertex project ID not provided, set GOOGLE_VERTEX_PROJECT, " +
"GOOGLE_CLOUD_PROJECT, or GCLOUD_PROJECT environment variable",
)
}
region := firstNonEmpty(
os.Getenv("GOOGLE_VERTEX_LOCATION"),
os.Getenv("CLOUD_ML_REGION"),
)
if region == "" {
region = "global"
}
opts := []google.Option{
google.WithVertex(projectID, region),
google.WithName("google-vertex"),
}
if config.TLSSkipVerify {
opts = append(opts, google.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
}
provider, err := google.New(opts...)
if err != nil {
return nil, wrapProviderErr("Google Vertex", "provider", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, wrapProviderErr("Google Vertex", "model", err)
}
return &ProviderResult{Model: model}, nil
}
+214
View File
@@ -0,0 +1,214 @@
package models
import (
"context"
"reflect"
"strings"
"testing"
)
// TestSDKDefaultBaseURL_CoversAllWireMappedPackages enforces the invariant
// that every npm package recognised by the auto-router has a corresponding
// default base URL — otherwise a provider that omits its `api` field in the
// registry would silently fail to route at runtime.
func TestSDKDefaultBaseURL_CoversAllWireMappedPackages(t *testing.T) {
for npm := range npmToWireProtocol {
// @ai-sdk/openai-compatible is a wire family, not a single SDK with
// a default URL — providers using it always supply their own `api`.
if npm == "@ai-sdk/openai-compatible" {
continue
}
if _, ok := sdkDefaultBaseURL[npm]; !ok {
t.Errorf("npm %q is in npmToWireProtocol but has no sdkDefaultBaseURL entry — "+
"providers using this npm with no `api` field cannot be routed", npm)
}
}
}
// TestSDKDefaultBaseURL_AllURLsAreAbsolute sanity-checks that every default
// URL is a well-formed absolute https endpoint (catches typos in the table).
func TestSDKDefaultBaseURL_AllURLsAreAbsolute(t *testing.T) {
for npm, url := range sdkDefaultBaseURL {
if !strings.HasPrefix(url, "https://") {
t.Errorf("sdkDefaultBaseURL[%q] = %q is not an absolute https URL", npm, url)
}
}
}
// TestResolveProviderBaseURL_RegistryFirst verifies that the registry's `api`
// field wins over any SDK default.
func TestResolveProviderBaseURL_RegistryFirst(t *testing.T) {
// xai is in the registry with no `api` field — its URL comes from the
// SDK default. Use a synthetic registry-backed provider to test the
// priority via the public registry instead.
url, err := ResolveProviderBaseURL("openai")
if err != nil {
t.Fatalf("ResolveProviderBaseURL(openai): %v", err)
}
if url != "https://api.openai.com/v1" {
t.Errorf("openai URL = %q, want https://api.openai.com/v1", url)
}
}
// TestResolveProviderBaseURL_SDKDefaultFallback verifies that providers
// without an `api` field (groq, cerebras, xai, …) resolve to their SDK
// hard-coded default URL.
func TestResolveProviderBaseURL_SDKDefaultFallback(t *testing.T) {
tests := map[string]string{
"groq": "https://api.groq.com/openai/v1",
"cerebras": "https://api.cerebras.ai/v1",
"xai": "https://api.x.ai/v1",
"mistral": "https://api.mistral.ai/v1",
"perplexity": "https://api.perplexity.ai",
"togetherai": "https://api.together.xyz/v1",
"deepinfra": "https://api.deepinfra.com/v1/openai",
"cohere": "https://api.cohere.com/compatibility/v1",
"v0": "https://api.v0.dev/v1",
"aihubmix": "https://aihubmix.com/v1",
"venice": "https://api.venice.ai/api/v1",
"openrouter": "https://openrouter.ai/api/v1",
}
for providerID, wantURL := range tests {
t.Run(providerID, func(t *testing.T) {
got, err := ResolveProviderBaseURL(providerID)
if err != nil {
t.Fatalf("ResolveProviderBaseURL(%s): %v", providerID, err)
}
if got != wantURL {
t.Errorf("%s URL = %q, want %q", providerID, got, wantURL)
}
})
}
}
// TestResolveProviderBaseURL_TemplatedURL_MissingEnv verifies that providers
// whose URL contains "${VAR}" placeholders surface a targeted error when the
// environment variables are unset.
func TestResolveProviderBaseURL_TemplatedURL_MissingEnv(t *testing.T) {
// cloudflare-workers-ai's api URL contains ${CLOUDFLARE_ACCOUNT_ID}.
// Ensure the variable is unset for this test.
t.Setenv("CLOUDFLARE_ACCOUNT_ID", "")
t.Setenv("CF_ACCOUNT_ID", "")
_, err := ResolveProviderBaseURL("cloudflare-workers-ai")
if err == nil {
t.Fatal("expected error for unset CLOUDFLARE_ACCOUNT_ID, got nil")
}
if !strings.Contains(err.Error(), "CLOUDFLARE_ACCOUNT_ID") {
t.Errorf("error should name the missing env var, got: %v", err)
}
if !strings.Contains(err.Error(), "--provider-url") {
t.Errorf("error should suggest --provider-url override, got: %v", err)
}
}
// TestResolveProviderBaseURL_TemplatedURL_Resolved verifies env-var
// substitution succeeds when the placeholder is set.
func TestResolveProviderBaseURL_TemplatedURL_Resolved(t *testing.T) {
t.Setenv("CLOUDFLARE_ACCOUNT_ID", "test-acct-123")
got, err := ResolveProviderBaseURL("cloudflare-workers-ai")
if err != nil {
t.Fatalf("ResolveProviderBaseURL: %v", err)
}
if !strings.Contains(got, "test-acct-123") {
t.Errorf("resolved URL %q should contain test-acct-123", got)
}
if strings.Contains(got, "${") {
t.Errorf("resolved URL %q still contains template placeholder", got)
}
}
// TestResolveProviderBaseURL_UnknownProvider verifies the not-in-registry error.
func TestResolveProviderBaseURL_UnknownProvider(t *testing.T) {
_, err := ResolveProviderBaseURL("does-not-exist")
if err == nil {
t.Fatal("expected error for unknown provider, got nil")
}
if !strings.Contains(err.Error(), "unknown provider") {
t.Errorf("error should say 'unknown provider', got: %v", err)
}
}
// TestAutoRouteProvider_SDKDefaultURLFallback verifies that providers whose
// registry entry omits the `api` field (groq, mistral, xai, etc.) are still
// auto-routed by falling back to the SDK's hard-coded default URL.
func TestAutoRouteProvider_SDKDefaultURLFallback(t *testing.T) {
tests := []struct {
name string
npmPackage string
wantInURL string
}{
{"groq", "@ai-sdk/groq", "groq.com"},
{"cerebras", "@ai-sdk/cerebras", "cerebras.ai"},
{"xai", "@ai-sdk/xai", "x.ai"},
{"mistral", "@ai-sdk/mistral", "mistral.ai"},
{"v0", "@ai-sdk/vercel", "v0.dev"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reg := &ModelsRegistry{
providers: map[string]ProviderInfo{
"testfallback": {
ID: "testfallback",
Name: "Test Fallback",
Env: []string{"TESTFALLBACK_API_KEY"},
NPM: tt.npmPackage,
// API intentionally omitted — must fall back to SDK default.
Models: map[string]ModelInfo{
"any-model": {ID: "any-model", Name: "any-model"},
},
},
},
}
config := &ProviderConfig{ProviderAPIKey: "test-key"}
result, err := autoRouteProvider(context.Background(), config, "testfallback", "any-model", reg)
if err != nil {
t.Fatalf("autoRouteProvider returned error: %v", err)
}
if result == nil || result.Model == nil {
t.Fatal("autoRouteProvider returned nil model")
}
// Verify the SDK default URL was picked up.
if !strings.Contains(config.ProviderURL, tt.wantInURL) {
t.Errorf("config.ProviderURL = %q, want substring %q (SDK default)",
config.ProviderURL, tt.wantInURL)
}
// All these wrappers route through the openai-compat wire.
gotType := reflect.TypeOf(result.Model).String()
if gotType != "openai.languageModel" {
t.Errorf("model type = %q, want openai.languageModel", gotType)
}
})
}
}
// TestResolveTemplatedAPIURL_NoPlaceholders verifies that URLs without
// placeholders are returned as-is (the caller keeps using the original).
func TestResolveTemplatedAPIURL_NoPlaceholders(t *testing.T) {
got, err := resolveTemplatedAPIURL("https://api.example.com/v1", &ProviderInfo{ID: "x"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "" {
t.Errorf("got %q, want empty string for URL with no placeholders", got)
}
}
// TestResolveTemplatedAPIURL_AltEnvVar verifies that the alternative env-var
// names (e.g. CF_ACCOUNT_ID for CLOUDFLARE_ACCOUNT_ID) are honoured.
func TestResolveTemplatedAPIURL_AltEnvVar(t *testing.T) {
t.Setenv("CLOUDFLARE_ACCOUNT_ID", "")
t.Setenv("CF_ACCOUNT_ID", "alt-name-123")
got, err := resolveTemplatedAPIURL(
"https://api.cloudflare.com/client/v4/accounts/${CLOUDFLARE_ACCOUNT_ID}/ai/v1",
&ProviderInfo{ID: "cloudflare-workers-ai"},
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(got, "alt-name-123") {
t.Errorf("resolved URL %q should have picked up CF_ACCOUNT_ID alternative", got)
}
}
+4 -9
View File
@@ -70,7 +70,8 @@ func ParseTemplate(path string) (*PromptTemplate, error) {
}
// ParseCommandArgs splits a command line into arguments respecting quotes.
// It handles single quotes, double quotes, and backslash escaping.
// It handles single quotes, double quotes, backslash escaping, and splits on
// spaces and tabs.
func ParseCommandArgs(input string) []string {
var args []string
var current strings.Builder
@@ -78,7 +79,7 @@ func ParseCommandArgs(input string) []string {
inDoubleQuote := false
escaped := false
for i, r := range input {
for _, r := range input {
if escaped {
current.WriteRune(r)
escaped = false
@@ -101,7 +102,7 @@ func ParseCommandArgs(input string) []string {
continue
}
if r == ' ' && !inSingleQuote && !inDoubleQuote {
if (r == ' ' || r == '\t') && !inSingleQuote && !inDoubleQuote {
if current.Len() > 0 {
args = append(args, current.String())
current.Reset()
@@ -110,7 +111,6 @@ func ParseCommandArgs(input string) []string {
}
current.WriteRune(r)
_ = i // silence unused warning when we need position later
}
if current.Len() > 0 {
@@ -325,8 +325,3 @@ func (t *PromptTemplate) Expand(argsInput string) string {
args := ParseCommandArgs(argsInput)
return SubstituteArgs(t.Content, args)
}
// ExpandWithArgs substitutes the provided arguments into the template content.
func (t *PromptTemplate) ExpandWithArgs(args []string) string {
return SubstituteArgs(t.Content, args)
}
-15
View File
@@ -458,11 +458,6 @@ func (tm *TreeManager) AppendLLMMessage(msg fantasy.Message) (string, error) {
return tm.AppendMessage(message.FromLLMMessage(msg))
}
// Deprecated: Use AppendLLMMessage instead.
func (tm *TreeManager) AppendFantasyMessage(msg fantasy.Message) (string, error) {
return tm.AppendLLMMessage(msg)
}
// AppendModelChange records a model/provider change.
func (tm *TreeManager) AppendModelChange(provider, modelID string) (string, error) {
tm.mu.Lock()
@@ -1170,11 +1165,6 @@ func (tm *TreeManager) AddLLMMessages(msgs []fantasy.Message) error {
return tm.flushLocked()
}
// Deprecated: Use AddLLMMessages instead.
func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error {
return tm.AddLLMMessages(msgs)
}
// GetLLMMessages builds the context and returns just the messages.
// This satisfies the same conceptual role as the old Manager.GetMessages().
func (tm *TreeManager) GetLLMMessages() []fantasy.Message {
@@ -1182,11 +1172,6 @@ func (tm *TreeManager) GetLLMMessages() []fantasy.Message {
return msgs
}
// Deprecated: Use GetLLMMessages instead.
func (tm *TreeManager) GetFantasyMessages() []fantasy.Message {
return tm.GetLLMMessages()
}
// --- Internal helpers ---
// addEntryToIndex adds an entry to the in-memory indices.
+12 -7
View File
@@ -18,8 +18,11 @@ type PromptTemplate struct {
Variables []string
}
// variableRe matches {{variable_name}} placeholders.
var variableRe = regexp.MustCompile(`\{\{(\w+)\}\}`)
// variableRe matches {{variable_name}} placeholders, tolerating surrounding
// whitespace inside the braces (e.g. {{ name }}). This is the canonical
// template grammar shared by skill prompts and the extension template API
// (pkg/kit ParseTemplate/RenderTemplate delegate here).
var variableRe = regexp.MustCompile(`\{\{\s*(\w+)\s*\}\}`)
// NewPromptTemplate creates a PromptTemplate, automatically extracting
// variable names from {{...}} placeholders in content.
@@ -50,11 +53,13 @@ func LoadPromptTemplate(path string) (*PromptTemplate, error) {
// Expand replaces all {{variable}} placeholders with values from the
// provided map. Missing variables are left as-is (no error).
func (t *PromptTemplate) Expand(values map[string]string) string {
result := t.Content
for k, v := range values {
result = strings.ReplaceAll(result, "{{"+k+"}}", v)
}
return result
return variableRe.ReplaceAllStringFunc(t.Content, func(m string) string {
name := variableRe.FindStringSubmatch(m)[1]
if v, ok := values[name]; ok {
return v
}
return m
})
}
// ExpandStrict replaces all {{variable}} placeholders and returns an error
+60 -57
View File
@@ -641,30 +641,16 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
Request: mcp.Request{Method: "tools/call"},
Params: callParams,
}
result, callErr := conn.client.CallTool(ctx, callRequest)
if callErr != nil {
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
}
result, callErr = conn.client.CallTool(ctx, callRequest)
if callErr != nil {
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
}
} else {
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
}
var result *mcp.CallToolResult
err := m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error {
var callErr error
result, callErr = conn.client.CallTool(ctx, callRequest)
return callErr
})
if err != nil {
return nil, err
}
marshaledResult, mErr := json.Marshal(result)
if mErr != nil {
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
}
return &MCPToolResult{
Content: string(marshaledResult),
IsError: result.IsError,
}, nil
return marshalToolResult(result)
}
// Task-augmented path. Bypass the upstream CallTool helper because its
@@ -683,40 +669,25 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
}
marshaledResult, mErr := json.Marshal(result)
if mErr != nil {
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
}
return &MCPToolResult{Content: string(marshaledResult), IsError: result.IsError}, nil
return marshalToolResult(result)
}
callResult, taskResult, callErr := callToolWithTask(ctx, rawClient, callParams)
if callErr != nil {
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
}
callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
if callErr != nil {
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
}
} else {
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
}
var (
callResult *mcp.CallToolResult
taskResult *mcp.CreateTaskResult
)
err = m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error {
var callErr error
callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
return callErr
})
if err != nil {
return nil, err
}
// Server chose to answer synchronously — same shape as the no-task path.
if callResult != nil {
marshaledResult, mErr := json.Marshal(callResult)
if mErr != nil {
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
}
return &MCPToolResult{
Content: string(marshaledResult),
IsError: callResult.IsError,
}, nil
return marshalToolResult(callResult)
}
// Asynchronous task path: poll until terminal, then return the result.
@@ -732,18 +703,50 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
}
// Adapt TaskResultResult → CallToolResult for downstream JSON shape parity.
adapted := &mcp.CallToolResult{
return marshalToolResult(&mcp.CallToolResult{
Content: final.Content,
StructuredContent: final.StructuredContent,
IsError: final.IsError,
})
}
// withOAuthRetry runs call once; when it fails with an OAuth error and an
// OAuth flow is configured, it re-authorizes the server and retries once.
// Connection failures are reported to the pool and wrapped uniformly. This
// consolidates the retry/error chain shared by the synchronous and
// task-augmented tool-call paths.
func (m *MCPToolManager) withOAuthRetry(ctx context.Context, serverName, toolName string, call func() error) error {
callErr := call()
if callErr == nil {
return nil
}
marshaledResult, mErr := json.Marshal(adapted)
if mErr != nil {
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, serverName, callErr); flowErr != nil {
return fmt.Errorf("OAuth re-authorization failed for tool %s: %w", toolName, flowErr)
}
if callErr = call(); callErr != nil {
m.connectionPool.HandleConnectionError(serverName, callErr)
return fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
}
return nil
}
m.connectionPool.HandleConnectionError(serverName, callErr)
return fmt.Errorf("failed to call mcp tool: %w", callErr)
}
// marshalToolResult converts an MCP CallToolResult into the JSON-encoded
// MCPToolResult shape returned to the agent.
func marshalToolResult(result *mcp.CallToolResult) (*MCPToolResult, error) {
if result == nil {
return nil, errors.New("mcp tool call returned nil result")
}
marshaled, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", err)
}
return &MCPToolResult{
Content: string(marshaledResult),
IsError: final.IsError,
Content: string(marshaled),
IsError: result.IsError,
}, nil
}
+17 -1
View File
@@ -146,9 +146,10 @@ var SlashCommands = []SlashCommand{
},
{
Name: "/new",
Description: "Start a new session",
Description: "Start a new session (optionally with an initial prompt)",
Category: "Navigation",
Aliases: []string{"/n"},
HasArgs: true,
},
{
Name: "/name",
@@ -167,6 +168,21 @@ var SlashCommands = []SlashCommand{
Category: "System",
Aliases: []string{"/cp"},
},
{
Name: "/retry",
Description: "Resubmit the last user message (e.g. after a provider error)",
Category: "System",
Aliases: []string{"/rt"},
},
{
Name: "/edit",
Description: "Open a file in $EDITOR (fuzzy-find a path, then edit)",
Category: "System",
Aliases: []string{"/ed"},
HasArgs: true,
// Note: no Complete callback — file fuzzy-finding is driven directly
// by InputComponent (mirroring the @file popup with directory drill).
},
{
Name: "/export",
Description: "Export session (JSONL by default, or /export path.jsonl)",
+29 -35
View File
@@ -2,7 +2,6 @@ package ui
import (
"fmt"
"strings"
"github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/models"
@@ -44,28 +43,39 @@ func parseModelName(modelString string) (provider, model string) {
// ollama or unrecognised models). This is used by the interactive TUI path
// which doesn't go through SetupCLI.
func CreateUsageTracker(modelString, providerAPIKey string) *UsageTracker {
provider, model := parseModelName(modelString)
if provider == "unknown" || model == "unknown" || provider == "ollama" {
return nil
}
registry := models.GetGlobalRegistry()
modelInfo := registry.LookupModel(provider, model)
modelInfo, provider := lookupTrackableModel(modelString)
if modelInfo == nil {
return nil
}
isOAuth := false
if provider == "anthropic" {
_, source, err := auth.GetAnthropicAPIKey(providerAPIKey)
if err == nil && strings.HasPrefix(source, "stored OAuth") {
isOAuth = true
}
}
isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)
return NewUsageTracker(modelInfo, provider, 80, isOAuth)
}
// UpdateUsageTrackerForModel refreshes an existing tracker after a model
// switch so token counting and cost reporting use the new model's metadata.
// No-op for a nil tracker or untrackable models (unknown/ollama).
func UpdateUsageTrackerForModel(t *UsageTracker, modelString, providerAPIKey string) {
if t == nil {
return
}
modelInfo, provider := lookupTrackableModel(modelString)
if modelInfo == nil {
return
}
isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)
t.UpdateModelInfo(modelInfo, provider, isOAuth)
}
// lookupTrackableModel resolves a model string to registry metadata, returning
// nil for models without usage tracking support (unknown or ollama models).
func lookupTrackableModel(modelString string) (*models.ModelInfo, string) {
provider, model := parseModelName(modelString)
if provider == "unknown" || model == "unknown" || provider == "ollama" {
return nil, provider
}
return models.GetGlobalRegistry().LookupModel(provider, model), provider
}
// SetupCLI creates, configures, and initializes a CLI instance with the provided
// options. It sets up model display, usage tracking for supported providers, and
// shows initial loading information. Returns nil in quiet mode or an initialized
@@ -89,24 +99,8 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
}
// Set up usage tracking for supported providers
if provider != "unknown" && model != "unknown" {
// Skip usage tracking for ollama as it's not in models.dev
if provider != "ollama" {
registry := models.GetGlobalRegistry()
if modelInfo := registry.LookupModel(provider, model); modelInfo != nil {
// Check if OAuth credentials are being used for Anthropic models
isOAuth := false
if provider == "anthropic" {
_, source, err := auth.GetAnthropicAPIKey(opts.ProviderAPIKey)
if err == nil && strings.HasPrefix(source, "stored OAuth") {
isOAuth = true
}
}
usageTracker := NewUsageTracker(modelInfo, provider, 80, isOAuth) // Will be updated with actual width
cli.SetUsageTracker(usageTracker)
}
}
if usageTracker := CreateUsageTracker(opts.ModelString, opts.ProviderAPIKey); usageTracker != nil {
cli.SetUsageTracker(usageTracker)
}
// Display model info (the system message block provides its own spacing).
+27
View File
@@ -125,6 +125,33 @@ func ExtractAtPrefix(line string, cursorCol int) (hasAt bool, prefix string, sta
return true, raw, atIdx
}
// editTriggerPrefixes lists the command tokens (including trailing space)
// that activate the /edit fuzzy-file picker. Aliases come first so the
// longer alias "/edit " is matched before a hypothetical superset.
var editTriggerPrefixes = []string{"/edit ", "/ed "}
// ExtractEditPrefix detects when the input value is a single-line /edit (or
// alias) invocation and returns the path-portion the user has typed so far.
//
// Returns:
// - cmdLen: byte offset where the path argument begins (i.e. length of
// the matched command token, including its trailing space)
// - pathPrefix: text the user has typed after the command token
// - ok: true when the value matches one of the /edit triggers
//
// Multi-line values never match — /edit only makes sense as a single line.
func ExtractEditPrefix(value string) (cmdLen int, pathPrefix string, ok bool) {
if strings.Contains(value, "\n") {
return 0, "", false
}
for _, p := range editTriggerPrefixes {
if strings.HasPrefix(value, p) {
return len(p), value[len(p):], true
}
}
return 0, "", false
}
// GetFileSuggestions returns file/directory suggestions matching the given
// prefix. It tries `git ls-files` first (fast, respects .gitignore), then
// falls back to a simple directory walk.
+233
View File
@@ -0,0 +1,233 @@
// Package imagepreview renders low-resolution, in-terminal thumbnails of
// images using Unicode upper half-block characters (U+2580, "▀") combined
// with SGR foreground/background color codes.
//
// The technique stacks two vertical pixels into a single character cell: the
// foreground color paints the top pixel and the background color paints the
// bottom pixel. This produces pure styled text — no graphics escape sequences
// — so the output survives terminal multiplexers (tmux, zellij) untouched.
//
// The Kitty graphics protocol, Sixel, and iTerm2 inline images are
// deliberately NOT used: those are graphics escape-sequence protocols that
// tmux and zellij strip or mangle by default.
package imagepreview
import (
"bytes"
"fmt"
"image"
"image/color"
"os"
"strings"
// Register the standard image decoders so image.Decode can handle the
// common clipboard / attachment formats.
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"github.com/charmbracelet/colorprofile"
"github.com/charmbracelet/x/ansi"
xdraw "golang.org/x/image/draw"
)
// upperHalfBlock is U+2580 ("▀"). The glyph fills the top half of a cell,
// letting the foreground color render the top pixel and the cell's background
// color render the bottom pixel.
const upperHalfBlock = "▀"
// reset is the SGR reset sequence appended after each rendered row.
const reset = "\x1b[0m"
// maxImageDimension is the largest width or height, in pixels, that Render will
// fully decode. Images larger than this in either axis are rejected before the
// expensive image.Decode call to guard against decompression bombs (small
// encoded payloads that expand to enormous pixel buffers).
const maxImageDimension = 20000
// Render returns a half-block ANSI thumbnail of the image, scaled to fit
// within maxCols x maxRows terminal cells while preserving aspect ratio.
//
// Each terminal cell encodes two vertically-stacked pixels, so the effective
// pixel resolution of the thumbnail is up to maxCols x (maxRows*2).
//
// Colors are emitted at the fidelity of the detected terminal color profile:
// truecolor (24-bit) when available, degrading to 256-color. When the
// terminal supports neither (no truecolor and no 256-color), Render returns
// an empty string and a nil error so the caller can fall back to a text
// indicator. A non-nil error is only returned when the image data cannot be
// decoded.
//
// bg is the color used to composite transparent pixels (typically the
// terminal background). A nil bg defaults to black.
func Render(data []byte, mediaType string, maxCols, maxRows int, bg color.Color) (string, error) {
profile := colorprofile.Env(os.Environ())
return renderWithProfile(data, maxCols, maxRows, bg, profile)
}
// renderWithProfile is the testable core of Render. It accepts an explicit
// color profile instead of detecting one from the environment.
func renderWithProfile(data []byte, maxCols, maxRows int, bg color.Color, profile colorprofile.Profile) (string, error) {
// Half-block fidelity needs at least 256-color support. Anything less
// degrades to the caller's text fallback.
if profile < colorprofile.ANSI256 {
return "", nil
}
if maxCols < 1 || maxRows < 1 {
return "", nil
}
if bg == nil {
bg = color.Black
}
// Guard against decompression bombs: inspect the header dimensions before
// fully decoding, so a small malicious payload cannot expand into an
// enormous pixel buffer.
cfg, _, err := image.DecodeConfig(bytes.NewReader(data))
if err != nil {
return "", fmt.Errorf("decode image config: %w", err)
}
if cfg.Width > maxImageDimension || cfg.Height > maxImageDimension {
return "", fmt.Errorf("decode image: dimensions %dx%d exceed limit %d", cfg.Width, cfg.Height, maxImageDimension)
}
img, _, err := image.Decode(bytes.NewReader(data))
if err != nil {
return "", fmt.Errorf("decode image: %w", err)
}
// Target pixel dimensions: one pixel per column horizontally and two
// pixels per row vertically (the half-block trick).
cols, rows := fitDimensions(img.Bounds().Dx(), img.Bounds().Dy(), maxCols, maxRows)
if cols < 1 || rows < 1 {
return "", nil
}
pxW, pxH := cols, rows*2
scaled := image.NewRGBA(image.Rect(0, 0, pxW, pxH))
xdraw.CatmullRom.Scale(scaled, scaled.Bounds(), img, img.Bounds(), xdraw.Over, nil)
var b strings.Builder
for y := 0; y < pxH; y += 2 {
for x := range pxW {
top := composite(scaled.At(x, y), bg)
bottom := composite(scaled.At(x, y+1), bg)
b.WriteString(sgr(top, bottom, profile))
b.WriteString(upperHalfBlock)
}
b.WriteString(reset)
if y+2 < pxH {
b.WriteByte('\n')
}
}
return b.String(), nil
}
// fitDimensions returns the largest cell dimensions (cols, rows) that fit a
// srcW x srcH image inside a maxCols x maxRows box while preserving aspect
// ratio. Because each cell stacks two vertical pixels, a terminal cell is
// treated as roughly twice as tall as it is wide, which keeps the thumbnail's
// aspect ratio visually correct.
func fitDimensions(srcW, srcH, maxCols, maxRows int) (cols, rows int) {
if srcW <= 0 || srcH <= 0 {
return 0, 0
}
// Work in pixel space: the box is maxCols wide and maxRows*2 tall.
maxPxW := float64(maxCols)
maxPxH := float64(maxRows * 2)
scale := maxPxW / float64(srcW)
if h := maxPxH / float64(srcH); h < scale {
scale = h
}
if scale > 1 {
scale = 1 // never upscale; keep the low-res look
}
pxW := int(float64(srcW) * scale)
pxH := int(float64(srcH) * scale)
if pxW < 1 {
pxW = 1
}
if pxH < 2 {
pxH = 2
}
// Convert back to cells; round the row count up to an even pixel height.
cols = pxW
rows = (pxH + 1) / 2
if cols > maxCols {
cols = maxCols
}
if rows > maxRows {
rows = maxRows
}
return cols, rows
}
// composite blends a (possibly translucent) pixel over the background color,
// returning an opaque color. Fully opaque pixels are returned unchanged.
func composite(c, bg color.Color) color.Color {
r, g, b, a := c.RGBA()
if a == 0xffff {
return c
}
br, bgc, bb, _ := bg.RGBA()
// Standard "over" alpha compositing in 16-bit space.
inv := 0xffff - a
out := color.RGBA64{
R: uint16(r + br*inv/0xffff),
G: uint16(g + bgc*inv/0xffff),
B: uint16(b + bb*inv/0xffff),
A: 0xffff,
}
return out
}
// sgr builds the SGR escape sequence that sets the foreground (top pixel) and
// background (bottom pixel) colors at the fidelity of the given profile.
func sgr(fg, bg color.Color, profile colorprofile.Profile) string {
if profile >= colorprofile.TrueColor {
fr, fgc, fb := rgb8(fg)
br, bgc, bb := rgb8(bg)
return fmt.Sprintf("\x1b[38;2;%d;%d;%d;48;2;%d;%d;%dm", fr, fgc, fb, br, bgc, bb)
}
return fmt.Sprintf("\x1b[38;5;%d;48;5;%dm", index256(fg, profile), index256(bg, profile))
}
// rgb8 reduces a color to 8-bit RGB components.
func rgb8(c color.Color) (r, g, b uint8) {
cr, cg, cb, _ := c.RGBA()
return uint8(cr >> 8), uint8(cg >> 8), uint8(cb >> 8)
}
// index256 converts a color to its nearest 256-color palette index using the
// supplied profile.
func index256(c color.Color, profile colorprofile.Profile) uint8 {
cc := profile.Convert(c)
if idx, ok := cc.(ansi.IndexedColor); ok {
return uint8(idx)
}
if idx, ok := cc.(ansi.BasicColor); ok {
return uint8(idx)
}
// Fallback: derive an index directly if conversion produced an
// unexpected type.
r, g, b := rgb8(c)
return ansi256FromRGB(r, g, b)
}
// ansi256FromRGB maps an 8-bit RGB color to the xterm 256-color cube. It is a
// best-effort fallback used only when profile.Convert does not yield a known
// indexed color type.
func ansi256FromRGB(r, g, b uint8) uint8 {
q := func(v uint8) int {
switch {
case v < 48:
return 0
case v < 115:
return 1
default:
return int((v - 35) / 40)
}
}
ri, gi, bi := q(r), q(g), q(b)
return uint8(16 + 36*ri + 6*gi + bi)
}
@@ -0,0 +1,193 @@
package imagepreview
import (
"bytes"
"image"
"image/color"
"image/png"
"strings"
"testing"
"github.com/charmbracelet/colorprofile"
)
// makePNG builds a simple w x h PNG filled with the given color and returns
// its encoded bytes.
func makePNG(t *testing.T, w, h int, c color.Color) []byte {
t.Helper()
img := image.NewRGBA(image.Rect(0, 0, w, h))
for y := range h {
for x := range w {
img.Set(x, y, c)
}
}
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
t.Fatalf("encode png: %v", err)
}
return buf.Bytes()
}
func TestRenderTrueColor(t *testing.T) {
data := makePNG(t, 20, 20, color.RGBA{R: 255, A: 255})
out, err := renderWithProfile(data, 10, 5, color.Black, colorprofile.TrueColor)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if out == "" {
t.Fatal("expected non-empty thumbnail for truecolor profile")
}
if !strings.Contains(out, upperHalfBlock) {
t.Error("output should contain upper half block glyphs")
}
if !strings.Contains(out, "\x1b[38;2;") || !strings.Contains(out, "48;2;") {
t.Errorf("expected truecolor SGR sequences, got %q", out)
}
// Red fill should appear as 255;0;0 somewhere.
if !strings.Contains(out, "255;0;0") {
t.Errorf("expected red color in output, got %q", out)
}
}
func TestRenderANSI256(t *testing.T) {
data := makePNG(t, 20, 20, color.RGBA{G: 255, A: 255})
out, err := renderWithProfile(data, 8, 4, color.Black, colorprofile.ANSI256)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if out == "" {
t.Fatal("expected non-empty thumbnail for ANSI256 profile")
}
if !strings.Contains(out, "\x1b[38;5;") || !strings.Contains(out, "48;5;") {
t.Errorf("expected 256-color SGR sequences, got %q", out)
}
if strings.Contains(out, "38;2;") {
t.Errorf("ANSI256 output should not contain truecolor sequences, got %q", out)
}
}
func TestRenderDegradesBelowANSI256(t *testing.T) {
data := makePNG(t, 20, 20, color.RGBA{B: 255, A: 255})
for _, p := range []colorprofile.Profile{colorprofile.ANSI, colorprofile.ASCII, colorprofile.NoTTY} {
out, err := renderWithProfile(data, 10, 5, color.Black, p)
if err != nil {
t.Fatalf("profile %v: unexpected error: %v", p, err)
}
if out != "" {
t.Errorf("profile %v: expected empty fallback, got %q", p, out)
}
}
}
func TestRenderInvalidImage(t *testing.T) {
out, err := renderWithProfile([]byte("not an image"), 10, 5, color.Black, colorprofile.TrueColor)
if err == nil {
t.Fatal("expected error for invalid image data")
}
if out != "" {
t.Errorf("expected empty output on decode error, got %q", out)
}
}
func TestRenderRejectsOversizedImage(t *testing.T) {
// A header advertising dimensions beyond maxImageDimension must be
// rejected before full decode (decompression-bomb guard). image.RGBA
// allocation is avoided by only checking the config path here.
w := maxImageDimension + 1
data := makePNG(t, w, 1, color.White)
out, err := renderWithProfile(data, 10, 5, color.Black, colorprofile.TrueColor)
if err == nil {
t.Fatal("expected error for oversized image dimensions")
}
if out != "" {
t.Errorf("expected empty output for oversized image, got %q", out)
}
}
func TestRenderZeroBox(t *testing.T) {
data := makePNG(t, 20, 20, color.White)
out, err := renderWithProfile(data, 0, 0, color.Black, colorprofile.TrueColor)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if out != "" {
t.Errorf("expected empty output for zero-sized box, got %q", out)
}
}
func TestRenderNilBackgroundDefaults(t *testing.T) {
data := makePNG(t, 10, 10, color.RGBA{R: 10, G: 20, B: 30, A: 255})
out, err := renderWithProfile(data, 6, 3, nil, colorprofile.TrueColor)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if out == "" {
t.Fatal("expected output with nil background (defaults to black)")
}
}
func TestRowCountWithinBounds(t *testing.T) {
// A tall image should be capped at maxRows cells.
data := makePNG(t, 10, 100, color.White)
out, err := renderWithProfile(data, 20, 6, color.Black, colorprofile.TrueColor)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
rows := strings.Count(out, "\n") + 1
if rows > 6 {
t.Errorf("expected at most 6 rows, got %d", rows)
}
}
func TestColumnCountWithinBounds(t *testing.T) {
// A wide image should be capped at maxCols cells per row.
data := makePNG(t, 100, 10, color.White)
out, err := renderWithProfile(data, 8, 20, color.Black, colorprofile.TrueColor)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
firstRow := strings.SplitN(out, "\n", 2)[0]
cols := strings.Count(firstRow, upperHalfBlock)
if cols > 8 {
t.Errorf("expected at most 8 columns, got %d", cols)
}
if cols == 0 {
t.Error("expected at least one column")
}
}
func TestFitDimensionsPreservesAspect(t *testing.T) {
// 2:1 (wide) image into a 40x20 box. Pixel box is 40x40; width-bound.
cols, rows := fitDimensions(200, 100, 40, 20)
if cols != 40 {
t.Errorf("expected 40 cols, got %d", cols)
}
// pxH = 100 * (40/200) = 20 → 10 rows.
if rows != 10 {
t.Errorf("expected 10 rows, got %d", rows)
}
}
func TestFitDimensionsNeverUpscales(t *testing.T) {
cols, rows := fitDimensions(4, 4, 40, 20)
if cols != 4 || rows != 2 {
t.Errorf("expected 4x2 (no upscale), got %dx%d", cols, rows)
}
}
func TestCompositeOpaquePassthrough(t *testing.T) {
c := color.RGBA{R: 1, G: 2, B: 3, A: 255}
got := composite(c, color.White)
if got != color.Color(c) {
t.Errorf("opaque color should pass through unchanged, got %v", got)
}
}
func TestCompositeTransparentOverBackground(t *testing.T) {
// Fully transparent pixel over red background should yield red.
got := composite(color.RGBA{}, color.RGBA{R: 255, A: 255})
r, g, b, a := got.RGBA()
if r>>8 != 255 || g>>8 != 0 || b>>8 != 0 || a != 0xffff {
t.Errorf("expected opaque red, got r=%d g=%d b=%d a=%d", r>>8, g>>8, b>>8, a)
}
}
+224 -187
View File
@@ -2,6 +2,7 @@ package ui
import (
"fmt"
"image/color"
"sort"
"strings"
@@ -13,6 +14,7 @@ import (
"github.com/mark3labs/kit/internal/clipboard"
"github.com/mark3labs/kit/internal/ui/commands"
"github.com/mark3labs/kit/internal/ui/core"
"github.com/mark3labs/kit/internal/ui/imagepreview"
"github.com/mark3labs/kit/internal/ui/style"
)
@@ -42,6 +44,12 @@ type InputComponent struct {
popupHeight int
submitNext bool // defer submit one tick so popup dismisses cleanly
// popup is the shared PopupList used to render the / and @ autocomplete
// dropdowns. State (items, cursor, visible search-driven filter) is
// driven externally by InputComponent — we only use PopupList for the
// rendering chrome so all popups in the app look identical.
popup *PopupList
// Argument completion state. When the user types "/cmd " followed by
// a partial argument and the command has a Complete function, the popup
// switches to argument-completion mode showing suggestions from Complete.
@@ -53,10 +61,16 @@ type InputComponent struct {
// file path, the popup shows file/directory suggestions from the cwd.
fileMode bool // true when showing @file completions
filePrefix string // current text after @ being matched
fileAtStartIdx int // byte offset of @ in the textarea value
fileAtStartIdx int // byte offset of @ (or path start in /edit mode) in the textarea value
fileSuggestions []FileSuggestion // backing storage for file entries
fileSynthCmds []commands.SlashCommand // synthetic commands.SlashCommands wrapping file entries
// fileEditMode is true when fileMode was activated by the /edit slash
// command rather than an @ trigger. Selecting a file submits the line
// (running $EDITOR on it); selecting a directory drills further like @
// does. MCP resources are excluded in this mode.
fileEditMode bool
// cwd is the working directory used for @file path resolution and
// autocomplete suggestions. Set by the parent via SetCwd.
cwd string
@@ -80,6 +94,23 @@ type InputComponent struct {
// Images are added via Ctrl+V and cleared on submit or Ctrl+U.
pendingImages []core.ImageAttachment
// imageThumbs caches the rendered half-block thumbnail for each entry in
// pendingImages (1:1 index correspondence). Thumbnails are rendered
// asynchronously off the Bubble Tea event loop (decode + resample is too
// slow to run inside Update), so an entry starts as the empty string
// placeholder and is filled in when the matching thumbnailReadyMsg
// arrives. An entry stays empty when the terminal cannot display a
// half-block preview, in which case the text pill is shown alone.
// See internal/ui/imagepreview.
imageThumbs []string
// imageGen is a monotonic generation counter incremented whenever the
// pending image set is cleared. Async thumbnail results carry the
// generation they were enqueued under and are discarded if it no longer
// matches, preventing a stale thumbnail from landing on the wrong slot
// after a clear + re-attach.
imageGen int
// history stores previously submitted prompts (most recent last).
// Limited to maxHistory entries; duplicates of the previous entry are
// skipped. Empty strings are never stored.
@@ -105,6 +136,16 @@ type clipboardImageMsg struct {
err error
}
// thumbnailReadyMsg carries the result of an async thumbnail render back to
// the Update loop. gen and index identify the pendingImages slot the
// thumbnail belongs to; the result is dropped if the generation no longer
// matches (the pending set was cleared) or the index is out of range.
type thumbnailReadyMsg struct {
gen int
index int
thumb string
}
// NewInputComponent creates a new InputComponent with the given width and
// optional AppController. If appCtrl is nil the component still works but
// /clear and /clear-queue are no-ops.
@@ -135,7 +176,7 @@ func NewInputComponent(width int, appCtrl AppController) *InputComponent {
styles.Focused.CursorLine = lipgloss.NewStyle()
ta.SetStyles(styles)
return &InputComponent{
ic := &InputComponent{
textarea: ta,
commands: commands.SlashCommands,
width: width,
@@ -143,6 +184,12 @@ func NewInputComponent(width int, appCtrl AppController) *InputComponent {
appCtrl: appCtrl,
hideHint: true,
}
ic.popup = NewPopupList("", nil, width, 0)
ic.popup.ShowSearch = false
ic.popup.HideCount = true
ic.popup.MaxVisible = ic.popupHeight
ic.popup.FooterHint = "↑↓ navigate • tab complete • ↵ select • esc dismiss"
return ic
}
// SetCwd sets the working directory used for @file autocomplete suggestions
@@ -193,7 +240,23 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return s, nil
}
if msg.image != nil {
s.pendingImages = append(s.pendingImages, *msg.image)
img := *msg.image
index := len(s.pendingImages)
s.pendingImages = append(s.pendingImages, img)
// Reserve a placeholder; the async render fills it in via
// thumbnailReadyMsg so Update never blocks on decode/resample.
s.imageThumbs = append(s.imageThumbs, "")
cols := s.thumbCols()
if cols < 1 {
return s, nil
}
return s, renderThumbnailCmd(img, cols, thumbMaxRows, style.GetTheme().Background, s.imageGen, index)
}
return s, nil
case thumbnailReadyMsg:
if msg.gen == s.imageGen && msg.index >= 0 && msg.index < len(s.imageThumbs) {
s.imageThumbs[msg.index] = msg.thumb
}
return s, nil
@@ -250,6 +313,8 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Clear all pending image attachments.
if len(s.pendingImages) > 0 {
s.pendingImages = nil
s.imageThumbs = nil
s.imageGen++
return s, nil
}
}
@@ -405,10 +470,17 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
} else {
s.showPopup = false
s.fileMode = false
s.fileEditMode = false
}
} else if len(lines) == 1 && strings.HasPrefix(lines[0], "/") {
s.fileMode = false
if !strings.Contains(lines[0], " ") {
s.fileEditMode = false
if cmdLen, pathPrefix, isEdit := ExtractEditPrefix(lines[0]); isEdit {
// /edit fuzzy-file picker. Behaves like @ except
// MCP resources are excluded and selecting a file
// submits the line (running $EDITOR).
s.updateEditFilePopup(cmdLen, pathPrefix)
} else if !strings.Contains(lines[0], " ") {
// Command name completion.
s.showPopup = true
s.argMode = false
@@ -428,6 +500,7 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
s.showPopup = false
s.argMode = false
s.fileMode = false
s.fileEditMode = false
}
}
return s, cmd
@@ -486,6 +559,8 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
// images and clear them.
images := s.pendingImages
s.pendingImages = nil
s.imageThumbs = nil
s.imageGen++
return func() tea.Msg {
return core.SubmitMsg{Text: trimmed, Images: images}
}
@@ -519,6 +594,42 @@ func (s *InputComponent) resetHistoryBrowsing() {
s.savedInput = ""
}
// thumbMaxCols and thumbMaxRows cap the size, in terminal cells, of pending
// image previews. Kept small for the low-res look and to keep scrollback
// light.
const (
thumbMaxCols = 40
thumbMaxRows = 12
)
// thumbCols returns the thumbnail width in terminal cells given the current
// input width, or 0 when there is no room to render a preview.
func (s *InputComponent) thumbCols() int {
if s.width <= 6 {
return 0
}
cols := min(thumbMaxCols, s.width-6)
if cols < 1 {
return 0
}
return cols
}
// renderThumbnailCmd returns a tea.Cmd that renders a half-block ANSI preview
// off the Bubble Tea event loop. The decode + resample work runs in the Cmd
// goroutine, and the result is delivered as a thumbnailReadyMsg tagged with
// the generation and slot index it was enqueued for. An empty thumbnail
// (terminal unsupported or render error) leaves the text pill in place.
func renderThumbnailCmd(img core.ImageAttachment, cols, rows int, bg color.Color, gen, index int) tea.Cmd {
return func() tea.Msg {
thumb, err := imagepreview.Render(img.Data, img.MediaType, cols, rows, bg)
if err != nil {
thumb = ""
}
return thumbnailReadyMsg{gen: gen, index: index, thumb: thumb}
}
}
// View implements tea.Model. Renders the textarea, autocomplete popup
// (if visible), and help text.
func (s *InputComponent) View() tea.View {
@@ -544,7 +655,9 @@ func (s *InputComponent) View() tea.View {
// Popup is now rendered as a centered overlay in AppModel.View()
// instead of inline here to prevent bottom overflow
// Show image attachment indicator when images are pending.
// Show image attachment previews when images are pending. A cached
// half-block thumbnail is rendered when the terminal supports it;
// otherwise the text pill alone is shown.
if len(s.pendingImages) > 0 {
imgStyle := lipgloss.NewStyle().
Foreground(theme.Secondary).
@@ -553,6 +666,14 @@ func (s *InputComponent) View() tea.View {
label := fmt.Sprintf("[%d image(s) attached] ctrl+u to clear", len(s.pendingImages))
view.WriteString("\n")
view.WriteString(imgStyle.Render(label))
thumbStyle := lipgloss.NewStyle().PaddingLeft(3)
for i := range s.pendingImages {
if i < len(s.imageThumbs) && s.imageThumbs[i] != "" {
view.WriteString("\n")
view.WriteString(thumbStyle.Render(s.imageThumbs[i]))
}
}
}
if !s.hideHint {
@@ -591,191 +712,37 @@ func (s *InputComponent) View() tea.View {
return tea.NewView(containerStyle.Render(view.String()))
}
// renderPopup renders the autocomplete popup for slash command suggestions.
// When rendered inline (not centered), returns the styled popup content.
// RenderPopupCentered renders the popup as a centered overlay.
// RenderPopupCentered renders the autocomplete popup for / or @ as a
// centered overlay. Returns "" when the popup is not currently shown.
// The actual filtering / selection state lives on InputComponent — this
// method merely converts the filtered FuzzyMatch list into PopupItems
// and asks the shared PopupList to draw it. As a result the / popup, the
// @ popup, the model picker, the tree selector and the session selector
// all share identical chrome.
func (s *InputComponent) RenderPopupCentered(termWidth, termHeight int) string {
if !s.showPopup || len(s.filtered) == 0 {
return ""
}
popupContent := s.renderPopupWithOptions(true)
// Center popup using lipgloss.Place
positioned := lipgloss.Place(
termWidth,
termHeight,
lipgloss.Center,
lipgloss.Center,
popupContent,
)
return positioned
}
// renderPopupWithOptions renders the popup content with optional center styling.
func (s *InputComponent) renderPopupWithOptions(centered bool) string {
theme := style.GetTheme()
popupWidth := max(s.width-4, 20)
// Use the theme background for the popup - the full-width item backgrounds
// and primary-colored selection will provide sufficient contrast
popupBg := theme.Background
popupStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(theme.Primary).
Background(popupBg).
Padding(1, 2).
Width(popupWidth).
MarginLeft(0).
MarginBottom(1) // Visual depth/shadow effect
// Inner content width: popup minus border (2) and horizontal padding (4).
innerWidth := max(popupWidth-6, 10)
// Item background styles for high contrast
normalItemBg := lipgloss.NewStyle().
Background(popupBg).
Foreground(theme.Text).
Width(innerWidth).
Padding(0, 1)
selectedItemBg := lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Background).
Width(innerWidth).
Padding(0, 1).
Bold(true)
var items []string
visibleItems := min(len(s.filtered), s.popupHeight)
startIdx := 0
if s.selected >= s.popupHeight {
startIdx = s.selected - s.popupHeight + 1
}
endIdx := min(startIdx+visibleItems, len(s.filtered))
for i := startIdx; i < endIdx; i++ {
match := s.filtered[i]
sc := match.Command
// Choose the appropriate background style
itemStyle := normalItemBg
if i == s.selected {
itemStyle = selectedItemBg
items := make([]PopupItem, len(s.filtered))
for i, m := range s.filtered {
desc := ""
if m.Command != nil {
desc = m.Command.Description
}
// Build indicator with proper coloring
var indicator string
if i == s.selected {
indicator = "> "
} else {
indicator = " "
name := ""
if m.Command != nil {
name = m.Command.Name
}
// Build content with name and description
var content string
if s.fileMode {
// File mode: use full width for the path, show description inline
maxNameLen := max(innerWidth-16, 8)
displayName := sc.Name
if len(displayName) > maxNameLen && maxNameLen > 3 {
displayName = displayName[:maxNameLen-3] + "..."
}
if sc.Description != "" && innerWidth > 30 {
content = indicator + displayName + " " + sc.Description
} else {
content = indicator + displayName
}
} else {
// Line layout: indicator(2) + name(nameWidth-2 visual) + desc
if innerWidth < 20 {
// Very narrow: show truncated name only
displayName := sc.Name
maxName := max(innerWidth-2, 3)
if len(displayName) > maxName {
displayName = displayName[:maxName-1] + "…"
}
content = indicator + displayName
} else {
// Compute nameWidth from the longest command name in the
// visible slice so we never truncate unnecessarily.
nameWidth := 0
for _, fm := range s.filtered {
if n := len([]rune(fm.Command.Name)); n > nameWidth {
nameWidth = n
}
}
nameWidth += 3 // account for indicator prefix (2) + gap before description (1)
// Ensure descriptions still get at least 20 chars when possible.
maxForName := innerWidth - 20
if maxForName < 8 {
maxForName = innerWidth * 2 / 3
}
if nameWidth > maxForName {
nameWidth = maxForName
}
if nameWidth < 8 {
nameWidth = 8
}
maxNameChars := nameWidth - 2
displayName := sc.Name
if len(displayName) > maxNameChars {
displayName = displayName[:maxNameChars-1] + "…"
}
// Description gets remaining space
maxDescLen := max(innerWidth-nameWidth, 0)
desc := sc.Description
if maxDescLen >= 4 && desc != "" {
if len(desc) > maxDescLen {
desc = desc[:maxDescLen-3] + "..."
}
content = indicator + lipgloss.NewStyle().Width(maxNameChars).Render(displayName) + desc
} else {
content = indicator + displayName
}
}
items[i] = PopupItem{
Label: name,
Description: desc,
}
items = append(items, itemStyle.Render(content))
}
// Add scroll indicators with background
scrollStyle := lipgloss.NewStyle().
Background(popupBg).
Foreground(theme.VeryMuted).
Width(innerWidth).
Padding(0, 1)
if startIdx > 0 {
items = append([]string{scrollStyle.Render(" ↑ more above")}, items...)
}
if endIdx < len(s.filtered) {
items = append(items, scrollStyle.Render(" ↓ more below"))
}
content := strings.Join(items, "\n")
// Adapt footer text to available width with background
var footerText string
if innerWidth >= 50 {
footerText = "↑↓ navigate • tab complete • ↵ select • esc dismiss"
} else if innerWidth >= 30 {
footerText = "↑↓ nav • tab • ↵ select • esc"
} else {
footerText = "↑↓ tab ↵ esc"
}
footer := lipgloss.NewStyle().
Background(popupBg).
Foreground(theme.VeryMuted).
Italic(true).
Render(footerText)
return popupStyle.Render(content + "\n\n" + footer)
s.popup.SetSize(termWidth, termHeight)
s.popup.SetItems(items)
s.popup.SetCursor(s.selected)
return s.popup.RenderCentered(termWidth, termHeight)
}
// completeArgs checks whether the input line matches a command with a Complete
@@ -844,6 +811,8 @@ func readClipboardImageCmd() tea.Cmd {
func (s *InputComponent) ClearPendingImages() []core.ImageAttachment {
images := s.pendingImages
s.pendingImages = nil
s.imageThumbs = nil
s.imageGen++
return images
}
@@ -862,6 +831,7 @@ func (s *InputComponent) Clear() bool {
s.showPopup = false
s.argMode = false
s.fileMode = false
s.fileEditMode = false
s.browsingHistory = false
s.savedInput = ""
return hadContent
@@ -871,6 +841,11 @@ func (s *InputComponent) Clear() bool {
// file or MCP resource suggestion. For directories, it keeps the popup open
// for further drilling. For files and resources, it closes the popup and adds
// a trailing space.
//
// When fileEditMode is active the same path-replacement happens against the
// /edit (or alias) command prefix instead of an @ trigger. Selecting a file
// also arms submitNext so the next tick runs $EDITOR on it; selecting a
// directory keeps the popup open for drill-down.
func (s *InputComponent) applyFileCompletion(idx int) {
if idx >= len(s.fileSuggestions) {
return
@@ -889,7 +864,17 @@ func (s *InputComponent) applyFileCompletion(idx int) {
beforeAt := lastLine[:s.fileAtStartIdx]
var replacement string
if suggestion.IsMCPResource {
switch {
case s.fileEditMode:
// /edit path mode — no @ prefix; the path is the bare argument.
// MCP resources are excluded upstream, so only file/dir entries reach here.
needsQuote := strings.Contains(suggestion.RelPath, " ")
if needsQuote {
replacement = `"` + suggestion.RelPath + `"`
} else {
replacement = suggestion.RelPath
}
case suggestion.IsMCPResource:
// MCP resources use @mcp:server:uri format.
// Quote if the URI contains spaces.
ref := "mcp:" + suggestion.MCPServerName + ":" + suggestion.MCPResourceURI
@@ -899,7 +884,7 @@ func (s *InputComponent) applyFileCompletion(idx int) {
replacement = "@" + ref
}
replacement += " "
} else {
default:
needsQuote := strings.Contains(suggestion.RelPath, " ")
if needsQuote {
replacement = `@"` + suggestion.RelPath + `"`
@@ -925,9 +910,61 @@ func (s *InputComponent) applyFileCompletion(idx int) {
if suggestion.IsDir && !suggestion.IsMCPResource {
// Keep popup open — trigger a refresh for the new directory.
s.lastValue = "" // force re-evaluation on next update tick
} else {
s.showPopup = false
s.fileMode = false
s.selected = 0
return
}
s.showPopup = false
s.fileMode = false
s.selected = 0
if s.fileEditMode {
// A file was selected via /edit — submit on the next tick so the
// popup dismisses cleanly before $EDITOR takes the terminal.
s.fileEditMode = false
s.submitNext = true
}
}
// updateEditFilePopup queries the file-suggestion engine for the /edit path
// prefix and populates the popup state. cmdLen is the byte offset of the path
// argument within the current line (i.e. length of "/edit " or "/ed ").
// Directories are kept so the user can drill down; MCP resources are skipped.
func (s *InputComponent) updateEditFilePopup(cmdLen int, pathPrefix string) {
var suggestions []FileSuggestion
if s.cwd != "" {
suggestions = GetFileSuggestions(pathPrefix, s.cwd)
}
if len(suggestions) == 0 {
s.showPopup = false
s.fileMode = false
s.fileEditMode = false
return
}
sort.Slice(suggestions, func(i, j int) bool {
return suggestions[i].Score > suggestions[j].Score
})
if len(suggestions) > maxFileSuggestions {
suggestions = suggestions[:maxFileSuggestions]
}
s.showPopup = true
s.fileMode = true
s.fileEditMode = true
s.argMode = false
s.filePrefix = pathPrefix
s.fileAtStartIdx = cmdLen
s.fileSuggestions = suggestions
s.fileSynthCmds = make([]commands.SlashCommand, len(suggestions))
s.filtered = make([]FuzzyMatch, len(suggestions))
for i, fs := range suggestions {
name := fs.RelPath
desc := ""
if fs.IsDir {
desc = "directory"
}
s.fileSynthCmds[i] = commands.SlashCommand{Name: name, Description: desc}
s.filtered[i] = FuzzyMatch{Command: &s.fileSynthCmds[i], Score: fs.Score}
}
s.selected = 0
}
+477 -133
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
@@ -25,6 +26,7 @@ import (
"github.com/mark3labs/kit/internal/ui/commands"
uicore "github.com/mark3labs/kit/internal/ui/core"
"github.com/mark3labs/kit/internal/ui/fileutil"
"github.com/mark3labs/kit/internal/ui/imagepreview"
"github.com/mark3labs/kit/internal/ui/prefs"
"github.com/mark3labs/kit/internal/ui/style"
kit "github.com/mark3labs/kit/pkg/kit"
@@ -124,6 +126,14 @@ type AppController interface {
// attachments (e.g. pasted images) into the currently running agent
// turn. Behaves like Steer but includes file parts alongside the text.
SteerWithFiles(prompt string, files []kit.LLMFilePart) int
// PopLastUserMessage truncates the tree session at the parent of the
// most recent user message on the current branch, syncs the in-memory
// message store, and returns that user prompt (plus any image file
// parts) so the caller can resubmit it. Used by /retry to recover from
// provider errors (overloaded, timeout) without duplicating the user
// message in context. Returns an error if the agent is busy, no tree
// session is active, or no user message exists on the current branch.
PopLastUserMessage() (string, []kit.LLMFilePart, error)
}
// SkillItem holds display metadata about a loaded skill for the startup
@@ -435,9 +445,12 @@ type AppModelOptions struct {
EmitBeforeFork func(targetID string, isUserMsg bool, userText string) (bool, string)
// EmitBeforeSessionSwitch, if non-nil, is called before switching
// to a new session branch (e.g. /new, /clear). Returns (cancelled,
// reason). May be nil if no extensions are loaded.
EmitBeforeSessionSwitch func(reason string) (bool, string)
// to a new session branch (e.g. /new, /clear). reason is the trigger
// ("new", "clear", "extension"); initialPrompt is the user prompt
// that will run as the first turn of the new session (empty when
// /new is called without arguments). Returns (cancelled, reason).
// May be nil if no extensions are loaded.
EmitBeforeSessionSwitch func(reason, initialPrompt string) (bool, string)
// GetGlobalShortcuts, if non-nil, returns extension-registered global
// keyboard shortcuts. Keys are binding strings (e.g., "ctrl+p").
@@ -565,6 +578,13 @@ type AppModel struct {
// flushed first, preserving chronological order.
pendingUserPrints []string
// newSessionResultCh, when non-nil, receives the outcome of an
// in-flight extension-triggered NewSession request. Set when an
// app.NewSessionRequestEvent arrives; cleared (with a result sent)
// in performNewSession success/failure paths or in the
// beforeSessionSwitchResultMsg cancellation path.
newSessionResultCh chan<- error
// canceling tracks whether the user has pressed ESC once during stateWorking.
// A second ESC within 2 seconds will cancel the current step.
canceling bool
@@ -667,7 +687,7 @@ type AppModel struct {
// emitBeforeSessionSwitch emits a before-session-switch event to extensions.
// Returns (cancelled, reason). May be nil if no extensions are loaded.
emitBeforeSessionSwitch func(reason string) (bool, string)
emitBeforeSessionSwitch func(reason, initialPrompt string) (bool, string)
// thinkingLevel is the current extended thinking level.
thinkingLevel string
@@ -1198,53 +1218,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.modelSelector = nil
m.state = stateInput
if m.setModel != nil {
previousModel := m.providerName + "/" + m.modelName
// Check if thinking level needs adjustment for the new model.
// Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none".
if m.thinkingLevel != "" && m.thinkingLevel != "off" {
parts := strings.SplitN(msg.ModelString, "/", 2)
if len(parts) == 2 {
modelName := parts[1]
currentLevel := models.ParseThinkingLevel(m.thinkingLevel)
if !models.IsValidThinkingLevelForModel(currentLevel, modelName) {
fallback := models.SuggestThinkingLevelFallback(currentLevel, modelName)
if fallback != models.ThinkingOff {
m.printSystemMessage(fmt.Sprintf(
"Note: Model %s doesn't support '%s' thinking level. Adjusted to '%s'.",
modelName, currentLevel, fallback,
))
m.thinkingLevel = string(fallback)
if m.setThinkingLevel != nil {
_ = m.setThinkingLevel(string(fallback))
}
go func() { _ = prefs.SaveThinkingLevelPreference(string(fallback)) }()
}
}
}
}
if err := m.setModel(msg.ModelString); err != nil {
m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
} else {
// Update display state directly — we cannot use
// NotifyModelChanged (prog.Send) from inside Update()
// without deadlocking BubbleTea.
parts := strings.SplitN(msg.ModelString, "/", 2)
if len(parts) == 2 {
m.providerName = parts[0]
m.modelName = parts[1]
}
m.printSystemMessage(fmt.Sprintf("Switched to %s", msg.ModelString))
// Persist model selection for next launch.
go func() { _ = prefs.SaveModelPreference(msg.ModelString) }()
if m.emitModelChange != nil {
emit := m.emitModelChange
newModel := msg.ModelString
prev := previousModel
go emit(newModel, prev, "user")
}
}
m.switchModel(msg.ModelString)
}
return m, tea.Batch(cmds...)
@@ -1794,14 +1768,27 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// messages stay in chronological order.
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
m.flushStreamAndPendingUserMessages()
// Insert inline thumbnail previews after the user message.
cmds = append(cmds, m.transcriptPreviewCmd(msg.Images, m.lastMessageID()))
}
} else {
m.printUserMessage(displayText)
// Insert inline thumbnail previews after the user message.
cmds = append(cmds, m.transcriptPreviewCmd(msg.Images, m.lastMessageID()))
}
if m.state != stateWorking {
m.state = stateWorking
}
// ── Async transcript image preview ───────────────────────────────────────
case imagePreviewReadyMsg:
if msg.block != "" {
item := NewStyledMessageItem(generateMessageID(), "user", "", msg.block)
m.insertMessageAfter(msg.anchorID, item)
m.refreshContent()
m.layoutDirty = true
}
// ── Shell command (! / !!) ───────────────────────────────────────────────
case uicore.ShellCommandMsg:
// Show spinner while the shell command runs.
@@ -2215,6 +2202,25 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
ic.textarea.CursorEnd()
}
case app.NewSessionRequestEvent:
// Extension wants to end the current session and start a fresh
// one (with an optional initial prompt). Stash the response
// channel so performNewSession (or the before-hook cancellation
// path) can signal completion, then run the same /new pipeline
// the user would trigger.
if msg.ResponseCh != nil {
// Only one new-session request in flight at a time. If a
// previous response channel is still pending, fail it before
// replacing it so the prior extension goroutine unblocks.
if m.newSessionResultCh != nil {
m.newSessionResultCh <- fmt.Errorf("superseded by a newer NewSession request")
}
m.newSessionResultCh = msg.ResponseCh
}
if cmd := m.handleNewCommand(msg.InitialPrompt); cmd != nil {
cmds = append(cmds, cmd)
}
case app.PasswordPromptEvent:
// Sudo password prompt - show a modal input prompt
// If already in prompt state, cancel the new request
@@ -2397,6 +2403,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.layoutDirty = true
}
case editFileMsg:
// User returned from $EDITOR after `/edit <path>`. The file was
// edited directly on disk — no textarea changes. Report the result.
if msg.err != nil {
m.printSystemMessage(fmt.Sprintf("Editor exited with error: %v", msg.err))
} else {
m.printSystemMessage(fmt.Sprintf("Edited `%s`", msg.path))
}
m.layoutDirty = true
case extReloadResultMsg:
if msg.err != nil {
m.printSystemMessage(fmt.Sprintf("Extension reload failed: %v", msg.err))
@@ -2410,8 +2426,9 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// session reset if the hook did not cancel.
if msg.cancelled {
m.printSystemMessage(msg.reason)
m.signalNewSessionResult(fmt.Errorf("session switch cancelled: %s", msg.reason))
} else {
cmds = append(cmds, m.performNewSession())
cmds = append(cmds, m.performNewSession(msg.initialPrompt))
}
case beforeForkResultMsg:
@@ -2447,6 +2464,19 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.printSystemMessage(msg.Text)
}
// ── Clipboard image attached / thumbnail rendered ────────────────────────
// Both messages change the input region's rendered height (the pill and
// the async half-block preview), so forward them to the input and mark the
// layout dirty — otherwise distributeHeight keeps a stale, too-short input
// height and the preview is clipped off the bottom of the screen.
case clipboardImageMsg, thumbnailReadyMsg:
if m.input != nil {
updated, cmd := m.input.Update(msg)
m.input, _ = updated.(inputComponentIface)
cmds = append(cmds, cmd)
}
m.layoutDirty = true
default:
// Pass unrecognised messages to all children.
if m.input != nil {
@@ -3046,6 +3076,85 @@ func truncateMessageForBlock(msg string, maxLines, width int) string {
// Print helpers — add content to ScrollList
// --------------------------------------------------------------------------
// imagePreviewReadyMsg carries an asynchronously rendered transcript image
// preview block back to the Update loop, where it is inserted into the
// ScrollList directly after the originating user message (identified by
// anchorID). Inserting by anchor — rather than appending — keeps the preview
// next to its message even when the agent's streamed reply has already been
// appended while the thumbnail was being decoded off the event loop.
type imagePreviewReadyMsg struct {
block string
anchorID string
}
// transcriptPreviewCmd returns a tea.Cmd that renders half-block thumbnail
// previews for the given clipboard images off the Bubble Tea event loop
// (decode + resample must not block Update). The rendered block is delivered
// via imagePreviewReadyMsg, tagged with anchorID so the consumer can place it
// directly after the originating user message. Returns nil when there is
// nothing to render or no room for a preview; an empty result (terminal lacks
// color support) yields a nil message that Bubble Tea ignores.
func (m *AppModel) transcriptPreviewCmd(images []uicore.ImageAttachment, anchorID string) tea.Cmd {
if len(images) == 0 {
return nil
}
cols := thumbMaxCols
if m.width > 6 && m.width-6 < cols {
cols = m.width - 6
}
if cols < 1 {
return nil
}
bg := style.GetTheme().Background
imgs := images
return func() tea.Msg {
pad := lipgloss.NewStyle().PaddingLeft(2)
var blocks []string
for _, img := range imgs {
thumb, err := imagepreview.Render(img.Data, img.MediaType, cols, thumbMaxRows, bg)
if err != nil || thumb == "" {
continue
}
blocks = append(blocks, pad.Render(thumb))
}
if len(blocks) == 0 {
return nil
}
return imagePreviewReadyMsg{block: strings.Join(blocks, "\n"), anchorID: anchorID}
}
}
// lastMessageID returns the ID of the most recently added ScrollList message,
// or "" when there are none. Used to anchor an async transcript preview to the
// user message that was just printed.
func (m *AppModel) lastMessageID() string {
if len(m.messages) == 0 {
return ""
}
return m.messages[len(m.messages)-1].ID()
}
// insertMessageAfter inserts item immediately after the message whose ID
// matches anchorID. If anchorID is empty or not found, item is appended.
func (m *AppModel) insertMessageAfter(anchorID string, item MessageItem) {
idx := -1
if anchorID != "" {
for i, msgItem := range m.messages {
if msgItem.ID() == anchorID {
idx = i
break
}
}
}
if idx < 0 {
m.messages = append(m.messages, item)
return
}
m.messages = append(m.messages, nil)
copy(m.messages[idx+2:], m.messages[idx+1:])
m.messages[idx+1] = item
}
// printUserMessage renders a user message into the ScrollList.
func (m *AppModel) printUserMessage(text string) {
// Check if this exact message was just added (prevents duplicates)
@@ -3162,7 +3271,7 @@ func (m *AppModel) handleSlashCommand(sc *commands.SlashCommand, args string) te
case "/fork":
return m.handleForkCommand()
case "/new":
return m.handleNewCommand()
return m.handleNewCommand(args)
case "/name":
return m.handleNameCommand(args)
case "/resume":
@@ -3171,6 +3280,10 @@ func (m *AppModel) handleSlashCommand(sc *commands.SlashCommand, args string) te
return m.handleExportCommand(args)
case "/copy":
return m.handleCopyCommand()
case "/retry":
return m.handleRetryCommand()
case "/edit":
return m.handleEditCommand(args)
case "/share":
return m.handleShareCommand()
case "/import":
@@ -3589,13 +3702,15 @@ func (m *AppModel) printHelpMessage() {
"**Navigation:**\n" +
"- `/tree`: Navigate session tree (switch branches)\n" +
"- `/fork`: Branch from an earlier message\n" +
"- `/new`: Start a new session (discards context, saves old session)\n" +
"- `/new [prompt]`: Start a new session (discards context, saves old session). With a prompt, runs it as the first message; supports `@file` attachments.\n" +
"- `/resume`: Open session picker to switch sessions\n" +
"- `/name <name>`: Set a display name for this session\n\n" +
"**System:**\n" +
"- `/compact [instructions]`: Summarise older messages to free context space\n" +
"- `/clear`: Clear message history\n" +
"- `/copy`: Copy the last message to the system clipboard\n" +
"- `/retry`: Resubmit the last user message (e.g. after a provider error)\n" +
"- `/edit [path]`: Open a file in `$EDITOR` (fuzzy-find from cwd)\n" +
"- `/export [path]`: Export session as JSONL\n" +
"- `/import <path.jsonl>`: Import session from JSONL file\n" +
"- `/reset-usage`: Reset usage statistics\n" +
@@ -4080,11 +4195,31 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
return nil
}
// Direct model switch with the provided model string.
m.switchModel(args)
return nil
}
// switchModel performs a direct model switch, shared by the model selector
// overlay and the /model slash command: it adjusts the thinking level when
// the new model doesn't support the current one, calls the setModel
// callback, updates display state, persists preferences, and emits the
// ModelChange extension event.
//
// Display state is updated directly — we cannot use NotifyModelChanged
// (prog.Send) from inside Update() without deadlocking BubbleTea.
func (m *AppModel) switchModel(modelString string) {
if m.setModel == nil {
m.printSystemMessage("Model switching is not available.")
return
}
previousModel := m.providerName + "/" + m.modelName
// Check if thinking level needs adjustment for the new model.
// Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none".
if m.thinkingLevel != "" && m.thinkingLevel != "off" {
parts := strings.SplitN(args, "/", 2)
if len(parts) == 2 {
if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 {
modelName := parts[1]
currentLevel := models.ParseThinkingLevel(m.thinkingLevel)
if !models.IsValidThinkingLevelForModel(currentLevel, modelName) {
@@ -4104,32 +4239,26 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
}
}
// Direct model switch with the provided model string.
previousModel := m.providerName + "/" + m.modelName
if err := m.setModel(args); err != nil {
if err := m.setModel(modelString); err != nil {
m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
return nil
return
}
// Update display state directly (cannot use prog.Send from Update).
parts := strings.SplitN(args, "/", 2)
if len(parts) == 2 {
if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 {
m.providerName = parts[0]
m.modelName = parts[1]
}
if m.emitModelChange != nil {
emit := m.emitModelChange
prev := previousModel
newModel := args
go emit(newModel, prev, "user")
}
m.printSystemMessage(fmt.Sprintf("Switched to %s", modelString))
// Persist model selection for next launch.
go func() { _ = prefs.SaveModelPreference(args) }()
go func() { _ = prefs.SaveModelPreference(modelString) }()
m.printSystemMessage(fmt.Sprintf("Switched to %s", args))
return nil
if m.emitModelChange != nil {
emit := m.emitModelChange
go emit(modelString, previousModel, "user")
}
}
// --------------------------------------------------------------------------
@@ -4269,7 +4398,12 @@ func (m *AppModel) handleForkCommand() tea.Cmd {
// handleNewCommand starts a completely new session (Pi-style /new behavior).
// Creates a new session file, discarding all context from the previous conversation.
func (m *AppModel) handleNewCommand() tea.Cmd {
// If initialPrompt is non-empty it is submitted as the first user turn of the
// new session, with @file references expanded the same way they are for
// regular user input.
func (m *AppModel) handleNewCommand(initialPrompt string) tea.Cmd {
initialPrompt = strings.TrimSpace(initialPrompt)
// Emit before-session-switch event in a goroutine so that extension
// handlers can call blocking operations (e.g. ctx.PromptConfirm) without
// deadlocking the BubbleTea event loop.
@@ -4277,23 +4411,25 @@ func (m *AppModel) handleNewCommand() tea.Cmd {
emit := m.emitBeforeSessionSwitch
ctrl := m.appCtrl
go func() {
cancelled, reason := emit("new")
cancelled, reason := emit("new", initialPrompt)
ctrl.SendEvent(beforeSessionSwitchResultMsg{
cancelled: cancelled,
reason: reason,
cancelled: cancelled,
reason: reason,
initialPrompt: initialPrompt,
})
}()
return noopCmd
}
return m.performNewSession()
return m.performNewSession(initialPrompt)
}
// performNewSession performs the actual session reset. Called either directly
// (when no before-hook exists) or after the async hook completes.
// Matches Pi behavior: creates a completely new session file, discarding all
// context from the previous conversation.
func (m *AppModel) performNewSession() tea.Cmd {
// context from the previous conversation. If initialPrompt is non-empty it
// is submitted as the first user turn (with @file expansion).
func (m *AppModel) performNewSession(initialPrompt string) tea.Cmd {
ts := m.appCtrl.GetTreeSession()
if ts == nil {
// No tree session — just clear messages.
@@ -4307,13 +4443,16 @@ func (m *AppModel) performNewSession() tea.Cmd {
// Clear the ScrollList so the new session starts fresh.
m.messages = []MessageItem{}
m.printSystemMessage("Conversation cleared. Starting fresh.")
return nil
cmd := m.submitInitialPrompt(initialPrompt)
m.signalNewSessionResult(nil)
return cmd
}
// Create a brand new session file (Pi-style /new behavior)
newTs, err := session.CreateTreeSession(m.cwd)
if err != nil {
m.printSystemMessage(fmt.Sprintf("Failed to create new session: %v", err))
m.signalNewSessionResult(fmt.Errorf("create new session: %w", err))
return nil
}
@@ -4326,6 +4465,67 @@ func (m *AppModel) performNewSession() tea.Cmd {
// Clear the ScrollList so the new session starts fresh.
m.messages = []MessageItem{}
m.printSystemMessage("New session started. Previous conversation saved.")
cmd := m.submitInitialPrompt(initialPrompt)
m.signalNewSessionResult(nil)
return cmd
}
// signalNewSessionResult delivers the outcome of an extension-triggered
// NewSession request (if one is in flight) and clears the response channel.
// Safe to call when no request is pending.
func (m *AppModel) signalNewSessionResult(err error) {
if m.newSessionResultCh == nil {
return
}
ch := m.newSessionResultCh
m.newSessionResultCh = nil
// Channel is buffered (cap >= 1) by contract — send is non-blocking.
ch <- err
}
// submitInitialPrompt is the shared submission path used by /new <prompt>
// and ctx.NewSession(prompt). It mirrors the SubmitMsg handler: @file
// references are expanded via fileutil.ProcessFileAttachments and the
// resulting prompt is forwarded to AppController.Run / RunWithFiles.
// Returns nil when prompt is empty.
func (m *AppModel) submitInitialPrompt(prompt string) tea.Cmd {
prompt = strings.TrimSpace(prompt)
if prompt == "" || m.appCtrl == nil {
return nil
}
processedText := prompt
var fileParts []kit.LLMFilePart
if m.cwd != "" {
result := fileutil.ProcessFileAttachments(prompt, m.cwd, m.mcpResourceReader)
processedText = result.ProcessedText
for _, fp := range result.FileParts {
fileParts = append(fileParts, kit.LLMFilePart{
Filename: fp.Filename,
Data: fp.Data,
MediaType: fp.MediaType,
})
}
}
displayText := prompt
if len(fileParts) > 0 {
displayText = fmt.Sprintf("%s\n[%d file(s) attached]", prompt, len(fileParts))
}
var qLen int
if len(fileParts) > 0 {
qLen = m.appCtrl.RunWithFiles(processedText, fileParts)
} else {
qLen = m.appCtrl.Run(processedText)
}
if qLen > 0 {
m.queuedMessages = append(m.queuedMessages, displayText)
m.layoutDirty = true
} else {
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
m.flushStreamAndPendingUserMessages()
}
return nil
}
@@ -4446,6 +4646,141 @@ func (m *AppModel) handleCopyCommand() tea.Cmd {
return clipboard.CopyToClipboard(text)
}
// handleRetryCommand resubmits the most recent user message on the current
// branch. Used to recover from transient provider errors (overloaded,
// timeout) without users having to retype — and without the duplicate-user-
// message bloat that retyping creates.
//
// Flow:
// 1. App.PopLastUserMessage() truncates the tree at the parent of the last
// user message and returns its text + any image parts. The failed turn's
// entries become orphaned (still on disk, off-branch) so they will not
// be re-sent to the LLM.
// 2. The visible message list is rebuilt from the truncated branch so the
// prior user message + any partial assistant + error rendering vanish.
// 3. The prompt is resubmitted via Run/RunWithFiles, mirroring the normal
// SubmitMsg display path (badge formatting, pending-prints flush,
// stateWorking transition).
func (m *AppModel) handleRetryCommand() tea.Cmd {
if m.appCtrl == nil {
m.printSystemMessage("App controller unavailable.")
return nil
}
prompt, files, err := m.appCtrl.PopLastUserMessage()
if err != nil {
m.printSystemMessage(fmt.Sprintf("Cannot retry: %v", err))
return nil
}
// Rebuild the visible ScrollList from the truncated branch so the failed
// turn's user message and any partial assistant/error rendering disappear
// before the resubmit prints a fresh user message.
m.messages = []MessageItem{}
m.renderSessionHistory()
// Mirror SubmitMsg's badge formatting for the display text.
var imageCount, fileOnlyCount int
for _, f := range files {
if strings.HasPrefix(f.MediaType, "image/") {
imageCount++
} else {
fileOnlyCount++
}
}
displayText := prompt
if imageCount > 0 || fileOnlyCount > 0 {
var badges []string
if imageCount > 0 {
badges = append(badges, fmt.Sprintf("%d image(s) pasted", imageCount))
}
if fileOnlyCount > 0 {
badges = append(badges, fmt.Sprintf("%d file(s) attached", fileOnlyCount))
}
displayText = fmt.Sprintf("%s\n[%s]", prompt, strings.Join(badges, ", "))
}
var qLen int
if len(files) > 0 {
qLen = m.appCtrl.RunWithFiles(prompt, files)
} else {
qLen = m.appCtrl.Run(prompt)
}
if qLen > 0 {
m.queuedMessages = append(m.queuedMessages, displayText)
m.layoutDirty = true
} else {
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
m.flushStreamAndPendingUserMessages()
}
if m.state != stateWorking {
m.state = stateWorking
}
return nil
}
// handleEditCommand opens the supplied path in $EDITOR via tea.ExecProcess,
// pausing the TUI for the duration of the editor session. The path is
// resolved relative to cwd; ~/ and absolute paths are honoured. Non-existent
// paths are allowed — most editors will create the file on save.
//
// On exit an editFileMsg is emitted with the resolved path (or error) so the
// Update loop can report the result. The textarea is not touched — use
// Ctrl+X e if you want to round-trip a prompt through $EDITOR instead.
func (m *AppModel) handleEditCommand(args string) tea.Cmd {
path := strings.TrimSpace(args)
if path == "" {
m.printSystemMessage("Usage: `/edit <path>` — or type `/edit ` and pick a file from the popup.")
return nil
}
// Strip optional surrounding double-quotes (the autocomplete inserts
// these when a path contains spaces).
if len(path) >= 2 && strings.HasPrefix(path, `"`) && strings.HasSuffix(path, `"`) {
path = path[1 : len(path)-1]
}
// Resolve ~/, relative, and absolute paths against cwd.
resolved := path
if strings.HasPrefix(resolved, "~/") {
if home, err := os.UserHomeDir(); err == nil {
resolved = filepath.Join(home, resolved[2:])
}
}
if !filepath.IsAbs(resolved) {
cwd, err := os.Getwd()
if err == nil {
resolved = filepath.Join(cwd, resolved)
}
}
resolved = filepath.Clean(resolved)
// Reject paths that exist but are directories — $EDITOR semantics vary.
if info, err := os.Stat(resolved); err == nil && info.IsDir() {
m.printSystemMessage(fmt.Sprintf("`%s` is a directory, not a file.", resolved))
return nil
}
editorApp := os.Getenv("VISUAL")
if editorApp == "" {
editorApp = os.Getenv("EDITOR")
}
if editorApp == "" {
m.printSystemMessage("Set `$EDITOR` or `$VISUAL` to use `/edit`")
return nil
}
editorCmd, cmdErr := editor.Command(editorApp, resolved)
if cmdErr != nil {
m.printSystemMessage(fmt.Sprintf("Failed to open editor: %v", cmdErr))
return nil
}
return tea.ExecProcess(editorCmd, func(err error) tea.Msg {
return editFileMsg{path: resolved, err: err}
})
}
// handleExportCommand exports the current session to a file.
// Usage: /export — copies the JSONL file to cwd with a descriptive name.
//
@@ -4561,61 +4896,11 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
return r
}, name)
tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name))
tmpPath, err := buildShareFile(name, data, sysPromptJSON)
if err != nil {
m.printSystemMessage(fmt.Sprintf("Failed to create temp file: %v", err))
m.printSystemMessage(fmt.Sprintf("Failed to share session: %v", err))
return nil
}
tmpPath := tmpFile.Name()
// Write the session data with the system prompt entry inserted after the header.
// The header is the first line, so we write:
// 1. First line (header) from original data
// 2. System prompt entry
// 3. Remaining lines from original data
lines := strings.Split(string(data), "\n")
if len(lines) > 0 && lines[len(lines)-1] == "" {
lines = lines[:len(lines)-1] // Remove trailing empty line
}
if len(lines) > 0 {
// Write header (first line)
if _, err := tmpFile.WriteString(lines[0] + "\n"); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
return nil
}
// Write system prompt entry
if _, err := tmpFile.Write(sysPromptJSON); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
m.printSystemMessage(fmt.Sprintf("Failed to write system prompt: %v", err))
return nil
}
if _, err := tmpFile.WriteString("\n"); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
return nil
}
// Write remaining lines
for i := 1; i < len(lines); i++ {
if lines[i] == "" {
continue // Skip empty lines
}
if _, err := tmpFile.WriteString(lines[i] + "\n"); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
return nil
}
}
}
_ = tmpFile.Close()
m.printSystemMessage("Uploading session to GitHub Gist...")
@@ -4641,6 +4926,56 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
}
}
// buildShareFile assembles a temp JSONL file containing the session data
// with the system-prompt entry inserted after the header line. On success
// the caller owns the returned file and must remove it when done; on error
// any partially-written temp file has already been cleaned up.
func buildShareFile(name string, data, sysPromptJSON []byte) (tmpPath string, err error) {
tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name))
if err != nil {
return "", fmt.Errorf("create temp file: %w", err)
}
tmpPath = tmpFile.Name()
defer func() {
_ = tmpFile.Close()
if err != nil {
_ = os.Remove(tmpPath)
}
}()
// Write the session data with the system prompt entry inserted after the
// header. The header is the first line, so we write:
// 1. First line (header) from original data
// 2. System prompt entry
// 3. Remaining lines from original data
lines := strings.Split(string(data), "\n")
if len(lines) > 0 && lines[len(lines)-1] == "" {
lines = lines[:len(lines)-1] // Remove trailing empty line
}
if len(lines) == 0 {
return tmpPath, nil
}
if _, err = tmpFile.WriteString(lines[0] + "\n"); err != nil {
return "", fmt.Errorf("write temp file: %w", err)
}
if _, err = tmpFile.Write(sysPromptJSON); err != nil {
return "", fmt.Errorf("write system prompt: %w", err)
}
if _, err = tmpFile.WriteString("\n"); err != nil {
return "", fmt.Errorf("write temp file: %w", err)
}
for i := 1; i < len(lines); i++ {
if lines[i] == "" {
continue // Skip empty lines
}
if _, err = tmpFile.WriteString(lines[i] + "\n"); err != nil {
return "", fmt.Errorf("write temp file: %w", err)
}
}
return tmpPath, nil
}
// handleImportCommand imports a session from a JSONL file.
// Usage: /import path.jsonl
func (m *AppModel) handleImportCommand(args string) tea.Cmd {
@@ -4856,6 +5191,14 @@ type externalEditorMsg struct {
err error
}
// editFileMsg is sent when the user returns from $EDITOR after invoking the
// /edit slash command on a specific file. Unlike externalEditorMsg, no text
// is read back — the user edited the file directly on disk.
type editFileMsg struct {
path string
err error
}
// shareResultMsg carries the result of an async gist upload.
type shareResultMsg struct {
err error
@@ -4891,8 +5234,9 @@ type mcpPromptResultMsg struct {
// executed before-session-switch hook. The hook runs in a goroutine so that
// blocking operations like ctx.PromptConfirm() do not deadlock the TUI.
type beforeSessionSwitchResultMsg struct {
cancelled bool
reason string
cancelled bool
reason string
initialPrompt string
}
// beforeForkResultMsg carries the result of an asynchronously executed
+130
View File
@@ -2,6 +2,7 @@ package ui
import (
"errors"
"fmt"
"strings"
"testing"
@@ -87,6 +88,10 @@ func (s *stubAppController) SteerWithFiles(prompt string, _ []kit.LLMFilePart) i
return s.queueLen
}
func (s *stubAppController) PopLastUserMessage() (string, []kit.LLMFilePart, error) {
return "", nil, fmt.Errorf("no user message to retry")
}
// --------------------------------------------------------------------------
// Stub child components
// --------------------------------------------------------------------------
@@ -1139,3 +1144,128 @@ func TestRenderQueuedMessages_truncatesLongMessages(t *testing.T) {
t.Fatalf("expected truncated output to be ≤10 lines, got %d lines", lines)
}
}
// --------------------------------------------------------------------------
// /new <prompt> and ctx.NewSession
// --------------------------------------------------------------------------
// TestNewCommand_noPrompt verifies that /new without an argument resets the
// session (clears messages, prints the system message) and does NOT submit
// any prompt to the controller.
func TestNewCommand_noPrompt(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
m.cwd = t.TempDir()
_ = m.handleNewCommand("")
if len(ctrl.runCalls) != 0 {
t.Fatalf("expected no Run calls for empty prompt, got %v", ctrl.runCalls)
}
if ctrl.clearMsgCalled == 0 {
t.Fatal("expected ClearMessages to be called when no tree session is active")
}
}
// TestNewCommand_withPrompt verifies that /new <prompt> submits the prompt
// to AppController.Run after clearing the session.
func TestNewCommand_withPrompt(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
m.cwd = t.TempDir()
_ = m.handleNewCommand("continue from where we left off")
if len(ctrl.runCalls) != 1 {
t.Fatalf("expected exactly 1 Run call, got %d (%v)", len(ctrl.runCalls), ctrl.runCalls)
}
if ctrl.runCalls[0] != "continue from where we left off" {
t.Fatalf("unexpected prompt submitted: %q", ctrl.runCalls[0])
}
}
// TestNewCommand_whitespacePromptIsEmpty verifies that an all-whitespace
// prompt is treated as empty (no Run call).
func TestNewCommand_whitespacePromptIsEmpty(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
m.cwd = t.TempDir()
_ = m.handleNewCommand(" \n\t ")
if len(ctrl.runCalls) != 0 {
t.Fatalf("expected no Run calls for whitespace-only prompt, got %v", ctrl.runCalls)
}
}
// TestNewSessionRequestEvent_signalsResponseCh verifies that
// app.NewSessionRequestEvent runs the same /new pipeline and delivers a
// nil error to the response channel on success.
func TestNewSessionRequestEvent_signalsResponseCh(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
m.cwd = t.TempDir()
ch := make(chan error, 1)
m = sendMsg(m, app.NewSessionRequestEvent{
InitialPrompt: "hello from extension",
ResponseCh: ch,
})
select {
case err := <-ch:
if err != nil {
t.Fatalf("expected nil error on success, got %v", err)
}
default:
t.Fatal("expected ResponseCh to receive a value")
}
if len(ctrl.runCalls) != 1 || ctrl.runCalls[0] != "hello from extension" {
t.Fatalf("expected prompt to be submitted to Run, got %v", ctrl.runCalls)
}
if m.newSessionResultCh != nil {
t.Fatal("expected newSessionResultCh to be cleared after signaling")
}
}
// TestNewSessionRequestEvent_cancelledByExtension verifies that when the
// before-session-switch hook cancels, the response channel receives an
// error.
func TestNewSessionRequestEvent_cancelledByExtension(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
m.cwd = t.TempDir()
m.emitBeforeSessionSwitch = func(reason, prompt string) (bool, string) {
return true, "vetoed by test"
}
ch := make(chan error, 1)
m = sendMsg(m, app.NewSessionRequestEvent{
InitialPrompt: "should be cancelled",
ResponseCh: ch,
})
// The before-hook runs in a goroutine, which sends back a
// beforeSessionSwitchResultMsg. Pump that synchronously by reading
// the SendEvent call indirectly: SendEvent on stub is a no-op so we
// need to dispatch the message ourselves to simulate the round trip.
sendMsg(m, beforeSessionSwitchResultMsg{
cancelled: true,
reason: "vetoed by test",
initialPrompt: "should be cancelled",
})
select {
case err := <-ch:
if err == nil {
t.Fatal("expected non-nil error on cancellation")
}
if !strings.Contains(err.Error(), "vetoed by test") {
t.Fatalf("expected error to mention the veto reason, got %v", err)
}
default:
t.Fatal("expected ResponseCh to receive a value")
}
if len(ctrl.runCalls) != 0 {
t.Fatalf("expected no Run calls when cancelled, got %v", ctrl.runCalls)
}
}
@@ -0,0 +1,85 @@
package ui
import (
"strings"
"testing"
tea "charm.land/bubbletea/v2"
uicore "github.com/mark3labs/kit/internal/ui/core"
)
// drainCmds runs a tea.Cmd chain back through m.Update like the BubbleTea
// event loop, expanding batches, until no further messages are produced.
func drainCmds(t *testing.T, m *AppModel, cmd tea.Cmd) *AppModel {
t.Helper()
queue := []tea.Cmd{cmd}
for i := 0; i < 50 && len(queue) > 0; i++ {
c := queue[0]
queue = queue[1:]
if c == nil {
continue
}
msg := c()
if msg == nil {
continue
}
if batch, ok := msg.(tea.BatchMsg); ok {
queue = append(queue, batch...)
continue
}
updated, nc := m.Update(msg)
m = updated.(*AppModel)
_ = m.View()
if nc != nil {
queue = append(queue, nc)
}
}
return m
}
func measuredInputHeight(m *AppModel) int {
rendered := m.renderInput()
if rendered == "" {
return 0
}
return strings.Count(rendered, "\n") + 1
}
// TestPendingThumbnailTriggersLayoutRecompute is a regression test for the bug
// where a pasted image's async half-block preview rendered but was clipped off
// the bottom of the screen: the thumbnail arrives via thumbnailReadyMsg after
// distributeHeight already measured the input region without it. The parent
// must mark the layout dirty so the (now taller) input is re-measured.
func TestPendingThumbnailTriggersLayoutRecompute(t *testing.T) {
// Force a truecolor profile so imagepreview.Render deterministically
// produces a thumbnail regardless of the CI terminal's color support.
// Without this, a low-color test environment yields an empty preview and
// the glyph / height assertions below would flake.
t.Setenv("TERM", "xterm-256color")
t.Setenv("COLORTERM", "truecolor")
t.Setenv("NO_COLOR", "")
real := NewInputComponent(80, nil)
m, _, _ := newTestAppModel(nil)
m.input = real
m = sendMsg(m, tea.WindowSizeMsg{Width: 80, Height: 24})
heightBefore := measuredInputHeight(m)
updated, cmd := m.Update(clipboardImageMsg{image: &uicore.ImageAttachment{
Data: makeTestPNG(t, 16, 16),
MediaType: "image/png",
}})
m = updated.(*AppModel)
_ = m.View()
m = drainCmds(t, m, cmd)
heightAfter := measuredInputHeight(m)
if heightAfter <= heightBefore {
t.Errorf("input region should grow to fit the thumbnail (before=%d after=%d)", heightBefore, heightAfter)
}
if !strings.Contains(m.View().Content, "▀") {
t.Error("parent View should contain the half-block thumbnail (was clipped or not rendered)")
}
}
+182 -46
View File
@@ -20,17 +20,23 @@ type PopupItem struct {
Meta any // opaque data returned on selection
}
// PopupList is a generic, themed, scrollable fuzzy-find popup list. It is
// rendered as a centered overlay on top of the normal TUI layout and can be
// reused by any feature that needs a selection popup (slash commands, model
// selector, session picker, extension-provided lists, etc.).
// PopupList is a generic, themed, scrollable popup list used by every
// list-style popup in the TUI (slash commands, @file autocomplete, model
// picker, session picker, tree navigation, etc.).
//
// The caller is responsible for:
// - Building the initial item list
// - Providing a fuzzy-filter callback (or nil for substring matching)
// - Handling the result when the user selects or cancels
// Two layout modes:
// - Centered (default): bordered ~80-col box centered on the screen. Used
// for the input-bar popups (/ and @) and the model picker.
// - FullScreen: bordered panel filling almost the entire terminal. Used by
// /tree, /fork, /sessions and other browse-many-items popups.
//
// Navigation: up/down to move, enter to select, esc to cancel, type to filter.
// Two usage modes:
// - Internal state: caller creates the list with items, calls HandleKey for
// navigation/search, and PopupList owns the cursor and search string.
// Used by selectors like ModelSelector, TreeSelector, SessionSelector.
// - External state: caller drives the items / cursor / search themselves
// (e.g. InputComponent, where typing in the textarea filters the list).
// Caller uses SetItems / SetCursor / SetSearch and only calls Render.
type PopupList struct {
// Title shown at the top of the popup.
Title string
@@ -38,20 +44,45 @@ type PopupList struct {
Subtitle string
// FooterHint overrides the default keyboard-hint footer.
FooterHint string
// ExtraFooter is appended to the footer line (after the default hint).
// Used by selectors to surface mode info like the active filter.
ExtraFooter string
allItems []PopupItem // full unfiltered list
filtered []PopupItem // subset matching the current search
cursor int
search string
// FullScreen renders the popup at almost the full terminal size instead
// of a centered ~80-col box. Used by tree/session/fork selectors.
FullScreen bool
// ShowSearch toggles the "> <query>" search input line. Default true.
ShowSearch bool
// HideCount suppresses the "(i/N)" count in the footer.
HideCount bool
// MaxVisible caps the number of items visible at once. 0 = derive from
// available height.
MaxVisible int
// RenderItem optionally renders a single item row. When nil, the
// built-in label + description + active-checkmark renderer is used.
// innerWidth is the usable line width inside the popup (after border
// and padding). The returned string must already be styled — the
// shared selection-row background is applied by the popup only when
// RenderItem is nil.
RenderItem func(item PopupItem, innerWidth int, isCursor bool) string
// FilterFunc is called with (query, allItems) and should return the
// filtered+scored subset. When nil, a default substring match is used.
// filtered+scored subset. When nil, a default substring + fuzzy match
// is used. Only consulted in internal-state mode (via HandleKey).
FilterFunc func(query string, items []PopupItem) []PopupItem
width int
height int
maxVisible int // max items visible at once (0 = auto from height)
showSearch bool
allItems []PopupItem // full unfiltered list (internal-state mode)
filtered []PopupItem // items currently rendered (driven by FilterFunc
// in internal-state mode, or set directly via SetItems in external mode)
cursor int
search string
width int
height int
}
// PopupResult is returned by HandleKey to tell the caller what happened.
@@ -72,7 +103,7 @@ func NewPopupList(title string, items []PopupItem, width, height int) *PopupList
filtered: items,
width: width,
height: height,
showSearch: true,
ShowSearch: true,
}
// Position cursor on the active item if one exists.
for i, item := range p.filtered {
@@ -90,25 +121,102 @@ func (p *PopupList) SetSize(width, height int) {
p.height = height
}
// SetItems replaces the displayed item list and clamps the cursor. Used by
// external-state callers (e.g. InputComponent) that filter items themselves.
// In internal-state mode, this also replaces the unfiltered backing list.
func (p *PopupList) SetItems(items []PopupItem) {
p.allItems = items
p.filtered = items
if p.cursor >= len(p.filtered) {
p.cursor = max(len(p.filtered)-1, 0)
}
if p.cursor < 0 {
p.cursor = 0
}
}
// SetCursor moves the selection to the given index (clamped to range).
func (p *PopupList) SetCursor(i int) {
if len(p.filtered) == 0 {
p.cursor = 0
return
}
if i < 0 {
i = 0
}
if i >= len(p.filtered) {
i = len(p.filtered) - 1
}
p.cursor = i
}
// Cursor returns the current selection index.
func (p *PopupList) Cursor() int { return p.cursor }
// SetSearch replaces the search string without rebuilding the filtered list.
// Used by external-state callers that filter items themselves.
func (p *PopupList) SetSearch(s string) { p.search = s }
// Items returns the currently-visible (filtered) items.
func (p *PopupList) Items() []PopupItem { return p.filtered }
// Search returns the current search string.
func (p *PopupList) Search() string { return p.search }
// dimensions returns the (popupWidth, popupHeight, innerWidth, innerHeight)
// the popup will render at, given its current size and FullScreen flag.
func (p *PopupList) dimensions() (popupW, popupH, innerW, innerH int) {
if p.FullScreen {
// Leave a small margin so the border doesn't kiss the screen edge.
popupW = max(p.width-2, 20)
popupH = max(p.height-2, 10)
} else {
// Centered: cap at 80 cols, leave a 4-col margin.
popupW = max(min(p.width-4, 80), 20)
// Height is dynamic — let it grow with content within the screen.
popupH = 0
}
// Border (2) + horizontal padding (4) = 6 chrome cols.
innerW = max(popupW-6, 10)
if popupH > 0 {
// Border (2) + vertical padding (2) = 4 chrome rows.
innerH = max(popupH-4, 6)
}
return
}
// visibleCount returns the number of items visible at once.
func (p *PopupList) visibleCount() int {
if p.maxVisible > 0 {
return p.maxVisible
if p.MaxVisible > 0 {
return p.MaxVisible
}
// Reserve: title(1) + subtitle(1) + search(1) + separator(1) + footer(2) + border(2) + padding(2) = 10
if p.FullScreen {
_, _, _, innerH := p.dimensions()
// Reserve: title(1) + subtitle(0|1) + search(0|2) + sep(1) + footer(2)
overhead := 4
if p.Subtitle != "" {
overhead++
}
if p.ShowSearch {
overhead += 2
}
return max(innerH-overhead, 3)
}
// Centered: derive from terminal height (legacy behaviour).
overhead := 8
if p.Subtitle != "" {
overhead++
}
if p.showSearch {
overhead += 2 // search line + separator
if p.ShowSearch {
overhead += 2
}
return max(p.height/2-overhead, 3)
}
// HandleKey processes a single key event and returns the result. The caller
// should inspect PopupResult to decide whether to re-render, close the popup,
// or act on a selection.
// or act on a selection. Internal-state mode only — external-state callers
// drive cursor/search themselves and never call this.
//
// keyName is the Bubble Tea key string (e.g. "up", "down", "enter", "esc").
// keyText is the printable text for character keys (e.g. "a", "1").
@@ -191,7 +299,7 @@ func (p *PopupList) HandleKey(keyName, keyText string) PopupResult {
// as a centered overlay via lipgloss.Place + overlayContent.
func (p *PopupList) Render() string {
theme := style.GetTheme()
popupWidth := max(min(p.width-4, 80), 20)
popupW, popupH, innerW, _ := p.dimensions()
popupBg := theme.Background
popupStyle := lipgloss.NewStyle().
@@ -199,11 +307,12 @@ func (p *PopupList) Render() string {
BorderForeground(theme.Primary).
Background(popupBg).
Padding(1, 2).
Width(popupWidth).
MarginBottom(1)
// Inner content width: popup minus border (2) and horizontal padding (4).
innerWidth := max(popupWidth-6, 10)
Width(popupW)
if popupH > 0 {
popupStyle = popupStyle.Height(popupH)
} else {
popupStyle = popupStyle.MarginBottom(1)
}
var b strings.Builder
@@ -212,7 +321,7 @@ func (p *PopupList) Render() string {
Bold(true).
Foreground(theme.Accent).
Background(popupBg).
Width(innerWidth)
Width(innerW)
b.WriteString(titleStyle.Render(p.Title))
b.WriteString("\n")
@@ -221,17 +330,17 @@ func (p *PopupList) Render() string {
subtitleStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(popupBg).
Width(innerWidth)
Width(innerW)
b.WriteString(subtitleStyle.Render(p.Subtitle))
b.WriteString("\n")
}
// Search input.
if p.showSearch {
if p.ShowSearch {
searchStyle := lipgloss.NewStyle().
Foreground(theme.Info).
Background(popupBg).
Width(innerWidth)
Width(innerW)
if p.search != "" {
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", p.search)))
} else {
@@ -243,7 +352,7 @@ func (p *PopupList) Render() string {
sepStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(popupBg)
b.WriteString(sepStyle.Render(strings.Repeat("─", innerWidth)))
b.WriteString(sepStyle.Render(strings.Repeat("─", innerW)))
b.WriteString("\n")
}
@@ -251,20 +360,20 @@ func (p *PopupList) Render() string {
normalItemBg := lipgloss.NewStyle().
Background(popupBg).
Foreground(theme.Text).
Width(innerWidth).
Width(innerW).
Padding(0, 1)
selectedItemBg := lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Background).
Width(innerWidth).
Width(innerW).
Padding(0, 1).
Bold(true)
scrollStyle := lipgloss.NewStyle().
Background(popupBg).
Foreground(theme.VeryMuted).
Width(innerWidth).
Width(innerW).
Padding(0, 1)
vis := p.visibleCount()
@@ -274,7 +383,7 @@ func (p *PopupList) Render() string {
emptyStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(popupBg).
Width(innerWidth).
Width(innerW).
Padding(0, 1)
if p.search != "" {
items = append(items, emptyStyle.Render("No matches for \""+p.search+"\""))
@@ -282,9 +391,14 @@ func (p *PopupList) Render() string {
items = append(items, emptyStyle.Render("No items"))
}
} else {
// Center the cursor in the visible window so the user always sees
// context above and below. Clamp to bounds.
startIdx := 0
if p.cursor >= vis {
startIdx = p.cursor - vis + 1
if len(p.filtered) > vis {
startIdx = max(p.cursor-vis/2, 0)
if startIdx+vis > len(p.filtered) {
startIdx = len(p.filtered) - vis
}
}
endIdx := min(startIdx+vis, len(p.filtered))
@@ -292,10 +406,27 @@ func (p *PopupList) Render() string {
items = append(items, scrollStyle.Render(" ↑ more above"))
}
// Account for the consumed padding (1 left + 1 right = 2 cols)
// when rendering item content so RenderItem callbacks can match.
itemContentWidth := max(innerW-2, 6)
for i := startIdx; i < endIdx; i++ {
entry := p.filtered[i]
isCursor := i == p.cursor
if p.RenderItem != nil {
// Custom renderer: caller produces the inner text. We still
// wrap it in a full-width row so the selection highlight
// covers the line edge-to-edge.
rowStyle := normalItemBg
if isCursor {
rowStyle = selectedItemBg
}
content := p.RenderItem(entry, itemContentWidth, isCursor)
items = append(items, rowStyle.Render(content))
continue
}
itemStyle := normalItemBg
if isCursor {
itemStyle = selectedItemBg
@@ -310,7 +441,7 @@ func (p *PopupList) Render() string {
}
// Build content: indicator + label + description + active checkmark.
content := p.renderItemContent(indicator, entry, innerWidth, isCursor)
content := p.renderItemContent(indicator, entry, itemContentWidth, isCursor)
items = append(items, itemStyle.Render(content))
}
@@ -323,19 +454,24 @@ func (p *PopupList) Render() string {
// Footer with count and keyboard hints.
var footerParts []string
footerParts = append(footerParts, fmt.Sprintf("(%d/%d)", p.cursor+1, len(p.filtered)))
if !p.HideCount {
footerParts = append(footerParts, fmt.Sprintf("(%d/%d)", p.cursor+1, len(p.filtered)))
}
footerHint := p.FooterHint
if footerHint == "" {
if innerWidth >= 50 {
if innerW >= 50 {
footerHint = "↑↓ navigate • enter select • esc cancel • type to filter"
} else if innerWidth >= 30 {
} else if innerW >= 30 {
footerHint = "↑↓ nav • ↵ select • esc"
} else {
footerHint = "↑↓ ↵ esc"
}
}
footerParts = append(footerParts, footerHint)
if p.ExtraFooter != "" {
footerParts = append(footerParts, p.ExtraFooter)
}
footer := lipgloss.NewStyle().
Background(popupBg).
+131 -304
View File
@@ -5,7 +5,6 @@ import (
"regexp"
"strings"
"time"
"unicode/utf8"
"charm.land/bubbles/v2/key"
tea "charm.land/bubbletea/v2"
@@ -62,17 +61,14 @@ func (m SessionFilterMode) String() string {
// controlCharsRe matches ASCII control characters for stripping from previews.
var controlCharsRe = regexp.MustCompile(`[\x00-\x1f\x7f]`)
// SessionSelectorComponent is a full-screen Bubble Tea component that lets
// the user browse and select from available sessions. Modeled after pi's
// session picker: right-aligned metadata, background-highlighted selection,
// scope/filter toggles, and inline search.
// SessionSelectorComponent is a Bubble Tea component that lets the user browse
// and select from available sessions. It wraps PopupList in FullScreen mode:
// PopupList owns the cursor/search/scroll math/chrome; this component owns
// the session list, scope/filter toggles, and delete-confirmation flow.
type SessionSelectorComponent struct {
allSessions []session.SessionInfo
cwdSessions []session.SessionInfo
filtered []session.SessionInfo
cursor int
search string
filtered []session.SessionInfo // matches popup.Items() 1:1
scope SessionScopeMode
filter SessionFilterMode
@@ -80,6 +76,7 @@ type SessionSelectorComponent struct {
// currentPath is the active session file path for marking it in the list.
currentPath string
popup *PopupList
width int
height int
active bool
@@ -110,7 +107,12 @@ func NewSessionSelector(cwd string, width, height int) *SessionSelectorComponent
ss.scope = SessionScopeAll
}
ss.rebuildFiltered()
ss.popup = NewPopupList("Resume Session", nil, width, height)
ss.popup.FullScreen = true
ss.popup.FooterHint = "↑↓ nav • ↵ open • esc cancel • tab scope • ^N named • d delete • type to search"
ss.popup.RenderItem = ss.renderEntry
ss.rebuild()
return ss
}
@@ -131,10 +133,11 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.WindowSizeMsg:
ss.width = msg.Width
ss.height = msg.Height
ss.popup.SetSize(msg.Width, msg.Height)
return ss, nil
case tea.KeyPressMsg:
// Delete confirmation mode.
// Delete confirmation mode swallows all keys until y/n.
if ss.confirmDelete >= 0 {
switch msg.String() {
case "y", "Y":
@@ -145,7 +148,7 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if err := session.DeleteSession(info.Path); err == nil {
name := sessionDisplayName(info)
ss.removeSession(info.Path)
ss.rebuildFiltered()
ss.rebuild()
return ss, func() tea.Msg {
return SessionDeletedMsg{Name: name}
}
@@ -159,64 +162,14 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
switch {
case key.Matches(msg, key.NewBinding(key.WithKeys("up"))):
if ss.cursor > 0 {
ss.cursor--
}
case key.Matches(msg, key.NewBinding(key.WithKeys("down"))):
if ss.cursor < len(ss.filtered)-1 {
ss.cursor++
}
case key.Matches(msg, key.NewBinding(key.WithKeys("pgup"))):
ss.cursor -= ss.visibleHeight()
if ss.cursor < 0 {
ss.cursor = 0
}
case key.Matches(msg, key.NewBinding(key.WithKeys("pgdown"))):
ss.cursor += ss.visibleHeight()
if ss.cursor >= len(ss.filtered) {
ss.cursor = len(ss.filtered) - 1
}
if ss.cursor < 0 {
ss.cursor = 0
}
case key.Matches(msg, key.NewBinding(key.WithKeys("home"))):
ss.cursor = 0
case key.Matches(msg, key.NewBinding(key.WithKeys("end"))):
ss.cursor = max(len(ss.filtered)-1, 0)
case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))):
if ss.cursor < len(ss.filtered) {
info := ss.filtered[ss.cursor]
ss.active = false
return ss, func() tea.Msg {
return SessionSelectedMsg{Path: info.Path}
}
}
case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))):
if ss.search != "" {
ss.search = ""
ss.rebuildFiltered()
} else {
ss.active = false
return ss, func() tea.Msg {
return SessionSelectorCancelledMsg{}
}
}
case key.Matches(msg, key.NewBinding(key.WithKeys("tab"))):
if ss.scope == SessionScopeCwd {
ss.scope = SessionScopeAll
} else {
ss.scope = SessionScopeCwd
}
ss.rebuildFiltered()
ss.rebuild()
return ss, nil
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+n"))):
if ss.filter == SessionFilterAll {
@@ -224,25 +177,48 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
} else {
ss.filter = SessionFilterAll
}
ss.rebuildFiltered()
case key.Matches(msg, key.NewBinding(key.WithKeys("d"))):
if ss.cursor < len(ss.filtered) {
ss.confirmDelete = ss.cursor
}
ss.rebuild()
return ss, nil
default:
if msg.Text != "" && len(msg.Text) == 1 {
ch := msg.Text[0]
if ch >= 32 && ch < 127 {
ss.search += string(ch)
ss.rebuildFiltered()
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+d"))):
// Ctrl+D as an explicit delete shortcut. Plain "d" still works
// below when the search field is empty so it doesn't conflict
// with typing the letter 'd' into a query.
if c := ss.popup.Cursor(); c < len(ss.filtered) {
ss.confirmDelete = c
}
return ss, nil
}
// Plain 'd' triggers delete only when there's no active search
// query (otherwise the user would never be able to type 'd' into
// a search like "doc").
if msg.String() == "d" && !ss.popup.IsSearching() {
if c := ss.popup.Cursor(); c < len(ss.filtered) {
ss.confirmDelete = c
return ss, nil
}
}
// Delegate everything else to the popup.
result := ss.popup.HandleKey(msg.String(), msg.Text)
if result.Changed {
ss.syncFiltered()
}
if result.Selected != nil {
cursor := ss.popup.Cursor()
if cursor < len(ss.filtered) {
info := ss.filtered[cursor]
ss.active = false
return ss, func() tea.Msg {
return SessionSelectedMsg{Path: info.Path}
}
}
if key.Matches(msg, key.NewBinding(key.WithKeys("backspace"))) && len(ss.search) > 0 {
ss.search = ss.search[:len(ss.search)-1]
ss.rebuildFiltered()
}
if result.Cancelled {
ss.active = false
return ss, func() tea.Msg {
return SessionSelectorCancelledMsg{}
}
}
}
@@ -251,152 +227,17 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// View implements tea.Model.
func (ss *SessionSelectorComponent) View() tea.View {
theme := style.GetTheme()
// Full-screen bordered container - uses entire terminal width and height
maxWidth := ss.width - 2 // Small margin on each side
if maxWidth < 20 {
maxWidth = ss.width
}
maxHeight := ss.height - 2 // Small margin top/bottom to prevent overflow
if maxHeight < 10 {
maxHeight = ss.height
}
horizontalPadding := 1
innerWidth := maxWidth - 4 // Account for border (2) + padding (2)
innerHeight := maxHeight - 4 // Account for border (2) + padding (2)
// Container style with border - full width/height like a framed panel
containerStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(theme.Primary).
Background(theme.Background).
Padding(1, horizontalPadding).
Width(maxWidth).
Height(maxHeight)
var contentBuilder strings.Builder
// ── Header: title + scope badges ─────────────────────────────
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(theme.Accent).
Background(theme.Background)
contentBuilder.WriteString(titleStyle.Render(fmt.Sprintf("Resume Session (%s)", ss.scope)))
contentBuilder.WriteString("\n")
// ── Help / keybindings ───────────────────────────────────────
helpStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
if innerWidth >= 75 {
contentBuilder.WriteString(helpStyle.Render("tab: scope N: named D: delete R: rename type to search esc: cancel"))
} else if innerWidth >= 50 {
contentBuilder.WriteString(helpStyle.Render("tab scope N named D del type to search esc"))
} else {
contentBuilder.WriteString(helpStyle.Render("tab N D esc"))
}
contentBuilder.WriteString("\n")
// ── Search (only shown when active) ──────────────────────────
if ss.search != "" {
searchStyle := lipgloss.NewStyle().
Foreground(theme.Info).
Background(theme.Background)
contentBuilder.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ss.search)))
contentBuilder.WriteString("\n")
}
// Separator line
sepWidth := innerWidth
contentBuilder.WriteString(
lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background).
Render(strings.Repeat("─", sepWidth)))
contentBuilder.WriteString("\n")
// ── Delete confirmation ──────────────────────────────────────
// Compose dynamic footer extras: scope + filter + (delete confirm).
extra := fmt.Sprintf("scope: %s • filter: %s", ss.scope, ss.filter)
if ss.confirmDelete >= 0 && ss.confirmDelete < len(ss.filtered) {
warnStyle := lipgloss.NewStyle().
Foreground(theme.Error).
Bold(true).
Background(theme.Background)
name := sessionDisplayName(ss.filtered[ss.confirmDelete])
contentBuilder.WriteString(warnStyle.Render(fmt.Sprintf("Delete %q? (y/N)", truncateRunes(name, 40))))
contentBuilder.WriteString("\n")
name := truncateRunes(sessionDisplayName(ss.filtered[ss.confirmDelete]), 30)
extra = fmt.Sprintf("delete %q? y/N", name)
}
ss.popup.Title = fmt.Sprintf("Resume Session (%s)", ss.scope)
ss.popup.ExtraFooter = extra
// ── Session list ─────────────────────────────────────────────
if len(ss.filtered) == 0 {
emptyStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
if ss.search != "" {
contentBuilder.WriteString(emptyStyle.Render(fmt.Sprintf("No sessions matching %q", ss.search)))
} else if ss.filter == SessionFilterNamed {
contentBuilder.WriteString(emptyStyle.Render("No named sessions. Press N to show all."))
} else if ss.scope == SessionScopeCwd {
contentBuilder.WriteString(emptyStyle.Render("No sessions in current folder. Press tab to view all."))
} else {
contentBuilder.WriteString(emptyStyle.Render("No sessions found"))
}
contentBuilder.WriteString("\n")
} else {
// Compute visible window based on inner container height
// Chrome: header(2) + separator(1) + footer separator(1) + footer(1) = 5
chromeLines := 5
if ss.search != "" {
chromeLines++
}
if ss.confirmDelete >= 0 {
chromeLines++
}
visH := max(innerHeight-chromeLines, 3)
// Center the cursor in the visible window.
startIdx := max(0, min(ss.cursor-visH/2, len(ss.filtered)-visH))
endIdx := min(startIdx+visH, len(ss.filtered))
for i := startIdx; i < endIdx; i++ {
info := ss.filtered[i]
isCursor := i == ss.cursor
isCurrent := info.Path == ss.currentPath
isDeleting := i == ss.confirmDelete
line := ss.renderEntry(info, isCursor, isCurrent, isDeleting, innerWidth)
contentBuilder.WriteString(line)
contentBuilder.WriteString("\n")
}
// Scroll position indicator.
if len(ss.filtered) > visH {
posStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
contentBuilder.WriteString(posStyle.Render(fmt.Sprintf("(%d/%d)", ss.cursor+1, len(ss.filtered))))
contentBuilder.WriteString("\n")
}
}
// Footer separator
contentBuilder.WriteString(
lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background).
Render(strings.Repeat("─", sepWidth)))
contentBuilder.WriteString("\n")
// Footer with filter info
footerStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
contentBuilder.WriteString(footerStyle.Render(fmt.Sprintf("Filter: %s", ss.filter)))
// Apply the bordered container
content := contentBuilder.String()
borderedContent := containerStyle.Render(content)
v := tea.NewView(borderedContent)
rendered := ss.popup.RenderCentered(ss.width, ss.height)
v := tea.NewView(rendered)
v.AltScreen = true
return v
}
@@ -408,20 +249,9 @@ func (ss *SessionSelectorComponent) IsActive() bool {
// --- Internal helpers ---
func (ss *SessionSelectorComponent) visibleHeight() int {
// Reserve: title(1) + help(1) + blank(1) + scroll indicator(1) = 4.
// Optional: search(1), delete confirm(1).
chrome := 4
if ss.search != "" {
chrome++
}
if ss.confirmDelete >= 0 {
chrome++
}
return max(ss.height-chrome, 3)
}
func (ss *SessionSelectorComponent) rebuildFiltered() {
// rebuild applies the scope and filter selections, then publishes the
// resulting session list to the popup.
func (ss *SessionSelectorComponent) rebuild() {
var source []session.SessionInfo
if ss.scope == SessionScopeCwd {
source = ss.cwdSessions
@@ -439,23 +269,33 @@ func (ss *SessionSelectorComponent) rebuildFiltered() {
source = named
}
if ss.search != "" {
query := strings.ToLower(ss.search)
var matches []session.SessionInfo
for _, s := range source {
haystack := strings.ToLower(s.Name + " " + s.FirstMessage + " " + s.Cwd)
if strings.Contains(haystack, query) {
matches = append(matches, s)
}
// Build PopupItems. The Label holds a haystack string (name + first
// message + cwd) so PopupList's default filter can match against any
// of those fields. We render each row with a custom RenderItem.
items := make([]PopupItem, len(source))
for i, s := range source {
haystack := strings.TrimSpace(s.Name + " " + s.FirstMessage + " " + s.Cwd)
items[i] = PopupItem{
Label: haystack,
Active: s.Path == ss.currentPath,
Meta: s,
}
ss.filtered = matches
} else {
ss.filtered = source
}
ss.popup.SetItems(items)
ss.syncFiltered()
}
if ss.cursor >= len(ss.filtered) {
ss.cursor = max(len(ss.filtered)-1, 0)
// syncFiltered refreshes the filtered slice from popup.Items() so cursor
// indices map back to session.SessionInfo for the parent.
func (ss *SessionSelectorComponent) syncFiltered() {
items := ss.popup.Items()
out := make([]session.SessionInfo, 0, len(items))
for _, it := range items {
if s, ok := it.Meta.(session.SessionInfo); ok {
out = append(out, s)
}
}
ss.filtered = out
}
func (ss *SessionSelectorComponent) removeSession(path string) {
@@ -473,87 +313,74 @@ func removeByPath(sessions []session.SessionInfo, path string) []session.Session
return result
}
// renderEntry renders a single session line with right-aligned metadata.
// Layout: [cursor 2] [message ...variable...] [padding] [count age] [cwd?]
func (ss *SessionSelectorComponent) renderEntry(info session.SessionInfo, isCursor, isCurrent, isDeleting bool, width int) string {
// renderEntry is the RenderItem callback handed to PopupList. It produces a
// single-line entry with left-aligned message text and right-aligned
// metadata (message count + relative time, plus optional cwd in "All" scope).
//
// When isCursor we return a plain (unstyled) string so PopupList's outer
// row style can paint one continuous fg+bg span. Mixing inner lipgloss
// Render calls with an outer Background() breaks the highlight into bars,
// because each inner Render emits an ANSI reset that drops the background.
func (ss *SessionSelectorComponent) renderEntry(item PopupItem, innerWidth int, isCursor bool) string {
theme := style.GetTheme()
info, ok := item.Meta.(session.SessionInfo)
if !ok {
return item.Label
}
isCurrent := info.Path == ss.currentPath
isDeleting := ss.confirmDelete >= 0 && ss.confirmDelete < len(ss.filtered) &&
ss.filtered[ss.confirmDelete].Path == info.Path
// ── Cursor indicator (2 chars) ───────────────────────────────
cursorStr := " "
// Cursor indicator (2 cells).
indicator := " "
if isCursor {
cursorStr = lipgloss.NewStyle().Foreground(theme.Accent).Render("> ")
indicator = "> "
}
const cursorW = 2
// ── Right part: message count + relative time (+ optional cwd) ──
// Right-hand metadata.
age := relativeTime(info.Modified)
msgCount := fmt.Sprintf("%d", info.MessageCount)
rightPart := msgCount + " " + age
right := fmt.Sprintf("%d %s", info.MessageCount, age)
if ss.scope == SessionScopeAll && info.Cwd != "" {
shortCwd := shortenPath(info.Cwd)
if len(shortCwd) > 25 {
shortCwd = "..." + shortCwd[len(shortCwd)-22:]
}
rightPart = shortCwd + " " + rightPart
shortCwd := truncateRunes(shortenPath(info.Cwd), 25)
right = shortCwd + " " + right
}
rightW := utf8.RuneCountInString(rightPart)
rightW := lipgloss.Width(right)
// Message text width: innerWidth minus indicator(2) minus right minus gap(2).
availForMsg := max(innerWidth-2-rightW-2, 10)
// ── Message text ─────────────────────────────────────────────
displayText := sessionDisplayName(info)
// Strip control characters and collapse whitespace.
displayText = controlCharsRe.ReplaceAllString(displayText, " ")
displayText = strings.Join(strings.Fields(displayText), " ")
displayText = truncateRunes(displayText, availForMsg)
availableForMsg := max(width-cursorW-rightW-2, 10) // 2 for min spacing
displayText = truncateRunes(displayText, availableForMsg)
msgW := utf8.RuneCountInString(displayText)
msgW := lipgloss.Width(displayText)
spacing := max(innerWidth-2-msgW-rightW, 1)
// ── Style the message ────────────────────────────────────────
var msgStyle lipgloss.Style
// Selected row: raw string, outer row style paints it.
if isCursor {
return indicator + displayText + strings.Repeat(" ", spacing) + right
}
// Color the message text by state.
var msgStyle, rightStyle lipgloss.Style
switch {
case isDeleting:
msgStyle = lipgloss.NewStyle().Foreground(theme.Error)
case isCurrent:
msgStyle = lipgloss.NewStyle().Foreground(theme.Accent)
msgStyle = lipgloss.NewStyle().Foreground(theme.Accent).Bold(true)
case info.Name != "":
msgStyle = lipgloss.NewStyle().Foreground(theme.Warning)
default:
msgStyle = lipgloss.NewStyle().Foreground(theme.Text)
}
// ── Style the right part ─────────────────────────────────────
rightColor := theme.Muted
if isDeleting {
rightColor = theme.Error
}
var styledRight string
// ── Assemble with spacing ────────────────────────────────────
spacing := max(width-cursorW-msgW-rightW, 1)
// If selected, use inverted colors like PopupList
if isCursor {
// Inverted colors for selected item
msgStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Background).
Bold(true)
styledRight = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(rightColor).
Render(rightPart)
cursorStr = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Accent).
Render("> ")
rightStyle = lipgloss.NewStyle().Foreground(theme.Error)
} else {
styledRight = lipgloss.NewStyle().Foreground(rightColor).Render(rightPart)
rightStyle = lipgloss.NewStyle().Foreground(theme.Muted)
}
styledMsg := msgStyle.Render(displayText)
line := cursorStr + styledMsg + strings.Repeat(" ", spacing) + styledRight
return line
return indicator + msgStyle.Render(displayText) + strings.Repeat(" ", spacing) + rightStyle.Render(right)
}
// --- Package helpers ---
@@ -570,7 +397,7 @@ func sessionDisplayName(info session.SessionInfo) string {
return "(empty session)"
}
// truncateRunes truncates a string to at most maxRunes runes, appending "..."
// truncateRunes truncates a string to at most maxRunes runes, appending ""
// if truncated.
func truncateRunes(s string, maxRunes int) string {
if maxRunes <= 0 {
+136
View File
@@ -0,0 +1,136 @@
package ui
import (
"bytes"
"image"
"image/color"
"image/png"
"strings"
"testing"
uicore "github.com/mark3labs/kit/internal/ui/core"
)
// makeTestPNG builds a small solid-color PNG for transcript preview tests.
func makeTestPNG(t *testing.T, w, h int) []byte {
t.Helper()
img := image.NewRGBA(image.Rect(0, 0, w, h))
for y := range h {
for x := range w {
img.Set(x, y, color.RGBA{R: 200, G: 40, B: 90, A: 255})
}
}
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
t.Fatalf("encode png: %v", err)
}
return buf.Bytes()
}
func TestTranscriptPreviewCmdNoImages(t *testing.T) {
m, _, _ := newTestAppModel(nil)
if cmd := m.transcriptPreviewCmd(nil, ""); cmd != nil {
t.Error("expected nil cmd when there are no images")
}
}
func TestTranscriptPreviewCmdRendersBlock(t *testing.T) {
m, _, _ := newTestAppModel(nil)
images := []uicore.ImageAttachment{
{Data: makeTestPNG(t, 16, 16), MediaType: "image/png"},
}
cmd := m.transcriptPreviewCmd(images, "anchor-1")
if cmd == nil {
t.Fatal("expected a non-nil cmd for a valid image")
}
msg := cmd()
// The result depends on the test process color profile. When the
// terminal supports color the cmd yields a preview block; otherwise it
// yields nil (caller keeps the text badge). Both are valid — assert the
// shape only when a block is produced.
if msg == nil {
t.Skip("color profile below ANSI256 in test env; preview correctly skipped")
}
ready, ok := msg.(imagePreviewReadyMsg)
if !ok {
t.Fatalf("expected imagePreviewReadyMsg, got %T", msg)
}
if !strings.Contains(ready.block, "▀") {
t.Errorf("preview block should contain half-block glyphs, got %q", ready.block)
}
if ready.anchorID != "anchor-1" {
t.Errorf("preview should carry the originating anchorID, got %q", ready.anchorID)
}
}
func TestImagePreviewReadyMsgAppendsItem(t *testing.T) {
m, _, _ := newTestAppModel(nil)
before := len(m.messages)
m = sendMsg(m, imagePreviewReadyMsg{block: "\x1b[38;2;1;2;3;48;2;4;5;6m▀\x1b[0m"})
if len(m.messages) != before+1 {
t.Fatalf("expected one appended message item, got %d (was %d)", len(m.messages), before)
}
last, ok := m.messages[len(m.messages)-1].(*TextMessageItem)
if !ok {
t.Fatalf("expected last item to be *TextMessageItem, got %T", m.messages[len(m.messages)-1])
}
if !strings.Contains(last.Render(0), "▀") {
t.Error("appended preview item should render the half-block block verbatim")
}
}
// TestImagePreviewReadyMsgInsertsAfterAnchor verifies the preview is placed
// directly after its originating user message even when a later message (e.g.
// a streamed assistant reply) was already appended while the thumbnail was
// being decoded asynchronously.
func TestImagePreviewReadyMsgInsertsAfterAnchor(t *testing.T) {
m, _, _ := newTestAppModel(nil)
userItem := NewStyledMessageItem("user-anchor", "user", "hi", "hi")
assistantItem := NewStyledMessageItem("assistant-1", "assistant", "reply", "reply")
m.messages = append(m.messages, userItem, assistantItem)
m = sendMsg(m, imagePreviewReadyMsg{
block: "\x1b[38;2;1;2;3;48;2;4;5;6m▀\x1b[0m",
anchorID: "user-anchor",
})
// Expect order: user, preview, assistant.
if len(m.messages) != 3 {
t.Fatalf("expected 3 messages, got %d", len(m.messages))
}
if m.messages[0].ID() != "user-anchor" {
t.Errorf("messages[0] should be the user message, got %q", m.messages[0].ID())
}
if m.messages[2].ID() != "assistant-1" {
t.Errorf("messages[2] should be the assistant message, got %q", m.messages[2].ID())
}
if !strings.Contains(m.messages[1].Render(0), "▀") {
t.Errorf("messages[1] should be the inserted preview, got %q", m.messages[1].Render(0))
}
}
// TestImagePreviewReadyMsgUnknownAnchorAppends verifies that when the anchor
// is missing (e.g. the message was cleared), the preview falls back to append.
func TestImagePreviewReadyMsgUnknownAnchorAppends(t *testing.T) {
m, _, _ := newTestAppModel(nil)
m.messages = append(m.messages, NewStyledMessageItem("only", "user", "hi", "hi"))
m = sendMsg(m, imagePreviewReadyMsg{
block: "\x1b[38;2;1;2;3;48;2;4;5;6m▀\x1b[0m",
anchorID: "does-not-exist",
})
if len(m.messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(m.messages))
}
if !strings.Contains(m.messages[1].Render(0), "▀") {
t.Error("preview should be appended as the last item when anchor is unknown")
}
}
func TestImagePreviewReadyMsgEmptyBlockIgnored(t *testing.T) {
m, _, _ := newTestAppModel(nil)
before := len(m.messages)
m = sendMsg(m, imagePreviewReadyMsg{block: ""})
if len(m.messages) != before {
t.Errorf("empty preview block should not append an item; got %d (was %d)", len(m.messages), before)
}
}
+183 -315
View File
@@ -53,16 +53,19 @@ type FlatNode struct {
}
// TreeSelectorComponent is a Bubble Tea component that renders the session
// tree as an ASCII art list with navigation and selection.
// tree as an ASCII art list with navigation and selection. It is a thin
// wrapper around PopupList (in FullScreen mode) — PopupList owns the cursor,
// search, scroll math, and chrome; TreeSelectorComponent supplies the
// filtered node list and a custom RenderItem that draws each tree node with
// its indentation prefix and role colors.
type TreeSelectorComponent struct {
tm *session.TreeManager
flatNodes []FlatNode
cursor int
flatNodes []FlatNode // visible nodes (matches popup.Items() 1:1)
filter TreeFilterMode
leafID string // real leaf for "active" marker
popup *PopupList
width int
height int
search string
active bool
selectedID string // set when user selects a node
cancelled bool
@@ -78,11 +81,12 @@ func NewTreeSelector(tm *session.TreeManager, width, height int) *TreeSelectorCo
height: height,
active: true,
}
ts.rebuildFlatList()
ts.initPopup()
ts.rebuild()
// Position cursor at the active leaf.
for i, node := range ts.flatNodes {
if node.ID == ts.leafID {
ts.cursor = i
ts.popup.SetCursor(i)
break
}
}
@@ -100,17 +104,25 @@ func NewTreeSelectorForFork(tm *session.TreeManager, width, height int) *TreeSel
height: height,
active: true,
}
ts.rebuildFlatList()
ts.initPopup()
ts.rebuild()
// Position cursor at the last user message before the leaf.
for i := len(ts.flatNodes) - 1; i >= 0; i-- {
if ts.isUserMessage(ts.flatNodes[i].Entry) {
ts.cursor = i
ts.popup.SetCursor(i)
break
}
}
return ts
}
func (ts *TreeSelectorComponent) initPopup() {
ts.popup = NewPopupList("Session Tree", nil, ts.width, ts.height)
ts.popup.FullScreen = true
ts.popup.FooterHint = "↑↓ nav • ←→ page • ↵ select • esc cancel • ^O filter • type to search"
ts.popup.RenderItem = ts.renderNode
}
// Init implements tea.Model.
func (ts *TreeSelectorComponent) Init() tea.Cmd {
return nil
@@ -122,96 +134,75 @@ func (ts *TreeSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.WindowSizeMsg:
ts.width = msg.Width
ts.height = msg.Height
ts.popup.SetSize(msg.Width, msg.Height)
return ts, nil
case tea.KeyPressMsg:
// Tree-specific keys we handle ourselves before delegating to popup.
switch {
case key.Matches(msg, key.NewBinding(key.WithKeys("up"))):
if ts.cursor > 0 {
ts.cursor--
}
case key.Matches(msg, key.NewBinding(key.WithKeys("down"))):
if ts.cursor < len(ts.flatNodes)-1 {
ts.cursor++
}
case key.Matches(msg, key.NewBinding(key.WithKeys("left", "pgup"))):
// Page up.
ts.cursor -= ts.visibleHeight()
if ts.cursor < 0 {
ts.cursor = 0
}
result := ts.popup.HandleKey("pgup", "")
_ = result
return ts, nil
case key.Matches(msg, key.NewBinding(key.WithKeys("right", "pgdown"))):
// Page down.
ts.cursor += ts.visibleHeight()
if ts.cursor >= len(ts.flatNodes) {
ts.cursor = len(ts.flatNodes) - 1
}
result := ts.popup.HandleKey("pgdown", "")
_ = result
return ts, nil
case key.Matches(msg, key.NewBinding(key.WithKeys("home"))):
ts.cursor = 0
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+o"))):
ts.filter = (ts.filter + 1) % 5
ts.rebuild()
return ts, nil
case key.Matches(msg, key.NewBinding(key.WithKeys("end"))):
ts.cursor = len(ts.flatNodes) - 1
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+d"))):
ts.filter = TreeFilterDefault
ts.rebuild()
return ts, nil
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+t"))):
ts.filter = TreeFilterNoTools
ts.rebuild()
return ts, nil
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+u"))):
ts.filter = TreeFilterUserOnly
ts.rebuild()
return ts, nil
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+l"))):
ts.filter = TreeFilterLabelOnly
ts.rebuild()
return ts, nil
}
case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))):
if ts.cursor < len(ts.flatNodes) {
ts.selectedID = ts.flatNodes[ts.cursor].ID
// Delegate everything else (nav, search, enter, esc) to the popup.
result := ts.popup.HandleKey(msg.String(), msg.Text)
// Update our flatNodes view if popup filtered/changed search.
if result.Changed {
ts.syncFlatNodes()
}
if result.Selected != nil {
cursor := ts.popup.Cursor()
if cursor < len(ts.flatNodes) {
node := ts.flatNodes[cursor]
ts.selectedID = node.ID
ts.active = false
return ts, func() tea.Msg {
return core.TreeNodeSelectedMsg{
ID: ts.selectedID,
Entry: ts.flatNodes[ts.cursor].Entry,
IsUser: ts.isUserMessage(ts.flatNodes[ts.cursor].Entry),
UserText: ts.extractUserText(ts.flatNodes[ts.cursor].Entry),
ID: node.ID,
Entry: node.Entry,
IsUser: ts.isUserMessage(node.Entry),
UserText: ts.extractUserText(node.Entry),
}
}
}
case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))):
if ts.search != "" {
ts.search = ""
ts.rebuildFlatList()
} else {
ts.cancelled = true
ts.active = false
return ts, func() tea.Msg {
return core.TreeCancelledMsg{}
}
}
// Filter cycle with ctrl+o.
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+o"))):
ts.filter = (ts.filter + 1) % 5
ts.rebuildFlatList()
// Direct filter shortcuts.
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+d"))):
ts.filter = TreeFilterDefault
ts.rebuildFlatList()
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+t"))):
ts.filter = TreeFilterNoTools
ts.rebuildFlatList()
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+u"))):
ts.filter = TreeFilterUserOnly
ts.rebuildFlatList()
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+l"))):
ts.filter = TreeFilterLabelOnly
ts.rebuildFlatList()
default:
// Typing search.
if msg.Text != "" && len(msg.Text) == 1 {
ch := msg.Text[0]
if ch >= 32 && ch < 127 {
ts.search += string(ch)
ts.rebuildFlatList()
}
}
if key.Matches(msg, key.NewBinding(key.WithKeys("backspace"))) && len(ts.search) > 0 {
ts.search = ts.search[:len(ts.search)-1]
ts.rebuildFlatList()
}
if result.Cancelled {
ts.cancelled = true
ts.active = false
return ts, func() tea.Msg {
return core.TreeCancelledMsg{}
}
}
}
@@ -220,128 +211,10 @@ func (ts *TreeSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// View implements tea.Model.
func (ts *TreeSelectorComponent) View() tea.View {
theme := GetTheme()
// Full-screen bordered container - uses entire terminal width and height
maxWidth := ts.width - 2 // Small margin on each side
if maxWidth < 20 {
maxWidth = ts.width
}
maxHeight := ts.height - 2 // Small margin top/bottom to prevent overflow
if maxHeight < 10 {
maxHeight = ts.height
}
horizontalPadding := 1
innerWidth := maxWidth - 4 // Account for border (2) + padding (2)
innerHeight := maxHeight - 4 // Account for border (2) + padding (2)
// Container style with border - full width/height like a framed panel
containerStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(theme.Primary).
Background(theme.Background).
Padding(1, horizontalPadding).
Width(maxWidth).
Height(maxHeight)
// Header style with background highlight (like PopupList title)
headerStyle := lipgloss.NewStyle().
Bold(true).
Foreground(theme.Accent).
Background(theme.Background)
// Help text style
helpStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
var contentBuilder strings.Builder
// Header row with title and help
headerRow := headerStyle.Render("Session Tree")
contentBuilder.WriteString(headerRow)
contentBuilder.WriteString("\n")
// Help text - adapt to terminal width
var helpText string
if ts.width >= 70 {
helpText = "↑/↓: move ←/→: page enter: select esc: cancel ^O: cycle filter"
} else if ts.width >= 45 {
helpText = "↑↓ move ↵ select esc cancel ^O filter"
} else {
helpText = "↑↓ ↵ esc ^O"
}
contentBuilder.WriteString(helpStyle.Render(helpText))
contentBuilder.WriteString("\n")
// Search display (if active)
if ts.search != "" {
searchStyle := lipgloss.NewStyle().
Foreground(theme.Info).
Background(theme.Background)
contentBuilder.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ts.search)))
contentBuilder.WriteString("\n")
}
// Separator line - full width
sepWidth := innerWidth
contentBuilder.WriteString(
lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background).
Render(strings.Repeat("─", sepWidth)))
contentBuilder.WriteString("\n")
// Tree content
if len(ts.flatNodes) == 0 {
emptyStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
contentBuilder.WriteString(emptyStyle.Render("No entries in session"))
contentBuilder.WriteString("\n")
} else {
// Compute visible window based on inner container height
// Chrome: header(2) + separator(1) + footer separator(1) + footer(1) = 5
chromeLines := 5
if ts.search != "" {
chromeLines++
}
visH := max(innerHeight-chromeLines, 3)
startIdx := 0
if ts.cursor >= visH {
startIdx = ts.cursor - visH + 1
}
endIdx := min(startIdx+visH, len(ts.flatNodes))
for i := startIdx; i < endIdx; i++ {
node := ts.flatNodes[i]
line := ts.renderNode(node, i == ts.cursor, node.ID == ts.leafID, innerWidth)
contentBuilder.WriteString(line)
contentBuilder.WriteString("\n")
}
}
// Footer separator
contentBuilder.WriteString(
lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background).
Render(strings.Repeat("─", sepWidth)))
contentBuilder.WriteString("\n")
// Footer with count and filter
footerStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
footer := fmt.Sprintf("(%d/%d) [%s]", ts.cursor+1, len(ts.flatNodes), ts.filter)
contentBuilder.WriteString(footerStyle.Render(footer))
// Apply the bordered container - full width, no centering
content := contentBuilder.String()
borderedContent := containerStyle.Render(content)
v := tea.NewView(borderedContent)
// Update extra footer with current filter mode.
ts.popup.ExtraFooter = fmt.Sprintf("[%s]", ts.filter)
rendered := ts.popup.RenderCentered(ts.width, ts.height)
v := tea.NewView(rendered)
v.AltScreen = true
return v
}
@@ -353,38 +226,46 @@ func (ts *TreeSelectorComponent) IsActive() bool {
// --- Internal helpers ---
func (ts *TreeSelectorComponent) visibleHeight() int {
// Chrome: header(1) + help(1) + separator(1) + entries + separator(1) + footer(1) = 5 fixed.
// Optional search line adds 1 more. Use 7 as a safe estimate.
const chromeLines = 7
return max(ts.height-chromeLines, 3)
}
func (ts *TreeSelectorComponent) rebuildFlatList() {
tree := ts.tm.GetTree()
// rebuild reflattens the tree under the current filter and reseeds the popup
// with PopupItems. Called on initial load and whenever the filter changes.
func (ts *TreeSelectorComponent) rebuild() {
ts.flatNodes = ts.flatNodes[:0]
tree := ts.tm.GetTree()
for i, root := range tree {
isLast := i == len(tree)-1
ts.flattenNode(root, 0, isLast, "")
}
ts.publishItems()
}
// Apply search filter.
if ts.search != "" {
query := strings.ToLower(ts.search)
filtered := make([]FlatNode, 0)
for _, node := range ts.flatNodes {
text := ts.entryDisplayText(node.Entry)
if strings.Contains(strings.ToLower(text), query) {
filtered = append(filtered, node)
}
// syncFlatNodes refreshes flatNodes from the popup's current filtered view.
// Called after a search-driven HandleKey result so the cursor index matches.
func (ts *TreeSelectorComponent) syncFlatNodes() {
items := ts.popup.Items()
newFlat := make([]FlatNode, len(items))
for i, it := range items {
if fn, ok := it.Meta.(FlatNode); ok {
newFlat[i] = fn
}
ts.flatNodes = filtered
}
ts.flatNodes = newFlat
}
// Clamp cursor.
if ts.cursor >= len(ts.flatNodes) {
ts.cursor = max(len(ts.flatNodes)-1, 0)
// publishItems converts flatNodes → PopupItems and seeds the popup. We rely
// on PopupList's default substring filter against item.Label (which holds
// the display text) for search.
func (ts *TreeSelectorComponent) publishItems() {
items := make([]PopupItem, len(ts.flatNodes))
for i, n := range ts.flatNodes {
items[i] = PopupItem{
Label: ts.entryDisplayText(n.Entry),
Active: n.ID == ts.leafID,
Meta: n,
}
}
ts.popup.SetItems(items)
// Mirror the popup's current view in flatNodes so cursor lookups work.
ts.syncFlatNodes()
}
func (ts *TreeSelectorComponent) flattenNode(node *session.TreeNode, depth int, isLast bool, gutterPrefix string) {
@@ -473,35 +354,73 @@ func (ts *TreeSelectorComponent) passesFilter(node *session.TreeNode) bool {
}
}
func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool, innerWidth int) string {
// renderNode is the RenderItem callback handed to PopupList. PopupList wraps
// the returned string with a full-width row style.
//
// When isCursor we return a plain (unstyled) string so the outer row style
// can paint a single continuous fg+bg span across the line. Composing inner
// lipgloss.Render calls emits ANSI resets mid-string which knock the
// background back out, breaking the highlight into disjoint bars (issue
// observed with deep tool-interaction branches).
func (ts *TreeSelectorComponent) renderNode(item PopupItem, innerWidth int, isCursor bool) string {
theme := GetTheme()
node, ok := item.Meta.(FlatNode)
if !ok {
return item.Label
}
isLeaf := node.ID == ts.leafID
// Cursor indicator - use ">" for selected (like PopupList)
var cursor string
// Indicator (2 cells).
indicator := " "
if isCursor {
cursor = lipgloss.NewStyle().Foreground(theme.Accent).Render("> ")
} else {
cursor = " "
indicator = "> "
}
// Role-colored content with background support for selection
text := ts.entryDisplayText(node.Entry)
// Prefix (tree art) — width measured in display cells via lipgloss.
prefix := node.Prefix
prefixW := lipgloss.Width(prefix)
// Calculate available width accounting for cursor, prefix, and markers
prefixLen := len(node.Prefix)
available := innerWidth - prefixLen - 4 // 4 for cursor and some padding
if available > 3 && len(text) > available {
trimLen := max(available-3, 1)
if trimLen < len(text) {
text = text[:trimLen] + "..."
// Compute right-side fixed parts: label badge + active marker.
var labelBadgeRaw, activeMarkerRaw string
if node.Label != "" {
labelBadgeRaw = " [" + node.Label + "]"
}
if isLeaf {
activeMarkerRaw = " ← active"
}
rightW := lipgloss.Width(labelBadgeRaw) + lipgloss.Width(activeMarkerRaw)
// If the tree prefix is so deep it would push the text off the row,
// truncate the prefix from the LEFT and prepend an ellipsis. Keeping
// the right-most segment preserves the most recent depth indicator
// (└─ / ├─) so the user can still see this row's connection to its
// parent. We reserve at least 20 cells for the actual entry text.
const minTextWidth = 20
budget := innerWidth - 2 - rightW - minTextWidth
if prefixW > budget && budget > 2 {
runes := []rune(prefix)
// Strip from the left until lipgloss.Width fits the budget.
for len(runes) > 0 && lipgloss.Width(string(runes)) > budget-1 {
runes = runes[1:]
}
prefix = "…" + string(runes)
prefixW = lipgloss.Width(prefix)
}
// Build the full line style
var lineStyle lipgloss.Style
var textStyle lipgloss.Style
// Reserve space for indicator(2) + prefix + right parts.
available := max(innerWidth-2-prefixW-rightW, 4)
// Base text color based on role
text := ts.entryDisplayText(node.Entry)
text = truncateRunes(text, available)
// Selected row: emit raw text. The outer row style applies fg+bg in one
// uninterrupted span, keeping the highlight solid edge-to-edge.
if isCursor {
return indicator + prefix + text + labelBadgeRaw + activeMarkerRaw
}
// Role-based text color.
var textStyle lipgloss.Style
switch e := node.Entry.(type) {
case *session.MessageEntry:
switch e.Role {
@@ -520,77 +439,27 @@ func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool
textStyle = lipgloss.NewStyle().Foreground(theme.Muted)
}
// Apply selection highlighting (like PopupList)
if isCursor {
// Inverted colors for selected item - matches PopupList style
lineStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Background).
Bold(true)
textStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Background).
Bold(true)
}
// Render components
content := textStyle.Render(text)
// Label badge.
var labelBadge string
if node.Label != "" {
labelStyle := lipgloss.NewStyle().Foreground(theme.Warning)
if isCursor {
labelStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Warning)
}
labelBadge = " " + labelStyle.Render("["+node.Label+"]")
}
// Active marker - use Success color for better visibility
var activeMarker string
if isLeaf {
markerStyle := lipgloss.NewStyle().Foreground(theme.Success).Bold(true)
if isCursor {
markerStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Success).
Bold(true)
}
activeMarker = markerStyle.Render(" ← active")
}
// Prefix (tree lines) - use MutedBorder for subtler appearance
prefixStyle := lipgloss.NewStyle().Foreground(theme.MutedBorder)
if isCursor {
prefixStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.MutedBorder)
labelStyle := lipgloss.NewStyle().Foreground(theme.Warning)
markerStyle := lipgloss.NewStyle().Foreground(theme.Success).Bold(true)
parts := indicator + prefixStyle.Render(prefix) + textStyle.Render(text)
if labelBadgeRaw != "" {
parts += labelStyle.Render(labelBadgeRaw)
}
renderedPrefix := prefixStyle.Render(node.Prefix)
// Combine all parts
line := cursor + renderedPrefix + content + labelBadge + activeMarker
// If selected, apply the background to the entire line
if isCursor {
return lineStyle.Render(line)
if activeMarkerRaw != "" {
parts += markerStyle.Render(activeMarkerRaw)
}
return line
return parts
}
func (ts *TreeSelectorComponent) entryDisplayText(entry any) string {
switch e := entry.(type) {
case *session.MessageEntry:
role := e.Role
text := extractTextFromParts(e.Parts)
if len(text) > 80 {
text = text[:80] + "..."
}
text := collapseToLine(extractTextFromParts(e.Parts))
text = truncateRunes(text, 200)
if text == "" {
// Tool call messages may not have text.
text = "(tool interaction)"
}
return fmt.Sprintf("%s: %s", role, text)
@@ -599,18 +468,10 @@ func (ts *TreeSelectorComponent) entryDisplayText(entry any) string {
return fmt.Sprintf("model: %s/%s", e.Provider, e.ModelID)
case *session.BranchSummaryEntry:
summary := e.Summary
if len(summary) > 60 {
summary = summary[:60] + "..."
}
return fmt.Sprintf("branch summary: %s", summary)
return fmt.Sprintf("branch summary: %s", truncateRunes(collapseToLine(e.Summary), 200))
case *session.CompactionEntry:
summary := e.Summary
if len(summary) > 60 {
summary = summary[:60] + "..."
}
return fmt.Sprintf("compaction: %s", summary)
return fmt.Sprintf("compaction: %s", truncateRunes(collapseToLine(e.Summary), 200))
case *session.LabelEntry:
return fmt.Sprintf("label: %s", e.Label)
@@ -623,6 +484,13 @@ func (ts *TreeSelectorComponent) entryDisplayText(entry any) string {
}
}
// collapseToLine flattens any multi-line string into a single line by
// replacing whitespace runs (including newlines and tabs) with single
// spaces. Used so popup rows never wrap and break the layout.
func collapseToLine(s string) string {
return strings.Join(strings.Fields(s), " ")
}
func (ts *TreeSelectorComponent) isUserMessage(entry any) bool {
if me, ok := entry.(*session.MessageEntry); ok {
return me.Role == "user"
+4 -2
View File
@@ -91,8 +91,10 @@ func (h *Harness) LoadString(src string, path string) *extensions.LoadedExtensio
func (h *Harness) loadSource(src string, path string) *extensions.LoadedExtension {
h.t.Helper()
// Create a fresh interpreter
i := interp.New(interp.Options{})
// Create a fresh interpreter. Seed the virtualized environment with the
// process environment so extensions can read env vars via os.Getenv,
// mirroring the production loader (see internal/extensions/loader.go).
i := interp.New(interp.Options{Env: os.Environ()})
// Expose Go stdlib
if err := i.Use(stdlib.Symbols); err != nil {
+4 -3
View File
@@ -74,7 +74,8 @@ host, err := kit.NewAgent(ctx,
Helpers: `WithModel`, `WithSystemPrompt`, `WithStreaming`, `WithMaxTokens`,
`WithThinkingLevel`, `WithTools`, `WithExtraTools`, `WithProviderAPIKey`,
`WithProviderURL`, `WithConfigFile`, `WithDebug`, and `Ephemeral`. `Option` is
`WithProviderURL`, `WithConfigFile`, `WithDebug`, `WithDebugLogger`, and
`Ephemeral`. `Option` is
a plain `func(*Options)`, so you can define your own. For fields without a
`With*` helper (`MCPConfig`, `InProcessMCPServers`, `SessionManager`, MCP task
tuning) construct an `Options` value and call `kit.New`.
@@ -329,7 +330,6 @@ kit.LLMFilePart // {Filename, Data []byte, MediaType}
// Agent configuration — concrete Kit-owned structs and function types.
// All fields use SDK types (e.g. `[]kit.Tool`), so consumers can construct
// these without importing any LLM-provider package.
kit.AgentConfig // Lower-level agent config — prefer Options unless you need direct control
kit.DebugLogger // Interface: LogDebug(string) / IsDebugEnabled() bool
kit.MCPTaskConfig // Task-aware MCP tools/call config (modes, polling, progress)
kit.ToolCallHandler // func(toolCallID, toolName, toolArgs string)
@@ -403,7 +403,8 @@ Key `Options` fields for SDK usage:
| `SessionPath` | Open specific session file |
| `Continue` | Resume most recent session |
| `InProcessMCPServers` | Map of name → `*kit.MCPServer` for in-process MCP servers |
| `Debug` | Enable debug logging |
| `Debug` | Enable debug logging via the built-in console logger (ignored when `DebugLogger` is set) |
| `DebugLogger` | Custom `DebugLogger` implementation — routes engine + MCP debug output into your own logging system |
## Environment Variables
+2 -2
View File
@@ -11,12 +11,12 @@ import (
// treeManagerAdapter adapts TreeManager to SessionManager interface.
// This is unexported - users don't interact with it directly.
type treeManagerAdapter struct {
inner *session.TreeManager
inner *TreeManager
}
// NewTreeManagerAdapter creates an adapter (exported for use in New function).
// This is used by the SDK when no custom SessionManager is provided.
func NewTreeManagerAdapter(tm *session.TreeManager) SessionManager {
func NewTreeManagerAdapter(tm *TreeManager) SessionManager {
return &treeManagerAdapter{inner: tm}
}
-208
View File
@@ -1,208 +0,0 @@
package kit
import (
"context"
"errors"
"testing"
"time"
"github.com/mark3labs/kit/internal/agent"
)
// TestAgentConfigToInternal verifies that the SDK-side AgentConfig converts
// faithfully to the internal agent.AgentConfig representation, preserving
// every field consumed by the internal agent layer.
//
// Regression test for https://github.com/mark3labs/kit/issues/30.
func TestAgentConfigToInternal(t *testing.T) {
t.Run("nil receiver returns nil", func(t *testing.T) {
var c *AgentConfig
if got := c.toInternal(); got != nil {
t.Errorf("nil.toInternal() = %v, want nil", got)
}
})
t.Run("scalar fields round-trip", func(t *testing.T) {
c := &AgentConfig{
SystemPrompt: "sys",
MaxSteps: 7,
StreamingEnabled: true,
DisableCoreTools: true,
}
got := c.toInternal()
if got == nil {
t.Fatal("toInternal() = nil")
}
if got.SystemPrompt != "sys" {
t.Errorf("SystemPrompt = %q, want %q", got.SystemPrompt, "sys")
}
if got.MaxSteps != 7 {
t.Errorf("MaxSteps = %d, want 7", got.MaxSteps)
}
if !got.StreamingEnabled {
t.Error("StreamingEnabled = false, want true")
}
if !got.DisableCoreTools {
t.Error("DisableCoreTools = false, want true")
}
})
t.Run("tool slices propagate without conversion", func(t *testing.T) {
// Tool is a type alias for the underlying LLM-tool type, so the
// SDK []Tool and internal []fantasy.AgentTool slices share the
// same backing array after conversion.
tool := NewTool[struct{}]("noop", "noop", nil)
c := &AgentConfig{
CoreTools: []Tool{tool},
ExtraTools: []Tool{tool, tool},
}
got := c.toInternal()
if len(got.CoreTools) != 1 {
t.Errorf("CoreTools len = %d, want 1", len(got.CoreTools))
}
if len(got.ExtraTools) != 2 {
t.Errorf("ExtraTools len = %d, want 2", len(got.ExtraTools))
}
})
t.Run("tool wrapper is invoked through internal config", func(t *testing.T) {
called := false
c := &AgentConfig{
ToolWrapper: func(in []Tool) []Tool {
called = true
return in
},
}
got := c.toInternal()
if got.ToolWrapper == nil {
t.Fatal("internal ToolWrapper is nil")
}
_ = got.ToolWrapper(nil)
if !called {
t.Error("SDK ToolWrapper was not invoked through the internal config")
}
})
t.Run("OnMCPServerLoaded propagates", func(t *testing.T) {
var captured string
wantErr := errors.New("boom")
c := &AgentConfig{
OnMCPServerLoaded: func(name string, _ int, _ error) {
captured = name
},
}
got := c.toInternal()
got.OnMCPServerLoaded("svr", 3, wantErr)
if captured != "svr" {
t.Errorf("OnMCPServerLoaded captured = %q, want %q", captured, "svr")
}
})
t.Run("DebugLogger propagates", func(t *testing.T) {
dl := &fakeDebugLogger{enabled: true}
c := &AgentConfig{DebugLogger: dl}
got := c.toInternal()
if got.DebugLogger == nil {
t.Fatal("internal DebugLogger is nil")
}
if !got.DebugLogger.IsDebugEnabled() {
t.Error("IsDebugEnabled = false, want true")
}
got.DebugLogger.LogDebug("hello")
if len(dl.messages) != 1 || dl.messages[0] != "hello" {
t.Errorf("messages = %v, want [hello]", dl.messages)
}
})
t.Run("MCPTaskConfig propagates with mode + progress", func(t *testing.T) {
c := &AgentConfig{
MCPTaskConfig: MCPTaskConfig{
PerServerMode: map[string]MCPTaskMode{
"build-svr": MCPTaskModeAlways,
},
DefaultTTL: 30 * time.Second,
PollInterval: 250 * time.Millisecond,
MaxPollInterval: 2 * time.Second,
Timeout: 5 * time.Minute,
Progress: func(_ MCPTaskProgress) {},
},
}
got := c.toInternal()
if got.MCPTaskConfig.DefaultTTL != 30*time.Second {
t.Errorf("DefaultTTL = %v, want 30s", got.MCPTaskConfig.DefaultTTL)
}
if got.MCPTaskConfig.PollInterval != 250*time.Millisecond {
t.Errorf("PollInterval = %v, want 250ms", got.MCPTaskConfig.PollInterval)
}
if got.MCPTaskConfig.MaxPollInterval != 2*time.Second {
t.Errorf("MaxPollInterval = %v, want 2s", got.MCPTaskConfig.MaxPollInterval)
}
if got.MCPTaskConfig.Timeout != 5*time.Minute {
t.Errorf("Timeout = %v, want 5m", got.MCPTaskConfig.Timeout)
}
mode, ok := got.MCPTaskConfig.PerServerMode["build-svr"]
if !ok {
t.Fatal("PerServerMode missing 'build-svr'")
}
if string(mode) != string(MCPTaskModeAlways) {
t.Errorf("mode = %q, want %q", mode, MCPTaskModeAlways)
}
if got.MCPTaskConfig.Progress == nil {
t.Fatal("internal Progress handler is nil")
}
})
t.Run("auth and token store factories are wired", func(t *testing.T) {
auth := &fakeAuthHandler{}
tokenCalls := 0
var tokenServer string
factory := MCPTokenStoreFactory(func(server string) (MCPTokenStore, error) {
tokenCalls++
tokenServer = server
return nil, nil
})
c := &AgentConfig{
AuthHandler: auth,
TokenStoreFactory: factory,
}
got := c.toInternal()
if got.AuthHandler == nil {
t.Fatal("internal AuthHandler is nil")
}
if got.TokenStoreFactory == nil {
t.Fatal("internal TokenStoreFactory is nil")
}
_, _ = got.TokenStoreFactory("https://example.test")
if tokenCalls != 1 {
t.Errorf("token factory call count = %d, want 1", tokenCalls)
}
if tokenServer != "https://example.test" {
t.Errorf("token factory server arg = %q", tokenServer)
}
if got.AuthHandler.RedirectURI() != "redirect" {
t.Errorf("RedirectURI = %q, want %q", got.AuthHandler.RedirectURI(), "redirect")
}
})
// Compile-time check that the internal type is what we expect.
//nolint:staticcheck // QF1011: explicit type asserts the conversion target.
var _ *agent.AgentConfig = (&AgentConfig{}).toInternal()
}
// fakeAuthHandler implements both kit.MCPAuthHandler and the structurally
// identical tools.MCPAuthHandler used by the internal layer.
type fakeAuthHandler struct{}
func (f *fakeAuthHandler) RedirectURI() string { return "redirect" }
func (f *fakeAuthHandler) HandleAuth(_ context.Context, _ string, _ string) (string, error) {
return "", nil
}
// fakeDebugLogger implements kit.DebugLogger for tests.
type fakeDebugLogger struct {
enabled bool
messages []string
}
func (f *fakeDebugLogger) LogDebug(m string) { f.messages = append(f.messages, m) }
func (f *fakeDebugLogger) IsDebugEnabled() bool { return f.enabled }
+34
View File
@@ -17,6 +17,9 @@ type AnthropicCredentials = auth.AnthropicCredentials
// and API key authentication methods.
type OpenAICredentials = auth.OpenAICredentials
// CopilotCredentials holds GitHub OAuth and Copilot API credentials.
type CopilotCredentials = auth.CopilotCredentials
// CredentialStore holds all stored credentials for various providers.
type CredentialStore = auth.CredentialStore
@@ -65,6 +68,37 @@ func HasOpenAICredentials() bool {
return has
}
// HasCopilotCredentials checks if valid GitHub Copilot credentials are stored.
func HasCopilotCredentials() bool {
cm, err := auth.NewCredentialManager()
if err != nil {
return false
}
has, err := cm.HasCopilotCredentials()
if err != nil {
return false
}
return has
}
// GetCopilotCredentials retrieves stored GitHub Copilot credentials.
func GetCopilotCredentials() (*CopilotCredentials, error) {
cm, err := auth.NewCredentialManager()
if err != nil {
return nil, err
}
return cm.GetCopilotCredentials()
}
// GetValidCopilotAccessToken returns a fresh GitHub Copilot access token.
func GetValidCopilotAccessToken() (string, error) {
cm, err := auth.NewCredentialManager()
if err != nil {
return "", err
}
return cm.GetValidCopilotAccessToken()
}
// GetOpenAIAPIKey resolves the OpenAI API key using the standard
// resolution order: stored credentials -> OPENAI_API_KEY env var.
// Returns an empty string if no key is found.
+54 -172
View File
@@ -3,6 +3,8 @@ package kit
import (
"encoding/json"
"sync"
"github.com/mark3labs/kit/internal/extensions"
)
// ---------------------------------------------------------------------------
@@ -103,34 +105,21 @@ type Event interface {
// appropriate visualizations (e.g. diff view for edit tools, command+output
// for execute tools) and file trackers to identify which results contain
// modifications.
//
// These constants re-export the canonical classification used by extension
// events, so SDK events and extension events always agree.
const (
ToolKindExecute = "execute" // Shell execution (bash)
ToolKindEdit = "edit" // File modification (edit, write)
ToolKindRead = "read" // File reading (read, ls)
ToolKindSearch = "search" // Content/file search (grep, find)
ToolKindSubagent = "agent" // Subagent spawning (subagent)
ToolKindExecute = extensions.ToolKindExecute // Shell execution (bash)
ToolKindEdit = extensions.ToolKindEdit // File modification (edit, write)
ToolKindRead = extensions.ToolKindRead // File reading (read, ls)
ToolKindSearch = extensions.ToolKindSearch // Content/file search (grep, find)
ToolKindSubagent = extensions.ToolKindSubagent // Subagent spawning (subagent)
)
// coreToolKinds maps built-in tool names to their kind. MCP and extension
// tools without an entry default to ToolKindExecute.
var coreToolKinds = map[string]string{
"bash": ToolKindExecute,
"edit": ToolKindEdit,
"write": ToolKindEdit,
"read": ToolKindRead,
"ls": ToolKindRead,
"grep": ToolKindSearch,
"find": ToolKindSearch,
"subagent": ToolKindSubagent,
}
// toolKindFor returns the ToolKind for a given tool name, defaulting to
// ToolKindExecute for unknown tools.
func toolKindFor(toolName string) string {
if kind, ok := coreToolKinds[toolName]; ok {
return kind
}
return ToolKindExecute
return extensions.ToolKindFor(toolName)
}
// parseToolArgs attempts to parse a JSON-encoded tool args string into a map.
@@ -571,67 +560,56 @@ func (eb *eventBus) emit(event Event) {
// Typed convenience subscribers
// ---------------------------------------------------------------------------
// subscribeTyped is the generic backbone of all the typed `On<EventName>`
// convenience methods on *Kit. It wraps Subscribe with a type assertion
// against E so handlers receive a strongly-typed event without each
// public method having to repeat the boilerplate. Returns an unsubscribe
// function.
func subscribeTyped[E Event](k *Kit, handler func(E)) func() {
return k.Subscribe(func(e Event) {
if tev, ok := e.(E); ok {
handler(tev)
}
})
}
// OnToolCall registers a handler that fires only for ToolCallEvent.
// Returns an unsubscribe function.
func (m *Kit) OnToolCall(handler func(ToolCallEvent)) func() {
return m.Subscribe(func(e Event) {
if tc, ok := e.(ToolCallEvent); ok {
handler(tc)
}
})
return subscribeTyped(m, handler)
}
// OnToolCallStart registers a handler that fires only for ToolCallStartEvent.
// This fires when the LLM begins generating tool call arguments — before the
// full argument JSON is available. Returns an unsubscribe function.
func (m *Kit) OnToolCallStart(handler func(ToolCallStartEvent)) func() {
return m.Subscribe(func(e Event) {
if tcs, ok := e.(ToolCallStartEvent); ok {
handler(tcs)
}
})
return subscribeTyped(m, handler)
}
// OnToolCallDelta registers a handler that fires only for ToolCallDeltaEvent.
// Each delta contains a JSON fragment of tool call arguments as they stream in.
// Returns an unsubscribe function.
func (m *Kit) OnToolCallDelta(handler func(ToolCallDeltaEvent)) func() {
return m.Subscribe(func(e Event) {
if tcd, ok := e.(ToolCallDeltaEvent); ok {
handler(tcd)
}
})
return subscribeTyped(m, handler)
}
// OnToolCallEnd registers a handler that fires only for ToolCallEndEvent.
// This fires when tool argument streaming is complete, before the tool call
// is parsed and execution begins. Returns an unsubscribe function.
func (m *Kit) OnToolCallEnd(handler func(ToolCallEndEvent)) func() {
return m.Subscribe(func(e Event) {
if tce, ok := e.(ToolCallEndEvent); ok {
handler(tce)
}
})
return subscribeTyped(m, handler)
}
// OnToolResult registers a handler that fires only for ToolResultEvent.
// Returns an unsubscribe function.
func (m *Kit) OnToolResult(handler func(ToolResultEvent)) func() {
return m.Subscribe(func(e Event) {
if tr, ok := e.(ToolResultEvent); ok {
handler(tr)
}
})
return subscribeTyped(m, handler)
}
// OnToolOutput registers a handler that fires only for ToolOutputEvent
// (streaming tool output chunks, e.g., from bash). Returns an unsubscribe function.
func (m *Kit) OnToolOutput(handler func(ToolOutputEvent)) func() {
return m.Subscribe(func(e Event) {
if to, ok := e.(ToolOutputEvent); ok {
handler(to)
}
})
return subscribeTyped(m, handler)
}
// OnStreaming registers a handler that fires only for MessageUpdateEvent
@@ -646,41 +624,25 @@ func (m *Kit) OnStreaming(handler func(MessageUpdateEvent)) func() {
// OnMessageUpdate registers a handler that fires only for MessageUpdateEvent
// (streaming text chunks). Returns an unsubscribe function.
func (m *Kit) OnMessageUpdate(handler func(MessageUpdateEvent)) func() {
return m.Subscribe(func(e Event) {
if mu, ok := e.(MessageUpdateEvent); ok {
handler(mu)
}
})
return subscribeTyped(m, handler)
}
// OnResponse registers a handler that fires only for ResponseEvent.
// Returns an unsubscribe function.
func (m *Kit) OnResponse(handler func(ResponseEvent)) func() {
return m.Subscribe(func(e Event) {
if r, ok := e.(ResponseEvent); ok {
handler(r)
}
})
return subscribeTyped(m, handler)
}
// OnTurnStart registers a handler that fires only for TurnStartEvent.
// Returns an unsubscribe function.
func (m *Kit) OnTurnStart(handler func(TurnStartEvent)) func() {
return m.Subscribe(func(e Event) {
if ts, ok := e.(TurnStartEvent); ok {
handler(ts)
}
})
return subscribeTyped(m, handler)
}
// OnTurnEnd registers a handler that fires only for TurnEndEvent.
// Returns an unsubscribe function.
func (m *Kit) OnTurnEnd(handler func(TurnEndEvent)) func() {
return m.Subscribe(func(e Event) {
if te, ok := e.(TurnEndEvent); ok {
handler(te)
}
})
return subscribeTyped(m, handler)
}
// ---------------------------------------------------------------------------
@@ -690,101 +652,61 @@ func (m *Kit) OnTurnEnd(handler func(TurnEndEvent)) func() {
// OnMessageStart registers a handler that fires only for MessageStartEvent.
// Returns an unsubscribe function.
func (m *Kit) OnMessageStart(handler func(MessageStartEvent)) func() {
return m.Subscribe(func(e Event) {
if ms, ok := e.(MessageStartEvent); ok {
handler(ms)
}
})
return subscribeTyped(m, handler)
}
// OnMessageEnd registers a handler that fires only for MessageEndEvent.
// Returns an unsubscribe function.
func (m *Kit) OnMessageEnd(handler func(MessageEndEvent)) func() {
return m.Subscribe(func(e Event) {
if me, ok := e.(MessageEndEvent); ok {
handler(me)
}
})
return subscribeTyped(m, handler)
}
// OnReasoningDelta registers a handler that fires only for ReasoningDeltaEvent.
// Returns an unsubscribe function.
func (m *Kit) OnReasoningDelta(handler func(ReasoningDeltaEvent)) func() {
return m.Subscribe(func(e Event) {
if rd, ok := e.(ReasoningDeltaEvent); ok {
handler(rd)
}
})
return subscribeTyped(m, handler)
}
// OnReasoningComplete registers a handler that fires only for ReasoningCompleteEvent.
// Returns an unsubscribe function.
func (m *Kit) OnReasoningComplete(handler func(ReasoningCompleteEvent)) func() {
return m.Subscribe(func(e Event) {
if rc, ok := e.(ReasoningCompleteEvent); ok {
handler(rc)
}
})
return subscribeTyped(m, handler)
}
// OnToolExecutionStart registers a handler that fires only for ToolExecutionStartEvent.
// Returns an unsubscribe function.
func (m *Kit) OnToolExecutionStart(handler func(ToolExecutionStartEvent)) func() {
return m.Subscribe(func(e Event) {
if tes, ok := e.(ToolExecutionStartEvent); ok {
handler(tes)
}
})
return subscribeTyped(m, handler)
}
// OnToolExecutionEnd registers a handler that fires only for ToolExecutionEndEvent.
// Returns an unsubscribe function.
func (m *Kit) OnToolExecutionEnd(handler func(ToolExecutionEndEvent)) func() {
return m.Subscribe(func(e Event) {
if tee, ok := e.(ToolExecutionEndEvent); ok {
handler(tee)
}
})
return subscribeTyped(m, handler)
}
// OnToolCallContent registers a handler that fires only for ToolCallContentEvent.
// Returns an unsubscribe function.
func (m *Kit) OnToolCallContent(handler func(ToolCallContentEvent)) func() {
return m.Subscribe(func(e Event) {
if tcc, ok := e.(ToolCallContentEvent); ok {
handler(tcc)
}
})
return subscribeTyped(m, handler)
}
// OnStepUsage registers a handler that fires only for StepUsageEvent.
// Returns an unsubscribe function.
func (m *Kit) OnStepUsage(handler func(StepUsageEvent)) func() {
return m.Subscribe(func(e Event) {
if su, ok := e.(StepUsageEvent); ok {
handler(su)
}
})
return subscribeTyped(m, handler)
}
// OnCompaction registers a handler that fires only for CompactionEvent.
// Returns an unsubscribe function.
func (m *Kit) OnCompaction(handler func(CompactionEvent)) func() {
return m.Subscribe(func(e Event) {
if ce, ok := e.(CompactionEvent); ok {
handler(ce)
}
})
return subscribeTyped(m, handler)
}
// OnSteerConsumed registers a handler that fires only for SteerConsumedEvent.
// Returns an unsubscribe function.
func (m *Kit) OnSteerConsumed(handler func(SteerConsumedEvent)) func() {
return m.Subscribe(func(e Event) {
if sc, ok := e.(SteerConsumedEvent); ok {
handler(sc)
}
})
return subscribeTyped(m, handler)
}
// ---------------------------------------------------------------------------
@@ -794,101 +716,61 @@ func (m *Kit) OnSteerConsumed(handler func(SteerConsumedEvent)) func() {
// OnStepStart registers a handler that fires only for StepStartEvent.
// Returns an unsubscribe function.
func (m *Kit) OnStepStart(handler func(StepStartEvent)) func() {
return m.Subscribe(func(e Event) {
if ss, ok := e.(StepStartEvent); ok {
handler(ss)
}
})
return subscribeTyped(m, handler)
}
// OnStepFinish registers a handler that fires only for StepFinishEvent.
// Returns an unsubscribe function.
func (m *Kit) OnStepFinish(handler func(StepFinishEvent)) func() {
return m.Subscribe(func(e Event) {
if sf, ok := e.(StepFinishEvent); ok {
handler(sf)
}
})
return subscribeTyped(m, handler)
}
// OnTextStart registers a handler that fires only for TextStartEvent.
// Returns an unsubscribe function.
func (m *Kit) OnTextStart(handler func(TextStartEvent)) func() {
return m.Subscribe(func(e Event) {
if ts, ok := e.(TextStartEvent); ok {
handler(ts)
}
})
return subscribeTyped(m, handler)
}
// OnTextEnd registers a handler that fires only for TextEndEvent.
// Returns an unsubscribe function.
func (m *Kit) OnTextEnd(handler func(TextEndEvent)) func() {
return m.Subscribe(func(e Event) {
if te, ok := e.(TextEndEvent); ok {
handler(te)
}
})
return subscribeTyped(m, handler)
}
// OnReasoningStart registers a handler that fires only for ReasoningStartEvent.
// Returns an unsubscribe function.
func (m *Kit) OnReasoningStart(handler func(ReasoningStartEvent)) func() {
return m.Subscribe(func(e Event) {
if rs, ok := e.(ReasoningStartEvent); ok {
handler(rs)
}
})
return subscribeTyped(m, handler)
}
// OnWarnings registers a handler that fires only for WarningsEvent.
// Returns an unsubscribe function.
func (m *Kit) OnWarnings(handler func(WarningsEvent)) func() {
return m.Subscribe(func(e Event) {
if w, ok := e.(WarningsEvent); ok {
handler(w)
}
})
return subscribeTyped(m, handler)
}
// OnSource registers a handler that fires only for SourceEvent.
// Returns an unsubscribe function.
func (m *Kit) OnSource(handler func(SourceEvent)) func() {
return m.Subscribe(func(e Event) {
if s, ok := e.(SourceEvent); ok {
handler(s)
}
})
return subscribeTyped(m, handler)
}
// OnStreamFinish registers a handler that fires only for StreamFinishEvent.
// Returns an unsubscribe function.
func (m *Kit) OnStreamFinish(handler func(StreamFinishEvent)) func() {
return m.Subscribe(func(e Event) {
if sf, ok := e.(StreamFinishEvent); ok {
handler(sf)
}
})
return subscribeTyped(m, handler)
}
// OnError registers a handler that fires only for ErrorEvent.
// Returns an unsubscribe function.
func (m *Kit) OnError(handler func(ErrorEvent)) func() {
return m.Subscribe(func(e Event) {
if ee, ok := e.(ErrorEvent); ok {
handler(ee)
}
})
return subscribeTyped(m, handler)
}
// OnRetry registers a handler that fires only for RetryEvent.
// Returns an unsubscribe function.
func (m *Kit) OnRetry(handler func(RetryEvent)) func() {
return m.Subscribe(func(e Event) {
if r, ok := e.(RetryEvent); ok {
handler(r)
}
})
return subscribeTyped(m, handler)
}
// ---------------------------------------------------------------------------
+6
View File
@@ -20,3 +20,9 @@ func (m *Kit) ConfigFloatForTest(key string) float64 { return m.v.GetFloat64(key
// ConfigBoolForTest returns the bool value of key from this Kit's isolated
// configuration store.
func (m *Kit) ConfigBoolForTest(key string) bool { return m.v.GetBool(key) }
// ConfigStringSliceForTest returns the string slice value of key from this
// Kit's isolated configuration store.
func (m *Kit) ConfigStringSliceForTest(key string) []string {
return m.v.GetStringSlice(key)
}
+120 -30
View File
@@ -2,6 +2,8 @@ package kit
import (
"fmt"
"log"
"strings"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/message"
@@ -96,6 +98,23 @@ type ExtensionAPI interface {
AppendEntry(extType, data string) (string, error)
GetEntries(extType string) []ExtensionEntry
// Session-scoped extension state (last-write-wins key-value store).
// Backed by an in-memory map and (optionally) a sidecar file per session;
// state lives outside the conversation tree and is not visible to the LLM.
SetState(key, value string)
GetState(key string) (string, bool)
DeleteState(key string)
ListState() []string
// InitStatePersistence loads any existing state from the per-session
// sidecar file and installs a saver hook so that subsequent SetState /
// DeleteState mutations are flushed to disk. Safe to call multiple times;
// repeat calls simply reload and reinstall the saver.
//
// For ephemeral or in-memory sessions (no session file path), the call
// is a no-op and state remains in memory for the lifetime of the runner.
InitStatePersistence() error
// Status bar
SetStatus(entry ExtensionStatusBarEntry)
RemoveStatus(key string)
@@ -118,6 +137,7 @@ type ExtensionAPI interface {
EmitCustomEvent(name, data string)
EmitBeforeFork(targetID string, isUserMsg bool, userText string) (cancelled bool, reason string)
EmitBeforeSessionSwitch(switchReason string) (cancelled bool, reason string)
EmitBeforeSessionSwitchWithPrompt(switchReason, initialPrompt string) (cancelled bool, reason string)
// Commands
Commands() []ExtensionCommandDef
@@ -155,17 +175,17 @@ func (m *Kit) Extensions() ExtensionAPI {
// Context management
func (e *extensionAPI) SetContext(ctx extensions.Context) {
func (e *extensionAPI) SetContext(ctx ExtensionContext) {
if e.kit.extRunner != nil {
e.kit.extRunner.SetContext(ctx)
}
}
func (e *extensionAPI) GetContext() extensions.Context {
func (e *extensionAPI) GetContext() ExtensionContext {
if e.kit.extRunner != nil {
return e.kit.extRunner.GetContext()
}
return extensions.Context{}
return ExtensionContext{}
}
func (e *extensionAPI) UpdateContextModel(model string) {
@@ -178,7 +198,7 @@ func (e *extensionAPI) UpdateContextModel(model string) {
// Widgets
func (e *extensionAPI) SetWidget(config extensions.WidgetConfig) {
func (e *extensionAPI) SetWidget(config ExtensionWidgetConfig) {
if e.kit.extRunner != nil {
e.kit.extRunner.SetWidget(config)
}
@@ -190,7 +210,7 @@ func (e *extensionAPI) RemoveWidget(id string) {
}
}
func (e *extensionAPI) GetWidgets(placement extensions.WidgetPlacement) []extensions.WidgetConfig {
func (e *extensionAPI) GetWidgets(placement ExtensionWidgetPlacement) []ExtensionWidgetConfig {
if e.kit.extRunner == nil {
return nil
}
@@ -199,7 +219,7 @@ func (e *extensionAPI) GetWidgets(placement extensions.WidgetPlacement) []extens
// Header/Footer
func (e *extensionAPI) SetHeader(config extensions.HeaderFooterConfig) {
func (e *extensionAPI) SetHeader(config ExtensionHeaderFooterConfig) {
if e.kit.extRunner != nil {
e.kit.extRunner.SetHeader(config)
}
@@ -211,14 +231,14 @@ func (e *extensionAPI) RemoveHeader() {
}
}
func (e *extensionAPI) GetHeader() *extensions.HeaderFooterConfig {
func (e *extensionAPI) GetHeader() *ExtensionHeaderFooterConfig {
if e.kit.extRunner == nil {
return nil
}
return e.kit.extRunner.GetHeader()
}
func (e *extensionAPI) SetFooter(config extensions.HeaderFooterConfig) {
func (e *extensionAPI) SetFooter(config ExtensionHeaderFooterConfig) {
if e.kit.extRunner != nil {
e.kit.extRunner.SetFooter(config)
}
@@ -230,7 +250,7 @@ func (e *extensionAPI) RemoveFooter() {
}
}
func (e *extensionAPI) GetFooter() *extensions.HeaderFooterConfig {
func (e *extensionAPI) GetFooter() *ExtensionHeaderFooterConfig {
if e.kit.extRunner == nil {
return nil
}
@@ -239,7 +259,7 @@ func (e *extensionAPI) GetFooter() *extensions.HeaderFooterConfig {
// Editor
func (e *extensionAPI) SetEditor(config extensions.EditorConfig) {
func (e *extensionAPI) SetEditor(config ExtensionEditorConfig) {
if e.kit.extRunner != nil {
e.kit.extRunner.SetEditor(config)
}
@@ -251,7 +271,7 @@ func (e *extensionAPI) ResetEditor() {
}
}
func (e *extensionAPI) GetEditor() *extensions.EditorConfig {
func (e *extensionAPI) GetEditor() *ExtensionEditorConfig {
if e.kit.extRunner == nil {
return nil
}
@@ -260,13 +280,13 @@ func (e *extensionAPI) GetEditor() *extensions.EditorConfig {
// UI Visibility
func (e *extensionAPI) SetUIVisibility(v extensions.UIVisibility) {
func (e *extensionAPI) SetUIVisibility(v ExtensionUIVisibility) {
if e.kit.extRunner != nil {
e.kit.extRunner.SetUIVisibility(v)
}
}
func (e *extensionAPI) GetUIVisibility() *extensions.UIVisibility {
func (e *extensionAPI) GetUIVisibility() *ExtensionUIVisibility {
if e.kit.extRunner == nil {
return nil
}
@@ -275,14 +295,14 @@ func (e *extensionAPI) GetUIVisibility() *extensions.UIVisibility {
// Tool rendering
func (e *extensionAPI) GetToolRenderer(toolName string) *extensions.ToolRenderConfig {
func (e *extensionAPI) GetToolRenderer(toolName string) *ExtensionToolRenderConfig {
if e.kit.extRunner == nil {
return nil
}
return e.kit.extRunner.GetToolRenderer(toolName)
}
func (e *extensionAPI) GetMessageRenderer(name string) *extensions.MessageRendererConfig {
func (e *extensionAPI) GetMessageRenderer(name string) *ExtensionMessageRendererConfig {
if e.kit.extRunner == nil {
return nil
}
@@ -291,7 +311,7 @@ func (e *extensionAPI) GetMessageRenderer(name string) *extensions.MessageRender
// Session data
func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage {
func (e *extensionAPI) GetSessionMessages() []ExtensionSessionMessage {
if e.kit.session == nil {
return nil
}
@@ -299,8 +319,8 @@ func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage {
// Try to use the legacy iterBranchMessages for backward compatibility
// with the default TreeManager adapter
if adapter, ok := e.kit.session.(*treeManagerAdapter); ok {
return iterBranchMessages(adapter.inner, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage {
return extensions.SessionMessage{
return iterBranchMessages(adapter.inner, func(me *session.MessageEntry, msg message.Message) ExtensionSessionMessage {
return ExtensionSessionMessage{
ID: me.ID,
Role: string(msg.Role),
Content: msg.Content(),
@@ -311,10 +331,10 @@ func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage {
// For custom SessionManagers, use the public interface
branch := e.kit.session.GetCurrentBranch()
var result []extensions.SessionMessage
var result []ExtensionSessionMessage
for _, entry := range branch {
if entry.Type == EntryTypeMessage {
result = append(result, extensions.SessionMessage{
result = append(result, ExtensionSessionMessage{
ID: entry.ID,
Role: entry.Role,
Content: entry.Content,
@@ -332,14 +352,75 @@ func (e *extensionAPI) AppendEntry(extType, data string) (string, error) {
return e.kit.session.AppendExtensionData(extType, data)
}
func (e *extensionAPI) GetEntries(extType string) []extensions.ExtensionEntry {
func (e *extensionAPI) SetState(key, value string) {
if e.kit.extRunner != nil {
e.kit.extRunner.SetState(key, value)
}
}
func (e *extensionAPI) GetState(key string) (string, bool) {
if e.kit.extRunner == nil {
return "", false
}
return e.kit.extRunner.GetState(key)
}
func (e *extensionAPI) DeleteState(key string) {
if e.kit.extRunner != nil {
e.kit.extRunner.DeleteState(key)
}
}
func (e *extensionAPI) ListState() []string {
if e.kit.extRunner == nil {
return nil
}
return e.kit.extRunner.ListState()
}
func (e *extensionAPI) InitStatePersistence() error {
if e.kit.extRunner == nil {
return nil
}
path := extStateSidecarPath(e.kit.GetSessionPath())
if path == "" {
// Ephemeral or in-memory session; no on-disk state.
e.kit.extRunner.SetStateSaver(nil)
return nil
}
if err := e.kit.extRunner.LoadStateFromFile(path); err != nil {
return err
}
runner := e.kit.extRunner
runner.SetStateSaver(func() {
if err := runner.SaveStateToFile(path); err != nil {
log.Printf("WARN extension state save failed: path=%s err=%v", path, err)
}
})
return nil
}
// extStateSidecarPath returns the path to the per-session extension state
// sidecar file derived from the session's JSONL path. Returns empty for
// ephemeral / in-memory sessions where no JSONL is being written.
func extStateSidecarPath(sessionPath string) string {
if sessionPath == "" {
return ""
}
if trimmed, ok := strings.CutSuffix(sessionPath, ".jsonl"); ok {
return trimmed + ".ext-state.json"
}
return sessionPath + ".ext-state.json"
}
func (e *extensionAPI) GetEntries(extType string) []ExtensionEntry {
if e.kit.session == nil {
return nil
}
entries := e.kit.session.GetExtensionData(extType)
result := make([]extensions.ExtensionEntry, 0, len(entries))
result := make([]ExtensionEntry, 0, len(entries))
for _, e := range entries {
result = append(result, extensions.ExtensionEntry{
result = append(result, ExtensionEntry{
ID: e.ID,
EntryType: e.ExtType,
Data: e.Data,
@@ -351,7 +432,7 @@ func (e *extensionAPI) GetEntries(extType string) []extensions.ExtensionEntry {
// Status bar
func (e *extensionAPI) SetStatus(entry extensions.StatusBarEntry) {
func (e *extensionAPI) SetStatus(entry ExtensionStatusBarEntry) {
if e.kit.extRunner != nil {
e.kit.extRunner.SetStatusEntry(entry)
}
@@ -363,7 +444,7 @@ func (e *extensionAPI) RemoveStatus(key string) {
}
}
func (e *extensionAPI) GetStatusEntries() []extensions.StatusBarEntry {
func (e *extensionAPI) GetStatusEntries() []ExtensionStatusBarEntry {
if e.kit.extRunner == nil {
return nil
}
@@ -394,12 +475,12 @@ func (e *extensionAPI) GetShortcuts() map[string]func() {
// Tools
func (e *extensionAPI) GetToolInfos() []extensions.ToolInfo {
func (e *extensionAPI) GetToolInfos() []ExtensionToolInfo {
agentTools := e.kit.agent.GetTools()
coreCount := e.kit.agent.GetCoreToolCount()
mcpCount := e.kit.agent.GetMCPToolCount()
result := make([]extensions.ToolInfo, 0, len(agentTools))
result := make([]ExtensionToolInfo, 0, len(agentTools))
for i, t := range agentTools {
info := t.Info()
source := "core"
@@ -412,7 +493,7 @@ func (e *extensionAPI) GetToolInfos() []extensions.ToolInfo {
if e.kit.extRunner != nil && e.kit.extRunner.IsToolDisabled(info.Name) {
enabled = false
}
result = append(result, extensions.ToolInfo{
result = append(result, ExtensionToolInfo{
Name: info.Name,
Description: info.Description,
Source: source,
@@ -487,11 +568,20 @@ func (e *extensionAPI) EmitBeforeFork(targetID string, isUserMsg bool, userText
}
func (e *extensionAPI) EmitBeforeSessionSwitch(switchReason string) (cancelled bool, reason string) {
return e.EmitBeforeSessionSwitchWithPrompt(switchReason, "")
}
// EmitBeforeSessionSwitchWithPrompt is like EmitBeforeSessionSwitch but also
// supplies the initial user prompt (if any) that will be submitted as the
// first turn of the new session. Extensions inspecting BeforeSessionSwitchEvent
// see this value in the event's InitialPrompt field.
func (e *extensionAPI) EmitBeforeSessionSwitchWithPrompt(switchReason, initialPrompt string) (cancelled bool, reason string) {
if e.kit.extRunner == nil || !e.kit.extRunner.HasHandlers(extensions.BeforeSessionSwitch) {
return false, ""
}
result, _ := e.kit.extRunner.Emit(extensions.BeforeSessionSwitchEvent{
Reason: switchReason,
Reason: switchReason,
InitialPrompt: initialPrompt,
})
if r, ok := result.(extensions.BeforeSessionSwitchResult); ok && r.Cancel {
reason := r.Reason
@@ -505,7 +595,7 @@ func (e *extensionAPI) EmitBeforeSessionSwitch(switchReason string) (cancelled b
// Commands
func (e *extensionAPI) Commands() []extensions.CommandDef {
func (e *extensionAPI) Commands() []ExtensionCommandDef {
if e.kit.extRunner == nil {
return nil
}
+226 -2
View File
@@ -3,8 +3,11 @@ package kit
import (
"strings"
"sync"
"time"
"github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/models"
)
// bridgeExtensions registers extension event handlers as SDK hooks and
@@ -19,6 +22,30 @@ import (
// wrapper (internal/extensions/wrapper.go) which composes underneath the SDK
// hook wrapper.
func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
// Per-turn aggregator: collects tool/LLM/usage signals between AgentStart
// and AgentEnd so the enriched AgentEndEvent can be populated without
// requiring extensions to maintain parallel bookkeeping.
//
// NOTE: this aggregator assumes a single in-flight turn per *Kit instance,
// which is the current contract — runTurn does not serialize callers and
// the SDK's TurnStartEvent/TurnEndEvent do not carry a turn ID, so two
// concurrent Prompt() calls on the same *Kit would clobber the counters.
// All current callers (TUI app layer, CLI runner, SDK examples) serialize
// turns above this layer. If concurrent turns become a supported use case,
// extend TurnStartEvent/TurnEndEvent with a turn ID and key this map per
// turn instead.
turnAgg := &turnAggregator{kit: m}
m.Subscribe(func(e Event) {
switch ev := e.(type) {
case TurnStartEvent:
turnAgg.start()
case ToolResultEvent:
turnAgg.recordTool(ev.ToolName)
case StepFinishEvent:
turnAgg.recordStep(ev.Usage)
}
})
// --- Interception hooks ---
// Extension Input → BeforeTurn hook (high priority, runs first).
@@ -109,9 +136,19 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
} else if stopReason == "" {
stopReason = "completed"
}
agg := turnAgg.consume()
_, _ = runner.Emit(extensions.AgentEndEvent{
Response: response,
StopReason: stopReason,
Response: response,
StopReason: stopReason,
ToolCallCount: agg.toolCallCount,
ToolNames: agg.toolNames,
LLMCallCount: agg.llmCallCount,
InputTokensDelta: agg.inputTokens,
OutputTokensDelta: agg.outputTokens,
CacheReadTokensDelta: agg.cacheReadTokens,
CacheWriteTokensDelta: agg.cacheWriteTokens,
CostDelta: agg.cost,
DurationMs: agg.durationMs(),
})
}
})
@@ -302,6 +339,32 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
}
})
// LLMUsage: derive per-call usage from StepFinish. Each step corresponds
// to one LLM provider call, so the step's usage is the per-call delta.
// Cost is computed from the current model's pricing (zero when unknown
// or OAuth credentials are in use). RequestID is left empty until the
// SDK surfaces a correlation id from the underlying provider.
if runner.HasHandlers(extensions.LLMUsage) {
m.Subscribe(func(e Event) {
ev, ok := e.(StepFinishEvent)
if !ok {
return
}
provider, modelID, cost := llmUsageMeta(m, ev.Usage)
_, _ = runner.Emit(extensions.LLMUsageEvent{
InputTokens: int(ev.Usage.InputTokens),
OutputTokens: int(ev.Usage.OutputTokens),
CacheReadTokens: int(ev.Usage.CacheReadTokens),
CacheWriteTokens: int(ev.Usage.CacheCreationTokens),
Cost: cost,
Model: modelID,
Provider: provider,
StepNumber: ev.StepNumber,
FinishReason: ev.FinishReason,
})
})
}
bridgeObserve(m, runner, extensions.ReasoningStart, func(ev ReasoningStartEvent) extensions.Event {
return extensions.ReasoningStartEvent{ID: ev.ID}
})
@@ -363,6 +426,167 @@ func bridgeObserve[In Event](m *Kit, runner *extensions.Runner, kind extensions.
})
}
// turnAggregator collects per-turn signals (tool calls, LLM round-trips, token
// usage, wall-clock duration) so that the enriched AgentEndEvent can be
// populated without requiring extensions to maintain parallel bookkeeping.
//
// The aggregator resets on each TurnStartEvent and is consumed (snapshotted +
// reset) on TurnEndEvent. All access is serialized via a mutex because the
// underlying event bus may fan handlers across goroutines in the future.
type turnAggregator struct {
mu sync.Mutex
started time.Time
ended time.Time
toolCallCount int
toolNames []string
llmCallCount int
inputTokens int
outputTokens int
cacheReadTokens int
cacheWriteTokens int
cost float64
kit *Kit
}
type turnSnapshot struct {
started time.Time
ended time.Time
toolCallCount int
toolNames []string
llmCallCount int
inputTokens int
outputTokens int
cacheReadTokens int
cacheWriteTokens int
cost float64
}
func (s turnSnapshot) durationMs() int64 {
if s.started.IsZero() {
return 0
}
end := s.ended
if end.IsZero() {
end = time.Now()
}
return end.Sub(s.started).Milliseconds()
}
// start resets all counters and records the turn's start time. Called from
// the TurnStartEvent subscriber.
func (a *turnAggregator) start() {
a.mu.Lock()
defer a.mu.Unlock()
a.started = time.Now()
a.ended = time.Time{}
a.toolCallCount = 0
a.toolNames = nil
a.llmCallCount = 0
a.inputTokens = 0
a.outputTokens = 0
a.cacheReadTokens = 0
a.cacheWriteTokens = 0
a.cost = 0
}
func (a *turnAggregator) recordTool(name string) {
a.mu.Lock()
defer a.mu.Unlock()
a.toolCallCount++
if name != "" {
a.toolNames = append(a.toolNames, name)
}
}
func (a *turnAggregator) recordStep(usage LLMUsage) {
a.mu.Lock()
defer a.mu.Unlock()
a.llmCallCount++
a.inputTokens += int(usage.InputTokens)
a.outputTokens += int(usage.OutputTokens)
a.cacheReadTokens += int(usage.CacheReadTokens)
a.cacheWriteTokens += int(usage.CacheCreationTokens)
if a.kit != nil {
_, _, c := llmUsageMeta(a.kit, usage)
a.cost += c
}
}
// consume returns a snapshot of the current turn and marks it ended.
// Subsequent start() calls clear the snapshot.
func (a *turnAggregator) consume() turnSnapshot {
a.mu.Lock()
defer a.mu.Unlock()
a.ended = time.Now()
names := a.toolNames
if len(names) > 0 {
copied := make([]string, len(names))
copy(copied, names)
names = copied
}
return turnSnapshot{
started: a.started,
ended: a.ended,
toolCallCount: a.toolCallCount,
toolNames: names,
llmCallCount: a.llmCallCount,
inputTokens: a.inputTokens,
outputTokens: a.outputTokens,
cacheReadTokens: a.cacheReadTokens,
cacheWriteTokens: a.cacheWriteTokens,
cost: a.cost,
}
}
// llmUsageMeta returns the current provider, model id, and computed cost for
// the given usage values using the Kit instance's active model. Cost is zero
// in any of the following cases:
// - the *Kit pointer is nil or has no active model;
// - the model is not in the registry (custom fine-tunes, unknown providers);
// - the model has no pricing fields set;
// - the active credential is an Anthropic OAuth token (matches the
// existing usage_tracker behavior of suppressing cost for OAuth users).
func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float64) {
if m == nil {
return "", "", 0
}
modelString := m.GetModelString()
if modelString == "" {
return "", "", 0
}
p, id, err := models.ParseModelString(modelString)
if err != nil {
return "", "", 0
}
provider, modelID = p, id
info := models.GetGlobalRegistry().LookupModel(provider, modelID)
if info == nil {
return provider, modelID, 0
}
if isAnthropicOAuth(m, provider) {
return provider, modelID, 0
}
cost = float64(usage.InputTokens) * info.Cost.Input / 1_000_000
cost += float64(usage.OutputTokens) * info.Cost.Output / 1_000_000
if info.Cost.CacheRead != nil {
cost += float64(usage.CacheReadTokens) * (*info.Cost.CacheRead) / 1_000_000
}
if info.Cost.CacheWrite != nil {
cost += float64(usage.CacheCreationTokens) * (*info.Cost.CacheWrite) / 1_000_000
}
return provider, modelID, cost
}
// isAnthropicOAuth reports whether the current Anthropic credential resolves
// to a stored OAuth token (in which case the user is not billed per-token),
// so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage().
func isAnthropicOAuth(m *Kit, provider string) bool {
if m == nil || provider != "anthropic" {
return false
}
return auth.IsAnthropicOAuth(m.v.GetString("provider-api-key"))
}
// llmToContextMessages converts a slice of LLM messages to extension
// ContextMessage values, extracting plain text from each message.
func llmToContextMessages(msgs []LLMMessage) []extensions.ContextMessage {
+140
View File
@@ -0,0 +1,140 @@
package kit
import (
"testing"
"time"
)
// TestTurnAggregator_BasicLifecycle exercises the per-turn aggregator:
// start → record several tools and steps → consume → snapshot should reflect
// the accumulated counts and zero out for the next turn.
func TestTurnAggregator_BasicLifecycle(t *testing.T) {
agg := &turnAggregator{}
agg.start()
agg.recordTool("bash")
agg.recordTool("read")
agg.recordTool("bash")
agg.recordStep(LLMUsage{
InputTokens: 100,
OutputTokens: 50,
CacheReadTokens: 10,
CacheCreationTokens: 5,
})
agg.recordStep(LLMUsage{
InputTokens: 200,
OutputTokens: 75,
})
snap := agg.consume()
if snap.toolCallCount != 3 {
t.Errorf("toolCallCount: got %d want 3", snap.toolCallCount)
}
wantNames := []string{"bash", "read", "bash"}
if len(snap.toolNames) != len(wantNames) {
t.Fatalf("toolNames length: got %d want %d", len(snap.toolNames), len(wantNames))
}
for i, n := range wantNames {
if snap.toolNames[i] != n {
t.Errorf("toolNames[%d]: got %q want %q", i, snap.toolNames[i], n)
}
}
if snap.llmCallCount != 2 {
t.Errorf("llmCallCount: got %d want 2", snap.llmCallCount)
}
if snap.inputTokens != 300 {
t.Errorf("inputTokens: got %d want 300", snap.inputTokens)
}
if snap.outputTokens != 125 {
t.Errorf("outputTokens: got %d want 125", snap.outputTokens)
}
if snap.cacheReadTokens != 10 {
t.Errorf("cacheReadTokens: got %d want 10", snap.cacheReadTokens)
}
if snap.cacheWriteTokens != 5 {
t.Errorf("cacheWriteTokens: got %d want 5", snap.cacheWriteTokens)
}
if snap.durationMs() < 0 {
t.Errorf("durationMs should not be negative, got %d", snap.durationMs())
}
}
func TestTurnAggregator_StartResetsCounters(t *testing.T) {
agg := &turnAggregator{}
agg.start()
agg.recordTool("bash")
agg.recordStep(LLMUsage{InputTokens: 50})
// Begin a new turn — previous counters should be cleared.
agg.start()
snap := agg.consume()
if snap.toolCallCount != 0 || snap.llmCallCount != 0 || snap.inputTokens != 0 {
t.Errorf("expected counters zeroed after start(), got %+v", snap)
}
if snap.toolNames != nil {
t.Errorf("expected toolNames=nil after start(), got %v", snap.toolNames)
}
}
// TestTurnAggregator_DurationMs verifies the snapshot computes a positive
// duration when consume() runs after start().
func TestTurnAggregator_DurationMs(t *testing.T) {
agg := &turnAggregator{}
agg.start()
time.Sleep(5 * time.Millisecond)
snap := agg.consume()
if snap.durationMs() < 1 {
t.Errorf("expected positive duration, got %d", snap.durationMs())
}
}
// TestTurnAggregator_ZeroStartSafe ensures a snapshot taken without a prior
// start() doesn't crash and reports zero duration.
func TestTurnAggregator_ZeroStartSafe(t *testing.T) {
agg := &turnAggregator{}
snap := agg.consume()
if snap.durationMs() != 0 {
t.Errorf("expected zero duration for unstarted aggregator, got %d", snap.durationMs())
}
}
// TestLLMUsageMeta_NilKit verifies the helper degrades gracefully when given
// a nil Kit instance (zero values, no panic).
func TestLLMUsageMeta_NilKit(t *testing.T) {
provider, modelID, cost := llmUsageMeta(nil, LLMUsage{InputTokens: 100})
if provider != "" || modelID != "" || cost != 0 {
t.Errorf("expected zero values for nil kit, got (%q,%q,%v)", provider, modelID, cost)
}
}
// TestIsAnthropicOAuth_NonAnthropic verifies the helper short-circuits for any
// provider other than "anthropic" without touching the credential store.
func TestIsAnthropicOAuth_NonAnthropic(t *testing.T) {
for _, provider := range []string{"openai", "google", "openrouter", ""} {
if isAnthropicOAuth(nil, provider) {
t.Errorf("isAnthropicOAuth(nil, %q) = true, want false", provider)
}
}
}
func TestExtStateSidecarPath(t *testing.T) {
tests := []struct {
name string
in string
want string
}{
{"empty", "", ""},
{"jsonl", "/tmp/sessions/abc.jsonl", "/tmp/sessions/abc.ext-state.json"},
{"jsonl with subdir", "/a/b/c.jsonl", "/a/b/c.ext-state.json"},
{"no extension", "/tmp/session-blob", "/tmp/session-blob.ext-state.json"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := extStateSidecarPath(tc.in)
if got != tc.want {
t.Errorf("extStateSidecarPath(%q): got %q want %q", tc.in, got, tc.want)
}
})
}
}
+59 -6
View File
@@ -138,6 +138,19 @@ func (m *Kit) GetToolNames() []string {
return names
}
// GetToolsForSubagent like GetTools but eliminates subagent tool
// to avoid infinite recursion.
func (m *Kit) GetToolsForSubagent() []Tool {
var tools []Tool
for _, t := range m.agent.GetTools() {
if t.Info().Name == "subagent" {
continue
}
tools = append(tools, t)
}
return tools
}
// GetLoadingMessage returns the agent's startup info message (e.g. GPU
// fallback info), or empty string if none.
func (m *Kit) GetLoadingMessage() string {
@@ -1034,9 +1047,25 @@ type Options struct {
AutoCompact bool // Auto-compact when near context limit
CompactionOptions *CompactionOptions // Config for auto-compaction (nil = defaults)
// Debug enables debug logging for the SDK.
// Debug enables debug logging for the SDK. When DebugLogger is nil this
// flag selects between the default no-op SimpleDebugLogger (Debug=false)
// and the built-in console/buffered logger (Debug=true). When DebugLogger
// is non-nil this flag is ignored — the supplied logger's
// IsDebugEnabled() controls whether downstream code emits messages.
Debug bool
// DebugLogger, if non-nil, routes low-level debug output from the engine
// and the MCP tool plumbing to a caller-supplied implementation. This is
// the SDK escape hatch for embedders that want to forward debug output
// into their own logging system (zap, slog, log/charm, an in-app TUI
// panel, etc.) instead of the built-in console logger.
//
// When nil (default) the Debug bool controls whether the built-in logger
// is installed. When non-nil this logger is used unconditionally and the
// Debug bool is ignored; the supplied logger's IsDebugEnabled() reports
// whether downstream code should bother formatting messages.
DebugLogger DebugLogger
// MCPAuthHandler handles OAuth authorization for remote MCP servers.
// When set, remote transports (streamable HTTP, SSE) are configured
// with OAuth support. If the server returns a 401, the handler is
@@ -1147,7 +1176,7 @@ type CLIOptions struct {
// - Continue: resume most recent session for SessionDir (or cwd)
// - SessionPath: open a specific JSONL session file
// - default: create a new tree session for SessionDir (or cwd)
func InitTreeSession(opts *Options) (*session.TreeManager, error) {
func InitTreeSession(opts *Options) (*TreeManager, error) {
if opts == nil {
opts = &Options{}
}
@@ -1317,9 +1346,25 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
}
// Load skills — either from explicit paths or via auto-discovery.
if !opts.NoSkills {
// Merge viper config with opts: CLI flag / config file values are
// already bound to viper by cmd/root.go, so v.GetBool("no-skills"),
// v.GetStringSlice("skill"), and v.GetString("skills-dir") capture
// both --flag and .kit.yml keys transparently.
noSkills := opts.NoSkills || v.GetBool("no-skills")
skillPaths := opts.Skills
if len(skillPaths) == 0 {
skillPaths = v.GetStringSlice("skill")
}
skillsDir := opts.SkillsDir
if skillsDir == "" {
skillsDir = v.GetString("skills-dir")
}
if !noSkills {
mergedOpts := *opts
mergedOpts.Skills = skillPaths
mergedOpts.SkillsDir = skillsDir
var err error
loadedSkills, err = loadSkills(opts)
loadedSkills, err = loadSkills(&mergedOpts)
if err != nil {
return fmt.Errorf("failed to load skills: %w", err)
}
@@ -1485,6 +1530,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
ProviderConfig: providerConfig,
Debug: debug,
DebugLogger: opts.DebugLogger,
NoExtensions: noExtensions,
MaxSteps: maxSteps,
StreamingEnabled: streaming,
@@ -1814,8 +1860,14 @@ type SubagentConfig struct {
// Empty string uses a minimal default prompt.
SystemPrompt string
// Tools overrides the tool set. If nil, SubagentTools() is used (all
// core tools except subagent, preventing infinite recursion).
// Tools overrides the tool set available to the subagent.
// If nil and the subagent is created via the SDK (Kit.Subagent()), the
// static SubagentTools() set (all core tools except "subagent") is used.
// When spawned internally by the agent loop, the parent's active tools
// minus "subagent" are used instead (see GetToolsForSubagent()).
// Pass m.GetToolsForSubagent() explicitly to opt into inheritance from
// SDK call sites.
// (The subagent tool is dropped to prevent infinite recursion.)
Tools []Tool
// NoSession, when true, uses an in-memory ephemeral session. When false
@@ -2076,6 +2128,7 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
SystemPrompt: systemPrompt,
Timeout: timeout,
OnEvent: onEvent,
Tools: m.GetToolsForSubagent(),
})
m.cleanupSubagentListeners(toolCallID)
if result == nil {
+75
View File
@@ -365,6 +365,81 @@ func TestNewSystemPromptFilePath(t *testing.T) {
}
}
// TestNewWithSkillsOptions verifies that the three skills-related Options
// fields (NoSkills, Skills, SkillsDir) are wired correctly into kit.New().
func TestNewWithSkillsOptions(t *testing.T) {
if os.Getenv("ANTHROPIC_API_KEY") == "" {
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
}
ctx := context.Background()
t.Run("NoSkills disables skill loading", func(t *testing.T) {
host, err := kit.New(ctx, &kit.Options{
Model: "anthropic/claude-sonnet-4-5-20250929",
Quiet: true,
NoSession: true,
NoSkills: true,
})
if err != nil {
t.Fatalf("kit.New failed: %v", err)
}
defer func() { _ = host.Close() }()
if got := host.GetSkills(); len(got) != 0 {
t.Errorf("NoSkills=true: expected 0 skills, got %d", len(got))
}
})
t.Run("SkillsDir propagates", func(t *testing.T) {
// Use a non-existent dir — no skills will load but the option must be
// accepted without error and result in zero skills.
dir := t.TempDir()
host, err := kit.New(ctx, &kit.Options{
Model: "anthropic/claude-sonnet-4-5-20250929",
Quiet: true,
NoSession: true,
SkillsDir: dir,
})
if err != nil {
t.Fatalf("kit.New failed: %v", err)
}
defer func() { _ = host.Close() }()
// Empty dir → no skills; the important thing is no error.
_ = host.GetSkills()
})
t.Run("explicit Skills paths load correctly", func(t *testing.T) {
// Write a minimal skill file to a temp dir.
dir := t.TempDir()
skillFile := dir + "/my-skill.md"
content := "---\nname: test-skill\ndescription: A test skill\n---\nDo the thing.\n"
if err := os.WriteFile(skillFile, []byte(content), 0o644); err != nil {
t.Fatalf("failed to write skill file: %v", err)
}
host, err := kit.New(ctx, &kit.Options{
Model: "anthropic/claude-sonnet-4-5-20250929",
Quiet: true,
NoSession: true,
Skills: []string{skillFile},
})
if err != nil {
t.Fatalf("kit.New failed: %v", err)
}
defer func() { _ = host.Close() }()
skills := host.GetSkills()
if len(skills) != 1 {
t.Fatalf("expected 1 skill, got %d", len(skills))
}
if skills[0].Name != "test-skill" {
t.Errorf("skill name = %q; want %q", skills[0].Name, "test-skill")
}
})
}
// TestNewSystemPromptInline confirms that inline system-prompt strings still
// flow through unchanged after the file-path resolution change.
func TestNewSystemPromptInline(t *testing.T) {
+5 -33
View File
@@ -102,10 +102,11 @@ type MCPTaskProgressHandler func(MCPTaskProgress)
// are optional; the zero value disables progress callbacks and applies
// sensible polling defaults inside the engine.
//
// For most consumers, the flat [Options] fields (`MCPTaskMode`,
// `MCPTaskTTL`, `MCPTaskPollInterval`, `MCPTaskMaxPollInterval`,
// `MCPTaskTimeout`, `MCPTaskProgress`) are the preferred entry point.
// MCPTaskConfig is exposed for the low-level [AgentConfig] path.
// Most consumers configure these via the flat [Options] fields
// (`MCPTaskMode`, `MCPTaskTTL`, `MCPTaskPollInterval`,
// `MCPTaskMaxPollInterval`, `MCPTaskTimeout`, `MCPTaskProgress`). The
// MCPTaskConfig type itself is retained for downstream consumers that
// receive it on engine-facing call sites.
type MCPTaskConfig struct {
// PerServerMode overrides the per-server task mode resolved from
// [MCPServerConfig]. Keys are server names. Missing entries fall back
@@ -133,35 +134,6 @@ type MCPTaskConfig struct {
Progress MCPTaskProgressHandler
}
// toToolsConfig converts the SDK-level [MCPTaskConfig] to the internal
// tools-package representation. Keeps the dependency arrow internal-only.
func (c MCPTaskConfig) toToolsConfig() tools.MCPTaskConfig {
cfg := tools.MCPTaskConfig{
DefaultTTL: c.DefaultTTL,
PollInterval: c.PollInterval,
MaxPollInterval: c.MaxPollInterval,
Timeout: c.Timeout,
}
if len(c.PerServerMode) > 0 {
cfg.PerServerMode = make(map[string]tools.MCPTaskMode, len(c.PerServerMode))
for k, v := range c.PerServerMode {
cfg.PerServerMode[k] = tools.MCPTaskMode(v)
}
}
if c.Progress != nil {
h := c.Progress
cfg.Progress = func(p tools.MCPTaskProgress) {
h(MCPTaskProgress{
Server: p.Server,
TaskID: p.TaskID,
Status: MCPTaskStatus(p.Status),
Message: p.Message,
})
}
}
return cfg
}
// mcpTaskOptions carries SDK consumer configuration into the agent setup.
// Stored on Options as a single value so the public surface stays compact;
// individual fields are exposed via WithMCP* builder functions.
+20
View File
@@ -61,3 +61,23 @@ func CheckProviderReady(provider string) error {
}
return models.GetGlobalRegistry().ValidateEnvironment(provider, "")
}
// ResolveProviderBaseURL returns the base API URL kit will use when talking to
// the given provider, applying the same resolution order that CreateProvider
// uses internally:
//
// 1. The provider's `api` field from the models.dev registry.
// 2. The hard-coded default base URL of its npm SDK package (e.g.
// @ai-sdk/groq → https://api.groq.com/openai/v1).
// 3. Template substitution against the current process environment when the
// URL contains "${VAR}" placeholders.
//
// Returns a non-nil error when the provider is unknown, when no URL can be
// derived, or when a templated URL has unset placeholders.
//
// Use this from your SDK integration to surface the effective endpoint before
// instantiating a Kit, or to validate that a provider is reachable without
// running an actual request.
func ResolveProviderBaseURL(providerID string) (string, error) {
return models.ResolveProviderBaseURL(providerID)
}
+11
View File
@@ -83,6 +83,17 @@ func WithConfigFile(path string) Option { return func(o *Options) { o.ConfigFile
// WithDebug enables SDK debug logging.
func WithDebug() Option { return func(o *Options) { o.Debug = true } }
// WithDebugLogger installs a caller-supplied [DebugLogger] for low-level
// engine and MCP tool plumbing output. When set this overrides the built-in
// logger selected by [WithDebug] — messages flow into the supplied logger
// unconditionally, and the logger's IsDebugEnabled reports whether downstream
// code should bother formatting them. Use this to forward Kit's debug output
// into your application's logging system (slog, zap, charm/log, an in-app
// panel, etc.).
func WithDebugLogger(l DebugLogger) Option {
return func(o *Options) { o.DebugLogger = l }
}
// Ephemeral configures an in-memory session with no persistence (equivalent to
// Options.NoSession = true).
func Ephemeral() Option { return func(o *Options) { o.NoSession = true } }
+27 -71
View File
@@ -7,45 +7,36 @@ import (
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/models"
"github.com/mark3labs/kit/internal/prompts"
"github.com/mark3labs/kit/internal/skills"
)
// ---------------------------------------------------------------------------
// Template Parsing Bridge for Extensions (Phase 3)
// ---------------------------------------------------------------------------
// varRegex matches {{variable}} placeholders in templates.
var varRegex = regexp.MustCompile(`\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}`)
// ParseTemplate extracts {{variables}} from template content.
// ParseTemplate extracts {{variables}} from template content. The template
// grammar is shared with skill prompt templates, so a template parses
// identically regardless of which API loads it.
func ParseTemplate(name, content string) extensions.PromptTemplate {
matches := varRegex.FindAllStringSubmatch(content, -1)
vars := make([]string, 0, len(matches))
seen := make(map[string]bool)
for _, m := range matches {
if len(m) > 1 && !seen[m[1]] {
seen[m[1]] = true
vars = append(vars, m[1])
}
tpl := skills.NewPromptTemplate(name, content)
vars := tpl.Variables
if vars == nil {
vars = []string{}
}
return extensions.PromptTemplate{
Name: name,
Content: content,
Name: tpl.Name,
Content: tpl.Content,
Variables: vars,
}
}
// RenderTemplate substitutes variables into template content.
// Handles {{name}} and {{ name }} (any whitespace) placeholders.
// Handles {{name}} and {{ name }} (any whitespace) placeholders; missing
// variables are left as-is.
func RenderTemplate(tpl extensions.PromptTemplate, vars map[string]string) string {
return varRegex.ReplaceAllStringFunc(tpl.Content, func(m string) string {
sub := varRegex.FindStringSubmatch(m)
if len(sub) > 1 {
if v, ok := vars[sub[1]]; ok {
return v
}
}
return m
})
t := skills.PromptTemplate{Content: tpl.Content}
return t.Expand(vars)
}
// ParseArguments parses command-line style arguments.
@@ -183,44 +174,12 @@ func SimpleParseArguments(input string, count int) []string {
return result
}
// parseFields splits input respecting quoted strings.
// parseFields splits input into arguments respecting quoted strings and
// backslash escaping. It delegates to the canonical tokenizer in
// internal/prompts so extension argument parsing and builtin prompt-template
// parsing agree on grammar.
func parseFields(input string) []string {
var fields []string
var current strings.Builder
inQuote := false
quoteChar := rune(0)
for _, r := range input {
switch r {
case '"', '\'':
if !inQuote {
inQuote = true
quoteChar = r
} else if r == quoteChar {
inQuote = false
quoteChar = 0
} else {
current.WriteRune(r)
}
case ' ', '\t':
if inQuote {
current.WriteRune(r)
} else {
if current.Len() > 0 {
fields = append(fields, current.String())
current.Reset()
}
}
default:
current.WriteRune(r)
}
}
if current.Len() > 0 {
fields = append(fields, current.String())
}
return fields
return prompts.ParseCommandArgs(input)
}
// EvaluateModelConditional checks if condition matches current model.
@@ -417,21 +376,18 @@ func MatchModelGlob(model, pattern string) bool {
}
// ExtractProviderFromPath extracts provider from a path-like model string.
//
// Deprecated: Use GetCurrentProvider instead.
func ExtractProviderFromPath(model string) string {
parts := strings.Split(model, "/")
if len(parts) >= 2 {
return parts[0]
}
return ""
return GetCurrentProvider(model)
}
// ExtractModelFromPath extracts model ID from a path-like model string.
//
// Deprecated: Use RemoveProviderFromModel instead, which correctly handles
// model IDs containing "/" (e.g. "openrouter/meta/llama").
func ExtractModelFromPath(model string) string {
parts := strings.Split(model, "/")
if len(parts) >= 2 {
return parts[1]
}
return model
return RemoveProviderFromModel(model)
}
// IsBareModelID checks if a string is a bare model ID (no provider).
+4 -108
View File
@@ -5,13 +5,11 @@ import (
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/agent"
"github.com/mark3labs/kit/internal/compaction"
"github.com/mark3labs/kit/internal/config"
"github.com/mark3labs/kit/internal/message"
"github.com/mark3labs/kit/internal/models"
"github.com/mark3labs/kit/internal/session"
"github.com/mark3labs/kit/internal/tools"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/server"
)
@@ -83,9 +81,10 @@ type MCPServerConfig = config.MCPServerConfig
// concurrent use.
//
// Most consumers do not need to provide one; pass [Options.Debug] = true
// to use the default logger. DebugLogger is exposed for the low-level
// [AgentConfig] path and for embedders that want to route debug output
// into their own logging system.
// (or use [WithDebug]) to install the built-in console logger. DebugLogger
// is the escape hatch for embedders that want to route debug output into
// their own logging system — install one via [Options.DebugLogger] or
// [WithDebugLogger].
type DebugLogger interface {
// LogDebug records a single debug message. Implementations may drop,
// buffer, or render the message however they choose.
@@ -95,109 +94,6 @@ type DebugLogger interface {
IsDebugEnabled() bool
}
// AgentConfig holds configuration options for constructing an agent at the
// SDK boundary. All fields use SDK-owned types, so consumers can populate
// this struct without importing any underlying LLM-provider package.
//
// For most use cases, prefer the high-level [New] entry point with
// [Options]. AgentConfig is exposed for advanced consumers that need
// direct access to the lower-level agent configuration shape.
type AgentConfig struct {
// ModelConfig holds the LLM provider configuration. A nil value means
// that the default provider/model resolution will be used.
ModelConfig *ProviderConfig
// MCPConfig describes any MCP servers whose tools should be loaded
// alongside core tools.
MCPConfig *Config
// SystemPrompt is the system prompt sent to the LLM.
SystemPrompt string
// MaxSteps caps the number of LLM iterations per turn. A value of
// zero means no cap is applied at this layer.
MaxSteps int
// StreamingEnabled controls whether the agent streams responses.
StreamingEnabled bool
// AuthHandler handles OAuth authorization for remote MCP servers.
// When nil, remote MCP servers requiring OAuth will fail to connect.
AuthHandler MCPAuthHandler
// TokenStoreFactory, if non-nil, creates a custom token store for each
// remote MCP server's OAuth tokens. When nil, the default file-based
// token store is used.
TokenStoreFactory MCPTokenStoreFactory
// CoreTools overrides the default core tool set. If empty, [AllTools]
// is used. Provide a custom tool set (e.g. [CodingTools] or tools
// built with a custom WorkDir) to scope agent capabilities.
CoreTools []Tool
// DisableCoreTools, when true, prevents loading any core tools.
// Combined with empty CoreTools this yields a chat-only agent with
// no built-in tools.
DisableCoreTools bool
// ExtraTools are additional tools loaded alongside core and MCP tools.
ExtraTools []Tool
// ToolWrapper, if non-nil, wraps the combined tool list before it is
// handed to the LLM. Used to intercept tool calls or results.
ToolWrapper func([]Tool) []Tool
// OnMCPServerLoaded, if non-nil, is invoked once for each MCP server
// when its tools have finished loading (or failed). Called from a
// background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
// DebugLogger receives low-level debug output from the engine and the
// MCP tool plumbing. Nil means no debug output is emitted at this
// layer (regardless of [Options.Debug], which feeds the higher-level
// [New] entry point). Pass an implementation here when wiring a custom
// logger through the lower-level AgentConfig path.
DebugLogger DebugLogger
// MCPTaskConfig configures task-aware MCP tools/call execution — mode
// overrides, polling intervals, timeouts, and the progress handler.
// The zero value preserves historical synchronous-only behaviour for
// any server that didn't advertise task support during initialize.
MCPTaskConfig MCPTaskConfig
}
// toInternal converts an AgentConfig to its internal representation.
// Slice and function fields convert without allocation because [Tool]
// is a type alias for the underlying LLM-tool type.
func (c *AgentConfig) toInternal() *agent.AgentConfig {
if c == nil {
return nil
}
out := &agent.AgentConfig{
ModelConfig: c.ModelConfig,
MCPConfig: c.MCPConfig,
SystemPrompt: c.SystemPrompt,
MaxSteps: c.MaxSteps,
StreamingEnabled: c.StreamingEnabled,
CoreTools: c.CoreTools,
DisableCoreTools: c.DisableCoreTools,
ExtraTools: c.ExtraTools,
ToolWrapper: c.ToolWrapper,
OnMCPServerLoaded: c.OnMCPServerLoaded,
}
if c.AuthHandler != nil {
out.AuthHandler = c.AuthHandler
}
if c.TokenStoreFactory != nil {
out.TokenStoreFactory = tools.TokenStoreFactory(c.TokenStoreFactory)
}
if c.DebugLogger != nil {
out.DebugLogger = c.DebugLogger
}
out.MCPTaskConfig = c.MCPTaskConfig.toToolsConfig()
return out
}
// ToolCallHandler is invoked when the LLM produces a tool call. It receives
// the call ID, tool name, and the JSON-encoded input arguments.
type ToolCallHandler func(toolCallID, toolName, toolArgs string)
+35 -41
View File
@@ -264,30 +264,31 @@ func TestConvertFromLLMMessage(t *testing.T) {
}
}
// TestAgentConfigNoFantasyImport verifies AgentConfig can be populated with
// every field — including CoreTools, ExtraTools, and ToolWrapper — using
// only SDK-owned types. This test deliberately does not import
// "charm.land/fantasy"; the package compiling at all is the proof that the
// SDK no longer leaks the dependency name through AgentConfig.
// TestOptionsNoFantasyImport verifies Options can be populated with the
// tool-related fields — Tools and ExtraTools — using only SDK-owned types.
// This test deliberately does not import "charm.land/fantasy"; the package
// compiling at all is the proof that the SDK no longer leaks the dependency
// name through the Options surface.
//
// Tool-call interception (formerly the AgentConfig.ToolWrapper escape hatch)
// is covered by the hook system — [Kit.OnBeforeToolCall] /
// [Kit.OnAfterToolResult] — whose hook payload types also use only
// SDK-owned identifiers; see hooks_test.go.
//
// Regression test for https://github.com/mark3labs/kit/issues/30.
func TestAgentConfigNoFantasyImport(t *testing.T) {
func TestOptionsNoFantasyImport(t *testing.T) {
myTool := kit.NewTool[struct{}]("noop", "does nothing", func(_ context.Context, _ struct{}) (kit.ToolOutput, error) {
return kit.TextResult("ok"), nil
})
wrapperCalled := false
cfg := kit.AgentConfig{
SystemPrompt: "you are a tester",
MaxSteps: 5,
StreamingEnabled: true,
CoreTools: []kit.Tool{myTool},
ExtraTools: []kit.Tool{myTool},
DisableCoreTools: false,
ToolWrapper: func(in []kit.Tool) []kit.Tool {
wrapperCalled = true
return in
},
streaming := true
cfg := kit.Options{
SystemPrompt: "you are a tester",
MaxSteps: 5,
Streaming: &streaming,
Tools: []kit.Tool{myTool},
ExtraTools: []kit.Tool{myTool},
DisableCoreTools: false,
OnMCPServerLoaded: func(_ string, _ int, _ error) {},
}
@@ -297,36 +298,29 @@ func TestAgentConfigNoFantasyImport(t *testing.T) {
if cfg.MaxSteps != 5 {
t.Errorf("MaxSteps = %d, want 5", cfg.MaxSteps)
}
if !cfg.StreamingEnabled {
t.Error("StreamingEnabled = false, want true")
if cfg.Streaming == nil || !*cfg.Streaming {
t.Error("Streaming = false/nil, want true")
}
if len(cfg.CoreTools) != 1 {
t.Errorf("CoreTools len = %d, want 1", len(cfg.CoreTools))
if len(cfg.Tools) != 1 {
t.Errorf("Tools len = %d, want 1", len(cfg.Tools))
}
if len(cfg.ExtraTools) != 1 {
t.Errorf("ExtraTools len = %d, want 1", len(cfg.ExtraTools))
}
// Exercise the wrapper to confirm the func type is usable.
out := cfg.ToolWrapper(cfg.CoreTools)
if !wrapperCalled {
t.Error("ToolWrapper was not invoked")
}
if len(out) != 1 {
t.Errorf("wrapped tool list len = %d, want 1", len(out))
}
}
// TestAgentConfigToolWrapperSignature documents that AgentConfig.ToolWrapper
// uses kit.Tool (not the underlying provider type) in its signature.
func TestAgentConfigToolWrapperSignature(t *testing.T) {
//nolint:staticcheck // QF1011: explicit type asserts the SDK-side func signature.
var _ func([]kit.Tool) []kit.Tool = func(in []kit.Tool) []kit.Tool { return in }
cfg := kit.AgentConfig{
ToolWrapper: func(in []kit.Tool) []kit.Tool { return in },
}
if cfg.ToolWrapper == nil {
t.Fatal("ToolWrapper assignment failed")
// TestToolSliceSignature documents that the kit.Tool alias — used by every
// SDK tool-related surface (Options.Tools, Options.ExtraTools, WithTools,
// WithExtraTools, hook payloads) — is referenced under its SDK-owned name
// in user code, without any fantasy import.
func TestToolSliceSignature(t *testing.T) {
var tools []kit.Tool
tools = append(tools, kit.NewTool[struct{}]("noop", "",
func(_ context.Context, _ struct{}) (kit.ToolOutput, error) {
return kit.TextResult("ok"), nil
}))
if len(tools) != 1 {
t.Fatalf("unexpected tool slice length: %d", len(tools))
}
}
+171
View File
@@ -63,6 +63,52 @@ func TestOptionFunctionsPlumbing(t *testing.T) {
}
}
// recordingDebugLogger is a kit.DebugLogger used to verify WithDebugLogger
// plumbs the supplied logger into Options. It records each LogDebug call.
type recordingDebugLogger struct {
enabled bool
messages []string
}
func (l *recordingDebugLogger) LogDebug(m string) { l.messages = append(l.messages, m) }
func (l *recordingDebugLogger) IsDebugEnabled() bool { return l.enabled }
// TestWithDebugLoggerPlumbing verifies that kit.WithDebugLogger assigns the
// supplied logger to Options.DebugLogger. End-to-end propagation into the
// engine is covered indirectly by the existing kitsetup tests; this test
// pins the SDK-surface contract.
func TestWithDebugLoggerPlumbing(t *testing.T) {
l := &recordingDebugLogger{enabled: true}
o := &kit.Options{}
kit.WithDebugLogger(l)(o)
if o.DebugLogger == nil {
t.Fatal("WithDebugLogger: expected Options.DebugLogger to be set")
}
if o.DebugLogger != l {
t.Error("WithDebugLogger: expected the supplied logger to be installed verbatim")
}
// Sanity: the installed logger satisfies the SDK interface contract.
if !o.DebugLogger.IsDebugEnabled() {
t.Error("installed logger IsDebugEnabled() returned false")
}
o.DebugLogger.LogDebug("hello")
if len(l.messages) != 1 || l.messages[0] != "hello" {
t.Errorf("LogDebug not forwarded; got %v", l.messages)
}
}
// TestWithDebugLoggerNilClears verifies that passing a nil logger to
// WithDebugLogger clears any previously-installed logger. This lets later
// options override earlier ones the same way WithModel / WithStreaming do.
func TestWithDebugLoggerNilClears(t *testing.T) {
o := &kit.Options{}
kit.WithDebugLogger(&recordingDebugLogger{enabled: true})(o)
kit.WithDebugLogger(nil)(o)
if o.DebugLogger != nil {
t.Errorf("WithDebugLogger(nil): expected DebugLogger to be cleared; got %#v", o.DebugLogger)
}
}
// TestOptionOrderingOverrides verifies later options override earlier ones.
func TestOptionOrderingOverrides(t *testing.T) {
o := &kit.Options{}
@@ -205,6 +251,131 @@ func TestNewZeroOptionsKeepsStreamingDefault(t *testing.T) {
}
}
// TestSkillsViperKeys verifies that the three skills config keys (no-skills,
// skill, skills-dir) flow through viper when set via a config file, matching
// the pattern used by no-extensions and no-core-tools. This test does not
// require an API key because it only exercises Options struct plumbing.
func TestSkillsViperKeys(t *testing.T) {
t.Run("NoSkills option disables skill loading", func(t *testing.T) {
o := &kit.Options{}
o.NoSkills = true
if !o.NoSkills {
t.Error("Options.NoSkills = true not reflected on struct")
}
})
t.Run("Skills paths set on Options", func(t *testing.T) {
o := &kit.Options{
Skills: []string{"/a/skill.md", "/b/skill.md"},
}
if len(o.Skills) != 2 {
t.Errorf("Options.Skills: got %d paths, want 2", len(o.Skills))
}
if o.Skills[0] != "/a/skill.md" {
t.Errorf("Options.Skills[0] = %q; want %q", o.Skills[0], "/a/skill.md")
}
})
t.Run("SkillsDir set on Options", func(t *testing.T) {
o := &kit.Options{
SkillsDir: "/custom/skills",
}
if o.SkillsDir != "/custom/skills" {
t.Errorf("Options.SkillsDir = %q; want %q", o.SkillsDir, "/custom/skills")
}
})
}
// TestSkillsConfigFileKeys verifies that no-skills, skill, and skills-dir
// config file keys are read via viper and applied correctly. Requires an API
// key because kit.New() is called to exercise the full config-load path.
func TestSkillsConfigFileKeys(t *testing.T) {
if os.Getenv("ANTHROPIC_API_KEY") == "" {
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
}
ctx := context.Background()
t.Run("no-skills config key disables skill loading", func(t *testing.T) {
// Write a config file with no-skills: true.
cfgFile := t.TempDir() + "/.kit.yml"
if err := os.WriteFile(cfgFile, []byte("no-skills: true\n"), 0o644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
host, err := kit.New(ctx, &kit.Options{
Model: "anthropic/claude-sonnet-4-5-20250929",
Quiet: true,
NoSession: true,
ConfigFile: cfgFile,
})
if err != nil {
t.Fatalf("kit.New failed: %v", err)
}
defer func() { _ = host.Close() }()
if got := host.GetSkills(); len(got) != 0 {
t.Errorf("no-skills:true in config: expected 0 skills, got %d", len(got))
}
})
t.Run("skill config key loads explicit skill files", func(t *testing.T) {
dir := t.TempDir()
skillFile := dir + "/cfg-skill.md"
if err := os.WriteFile(skillFile, []byte("---\nname: cfg-skill\ndescription: from config\n---\nContent.\n"), 0o644); err != nil {
t.Fatalf("failed to write skill file: %v", err)
}
cfgContent := "skill:\n - " + skillFile + "\n"
cfgFile := dir + "/.kit.yml"
if err := os.WriteFile(cfgFile, []byte(cfgContent), 0o644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
host, err := kit.New(ctx, &kit.Options{
Model: "anthropic/claude-sonnet-4-5-20250929",
Quiet: true,
NoSession: true,
ConfigFile: cfgFile,
})
if err != nil {
t.Fatalf("kit.New failed: %v", err)
}
defer func() { _ = host.Close() }()
skills := host.GetSkills()
if len(skills) != 1 {
t.Fatalf("expected 1 skill from config, got %d", len(skills))
}
if skills[0].Name != "cfg-skill" {
t.Errorf("skill name = %q; want %q", skills[0].Name, "cfg-skill")
}
})
t.Run("skills-dir config key overrides auto-discovery root", func(t *testing.T) {
dir := t.TempDir()
cfgContent := "skills-dir: " + dir + "\n"
cfgFile := dir + "/.kit.yml"
if err := os.WriteFile(cfgFile, []byte(cfgContent), 0o644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
host, err := kit.New(ctx, &kit.Options{
Model: "anthropic/claude-sonnet-4-5-20250929",
Quiet: true,
NoSession: true,
ConfigFile: cfgFile,
})
if err != nil {
t.Fatalf("kit.New failed: %v", err)
}
defer func() { _ = host.Close() }()
// Empty dir → 0 skills; the key point is no error during init.
_ = host.GetSkills()
})
}
// TestNewStreamingExplicitOptOut verifies that a raw Options can still disable
// streaming by setting Streaming to a pointer to false.
func TestNewStreamingExplicitOptOut(t *testing.T) {
+57 -2
View File
@@ -88,7 +88,8 @@ api.OnAgentStart(func(e ext.AgentStartEvent, ctx ext.Context) {
// e.Prompt string
})
// Agent finished responding.
// Agent finished responding. Carries per-turn aggregates so observer-style
// extensions don't need to maintain parallel bookkeeping.
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
// e.Response string
// e.StopReason string — "error" (on failure), "completed" (when LLM returns
@@ -96,6 +97,33 @@ api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
// (e.g. "stop", "length" (max output tokens hit), "tool-calls", "content-filter").
// To detect errors, check e.StopReason == "error".
// Do NOT compare against "completed" for success — instead check != "error".
//
// Per-turn aggregates (computed by Kit's runtime):
// e.ToolCallCount int — total tool invocations this turn
// e.ToolNames []string — tool names in call order (duplicates preserved)
// e.LLMCallCount int — LLM round-trips / tool-loop iterations
// e.InputTokensDelta int — sum of input tokens across LLM calls this turn
// e.OutputTokensDelta int
// e.CacheReadTokensDelta int
// e.CacheWriteTokensDelta int
// e.CostDelta float64 — USD cost (zero when pricing unknown / OAuth)
// e.DurationMs int64 — wall-clock duration AgentStart→AgentEnd
})
// Per-LLM-call usage — fires after each provider round-trip with token + cost
// deltas attributed to that specific call. A single turn typically produces
// multiple LLMUsageEvents (one per tool-loop iteration). Use this for accurate
// budget enforcement that needs to react between calls instead of waiting
// for the turn to finish.
api.OnLLMUsage(func(e ext.LLMUsageEvent, ctx ext.Context) {
// e.InputTokens, e.OutputTokens int
// e.CacheReadTokens, e.CacheWriteTokens int
// e.Cost float64 — USD; zero when pricing unknown / OAuth
// e.Model, e.Provider string — model used for THIS call
// (may differ across calls if SetModel was called)
// e.StepNumber int — zero-based step index in this turn
// e.FinishReason string — "stop" / "tool_calls" / "length" / ...
// e.RequestID string — optional provider correlation id (may be empty)
})
```
@@ -528,11 +556,38 @@ stats := ctx.GetContextStats() // .EstimatedTokens, .ContextLimit, .UsagePer
msgs := ctx.GetMessages() // []ext.SessionMessage on current branch
path := ctx.GetSessionPath() // file path of session JSONL
// Persist custom data in the session tree:
// Append-only log in the session tree (fork-aware, walked on every branch read):
id, err := ctx.AppendEntry("my-type", "data string")
entries := ctx.GetEntries("my-type") // []ext.ExtensionEntry{ID, EntryType, Data, Timestamp}
```
### Session State (last-write-wins)
Key-value store scoped to the session, persisted to a sidecar file
(`<session>.ext-state.json`) outside the conversation tree. Reads are O(1)
(no branch walk), writes don't grow the JSONL, and the store is not
duplicated on fork. State is invisible to the LLM and survives session
resume. For ephemeral / in-memory sessions, state lives only in memory.
```go
ctx.SetState("myext:budget-cap", "10.00") // last write wins
val, ok := ctx.GetState("myext:budget-cap") // (string, bool)
ctx.DeleteState("myext:budget-cap") // no-op if missing
keys := ctx.ListState() // []string, unspecified order
```
**When to use which:**
| Need | Use |
|------|-----|
| Snapshot state ("current value of X") | `SetState` / `GetState` |
| Audit log / event history | `AppendEntry` / `GetEntries` |
| One-shot per-turn signal | enriched `AgentEndEvent` fields |
| Per-LLM-call observation | `OnLLMUsage` event |
Namespace keys with your extension name (e.g. `"myext:budget-cap"`) to avoid
collisions across extensions.
### Model Management
```go
+13
View File
@@ -1104,6 +1104,19 @@ if extAPI.HasExtensions() {
tools := extAPI.GetToolInfos()
extAPI.SetActiveTools([]string{"bash", "read"})
// Session-scoped extension state (last-write-wins key-value store).
// Backed by an in-memory map and a per-session sidecar file
// (<session>.ext-state.json) outside the conversation tree.
extAPI.SetState("myext:budget-cap", "10.00")
val, ok := extAPI.GetState("myext:budget-cap")
extAPI.DeleteState("myext:budget-cap")
keys := extAPI.ListState()
// Load any existing state from the sidecar and install a saver hook so
// subsequent SetState/DeleteState mutations are flushed atomically.
// No-op for ephemeral / in-memory sessions. Safe to call multiple times.
_ = extAPI.InitStatePersistence()
// Events
extAPI.EmitSessionStart()
extAPI.EmitModelChange("new/model", "old/model", "extension")
+68
View File
@@ -56,6 +56,57 @@ kit install --all # Install all extensions without prompting
kit skill # Install the Kit extensions skill via skills.sh
```
### Skills CLI flags
Control which skills are loaded at startup:
```bash
# Load a specific skill file
kit --skill path/to/skill.md "prompt"
# Load multiple skill files or directories (flag is repeatable)
kit --skill ./skill1.md --skill ./skill2.md "prompt"
# Load all skills from a custom directory instead of the default locations
kit --skills-dir /path/to/skills "prompt"
# Disable all skill loading (auto-discovery and explicit)
kit --no-skills "prompt"
```
Skills are auto-discovered from `~/.config/kit/skills/`, `.kit/skills/`, and `.agents/skills/` by default. Use `--skills-dir` to override the project-local search root, or `--skill` to load files explicitly (which disables auto-discovery). `--no-skills` suppresses all skill loading regardless of other flags.
## GitHub integration
Scaffold a GitHub Actions workflow that runs Kit as an automated collaborator/reviewer. The workflow triggers when someone comments `/kit ...` on an issue or pull request review, runs the agent non-interactively in the runner, and lets it respond.
```bash
kit github install # Scaffold .github/workflows/kit.yml
kit github install --model anthropic/claude-sonnet-4-5-20250929 # Skip the model prompt
kit github install --force # Overwrite an existing workflow file
kit github install --no-secret # Skip the offer to set the provider secret via the gh CLI
```
By default the command prompts for the model (pre-filled with a sensible default). If the [`gh` CLI](https://cli.github.com/) is detected on your `PATH` and the provider API key is present in your environment, you'll be offered the option to store it as a repository secret automatically.
The generated workflow:
- Triggers only on `issue_comment` and `pull_request_review_comment` (`types: [created]`).
- Runs only when the comment begins with the `/kit` command token.
- Restricts triggers to repository owners, members, and collaborators (via `author_association`).
- Uses least-privilege `permissions` and `persist-credentials: false`.
- Authenticates git/PR operations with the built-in `secrets.GITHUB_TOKEN` and the provider via a repository secret (e.g. `ANTHROPIC_API_KEY`).
After committing the workflow and setting the provider secret, comment `/kit <your request>` on any issue or pull request to trigger Kit.
The generated workflow uses the bundled [`mark3labs/kit`](https://github.com/mark3labs/kit/blob/master/action.yml) composite action, which installs the Kit binary and runs `kit github run`. That command reads the triggering event, enforces permissions, reacts with an emoji, runs the agent against the issue thread or PR, posts the response as a comment, and — if the agent changed files — pushes a `kit-agent[bot]` branch and opens a pull request.
| Flag | Description |
|------|-------------|
| `--model` | Provider/model to write into the workflow |
| `--force` | Overwrite an existing workflow file |
| `--no-secret` | Skip the offer to set the provider secret via the `gh` CLI |
## Interactive slash commands
These commands are available inside the Kit TUI during an interactive session:
@@ -110,6 +161,23 @@ Press **Ctrl+X s** during streaming to inject a system-level instruction mid-tur
Example: While the model is writing code, press Ctrl+X s and type "Use async/await instead" to change the implementation approach.
### Image attachments
Attach images to your next prompt straight from the clipboard:
- Copy an image (e.g. a screenshot) to the system clipboard, then press **Ctrl+V** in the input to attach it.
- Press **Ctrl+U** to clear all pending image attachments.
- Attachments are sent alongside your text when you submit, and cleared afterward.
When a terminal supports color, Kit renders a small low-resolution **thumbnail preview** of each pending image directly in the input, below the `[N image(s) attached]` indicator, so you can confirm the right image was attached before sending.
The preview is drawn with Unicode half-block characters and ordinary terminal colors — not a graphics protocol — so it renders correctly inside terminal multiplexers like **tmux** and **zellij**. Thumbnails are capped to a small cell box for a glanceable, low-res look.
- Best fidelity needs a **truecolor** terminal (`COLORTERM=truecolor`); Kit degrades to 256-color where truecolor is unavailable.
- On terminals with neither, the preview is skipped and the `[N image(s) attached]` text indicator is shown alone.
You can also attach image files by referencing them with `@path/to/image.png` — binary files are auto-detected by MIME type. See [Quick Start](/quick-start) for the `@` attachment syntax.
## Prompt templates
### Creating templates

Some files were not shown because too many files have changed in this diff Show More