refactor: remove dead code, fix SDK leakage, deduplicate helpers

- Remove unused SetOpenAICredentials/validateOpenAIAPIKey (internal/auth)
- Remove unused SudoPasswordRequiredMetadata/IsSudoPasswordRequiredResult
  (internal/core)
- Add Extension* type aliases in pkg/kit/extension_api.go so the public
  ExtensionAPI interface no longer exposes internal/extensions types
- Extract bridgeObserve generic helper and llmToContextMessages /
  contextMessagesToLLM in pkg/kit/extensions_bridge.go (~150 lines saved)
- Extract parseHeaders and buildOAuthConfig in connection_pool.go to
  deduplicate SSE/Streamable client construction (~60 lines saved)
- Eliminate redundant second buildInteractiveExtensionContext call in
  cmd/root.go; swap print closures on the same context instead
- Replace 'Fantasy' with 'agent' in internal comment (pkg/kit/kit.go)
This commit is contained in:
Ed Zynda
2026-05-25 13:30:22 +03:00
parent bd24f3315c
commit d7c4565999
7 changed files with 278 additions and 368 deletions
+2 -10
View File
@@ -899,8 +899,9 @@ func runNormalMode(ctx context.Context) error {
appInstance: appInstance,
usageTracker: usageTracker,
})
// During startup, buffer extension messages so they appear after the banner.
extCtx.Print = func(text string) {
// Capture messages during startup, print after startup banner.
startupExtensionMessages = append(startupExtensionMessages, text)
}
extCtx.PrintInfo = func(text string) {
@@ -913,15 +914,6 @@ func runNormalMode(ctx context.Context) error {
kitInstance.Extensions().EmitSessionStart()
// Restore normal print functions for runtime use.
extCtx = buildInteractiveExtensionContext(extensionContextDeps{
ctx: ctx,
cwd: cwd,
modelName: modelName,
interactive: positionalPrompt == "",
kitInstance: kitInstance,
appInstance: appInstance,
usageTracker: usageTracker,
})
extCtx.Print = func(text string) { appInstance.PrintFromExtension("", text) }
extCtx.PrintInfo = func(text string) { appInstance.PrintFromExtension("info", text) }
extCtx.PrintError = func(text string) { appInstance.PrintFromExtension("error", text) }
-43
View File
@@ -255,29 +255,6 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
}
}
// SetOpenAICredentials stores OpenAI API key credentials. It validates the
// API key format before storing. The API key must start with "sk-" and be
// at least 20 characters long. Returns an error if the API key is invalid or
// if storage fails.
func (cm *CredentialManager) SetOpenAICredentials(apiKey string) error {
if err := validateOpenAIAPIKey(apiKey); err != nil {
return err
}
store, err := cm.LoadCredentials()
if err != nil {
return err
}
store.OpenAI = &OpenAICredentials{
Type: "api_key",
APIKey: apiKey,
CreatedAt: time.Now(),
}
return cm.SaveCredentials(store)
}
// GetOpenAICredentials retrieves stored OpenAI credentials. Returns nil if
// no credentials are stored. The returned credentials may be either OAuth or API
// key type, check the Type field to determine which.
@@ -417,26 +394,6 @@ func validateAnthropicAPIKey(apiKey string) error {
return nil
}
// validateOpenAIAPIKey validates the format of an OpenAI API key
func validateOpenAIAPIKey(apiKey string) error {
apiKey = strings.TrimSpace(apiKey)
if apiKey == "" {
return fmt.Errorf("API key cannot be empty")
}
// OpenAI API keys typically start with "sk-" and are quite long
if !strings.HasPrefix(apiKey, "sk-") {
return fmt.Errorf("invalid OpenAI API key format (should start with 'sk-')")
}
if len(apiKey) < 20 {
return fmt.Errorf("API key appears to be too short")
}
return nil
}
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
// 1. Command-line flag value (highest priority)
// 2. Stored credentials (OAuth or API key)
-9
View File
@@ -160,15 +160,6 @@ func rewriteSudoForStdin(command string) string {
return result
}
// SudoPasswordRequiredResult is a special marker that indicates sudo needs a password.
// This is stored in tool response metadata to signal the TUI to prompt for password.
const SudoPasswordRequiredMetadata = `{"sudo_password_required":true}`
// IsSudoPasswordRequiredResult checks if a tool response indicates sudo password is needed.
func IsSudoPasswordRequiredResult(resp fantasy.ToolResponse) bool {
return resp.Metadata == SudoPasswordRequiredMetadata
}
func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
var args bashArgs
if err := parseArgs(call.Input, &args); err != nil {
+63 -67
View File
@@ -345,49 +345,70 @@ func (p *MCPConnectionPool) createStdioClient(ctx context.Context, serverConfig
return stdioClient, nil
}
// createSSEClient creates an SSE client
// parseHeaders parses "Key: Value" header strings into a map.
func parseHeaders(raw []string) map[string]string {
if len(raw) == 0 {
return nil
}
headers := make(map[string]string)
for _, header := range raw {
parts := strings.SplitN(header, ":", 2)
if len(parts) == 2 {
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
headers[key] = value
}
}
if len(headers) == 0 {
return nil
}
return headers
}
// buildOAuthConfig constructs a transport.OAuthConfig from the server config
// and the pool's OAuth flow. Returns nil if OAuth is not applicable.
func (p *MCPConnectionPool) buildOAuthConfig(serverConfig config.MCPServerConfig) (*transport.OAuthConfig, error) {
if p.oauthFlow == nil || serverConfig.NoOAuth {
return nil, nil
}
tokenStore, err := p.createTokenStore(serverConfig.URL)
if err != nil {
return nil, fmt.Errorf("failed to create token store: %w", err)
}
cfg := &transport.OAuthConfig{
RedirectURI: p.oauthFlow.handler.RedirectURI(),
PKCEEnabled: true,
TokenStore: tokenStore,
}
if serverConfig.OAuthClientID != "" {
cfg.ClientID = serverConfig.OAuthClientID
}
if serverConfig.OAuthClientSecret != "" {
cfg.ClientSecret = serverConfig.OAuthClientSecret
}
if len(serverConfig.OAuthScopes) > 0 {
cfg.Scopes = serverConfig.OAuthScopes
}
return cfg, nil
}
func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
var options []transport.ClientOption
if len(serverConfig.Headers) > 0 {
headers := make(map[string]string)
for _, header := range serverConfig.Headers {
parts := strings.SplitN(header, ":", 2)
if len(parts) == 2 {
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
headers[key] = value
}
}
if len(headers) > 0 {
options = append(options, transport.WithHeaders(headers))
}
if headers := parseHeaders(serverConfig.Headers); headers != nil {
options = append(options, transport.WithHeaders(headers))
}
// Enable OAuth for remote transports when an auth handler is configured
// and the server hasn't opted out via NoOAuth. Public MCP servers (e.g.
// PubMed) set NoOAuth to skip dynamic client registration and token
// exchange, which would otherwise fail with a 404.
if p.oauthFlow != nil && !serverConfig.NoOAuth {
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
oauthCfg := transport.OAuthConfig{
RedirectURI: p.oauthFlow.handler.RedirectURI(),
PKCEEnabled: true,
TokenStore: tokenStore,
}
if serverConfig.OAuthClientID != "" {
oauthCfg.ClientID = serverConfig.OAuthClientID
}
if serverConfig.OAuthClientSecret != "" {
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
}
if len(serverConfig.OAuthScopes) > 0 {
oauthCfg.Scopes = serverConfig.OAuthScopes
}
options = append(options, transport.WithOAuth(oauthCfg))
oauthCfg, err := p.buildOAuthConfig(serverConfig)
if err != nil {
return nil, err
}
if oauthCfg != nil {
options = append(options, transport.WithOAuth(*oauthCfg))
}
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
@@ -406,43 +427,18 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
var options []transport.StreamableHTTPCOption
if len(serverConfig.Headers) > 0 {
headers := make(map[string]string)
for _, header := range serverConfig.Headers {
parts := strings.SplitN(header, ":", 2)
if len(parts) == 2 {
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
headers[key] = value
}
}
if len(headers) > 0 {
options = append(options, transport.WithHTTPHeaders(headers))
}
if headers := parseHeaders(serverConfig.Headers); headers != nil {
options = append(options, transport.WithHTTPHeaders(headers))
}
// Enable OAuth for remote transports when an auth handler is configured
// and the server hasn't opted out via NoOAuth.
if p.oauthFlow != nil && !serverConfig.NoOAuth {
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
oauthCfg := transport.OAuthConfig{
RedirectURI: p.oauthFlow.handler.RedirectURI(),
PKCEEnabled: true,
TokenStore: tokenStore,
}
if serverConfig.OAuthClientID != "" {
oauthCfg.ClientID = serverConfig.OAuthClientID
}
if serverConfig.OAuthClientSecret != "" {
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
}
if len(serverConfig.OAuthScopes) > 0 {
oauthCfg.Scopes = serverConfig.OAuthScopes
}
options = append(options, transport.WithHTTPOAuth(oauthCfg))
oauthCfg, err := p.buildOAuthConfig(serverConfig)
if err != nil {
return nil, err
}
if oauthCfg != nil {
options = append(options, transport.WithHTTPOAuth(*oauthCfg))
}
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
+69 -20
View File
@@ -8,55 +8,104 @@ import (
"github.com/mark3labs/kit/internal/session"
)
// ==== Extension Types ====
//
// Type aliases for internal extension types exposed through the public
// ExtensionAPI interface. External SDK consumers can use these without
// importing internal packages directly.
// ExtensionContext holds the runtime context passed to extensions, including
// callbacks for printing, sending messages, and accessing session state.
type ExtensionContext = extensions.Context
// ExtensionWidgetConfig describes a widget registered by an extension.
type ExtensionWidgetConfig = extensions.WidgetConfig
// ExtensionWidgetPlacement indicates where a widget should be rendered
// (e.g. above or below the conversation).
type ExtensionWidgetPlacement = extensions.WidgetPlacement
// ExtensionHeaderFooterConfig describes a header or footer registered by an extension.
type ExtensionHeaderFooterConfig = extensions.HeaderFooterConfig
// ExtensionEditorConfig configures editor behaviour overrides set by extensions.
type ExtensionEditorConfig = extensions.EditorConfig
// ExtensionUIVisibility controls which UI elements are visible.
type ExtensionUIVisibility = extensions.UIVisibility
// ExtensionToolRenderConfig describes custom tool output rendering registered by an extension.
type ExtensionToolRenderConfig = extensions.ToolRenderConfig
// ExtensionMessageRendererConfig describes custom message rendering registered by an extension.
type ExtensionMessageRendererConfig = extensions.MessageRendererConfig
// ExtensionSessionMessage represents a single message in the session history
// as exposed to extensions.
type ExtensionSessionMessage = extensions.SessionMessage
// ExtensionEntry represents a custom data entry stored by an extension
// in the session tree.
type ExtensionEntry = extensions.ExtensionEntry
// ExtensionStatusBarEntry describes a status bar entry registered by an extension.
type ExtensionStatusBarEntry = extensions.StatusBarEntry
// ExtensionToolInfo describes a tool available to the agent, as seen by extensions.
type ExtensionToolInfo = extensions.ToolInfo
// ExtensionCommandDef describes a slash command registered by an extension.
type ExtensionCommandDef = extensions.CommandDef
// ExtensionAPI provides grouped access to all extension-related functionality.
// This cleans up the main Kit API surface while keeping all extension capabilities available.
type ExtensionAPI interface {
// Context management
SetContext(ctx extensions.Context)
GetContext() extensions.Context
SetContext(ctx ExtensionContext)
GetContext() ExtensionContext
UpdateContextModel(model string)
// Widgets
SetWidget(config extensions.WidgetConfig)
SetWidget(config ExtensionWidgetConfig)
RemoveWidget(id string)
GetWidgets(placement extensions.WidgetPlacement) []extensions.WidgetConfig
GetWidgets(placement ExtensionWidgetPlacement) []ExtensionWidgetConfig
// Header/Footer
SetHeader(config extensions.HeaderFooterConfig)
SetHeader(config ExtensionHeaderFooterConfig)
RemoveHeader()
GetHeader() *extensions.HeaderFooterConfig
SetFooter(config extensions.HeaderFooterConfig)
GetHeader() *ExtensionHeaderFooterConfig
SetFooter(config ExtensionHeaderFooterConfig)
RemoveFooter()
GetFooter() *extensions.HeaderFooterConfig
GetFooter() *ExtensionHeaderFooterConfig
// Editor
SetEditor(config extensions.EditorConfig)
SetEditor(config ExtensionEditorConfig)
ResetEditor()
GetEditor() *extensions.EditorConfig
GetEditor() *ExtensionEditorConfig
// UI Visibility
SetUIVisibility(v extensions.UIVisibility)
GetUIVisibility() *extensions.UIVisibility
SetUIVisibility(v ExtensionUIVisibility)
GetUIVisibility() *ExtensionUIVisibility
// Tool rendering
GetToolRenderer(toolName string) *extensions.ToolRenderConfig
GetMessageRenderer(name string) *extensions.MessageRendererConfig
GetToolRenderer(toolName string) *ExtensionToolRenderConfig
GetMessageRenderer(name string) *ExtensionMessageRendererConfig
// Session data
GetSessionMessages() []extensions.SessionMessage
GetSessionMessages() []ExtensionSessionMessage
AppendEntry(extType, data string) (string, error)
GetEntries(extType string) []extensions.ExtensionEntry
GetEntries(extType string) []ExtensionEntry
// Status bar
SetStatus(entry extensions.StatusBarEntry)
SetStatus(entry ExtensionStatusBarEntry)
RemoveStatus(key string)
GetStatusEntries() []extensions.StatusBarEntry
GetStatusEntries() []ExtensionStatusBarEntry
// Shortcuts
GetShortcuts() map[string]func()
// Tools
GetToolInfos() []extensions.ToolInfo
GetToolInfos() []ExtensionToolInfo
SetActiveTools(names []string)
// Options
@@ -71,7 +120,7 @@ type ExtensionAPI interface {
EmitBeforeSessionSwitch(switchReason string) (cancelled bool, reason string)
// Commands
Commands() []extensions.CommandDef
Commands() []ExtensionCommandDef
// Lifecycle
Reload() error
+143 -218
View File
@@ -54,83 +54,51 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
// Subscribe to SDK events and forward to extension runner so extensions
// see lifecycle events from the SDK's runTurn()/generate() path.
if runner.HasHandlers(extensions.AgentStart) {
m.Subscribe(func(e Event) {
if ev, ok := e.(TurnStartEvent); ok {
_, _ = runner.Emit(extensions.AgentStartEvent{Prompt: ev.Prompt})
}
})
}
bridgeObserve(m, runner, extensions.AgentStart, func(ev TurnStartEvent) extensions.Event {
return extensions.AgentStartEvent{Prompt: ev.Prompt}
})
if runner.HasHandlers(extensions.MessageStart) {
m.Subscribe(func(e Event) {
if _, ok := e.(MessageStartEvent); ok {
_, _ = runner.Emit(extensions.MessageStartEvent{})
}
})
}
bridgeObserve(m, runner, extensions.MessageStart, func(_ MessageStartEvent) extensions.Event {
return extensions.MessageStartEvent{}
})
if runner.HasHandlers(extensions.MessageUpdate) {
m.Subscribe(func(e Event) {
if ev, ok := e.(MessageUpdateEvent); ok {
_, _ = runner.Emit(extensions.MessageUpdateEvent{Chunk: ev.Chunk})
}
})
}
bridgeObserve(m, runner, extensions.MessageUpdate, func(ev MessageUpdateEvent) extensions.Event {
return extensions.MessageUpdateEvent{Chunk: ev.Chunk}
})
if runner.HasHandlers(extensions.MessageEnd) {
m.Subscribe(func(e Event) {
if ev, ok := e.(MessageEndEvent); ok {
_, _ = runner.Emit(extensions.MessageEndEvent{Content: ev.Content})
}
})
}
bridgeObserve(m, runner, extensions.MessageEnd, func(ev MessageEndEvent) extensions.Event {
return extensions.MessageEndEvent{Content: ev.Content}
})
// Tool output streaming events (observation only).
if runner.HasHandlers(extensions.ToolOutput) {
m.Subscribe(func(e Event) {
if ev, ok := e.(ToolOutputEvent); ok {
_, _ = runner.Emit(extensions.ToolOutputEvent{
ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName,
Chunk: ev.Chunk,
IsStderr: ev.IsStderr,
})
}
})
}
bridgeObserve(m, runner, extensions.ToolOutput, func(ev ToolOutputEvent) extensions.Event {
return extensions.ToolOutputEvent{
ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName,
Chunk: ev.Chunk,
IsStderr: ev.IsStderr,
}
})
// Tool call input streaming events — fire as the LLM generates tool arguments.
if runner.HasHandlers(extensions.ToolCallInputStart) {
m.Subscribe(func(e Event) {
if ev, ok := e.(ToolCallStartEvent); ok {
_, _ = runner.Emit(extensions.ToolCallInputStartEvent{
ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName,
ToolKind: ev.ToolKind,
})
}
})
}
if runner.HasHandlers(extensions.ToolCallInputDelta) {
m.Subscribe(func(e Event) {
if ev, ok := e.(ToolCallDeltaEvent); ok {
_, _ = runner.Emit(extensions.ToolCallInputDeltaEvent{
ToolCallID: ev.ToolCallID,
Delta: ev.Delta,
})
}
})
}
if runner.HasHandlers(extensions.ToolCallInputEnd) {
m.Subscribe(func(e Event) {
if ev, ok := e.(ToolCallEndEvent); ok {
_, _ = runner.Emit(extensions.ToolCallInputEndEvent{
ToolCallID: ev.ToolCallID,
})
}
})
}
bridgeObserve(m, runner, extensions.ToolCallInputStart, func(ev ToolCallStartEvent) extensions.Event {
return extensions.ToolCallInputStartEvent{
ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName,
ToolKind: ev.ToolKind,
}
})
bridgeObserve(m, runner, extensions.ToolCallInputDelta, func(ev ToolCallDeltaEvent) extensions.Event {
return extensions.ToolCallInputDeltaEvent{
ToolCallID: ev.ToolCallID,
Delta: ev.Delta,
}
})
bridgeObserve(m, runner, extensions.ToolCallInputEnd, func(ev ToolCallEndEvent) extensions.Event {
return extensions.ToolCallInputEndEvent{
ToolCallID: ev.ToolCallID,
}
})
if runner.HasHandlers(extensions.AgentEnd) {
m.Subscribe(func(e Event) {
@@ -278,54 +246,13 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
// Extension ContextPrepare → SDK ContextPrepare hook.
if runner.HasHandlers(extensions.ContextPrepare) {
m.OnContextPrepare(HookPriorityNormal, func(h ContextPrepareHook) *ContextPrepareResult {
// Convert LLM message slice to extension ContextMessage slice.
// Extract plain text from each message for the extension API.
extMsgs := make([]extensions.ContextMessage, len(h.Messages))
for i, msg := range h.Messages {
var sb strings.Builder
for _, part := range msg.Content {
if tp, ok := part.(LLMTextPart); ok {
sb.WriteString(tp.Text)
}
}
extMsgs[i] = extensions.ContextMessage{
Index: i,
Role: string(msg.Role),
Content: sb.String(),
}
}
extMsgs := llmToContextMessages(h.Messages)
result, _ := runner.Emit(extensions.ContextPrepareEvent{Messages: extMsgs})
r, ok := result.(extensions.ContextPrepareResult)
if !ok || r.Messages == nil {
return nil
}
// Rebuild LLM message slice from extension result.
rebuilt := make([]LLMMessage, 0, len(r.Messages))
for _, cm := range r.Messages {
if cm.Index >= 0 && cm.Index < len(h.Messages) {
// Reuse original message (preserves original role and content).
rebuilt = append(rebuilt, h.Messages[cm.Index])
} else {
// New message injected by extension — construct from role + text.
role := LLMRoleUser
switch cm.Role {
case "assistant":
role = LLMRoleAssistant
case "system":
role = LLMRoleSystem
case "tool":
role = LLMRoleTool
}
rebuilt = append(rebuilt, LLMMessage{
Role: role,
Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}},
})
}
}
return &ContextPrepareResult{Messages: rebuilt}
return &ContextPrepareResult{Messages: contextMessagesToLLM(r.Messages, h.Messages)}
})
}
@@ -359,99 +286,56 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
// --- Step lifecycle observation events ---
if runner.HasHandlers(extensions.StepStart) {
m.Subscribe(func(e Event) {
if ev, ok := e.(StepStartEvent); ok {
_, _ = runner.Emit(extensions.StepStartEvent{StepNumber: ev.StepNumber})
}
})
}
bridgeObserve(m, runner, extensions.StepStart, func(ev StepStartEvent) extensions.Event {
return extensions.StepStartEvent{StepNumber: ev.StepNumber}
})
if runner.HasHandlers(extensions.StepFinish) {
m.Subscribe(func(e Event) {
if ev, ok := e.(StepFinishEvent); ok {
_, _ = runner.Emit(extensions.StepFinishEvent{
StepNumber: ev.StepNumber,
HasToolCalls: ev.HasToolCalls,
FinishReason: ev.FinishReason,
InputTokens: ev.Usage.InputTokens,
OutputTokens: ev.Usage.OutputTokens,
CacheReadTokens: ev.Usage.CacheReadTokens,
CacheWriteTokens: ev.Usage.CacheCreationTokens,
})
}
})
}
bridgeObserve(m, runner, extensions.StepFinish, func(ev StepFinishEvent) extensions.Event {
return extensions.StepFinishEvent{
StepNumber: ev.StepNumber,
HasToolCalls: ev.HasToolCalls,
FinishReason: ev.FinishReason,
InputTokens: ev.Usage.InputTokens,
OutputTokens: ev.Usage.OutputTokens,
CacheReadTokens: ev.Usage.CacheReadTokens,
CacheWriteTokens: ev.Usage.CacheCreationTokens,
}
})
if runner.HasHandlers(extensions.ReasoningStart) {
m.Subscribe(func(e Event) {
if ev, ok := e.(ReasoningStartEvent); ok {
_, _ = runner.Emit(extensions.ReasoningStartEvent{ID: ev.ID})
}
})
}
bridgeObserve(m, runner, extensions.ReasoningStart, func(ev ReasoningStartEvent) extensions.Event {
return extensions.ReasoningStartEvent{ID: ev.ID}
})
if runner.HasHandlers(extensions.Warnings) {
m.Subscribe(func(e Event) {
if ev, ok := e.(WarningsEvent); ok {
_, _ = runner.Emit(extensions.WarningsEvent{Warnings: ev.Warnings})
}
})
}
bridgeObserve(m, runner, extensions.Warnings, func(ev WarningsEvent) extensions.Event {
return extensions.WarningsEvent{Warnings: ev.Warnings}
})
if runner.HasHandlers(extensions.Source) {
m.Subscribe(func(e Event) {
if ev, ok := e.(SourceEvent); ok {
_, _ = runner.Emit(extensions.SourceEvent{
SourceType: ev.SourceType,
ID: ev.ID,
URL: ev.URL,
Title: ev.Title,
})
}
})
}
bridgeObserve(m, runner, extensions.Source, func(ev SourceEvent) extensions.Event {
return extensions.SourceEvent{
SourceType: ev.SourceType,
ID: ev.ID,
URL: ev.URL,
Title: ev.Title,
}
})
if runner.HasHandlers(extensions.Error) {
m.Subscribe(func(e Event) {
if ev, ok := e.(ErrorEvent); ok {
_, _ = runner.Emit(extensions.ErrorEvent{Error: ev.Error.Error()})
}
})
}
bridgeObserve(m, runner, extensions.Error, func(ev ErrorEvent) extensions.Event {
return extensions.ErrorEvent{Error: ev.Error.Error()}
})
if runner.HasHandlers(extensions.Retry) {
m.Subscribe(func(e Event) {
if ev, ok := e.(RetryEvent); ok {
_, _ = runner.Emit(extensions.RetryEvent{
Attempt: ev.Attempt,
Error: ev.Error.Error(),
})
}
})
}
bridgeObserve(m, runner, extensions.Retry, func(ev RetryEvent) extensions.Event {
return extensions.RetryEvent{
Attempt: ev.Attempt,
Error: ev.Error.Error(),
}
})
// --- PrepareStep hook ---
// Extension PrepareStep → SDK PrepareStep hook.
// Same pattern as ContextPrepare: convert LLMMessage ↔ ContextMessage.
if runner.HasHandlers(extensions.PrepareStep) {
m.OnPrepareStep(HookPriorityNormal, func(h PrepareStepHook) *PrepareStepResult {
// Convert LLM message slice to extension ContextMessage slice.
extMsgs := make([]extensions.ContextMessage, len(h.Messages))
for i, msg := range h.Messages {
var sb strings.Builder
for _, part := range msg.Content {
if tp, ok := part.(LLMTextPart); ok {
sb.WriteString(tp.Text)
}
}
extMsgs[i] = extensions.ContextMessage{
Index: i,
Role: string(msg.Role),
Content: sb.String(),
}
}
extMsgs := llmToContextMessages(h.Messages)
result, _ := runner.Emit(extensions.PrepareStepEvent{
StepNumber: h.StepNumber,
Messages: extMsgs,
@@ -460,30 +344,71 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
if !ok || r.Messages == nil {
return nil
}
// Rebuild LLM message slice from extension result.
rebuilt := make([]LLMMessage, 0, len(r.Messages))
for _, cm := range r.Messages {
if cm.Index >= 0 && cm.Index < len(h.Messages) {
rebuilt = append(rebuilt, h.Messages[cm.Index])
} else {
role := LLMRoleUser
switch cm.Role {
case "assistant":
role = LLMRoleAssistant
case "system":
role = LLMRoleSystem
case "tool":
role = LLMRoleTool
}
rebuilt = append(rebuilt, LLMMessage{
Role: role,
Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}},
})
}
}
return &PrepareStepResult{Messages: rebuilt}
return &PrepareStepResult{Messages: contextMessagesToLLM(r.Messages, h.Messages)}
})
}
}
// bridgeObserve subscribes to SDK events of type In and forwards them to the
// extension runner as the event returned by conv. The subscription is only
// registered when the runner has handlers for the given event kind.
func bridgeObserve[In Event](m *Kit, runner *extensions.Runner, kind extensions.EventType, conv func(In) extensions.Event) {
if !runner.HasHandlers(kind) {
return
}
m.Subscribe(func(e Event) {
if ev, ok := e.(In); ok {
_, _ = runner.Emit(conv(ev))
}
})
}
// llmToContextMessages converts a slice of LLM messages to extension
// ContextMessage values, extracting plain text from each message.
func llmToContextMessages(msgs []LLMMessage) []extensions.ContextMessage {
extMsgs := make([]extensions.ContextMessage, len(msgs))
for i, msg := range msgs {
var sb strings.Builder
for _, part := range msg.Content {
if tp, ok := part.(LLMTextPart); ok {
sb.WriteString(tp.Text)
}
}
extMsgs[i] = extensions.ContextMessage{
Index: i,
Role: string(msg.Role),
Content: sb.String(),
}
}
return extMsgs
}
// contextMessagesToLLM rebuilds an LLM message slice from extension
// ContextMessages. Messages with a valid index reuse the original from
// originals; new messages injected by extensions are constructed from
// role + text.
func contextMessagesToLLM(cms []extensions.ContextMessage, originals []LLMMessage) []LLMMessage {
rebuilt := make([]LLMMessage, 0, len(cms))
for _, cm := range cms {
if cm.Index >= 0 && cm.Index < len(originals) {
// Reuse original message (preserves original role and content).
rebuilt = append(rebuilt, originals[cm.Index])
} else {
// New message injected by extension — construct from role + text.
role := LLMRoleUser
switch cm.Role {
case "assistant":
role = LLMRoleAssistant
case "system":
role = LLMRoleSystem
case "tool":
role = LLMRoleTool
}
rebuilt = append(rebuilt, LLMMessage{
Role: role,
Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}},
})
}
}
return rebuilt
}
+1 -1
View File
@@ -2126,7 +2126,7 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
})
},
// New callbacks for previously unwired Fantasy lifecycle events.
// New callbacks for previously unwired agent lifecycle events.
OnStepStart: func(stepNumber int) {
m.events.emit(StepStartEvent{StepNumber: stepNumber})
},