From 3c51c20be702fb7dbaffe4ecb13a6a4c10df3820 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Wed, 15 Apr 2026 15:23:01 +0300 Subject: [PATCH] feat(mcp): handle embedded resources in prompt messages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract all MCP content types in prompt expansion: ImageContent, AudioContent, EmbeddedResource (text and blob), and ResourceLink - Add MCPFilePart type to carry decoded binary attachments through the tools → SDK → bridge → UI layers - Inline text resources as fenced code blocks with URI annotation - Decode image/audio/blob content from base64 into LLMFilePart attachments submitted via RunWithFiles - Render ResourceLink as text annotation for the LLM - Show attachment badges on user messages (e.g. '1 image(s) attached') matching the existing clipboard paste UI pattern - Log warnings on base64 decode failures instead of silently dropping --- cmd/root.go | 5 +- internal/tools/mcp.go | 139 +++++++++++-- internal/tools/mcp_prompts_test.go | 319 +++++++++++++++++++++++++++-- internal/ui/model.go | 59 +++++- pkg/kit/kit.go | 17 +- 5 files changed, 490 insertions(+), 49 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index ffc7f73b..cfc69cda 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1754,8 +1754,9 @@ func runNormalMode(ctx context.Context) error { msgs := make([]ui.MCPPromptMessageInfo, len(result.Messages)) for i, m := range result.Messages { msgs[i] = ui.MCPPromptMessageInfo{ - Role: m.Role, - Content: m.Content, + Role: m.Role, + Content: m.Content, + FileParts: m.FileParts, } } return &ui.MCPPromptExpandResult{Messages: msgs}, nil diff --git a/internal/tools/mcp.go b/internal/tools/mcp.go index 8537fcd2..da7f5086 100644 --- a/internal/tools/mcp.go +++ b/internal/tools/mcp.go @@ -10,6 +10,8 @@ import ( "strings" "sync" + log "github.com/charmbracelet/log" + "github.com/mark3labs/kit/internal/config" "github.com/mark3labs/mcp-go/mcp" ) @@ -68,6 +70,20 @@ type MCPPromptMessage struct { Role string // Content is the text content of the message. Content string + // FileParts contains binary attachments extracted from embedded resources, + // images, or audio content blocks. Empty for text-only messages. + FileParts []MCPFilePart +} + +// MCPFilePart represents a binary file attachment extracted from an MCP prompt +// content block (ImageContent, AudioContent, or EmbeddedResource with blob data). +type MCPFilePart struct { + // Filename is a best-effort name derived from the resource URI or content type. + Filename string + // Data is the raw binary content (already base64-decoded). + Data []byte + // MediaType is the MIME type (e.g. "image/png", "audio/wav"). + MediaType string } // MCPPromptResult is the result of expanding an MCP prompt via GetPrompt. @@ -650,14 +666,15 @@ func (m *MCPToolManager) GetPrompt(ctx context.Context, serverName, promptName s return nil, fmt.Errorf("failed to get prompt %q from server %q: %w", promptName, serverName, err) } - // Convert MCP messages to our types, extracting text content. + // Convert MCP messages to our types, extracting all content types. var messages []MCPPromptMessage for _, msg := range result.Messages { - text := extractContentText(msg.Content) - if text != "" { + text, fileParts := extractPromptContent(msg.Content) + if text != "" || len(fileParts) > 0 { messages = append(messages, MCPPromptMessage{ - Role: string(msg.Role), - Content: text, + Role: string(msg.Role), + Content: text, + FileParts: fileParts, }) } } @@ -668,18 +685,110 @@ func (m *MCPToolManager) GetPrompt(ctx context.Context, serverName, promptName s }, nil } -// extractContentText extracts text from an MCP Content value. -// Content can be TextContent, ImageContent, AudioContent, or EmbeddedResource. -// We only extract text content; other types are skipped. -func extractContentText(content mcp.Content) string { - if tc, ok := content.(mcp.TextContent); ok { - return tc.Text +// extractPromptContent extracts text and binary attachments from an MCP Content value. +// Handles all MCP content types: TextContent, ImageContent, AudioContent, +// EmbeddedResource (text and blob), and ResourceLink. +func extractPromptContent(content mcp.Content) (string, []MCPFilePart) { + switch c := content.(type) { + case mcp.TextContent: + return c.Text, nil + case *mcp.TextContent: + if c != nil { + return c.Text, nil + } + return "", nil + + case mcp.ImageContent: + return "", decodeBase64FilePart(c.Data, c.MIMEType, "image/png", "image.png") + case *mcp.ImageContent: + if c != nil { + return "", decodeBase64FilePart(c.Data, c.MIMEType, "image/png", "image.png") + } + return "", nil + + case mcp.AudioContent: + return "", decodeBase64FilePart(c.Data, c.MIMEType, "audio/wav", "audio.wav") + case *mcp.AudioContent: + if c != nil { + return "", decodeBase64FilePart(c.Data, c.MIMEType, "audio/wav", "audio.wav") + } + return "", nil + + case mcp.EmbeddedResource: + return extractEmbeddedResourceContent(c.Resource) + case *mcp.EmbeddedResource: + if c != nil { + return extractEmbeddedResourceContent(c.Resource) + } + return "", nil + + case mcp.ResourceLink: + // ResourceLink is a reference without inline content — include as a + // text annotation so the LLM knows about it. + return fmt.Sprintf("[Referenced resource: %s (%s)]", c.URI, c.Name), nil + case *mcp.ResourceLink: + if c != nil { + return fmt.Sprintf("[Referenced resource: %s (%s)]", c.URI, c.Name), nil + } + return "", nil + + default: + return "", nil } - // Try pointer form as well. - if tc, ok := content.(*mcp.TextContent); ok && tc != nil { - return tc.Text +} + +// extractEmbeddedResourceContent handles the two variants of embedded resource +// content: text resources are inlined as fenced code blocks, blob resources +// are base64-decoded into MCPFilePart attachments. +func extractEmbeddedResourceContent(res mcp.ResourceContents) (string, []MCPFilePart) { + switch r := res.(type) { + case mcp.TextResourceContents: + return fmt.Sprintf("[File: %s]\n```\n%s\n```", r.URI, r.Text), nil + case *mcp.TextResourceContents: + if r != nil { + return fmt.Sprintf("[File: %s]\n```\n%s\n```", r.URI, r.Text), nil + } + return "", nil + case mcp.BlobResourceContents: + return "", decodeBase64FilePart(r.Blob, r.MIMEType, "application/octet-stream", filenameFromURI(r.URI)) + case *mcp.BlobResourceContents: + if r != nil { + return "", decodeBase64FilePart(r.Blob, r.MIMEType, "application/octet-stream", filenameFromURI(r.URI)) + } + return "", nil + default: + return "", nil } - return "" +} + +// decodeBase64FilePart decodes base64-encoded data into an MCPFilePart. +// Returns nil on decode failure (logged as a warning). +func decodeBase64FilePart(data, mimeType, defaultMIME, filename string) []MCPFilePart { + decoded, err := base64.StdEncoding.DecodeString(data) + if err != nil { + log.Warn("mcp prompt: failed to decode base64 content", "filename", filename, "error", err) + return nil + } + if mimeType == "" { + mimeType = defaultMIME + } + return []MCPFilePart{{ + Filename: filename, + Data: decoded, + MediaType: mimeType, + }} +} + +// filenameFromURI extracts a filename from a URI (e.g. "file:///path/to/img.png" → "img.png"). +func filenameFromURI(uri string) string { + uri = strings.TrimPrefix(uri, "file://") + if idx := strings.LastIndex(uri, "/"); idx >= 0 { + return uri[idx+1:] + } + if uri == "" { + return "resource" + } + return uri } // loadServerPrompts loads prompts from a single MCP server connection. diff --git a/internal/tools/mcp_prompts_test.go b/internal/tools/mcp_prompts_test.go index 8f09aa10..b098e953 100644 --- a/internal/tools/mcp_prompts_test.go +++ b/internal/tools/mcp_prompts_test.go @@ -2,7 +2,9 @@ package tools import ( "context" + "encoding/base64" "fmt" + "strings" "testing" mcpclient "github.com/mark3labs/mcp-go/client" @@ -383,30 +385,307 @@ func TestLoadServerPrompts_NoPromptCapability(t *testing.T) { } } -func TestExtractContentText(t *testing.T) { - tests := []struct { - name string - content mcp.Content - want string - }{ - { - name: "TextContent", - content: mcp.TextContent{Type: "text", Text: "hello world"}, - want: "hello world", - }, - { - name: "ImageContent", - content: mcp.ImageContent{Type: "image", Data: "base64data", MIMEType: "image/png"}, - want: "", - }, - } +func TestExtractPromptContent(t *testing.T) { + t.Run("TextContent", func(t *testing.T) { + text, parts := extractPromptContent(mcp.TextContent{Type: "text", Text: "hello world"}) + if text != "hello world" { + t.Errorf("text = %q, want %q", text, "hello world") + } + if len(parts) != 0 { + t.Errorf("expected 0 file parts, got %d", len(parts)) + } + }) + t.Run("ImageContent", func(t *testing.T) { + // base64 of "fake image" + encoded := base64.StdEncoding.EncodeToString([]byte("fake image")) + text, parts := extractPromptContent(mcp.ImageContent{ + Type: "image", + Data: encoded, + MIMEType: "image/png", + }) + if text != "" { + t.Errorf("expected empty text, got %q", text) + } + if len(parts) != 1 { + t.Fatalf("expected 1 file part, got %d", len(parts)) + } + if parts[0].MediaType != "image/png" { + t.Errorf("media type = %q, want %q", parts[0].MediaType, "image/png") + } + if parts[0].Filename != "image.png" { + t.Errorf("filename = %q, want %q", parts[0].Filename, "image.png") + } + if string(parts[0].Data) != "fake image" { + t.Errorf("data = %q, want %q", string(parts[0].Data), "fake image") + } + }) + + t.Run("ImageContent_DefaultMIME", func(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("img")) + _, parts := extractPromptContent(mcp.ImageContent{ + Type: "image", + Data: encoded, + // no MIMEType → should default to image/png + }) + if len(parts) != 1 { + t.Fatalf("expected 1 file part, got %d", len(parts)) + } + if parts[0].MediaType != "image/png" { + t.Errorf("default MIME = %q, want %q", parts[0].MediaType, "image/png") + } + }) + + t.Run("AudioContent", func(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("fake audio")) + text, parts := extractPromptContent(mcp.AudioContent{ + Type: "audio", + Data: encoded, + MIMEType: "audio/mp3", + }) + if text != "" { + t.Errorf("expected empty text, got %q", text) + } + if len(parts) != 1 { + t.Fatalf("expected 1 file part, got %d", len(parts)) + } + if parts[0].MediaType != "audio/mp3" { + t.Errorf("media type = %q, want %q", parts[0].MediaType, "audio/mp3") + } + if parts[0].Filename != "audio.wav" { + t.Errorf("filename = %q, want %q", parts[0].Filename, "audio.wav") + } + }) + + t.Run("EmbeddedResource_Text", func(t *testing.T) { + text, parts := extractPromptContent(mcp.EmbeddedResource{ + Type: "resource", + Resource: mcp.TextResourceContents{ + URI: "file:///project/main.go", + MIMEType: "text/x-go", + Text: "package main", + }, + }) + if text == "" { + t.Fatal("expected non-empty text for text resource") + } + if !strings.Contains(text, "package main") { + t.Errorf("text should contain resource content, got %q", text) + } + if !strings.Contains(text, "file:///project/main.go") { + t.Errorf("text should contain URI, got %q", text) + } + if len(parts) != 0 { + t.Errorf("expected 0 file parts for text resource, got %d", len(parts)) + } + }) + + t.Run("EmbeddedResource_Blob", func(t *testing.T) { + blobData := []byte("binary content") + encoded := base64.StdEncoding.EncodeToString(blobData) + text, parts := extractPromptContent(mcp.EmbeddedResource{ + Type: "resource", + Resource: mcp.BlobResourceContents{ + URI: "file:///project/data.bin", + MIMEType: "application/octet-stream", + Blob: encoded, + }, + }) + if text != "" { + t.Errorf("expected empty text for blob resource, got %q", text) + } + if len(parts) != 1 { + t.Fatalf("expected 1 file part for blob resource, got %d", len(parts)) + } + if parts[0].Filename != "data.bin" { + t.Errorf("filename = %q, want %q", parts[0].Filename, "data.bin") + } + if parts[0].MediaType != "application/octet-stream" { + t.Errorf("media type = %q, want %q", parts[0].MediaType, "application/octet-stream") + } + if string(parts[0].Data) != "binary content" { + t.Errorf("data = %q, want %q", string(parts[0].Data), "binary content") + } + }) + + t.Run("ResourceLink", func(t *testing.T) { + text, parts := extractPromptContent(mcp.ResourceLink{ + Type: "resource_link", + URI: "file:///docs/readme.md", + Name: "readme.md", + }) + if text == "" { + t.Fatal("expected non-empty text for resource link") + } + if !strings.Contains(text, "file:///docs/readme.md") { + t.Errorf("text should contain URI, got %q", text) + } + if !strings.Contains(text, "readme.md") { + t.Errorf("text should contain name, got %q", text) + } + if len(parts) != 0 { + t.Errorf("expected 0 file parts for resource link, got %d", len(parts)) + } + }) + + t.Run("InvalidBase64", func(t *testing.T) { + _, parts := extractPromptContent(mcp.ImageContent{ + Type: "image", + Data: "not-valid-base64!!!", + MIMEType: "image/png", + }) + if len(parts) != 0 { + t.Errorf("expected 0 file parts for invalid base64, got %d", len(parts)) + } + }) + + t.Run("NilContent", func(t *testing.T) { + text, parts := extractPromptContent((*mcp.TextContent)(nil)) + if text != "" { + t.Errorf("expected empty text for nil, got %q", text) + } + if len(parts) != 0 { + t.Errorf("expected 0 parts for nil, got %d", len(parts)) + } + }) +} + +func TestFilenameFromURI(t *testing.T) { + tests := []struct { + uri string + want string + }{ + {"file:///path/to/image.png", "image.png"}, + {"file:///single.txt", "single.txt"}, + {"resource://server/data.json", "data.json"}, + {"nopath", "nopath"}, + {"", "resource"}, + } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := extractContentText(tt.content) + t.Run(tt.uri, func(t *testing.T) { + got := filenameFromURI(tt.uri) if got != tt.want { - t.Errorf("extractContentText() = %q, want %q", got, tt.want) + t.Errorf("filenameFromURI(%q) = %q, want %q", tt.uri, got, tt.want) } }) } } + +func TestGetPrompt_EmbeddedResources(t *testing.T) { + ctx := context.Background() + + imgData := base64.StdEncoding.EncodeToString([]byte("fake-png")) + blobData := base64.StdEncoding.EncodeToString([]byte("binary-blob")) + + client := newTestPromptServer(t, + server.ServerPrompt{ + Prompt: mcp.NewPrompt("review-with-files", + mcp.WithPromptDescription("Review with embedded resources"), + ), + Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{ + Description: "Review prompt with embedded files", + Messages: []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{Type: "text", Text: "Please review these files:"}, + }, + { + Role: mcp.RoleUser, + Content: mcp.EmbeddedResource{ + Type: "resource", + Resource: mcp.TextResourceContents{ + URI: "file:///src/main.go", + MIMEType: "text/x-go", + Text: "package main\n\nfunc main() {}", + }, + }, + }, + { + Role: mcp.RoleUser, + Content: mcp.ImageContent{ + Type: "image", + Data: imgData, + MIMEType: "image/png", + }, + }, + { + Role: mcp.RoleUser, + Content: mcp.EmbeddedResource{ + Type: "resource", + Resource: mcp.BlobResourceContents{ + URI: "file:///data/model.bin", + MIMEType: "application/octet-stream", + Blob: blobData, + }, + }, + }, + }, + }, nil + }, + }, + ) + + m := injectClientIntoManager(t, "test", client) + + result, err := m.GetPrompt(ctx, "test", "review-with-files", nil) + if err != nil { + t.Fatalf("GetPrompt error: %v", err) + } + if result.Description != "Review prompt with embedded files" { + t.Errorf("unexpected description: %q", result.Description) + } + + // Should have 4 messages: text, embedded text resource, image, embedded blob + if len(result.Messages) != 4 { + t.Fatalf("expected 4 messages, got %d", len(result.Messages)) + } + + // Message 0: plain text + msg0 := result.Messages[0] + if msg0.Content != "Please review these files:" { + t.Errorf("msg[0] content = %q", msg0.Content) + } + if len(msg0.FileParts) != 0 { + t.Errorf("msg[0] expected 0 file parts, got %d", len(msg0.FileParts)) + } + + // Message 1: embedded text resource → inlined as text + msg1 := result.Messages[1] + if !strings.Contains(msg1.Content, "package main") { + t.Errorf("msg[1] should contain resource text, got %q", msg1.Content) + } + if len(msg1.FileParts) != 0 { + t.Errorf("msg[1] expected 0 file parts (text resource), got %d", len(msg1.FileParts)) + } + + // Message 2: image → file part + msg2 := result.Messages[2] + if msg2.Content != "" { + t.Errorf("msg[2] expected empty text for image, got %q", msg2.Content) + } + if len(msg2.FileParts) != 1 { + t.Fatalf("msg[2] expected 1 file part, got %d", len(msg2.FileParts)) + } + if msg2.FileParts[0].MediaType != "image/png" { + t.Errorf("msg[2] file part MIME = %q", msg2.FileParts[0].MediaType) + } + if string(msg2.FileParts[0].Data) != "fake-png" { + t.Errorf("msg[2] file part data = %q", string(msg2.FileParts[0].Data)) + } + + // Message 3: embedded blob resource → file part + msg3 := result.Messages[3] + if msg3.Content != "" { + t.Errorf("msg[3] expected empty text for blob resource, got %q", msg3.Content) + } + if len(msg3.FileParts) != 1 { + t.Fatalf("msg[3] expected 1 file part, got %d", len(msg3.FileParts)) + } + if msg3.FileParts[0].Filename != "model.bin" { + t.Errorf("msg[3] filename = %q, want %q", msg3.FileParts[0].Filename, "model.bin") + } + if string(msg3.FileParts[0].Data) != "binary-blob" { + t.Errorf("msg[3] file part data = %q", string(msg3.FileParts[0].Data)) + } +} diff --git a/internal/ui/model.go b/internal/ui/model.go index 6453e644..2d1e4fbb 100644 --- a/internal/ui/model.go +++ b/internal/ui/model.go @@ -157,8 +157,9 @@ type MCPPromptExpandResult struct { // MCPPromptMessageInfo is a single message from an expanded MCP prompt. type MCPPromptMessageInfo struct { - Role string // "user" or "assistant" - Content string + Role string // "user" or "assistant" + Content string + FileParts []kit.LLMFilePart } // ToolRendererData holds extension-provided rendering functions for a specific @@ -2153,7 +2154,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // as a user message (same behavior as local prompt templates). if msg.err != nil { m.printSystemMessage(fmt.Sprintf("MCP prompt error: %v", msg.err)) - } else if msg.text != "" { + } else if msg.text != "" || len(msg.fileParts) > 0 { // Process @file references and submit. processedText := msg.text var fileParts []kit.LLMFilePart @@ -2168,6 +2169,35 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { }) } } + // Merge file parts from embedded resources (images, audio, blobs) + // with any @file/@mcp: file parts extracted from the text. + fileParts = append(fileParts, msg.fileParts...) + + // Build display text with attachment badges (matches the + // normal submit path so embedded resources look like pasted + // images / attached files). + displayText := msg.text + if len(msg.fileParts) > 0 { + var imageCount, fileCount int + for _, fp := range msg.fileParts { + if strings.HasPrefix(fp.MediaType, "image/") { + imageCount++ + } else { + fileCount++ + } + } + var badges []string + if imageCount > 0 { + badges = append(badges, fmt.Sprintf("%d image(s) attached", imageCount)) + } + if fileCount > 0 { + badges = append(badges, fmt.Sprintf("%d file(s) attached", fileCount)) + } + if len(badges) > 0 { + displayText = fmt.Sprintf("%s\n[%s]", msg.text, strings.Join(badges, ", ")) + } + } + if m.appCtrl != nil { var qLen int if len(fileParts) > 0 { @@ -2176,10 +2206,10 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { qLen = m.appCtrl.Run(processedText) } if qLen > 0 { - m.queuedMessages = append(m.queuedMessages, msg.text) + m.queuedMessages = append(m.queuedMessages, displayText) m.layoutDirty = true } else { - m.pendingUserPrints = append(m.pendingUserPrints, msg.text) + m.pendingUserPrints = append(m.pendingUserPrints, displayText) m.flushStreamAndPendingUserMessages() } if m.state != stateWorking { @@ -3125,14 +3155,22 @@ func (m *AppModel) handleMCPPromptCommand(text string) tea.Cmd { ctrl.SendEvent(mcpPromptResultMsg{err: err}) return } - // Concatenate user-role messages as the prompt text. + // Concatenate user-role messages as the prompt text and collect + // any binary attachments from embedded resources. var parts []string + var allFileParts []kit.LLMFilePart for _, msg := range result.Messages { if msg.Role == "user" { - parts = append(parts, msg.Content) + if msg.Content != "" { + parts = append(parts, msg.Content) + } + allFileParts = append(allFileParts, msg.FileParts...) } } - ctrl.SendEvent(mcpPromptResultMsg{text: strings.Join(parts, "\n\n")}) + ctrl.SendEvent(mcpPromptResultMsg{ + text: strings.Join(parts, "\n\n"), + fileParts: allFileParts, + }) }() return noopCmd @@ -4472,8 +4510,9 @@ type extensionCmdResultMsg struct { // mcpPromptResultMsg carries the result of an asynchronously expanded MCP // prompt. The expansion runs in a goroutine since it contacts the MCP server. type mcpPromptResultMsg struct { - text string // concatenated user messages to submit as the prompt - err error // error from the server + text string // concatenated user messages to submit as the prompt + fileParts []kit.LLMFilePart // binary attachments from embedded resources + err error // error from the server } // beforeSessionSwitchResultMsg carries the result of an asynchronously diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index d4eaa7a1..ef0efd0c 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -257,6 +257,10 @@ type MCPPromptMessage struct { Role string // Content is the text content of the message. Content string + // FileParts contains binary attachments extracted from embedded resources, + // images, or audio content blocks within the prompt message. Empty for + // text-only messages. + FileParts []LLMFilePart } // MCPPromptResult is the result of expanding an MCP prompt. @@ -308,9 +312,18 @@ func (m *Kit) GetMCPPrompt(ctx context.Context, serverName, promptName string, a } msgs := make([]MCPPromptMessage, len(internal.Messages)) for i, msg := range internal.Messages { + var fileParts []LLMFilePart + for _, fp := range msg.FileParts { + fileParts = append(fileParts, LLMFilePart{ + Filename: fp.Filename, + Data: fp.Data, + MediaType: fp.MediaType, + }) + } msgs[i] = MCPPromptMessage{ - Role: msg.Role, - Content: msg.Content, + Role: msg.Role, + Content: msg.Content, + FileParts: fileParts, } } return &MCPPromptResult{