From d7c4565999bb01871cedd914d2073ebfda8cba5c Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Mon, 25 May 2026 13:30:22 +0300 Subject: [PATCH] refactor: remove dead code, fix SDK leakage, deduplicate helpers - Remove unused SetOpenAICredentials/validateOpenAIAPIKey (internal/auth) - Remove unused SudoPasswordRequiredMetadata/IsSudoPasswordRequiredResult (internal/core) - Add Extension* type aliases in pkg/kit/extension_api.go so the public ExtensionAPI interface no longer exposes internal/extensions types - Extract bridgeObserve generic helper and llmToContextMessages / contextMessagesToLLM in pkg/kit/extensions_bridge.go (~150 lines saved) - Extract parseHeaders and buildOAuthConfig in connection_pool.go to deduplicate SSE/Streamable client construction (~60 lines saved) - Eliminate redundant second buildInteractiveExtensionContext call in cmd/root.go; swap print closures on the same context instead - Replace 'Fantasy' with 'agent' in internal comment (pkg/kit/kit.go) --- cmd/root.go | 12 +- internal/auth/credentials.go | 43 ---- internal/core/bash.go | 9 - internal/tools/connection_pool.go | 130 ++++++----- pkg/kit/extension_api.go | 89 ++++++-- pkg/kit/extensions_bridge.go | 361 ++++++++++++------------------ pkg/kit/kit.go | 2 +- 7 files changed, 278 insertions(+), 368 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index f9a284b3..e8478318 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -899,8 +899,9 @@ func runNormalMode(ctx context.Context) error { appInstance: appInstance, usageTracker: usageTracker, }) + + // During startup, buffer extension messages so they appear after the banner. extCtx.Print = func(text string) { - // Capture messages during startup, print after startup banner. startupExtensionMessages = append(startupExtensionMessages, text) } extCtx.PrintInfo = func(text string) { @@ -913,15 +914,6 @@ func runNormalMode(ctx context.Context) error { kitInstance.Extensions().EmitSessionStart() // Restore normal print functions for runtime use. - extCtx = buildInteractiveExtensionContext(extensionContextDeps{ - ctx: ctx, - cwd: cwd, - modelName: modelName, - interactive: positionalPrompt == "", - kitInstance: kitInstance, - appInstance: appInstance, - usageTracker: usageTracker, - }) extCtx.Print = func(text string) { appInstance.PrintFromExtension("", text) } extCtx.PrintInfo = func(text string) { appInstance.PrintFromExtension("info", text) } extCtx.PrintError = func(text string) { appInstance.PrintFromExtension("error", text) } diff --git a/internal/auth/credentials.go b/internal/auth/credentials.go index ff9588df..b226710c 100644 --- a/internal/auth/credentials.go +++ b/internal/auth/credentials.go @@ -255,29 +255,6 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) { } } -// SetOpenAICredentials stores OpenAI API key credentials. It validates the -// API key format before storing. The API key must start with "sk-" and be -// at least 20 characters long. Returns an error if the API key is invalid or -// if storage fails. -func (cm *CredentialManager) SetOpenAICredentials(apiKey string) error { - if err := validateOpenAIAPIKey(apiKey); err != nil { - return err - } - - store, err := cm.LoadCredentials() - if err != nil { - return err - } - - store.OpenAI = &OpenAICredentials{ - Type: "api_key", - APIKey: apiKey, - CreatedAt: time.Now(), - } - - return cm.SaveCredentials(store) -} - // GetOpenAICredentials retrieves stored OpenAI credentials. Returns nil if // no credentials are stored. The returned credentials may be either OAuth or API // key type, check the Type field to determine which. @@ -417,26 +394,6 @@ func validateAnthropicAPIKey(apiKey string) error { return nil } -// validateOpenAIAPIKey validates the format of an OpenAI API key -func validateOpenAIAPIKey(apiKey string) error { - apiKey = strings.TrimSpace(apiKey) - - if apiKey == "" { - return fmt.Errorf("API key cannot be empty") - } - - // OpenAI API keys typically start with "sk-" and are quite long - if !strings.HasPrefix(apiKey, "sk-") { - return fmt.Errorf("invalid OpenAI API key format (should start with 'sk-')") - } - - if len(apiKey) < 20 { - return fmt.Errorf("API key appears to be too short") - } - - return nil -} - // GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order: // 1. Command-line flag value (highest priority) // 2. Stored credentials (OAuth or API key) diff --git a/internal/core/bash.go b/internal/core/bash.go index ad637665..0e641f93 100644 --- a/internal/core/bash.go +++ b/internal/core/bash.go @@ -160,15 +160,6 @@ func rewriteSudoForStdin(command string) string { return result } -// SudoPasswordRequiredResult is a special marker that indicates sudo needs a password. -// This is stored in tool response metadata to signal the TUI to prompt for password. -const SudoPasswordRequiredMetadata = `{"sudo_password_required":true}` - -// IsSudoPasswordRequiredResult checks if a tool response indicates sudo password is needed. -func IsSudoPasswordRequiredResult(resp fantasy.ToolResponse) bool { - return resp.Metadata == SudoPasswordRequiredMetadata -} - func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) { var args bashArgs if err := parseArgs(call.Input, &args); err != nil { diff --git a/internal/tools/connection_pool.go b/internal/tools/connection_pool.go index 8517ea92..f4964089 100644 --- a/internal/tools/connection_pool.go +++ b/internal/tools/connection_pool.go @@ -345,49 +345,70 @@ func (p *MCPConnectionPool) createStdioClient(ctx context.Context, serverConfig return stdioClient, nil } -// createSSEClient creates an SSE client +// parseHeaders parses "Key: Value" header strings into a map. +func parseHeaders(raw []string) map[string]string { + if len(raw) == 0 { + return nil + } + headers := make(map[string]string) + for _, header := range raw { + parts := strings.SplitN(header, ":", 2) + if len(parts) == 2 { + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + headers[key] = value + } + } + if len(headers) == 0 { + return nil + } + return headers +} + +// buildOAuthConfig constructs a transport.OAuthConfig from the server config +// and the pool's OAuth flow. Returns nil if OAuth is not applicable. +func (p *MCPConnectionPool) buildOAuthConfig(serverConfig config.MCPServerConfig) (*transport.OAuthConfig, error) { + if p.oauthFlow == nil || serverConfig.NoOAuth { + return nil, nil + } + tokenStore, err := p.createTokenStore(serverConfig.URL) + if err != nil { + return nil, fmt.Errorf("failed to create token store: %w", err) + } + cfg := &transport.OAuthConfig{ + RedirectURI: p.oauthFlow.handler.RedirectURI(), + PKCEEnabled: true, + TokenStore: tokenStore, + } + if serverConfig.OAuthClientID != "" { + cfg.ClientID = serverConfig.OAuthClientID + } + if serverConfig.OAuthClientSecret != "" { + cfg.ClientSecret = serverConfig.OAuthClientSecret + } + if len(serverConfig.OAuthScopes) > 0 { + cfg.Scopes = serverConfig.OAuthScopes + } + return cfg, nil +} + func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) { var options []transport.ClientOption - if len(serverConfig.Headers) > 0 { - headers := make(map[string]string) - for _, header := range serverConfig.Headers { - parts := strings.SplitN(header, ":", 2) - if len(parts) == 2 { - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - headers[key] = value - } - } - if len(headers) > 0 { - options = append(options, transport.WithHeaders(headers)) - } + if headers := parseHeaders(serverConfig.Headers); headers != nil { + options = append(options, transport.WithHeaders(headers)) } // Enable OAuth for remote transports when an auth handler is configured // and the server hasn't opted out via NoOAuth. Public MCP servers (e.g. // PubMed) set NoOAuth to skip dynamic client registration and token // exchange, which would otherwise fail with a 404. - if p.oauthFlow != nil && !serverConfig.NoOAuth { - tokenStore, tsErr := p.createTokenStore(serverConfig.URL) - if tsErr != nil { - return nil, fmt.Errorf("failed to create token store: %w", tsErr) - } - oauthCfg := transport.OAuthConfig{ - RedirectURI: p.oauthFlow.handler.RedirectURI(), - PKCEEnabled: true, - TokenStore: tokenStore, - } - if serverConfig.OAuthClientID != "" { - oauthCfg.ClientID = serverConfig.OAuthClientID - } - if serverConfig.OAuthClientSecret != "" { - oauthCfg.ClientSecret = serverConfig.OAuthClientSecret - } - if len(serverConfig.OAuthScopes) > 0 { - oauthCfg.Scopes = serverConfig.OAuthScopes - } - options = append(options, transport.WithOAuth(oauthCfg)) + oauthCfg, err := p.buildOAuthConfig(serverConfig) + if err != nil { + return nil, err + } + if oauthCfg != nil { + options = append(options, transport.WithOAuth(*oauthCfg)) } sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...) @@ -406,43 +427,18 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) { var options []transport.StreamableHTTPCOption - if len(serverConfig.Headers) > 0 { - headers := make(map[string]string) - for _, header := range serverConfig.Headers { - parts := strings.SplitN(header, ":", 2) - if len(parts) == 2 { - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - headers[key] = value - } - } - if len(headers) > 0 { - options = append(options, transport.WithHTTPHeaders(headers)) - } + if headers := parseHeaders(serverConfig.Headers); headers != nil { + options = append(options, transport.WithHTTPHeaders(headers)) } // Enable OAuth for remote transports when an auth handler is configured // and the server hasn't opted out via NoOAuth. - if p.oauthFlow != nil && !serverConfig.NoOAuth { - tokenStore, tsErr := p.createTokenStore(serverConfig.URL) - if tsErr != nil { - return nil, fmt.Errorf("failed to create token store: %w", tsErr) - } - oauthCfg := transport.OAuthConfig{ - RedirectURI: p.oauthFlow.handler.RedirectURI(), - PKCEEnabled: true, - TokenStore: tokenStore, - } - if serverConfig.OAuthClientID != "" { - oauthCfg.ClientID = serverConfig.OAuthClientID - } - if serverConfig.OAuthClientSecret != "" { - oauthCfg.ClientSecret = serverConfig.OAuthClientSecret - } - if len(serverConfig.OAuthScopes) > 0 { - oauthCfg.Scopes = serverConfig.OAuthScopes - } - options = append(options, transport.WithHTTPOAuth(oauthCfg)) + oauthCfg, err := p.buildOAuthConfig(serverConfig) + if err != nil { + return nil, err + } + if oauthCfg != nil { + options = append(options, transport.WithHTTPOAuth(*oauthCfg)) } streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...) diff --git a/pkg/kit/extension_api.go b/pkg/kit/extension_api.go index 50fba36b..efdc1a82 100644 --- a/pkg/kit/extension_api.go +++ b/pkg/kit/extension_api.go @@ -8,55 +8,104 @@ import ( "github.com/mark3labs/kit/internal/session" ) +// ==== Extension Types ==== +// +// Type aliases for internal extension types exposed through the public +// ExtensionAPI interface. External SDK consumers can use these without +// importing internal packages directly. + +// ExtensionContext holds the runtime context passed to extensions, including +// callbacks for printing, sending messages, and accessing session state. +type ExtensionContext = extensions.Context + +// ExtensionWidgetConfig describes a widget registered by an extension. +type ExtensionWidgetConfig = extensions.WidgetConfig + +// ExtensionWidgetPlacement indicates where a widget should be rendered +// (e.g. above or below the conversation). +type ExtensionWidgetPlacement = extensions.WidgetPlacement + +// ExtensionHeaderFooterConfig describes a header or footer registered by an extension. +type ExtensionHeaderFooterConfig = extensions.HeaderFooterConfig + +// ExtensionEditorConfig configures editor behaviour overrides set by extensions. +type ExtensionEditorConfig = extensions.EditorConfig + +// ExtensionUIVisibility controls which UI elements are visible. +type ExtensionUIVisibility = extensions.UIVisibility + +// ExtensionToolRenderConfig describes custom tool output rendering registered by an extension. +type ExtensionToolRenderConfig = extensions.ToolRenderConfig + +// ExtensionMessageRendererConfig describes custom message rendering registered by an extension. +type ExtensionMessageRendererConfig = extensions.MessageRendererConfig + +// ExtensionSessionMessage represents a single message in the session history +// as exposed to extensions. +type ExtensionSessionMessage = extensions.SessionMessage + +// ExtensionEntry represents a custom data entry stored by an extension +// in the session tree. +type ExtensionEntry = extensions.ExtensionEntry + +// ExtensionStatusBarEntry describes a status bar entry registered by an extension. +type ExtensionStatusBarEntry = extensions.StatusBarEntry + +// ExtensionToolInfo describes a tool available to the agent, as seen by extensions. +type ExtensionToolInfo = extensions.ToolInfo + +// ExtensionCommandDef describes a slash command registered by an extension. +type ExtensionCommandDef = extensions.CommandDef + // ExtensionAPI provides grouped access to all extension-related functionality. // This cleans up the main Kit API surface while keeping all extension capabilities available. type ExtensionAPI interface { // Context management - SetContext(ctx extensions.Context) - GetContext() extensions.Context + SetContext(ctx ExtensionContext) + GetContext() ExtensionContext UpdateContextModel(model string) // Widgets - SetWidget(config extensions.WidgetConfig) + SetWidget(config ExtensionWidgetConfig) RemoveWidget(id string) - GetWidgets(placement extensions.WidgetPlacement) []extensions.WidgetConfig + GetWidgets(placement ExtensionWidgetPlacement) []ExtensionWidgetConfig // Header/Footer - SetHeader(config extensions.HeaderFooterConfig) + SetHeader(config ExtensionHeaderFooterConfig) RemoveHeader() - GetHeader() *extensions.HeaderFooterConfig - SetFooter(config extensions.HeaderFooterConfig) + GetHeader() *ExtensionHeaderFooterConfig + SetFooter(config ExtensionHeaderFooterConfig) RemoveFooter() - GetFooter() *extensions.HeaderFooterConfig + GetFooter() *ExtensionHeaderFooterConfig // Editor - SetEditor(config extensions.EditorConfig) + SetEditor(config ExtensionEditorConfig) ResetEditor() - GetEditor() *extensions.EditorConfig + GetEditor() *ExtensionEditorConfig // UI Visibility - SetUIVisibility(v extensions.UIVisibility) - GetUIVisibility() *extensions.UIVisibility + SetUIVisibility(v ExtensionUIVisibility) + GetUIVisibility() *ExtensionUIVisibility // Tool rendering - GetToolRenderer(toolName string) *extensions.ToolRenderConfig - GetMessageRenderer(name string) *extensions.MessageRendererConfig + GetToolRenderer(toolName string) *ExtensionToolRenderConfig + GetMessageRenderer(name string) *ExtensionMessageRendererConfig // Session data - GetSessionMessages() []extensions.SessionMessage + GetSessionMessages() []ExtensionSessionMessage AppendEntry(extType, data string) (string, error) - GetEntries(extType string) []extensions.ExtensionEntry + GetEntries(extType string) []ExtensionEntry // Status bar - SetStatus(entry extensions.StatusBarEntry) + SetStatus(entry ExtensionStatusBarEntry) RemoveStatus(key string) - GetStatusEntries() []extensions.StatusBarEntry + GetStatusEntries() []ExtensionStatusBarEntry // Shortcuts GetShortcuts() map[string]func() // Tools - GetToolInfos() []extensions.ToolInfo + GetToolInfos() []ExtensionToolInfo SetActiveTools(names []string) // Options @@ -71,7 +120,7 @@ type ExtensionAPI interface { EmitBeforeSessionSwitch(switchReason string) (cancelled bool, reason string) // Commands - Commands() []extensions.CommandDef + Commands() []ExtensionCommandDef // Lifecycle Reload() error diff --git a/pkg/kit/extensions_bridge.go b/pkg/kit/extensions_bridge.go index af589786..03ac8983 100644 --- a/pkg/kit/extensions_bridge.go +++ b/pkg/kit/extensions_bridge.go @@ -54,83 +54,51 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) { // Subscribe to SDK events and forward to extension runner so extensions // see lifecycle events from the SDK's runTurn()/generate() path. - if runner.HasHandlers(extensions.AgentStart) { - m.Subscribe(func(e Event) { - if ev, ok := e.(TurnStartEvent); ok { - _, _ = runner.Emit(extensions.AgentStartEvent{Prompt: ev.Prompt}) - } - }) - } + bridgeObserve(m, runner, extensions.AgentStart, func(ev TurnStartEvent) extensions.Event { + return extensions.AgentStartEvent{Prompt: ev.Prompt} + }) - if runner.HasHandlers(extensions.MessageStart) { - m.Subscribe(func(e Event) { - if _, ok := e.(MessageStartEvent); ok { - _, _ = runner.Emit(extensions.MessageStartEvent{}) - } - }) - } + bridgeObserve(m, runner, extensions.MessageStart, func(_ MessageStartEvent) extensions.Event { + return extensions.MessageStartEvent{} + }) - if runner.HasHandlers(extensions.MessageUpdate) { - m.Subscribe(func(e Event) { - if ev, ok := e.(MessageUpdateEvent); ok { - _, _ = runner.Emit(extensions.MessageUpdateEvent{Chunk: ev.Chunk}) - } - }) - } + bridgeObserve(m, runner, extensions.MessageUpdate, func(ev MessageUpdateEvent) extensions.Event { + return extensions.MessageUpdateEvent{Chunk: ev.Chunk} + }) - if runner.HasHandlers(extensions.MessageEnd) { - m.Subscribe(func(e Event) { - if ev, ok := e.(MessageEndEvent); ok { - _, _ = runner.Emit(extensions.MessageEndEvent{Content: ev.Content}) - } - }) - } + bridgeObserve(m, runner, extensions.MessageEnd, func(ev MessageEndEvent) extensions.Event { + return extensions.MessageEndEvent{Content: ev.Content} + }) // Tool output streaming events (observation only). - if runner.HasHandlers(extensions.ToolOutput) { - m.Subscribe(func(e Event) { - if ev, ok := e.(ToolOutputEvent); ok { - _, _ = runner.Emit(extensions.ToolOutputEvent{ - ToolCallID: ev.ToolCallID, - ToolName: ev.ToolName, - Chunk: ev.Chunk, - IsStderr: ev.IsStderr, - }) - } - }) - } + bridgeObserve(m, runner, extensions.ToolOutput, func(ev ToolOutputEvent) extensions.Event { + return extensions.ToolOutputEvent{ + ToolCallID: ev.ToolCallID, + ToolName: ev.ToolName, + Chunk: ev.Chunk, + IsStderr: ev.IsStderr, + } + }) // Tool call input streaming events — fire as the LLM generates tool arguments. - if runner.HasHandlers(extensions.ToolCallInputStart) { - m.Subscribe(func(e Event) { - if ev, ok := e.(ToolCallStartEvent); ok { - _, _ = runner.Emit(extensions.ToolCallInputStartEvent{ - ToolCallID: ev.ToolCallID, - ToolName: ev.ToolName, - ToolKind: ev.ToolKind, - }) - } - }) - } - if runner.HasHandlers(extensions.ToolCallInputDelta) { - m.Subscribe(func(e Event) { - if ev, ok := e.(ToolCallDeltaEvent); ok { - _, _ = runner.Emit(extensions.ToolCallInputDeltaEvent{ - ToolCallID: ev.ToolCallID, - Delta: ev.Delta, - }) - } - }) - } - if runner.HasHandlers(extensions.ToolCallInputEnd) { - m.Subscribe(func(e Event) { - if ev, ok := e.(ToolCallEndEvent); ok { - _, _ = runner.Emit(extensions.ToolCallInputEndEvent{ - ToolCallID: ev.ToolCallID, - }) - } - }) - } + bridgeObserve(m, runner, extensions.ToolCallInputStart, func(ev ToolCallStartEvent) extensions.Event { + return extensions.ToolCallInputStartEvent{ + ToolCallID: ev.ToolCallID, + ToolName: ev.ToolName, + ToolKind: ev.ToolKind, + } + }) + bridgeObserve(m, runner, extensions.ToolCallInputDelta, func(ev ToolCallDeltaEvent) extensions.Event { + return extensions.ToolCallInputDeltaEvent{ + ToolCallID: ev.ToolCallID, + Delta: ev.Delta, + } + }) + bridgeObserve(m, runner, extensions.ToolCallInputEnd, func(ev ToolCallEndEvent) extensions.Event { + return extensions.ToolCallInputEndEvent{ + ToolCallID: ev.ToolCallID, + } + }) if runner.HasHandlers(extensions.AgentEnd) { m.Subscribe(func(e Event) { @@ -278,54 +246,13 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) { // Extension ContextPrepare → SDK ContextPrepare hook. if runner.HasHandlers(extensions.ContextPrepare) { m.OnContextPrepare(HookPriorityNormal, func(h ContextPrepareHook) *ContextPrepareResult { - // Convert LLM message slice to extension ContextMessage slice. - // Extract plain text from each message for the extension API. - extMsgs := make([]extensions.ContextMessage, len(h.Messages)) - for i, msg := range h.Messages { - var sb strings.Builder - for _, part := range msg.Content { - if tp, ok := part.(LLMTextPart); ok { - sb.WriteString(tp.Text) - } - } - extMsgs[i] = extensions.ContextMessage{ - Index: i, - Role: string(msg.Role), - Content: sb.String(), - } - } - + extMsgs := llmToContextMessages(h.Messages) result, _ := runner.Emit(extensions.ContextPrepareEvent{Messages: extMsgs}) r, ok := result.(extensions.ContextPrepareResult) if !ok || r.Messages == nil { return nil } - - // Rebuild LLM message slice from extension result. - rebuilt := make([]LLMMessage, 0, len(r.Messages)) - for _, cm := range r.Messages { - if cm.Index >= 0 && cm.Index < len(h.Messages) { - // Reuse original message (preserves original role and content). - rebuilt = append(rebuilt, h.Messages[cm.Index]) - } else { - // New message injected by extension — construct from role + text. - role := LLMRoleUser - switch cm.Role { - case "assistant": - role = LLMRoleAssistant - case "system": - role = LLMRoleSystem - case "tool": - role = LLMRoleTool - } - rebuilt = append(rebuilt, LLMMessage{ - Role: role, - Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}}, - }) - } - } - - return &ContextPrepareResult{Messages: rebuilt} + return &ContextPrepareResult{Messages: contextMessagesToLLM(r.Messages, h.Messages)} }) } @@ -359,99 +286,56 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) { // --- Step lifecycle observation events --- - if runner.HasHandlers(extensions.StepStart) { - m.Subscribe(func(e Event) { - if ev, ok := e.(StepStartEvent); ok { - _, _ = runner.Emit(extensions.StepStartEvent{StepNumber: ev.StepNumber}) - } - }) - } + bridgeObserve(m, runner, extensions.StepStart, func(ev StepStartEvent) extensions.Event { + return extensions.StepStartEvent{StepNumber: ev.StepNumber} + }) - if runner.HasHandlers(extensions.StepFinish) { - m.Subscribe(func(e Event) { - if ev, ok := e.(StepFinishEvent); ok { - _, _ = runner.Emit(extensions.StepFinishEvent{ - StepNumber: ev.StepNumber, - HasToolCalls: ev.HasToolCalls, - FinishReason: ev.FinishReason, - InputTokens: ev.Usage.InputTokens, - OutputTokens: ev.Usage.OutputTokens, - CacheReadTokens: ev.Usage.CacheReadTokens, - CacheWriteTokens: ev.Usage.CacheCreationTokens, - }) - } - }) - } + bridgeObserve(m, runner, extensions.StepFinish, func(ev StepFinishEvent) extensions.Event { + return extensions.StepFinishEvent{ + StepNumber: ev.StepNumber, + HasToolCalls: ev.HasToolCalls, + FinishReason: ev.FinishReason, + InputTokens: ev.Usage.InputTokens, + OutputTokens: ev.Usage.OutputTokens, + CacheReadTokens: ev.Usage.CacheReadTokens, + CacheWriteTokens: ev.Usage.CacheCreationTokens, + } + }) - if runner.HasHandlers(extensions.ReasoningStart) { - m.Subscribe(func(e Event) { - if ev, ok := e.(ReasoningStartEvent); ok { - _, _ = runner.Emit(extensions.ReasoningStartEvent{ID: ev.ID}) - } - }) - } + bridgeObserve(m, runner, extensions.ReasoningStart, func(ev ReasoningStartEvent) extensions.Event { + return extensions.ReasoningStartEvent{ID: ev.ID} + }) - if runner.HasHandlers(extensions.Warnings) { - m.Subscribe(func(e Event) { - if ev, ok := e.(WarningsEvent); ok { - _, _ = runner.Emit(extensions.WarningsEvent{Warnings: ev.Warnings}) - } - }) - } + bridgeObserve(m, runner, extensions.Warnings, func(ev WarningsEvent) extensions.Event { + return extensions.WarningsEvent{Warnings: ev.Warnings} + }) - if runner.HasHandlers(extensions.Source) { - m.Subscribe(func(e Event) { - if ev, ok := e.(SourceEvent); ok { - _, _ = runner.Emit(extensions.SourceEvent{ - SourceType: ev.SourceType, - ID: ev.ID, - URL: ev.URL, - Title: ev.Title, - }) - } - }) - } + bridgeObserve(m, runner, extensions.Source, func(ev SourceEvent) extensions.Event { + return extensions.SourceEvent{ + SourceType: ev.SourceType, + ID: ev.ID, + URL: ev.URL, + Title: ev.Title, + } + }) - if runner.HasHandlers(extensions.Error) { - m.Subscribe(func(e Event) { - if ev, ok := e.(ErrorEvent); ok { - _, _ = runner.Emit(extensions.ErrorEvent{Error: ev.Error.Error()}) - } - }) - } + bridgeObserve(m, runner, extensions.Error, func(ev ErrorEvent) extensions.Event { + return extensions.ErrorEvent{Error: ev.Error.Error()} + }) - if runner.HasHandlers(extensions.Retry) { - m.Subscribe(func(e Event) { - if ev, ok := e.(RetryEvent); ok { - _, _ = runner.Emit(extensions.RetryEvent{ - Attempt: ev.Attempt, - Error: ev.Error.Error(), - }) - } - }) - } + bridgeObserve(m, runner, extensions.Retry, func(ev RetryEvent) extensions.Event { + return extensions.RetryEvent{ + Attempt: ev.Attempt, + Error: ev.Error.Error(), + } + }) // --- PrepareStep hook --- // Extension PrepareStep → SDK PrepareStep hook. // Same pattern as ContextPrepare: convert LLMMessage ↔ ContextMessage. if runner.HasHandlers(extensions.PrepareStep) { m.OnPrepareStep(HookPriorityNormal, func(h PrepareStepHook) *PrepareStepResult { - // Convert LLM message slice to extension ContextMessage slice. - extMsgs := make([]extensions.ContextMessage, len(h.Messages)) - for i, msg := range h.Messages { - var sb strings.Builder - for _, part := range msg.Content { - if tp, ok := part.(LLMTextPart); ok { - sb.WriteString(tp.Text) - } - } - extMsgs[i] = extensions.ContextMessage{ - Index: i, - Role: string(msg.Role), - Content: sb.String(), - } - } - + extMsgs := llmToContextMessages(h.Messages) result, _ := runner.Emit(extensions.PrepareStepEvent{ StepNumber: h.StepNumber, Messages: extMsgs, @@ -460,30 +344,71 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) { if !ok || r.Messages == nil { return nil } - - // Rebuild LLM message slice from extension result. - rebuilt := make([]LLMMessage, 0, len(r.Messages)) - for _, cm := range r.Messages { - if cm.Index >= 0 && cm.Index < len(h.Messages) { - rebuilt = append(rebuilt, h.Messages[cm.Index]) - } else { - role := LLMRoleUser - switch cm.Role { - case "assistant": - role = LLMRoleAssistant - case "system": - role = LLMRoleSystem - case "tool": - role = LLMRoleTool - } - rebuilt = append(rebuilt, LLMMessage{ - Role: role, - Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}}, - }) - } - } - - return &PrepareStepResult{Messages: rebuilt} + return &PrepareStepResult{Messages: contextMessagesToLLM(r.Messages, h.Messages)} }) } } + +// bridgeObserve subscribes to SDK events of type In and forwards them to the +// extension runner as the event returned by conv. The subscription is only +// registered when the runner has handlers for the given event kind. +func bridgeObserve[In Event](m *Kit, runner *extensions.Runner, kind extensions.EventType, conv func(In) extensions.Event) { + if !runner.HasHandlers(kind) { + return + } + m.Subscribe(func(e Event) { + if ev, ok := e.(In); ok { + _, _ = runner.Emit(conv(ev)) + } + }) +} + +// llmToContextMessages converts a slice of LLM messages to extension +// ContextMessage values, extracting plain text from each message. +func llmToContextMessages(msgs []LLMMessage) []extensions.ContextMessage { + extMsgs := make([]extensions.ContextMessage, len(msgs)) + for i, msg := range msgs { + var sb strings.Builder + for _, part := range msg.Content { + if tp, ok := part.(LLMTextPart); ok { + sb.WriteString(tp.Text) + } + } + extMsgs[i] = extensions.ContextMessage{ + Index: i, + Role: string(msg.Role), + Content: sb.String(), + } + } + return extMsgs +} + +// contextMessagesToLLM rebuilds an LLM message slice from extension +// ContextMessages. Messages with a valid index reuse the original from +// originals; new messages injected by extensions are constructed from +// role + text. +func contextMessagesToLLM(cms []extensions.ContextMessage, originals []LLMMessage) []LLMMessage { + rebuilt := make([]LLMMessage, 0, len(cms)) + for _, cm := range cms { + if cm.Index >= 0 && cm.Index < len(originals) { + // Reuse original message (preserves original role and content). + rebuilt = append(rebuilt, originals[cm.Index]) + } else { + // New message injected by extension — construct from role + text. + role := LLMRoleUser + switch cm.Role { + case "assistant": + role = LLMRoleAssistant + case "system": + role = LLMRoleSystem + case "tool": + role = LLMRoleTool + } + rebuilt = append(rebuilt, LLMMessage{ + Role: role, + Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}}, + }) + } + } + return rebuilt +} diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index e79048a4..1df69b09 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -2126,7 +2126,7 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent. }) }, - // New callbacks for previously unwired Fantasy lifecycle events. + // New callbacks for previously unwired agent lifecycle events. OnStepStart: func(stepNumber int) { m.events.emit(StepStartEvent{StepNumber: stepNumber}) },