Compare commits

...

17 Commits

Author SHA1 Message Date
YuTengjing 976155f0e8 Adjust SSRF fetch usage 2025-11-10 16:53:22 +08:00
YuTengjing c73a15d4a5 Respect custom fetch overrides 2025-11-10 16:53:22 +08:00
YuTengjing 2187409dcf 🔧 refactor: update fetch type assertion in ModelRuntime initialization
Changed the fetch property type assertion from 'any' to 'typeof fetch' for improved type safety in the ModelRuntime initialization function.
2025-11-10 16:53:22 +08:00
YuTengjing 0274a623fa Revert edge runtime changes 2025-11-10 16:53:22 +08:00
YuTengjing 55d6597644 🔧 refactor: enhance fetch functionality with SSRF options across multiple context builders 2025-11-10 16:53:22 +08:00
YuTengjing 2bdf218c57 test: fix ollama URL image conversion timeout by mocking imageUrlToBase64
Mock the imageUrlToBase64 function from @lobechat/utils to prevent real network
requests in tests, which was causing the URL image conversion failure test to timeout.
2025-11-10 16:53:22 +08:00
YuTengjing aedfc5d59c test: fix mock assertions for updated image conversion function signatures
Update test assertions to match the new function signatures that now accept
an optional customFetch parameter for imageUrlToBase64 and convertImageUrlToFile.
2025-11-10 16:53:22 +08:00
YuTengjing 81deeb7374 ♻️ refactor: unify fetch property naming across providers and extract ModelRuntimeOptions type
- Rename _fetch to fetch in Anthropic provider
- Rename fetchImpl to fetch in Google, BFL, and Ollama providers
- Extract ModelRuntimeOptions interface from complex intersection types
- Add type-safe parameter interfaces for each provider (BflAIParams, OllamaAIParams)
- Remove unnecessary type assertions (as typeof fetch | undefined)
- All fetch properties are now consistently named public fetch with optional signature
2025-11-10 16:53:22 +08:00
YuTengjing 8312f34703 chore: remove temp logs 2025-11-10 16:53:22 +08:00
YuTengjing 751d09b14a chore: try fix node:stream import 2025-11-10 16:53:22 +08:00
YuTengjing 417182b31b chore: add some debug logs 2025-11-10 16:53:22 +08:00
YuTengjing f91c036c6c 🐛 fix: provide browser entry for ssrf-safe-fetch 2025-11-10 16:53:22 +08:00
YuTengjing f6c5ad1498 refactor: remove edge runtime setting from OpenAI route 2025-11-10 16:53:22 +08:00
YuTengjing 5612dfe5a4 refactor: remove tts edge runtime setting 2025-11-10 16:53:22 +08:00
YuTengjing ade6faa020 🐛 fix: support azure ai node runtime 2025-11-10 16:53:22 +08:00
YuTengjing cbd0346ea0 🐛 fix: ollama image to base64 before send message #10002 2025-11-10 16:53:22 +08:00
YuTengjing de8e4e4a1a 🐛 fix: imageUrlToBase64 ssrf 2025-11-10 16:53:22 +08:00
32 changed files with 504 additions and 354 deletions
+3
View File
@@ -22,6 +22,9 @@ config.rules['unicorn/no-array-callback-reference'] = 0;
// FIXME: Linting error in src/app/[variants]/(main)/chat/features/Migration/DBReader.ts, the fundamental solution should be upgrading typescript-eslint
config.rules['@typescript-eslint/no-useless-constructor'] = 0;
if (!config.globals) config.globals = {};
config.globals.RequestInit = true;
config.overrides = [
{
extends: ['plugin:mdx/recommended'],
@@ -25,6 +25,16 @@ export interface AgentChatOptions {
trace?: TracePayload;
}
export interface ModelRuntimeOptions
extends ClientOptions,
LobeBedrockAIParams,
LobeCloudflareParams {
apiKey?: string;
apiVersion?: string;
baseURL?: string;
fetch?: typeof fetch;
}
export class ModelRuntime {
private _runtime: LobeRuntimeAI;
@@ -114,14 +124,7 @@ export class ModelRuntime {
* - `src/app/api/chat/agentRuntime.ts: initAgentRuntimeWithUserPayload` on server
* - `src/services/chat.ts: initializeWithClientStore` on client
*/
static initializeWithProvider(
provider: string,
params: Partial<
ClientOptions &
LobeBedrockAIParams &
LobeCloudflareParams & { apiKey?: string; apiVersion?: string; baseURL?: string }
>,
) {
static initializeWithProvider(provider: string, params: Partial<ModelRuntimeOptions>) {
// @ts-expect-error runtime map not include vertex so it will be undefined
const providerAI = providerRuntimeMap[provider] ?? LobeOpenAI;
@@ -1,8 +1,8 @@
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
import { OpenAI } from 'openai';
import { describe, expect, it, vi } from 'vitest';
import { OpenAIChatMessage, UserMessageContentPart } from '../../types/chat';
import { imageUrlToBase64 } from '../../utils/imageToBase64';
import { parseDataUri } from '../../utils/uriParser';
import {
buildAnthropicBlock,
@@ -19,7 +19,7 @@ vi.mock('../../utils/uriParser', () => ({
type: 'base64',
}),
}));
vi.mock('../../utils/imageToBase64');
vi.mock('@lobechat/utils/imageToBase64');
describe('anthropicHelpers', () => {
describe('buildAnthropicBlock', () => {
@@ -65,7 +65,7 @@ describe('anthropicHelpers', () => {
const result = await buildAnthropicBlock(content);
expect(parseDataUri).toHaveBeenCalledWith(content.image_url.url);
expect(imageUrlToBase64).toHaveBeenCalledWith(content.image_url.url);
expect(imageUrlToBase64).toHaveBeenCalledWith(content.image_url.url, undefined);
expect(result).toEqual({
source: {
data: 'convertedBase64String',
@@ -1,12 +1,13 @@
import Anthropic from '@anthropic-ai/sdk';
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
import OpenAI from 'openai';
import { OpenAIChatMessage, UserMessageContentPart } from '../../types';
import { imageUrlToBase64 } from '../../utils/imageToBase64';
import { parseDataUri } from '../../utils/uriParser';
export const buildAnthropicBlock = async (
content: UserMessageContentPart,
customFetch?: typeof fetch,
): Promise<Anthropic.ContentBlock | Anthropic.ImageBlockParam | undefined> => {
switch (content.type) {
case 'thinking': {
@@ -34,7 +35,7 @@ export const buildAnthropicBlock = async (
};
if (type === 'url') {
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url);
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url, customFetch);
return {
source: {
data: base64 as string,
@@ -50,9 +51,11 @@ export const buildAnthropicBlock = async (
}
};
const buildArrayContent = async (content: UserMessageContentPart[]) => {
const buildArrayContent = async (content: UserMessageContentPart[], customFetch?: typeof fetch) => {
let messageContent = (await Promise.all(
(content as UserMessageContentPart[]).map(async (c) => await buildAnthropicBlock(c)),
(content as UserMessageContentPart[]).map(
async (c) => await buildAnthropicBlock(c, customFetch),
),
)) as Anthropic.Messages.ContentBlockParam[];
messageContent = messageContent.filter(Boolean);
@@ -62,6 +65,7 @@ const buildArrayContent = async (content: UserMessageContentPart[]) => {
export const buildAnthropicMessage = async (
message: OpenAIChatMessage,
customFetch?: typeof fetch,
): Promise<Anthropic.Messages.MessageParam> => {
const content = message.content as string | UserMessageContentPart[];
@@ -72,7 +76,8 @@ export const buildAnthropicMessage = async (
case 'user': {
return {
content: typeof content === 'string' ? content : await buildArrayContent(content),
content:
typeof content === 'string' ? content : await buildArrayContent(content, customFetch),
role: 'user',
};
}
@@ -100,7 +105,7 @@ export const buildAnthropicMessage = async (
? ([{ text: message.content, type: 'text' }] as UserMessageContentPart[])
: content;
const messageContent = await buildArrayContent(rawContent);
const messageContent = await buildArrayContent(rawContent, customFetch);
return {
content: [
@@ -129,7 +134,7 @@ export const buildAnthropicMessage = async (
export const buildAnthropicMessages = async (
oaiMessages: OpenAIChatMessage[],
options: { enabledContextCaching?: boolean } = {},
options: { customFetch?: typeof fetch; enabledContextCaching?: boolean } = {},
): Promise<Anthropic.Messages.MessageParam[]> => {
const messages: Anthropic.Messages.MessageParam[] = [];
let pendingToolResults: Anthropic.ToolResultBlockParam[] = [];
@@ -175,7 +180,7 @@ export const buildAnthropicMessages = async (
});
}
} else {
const anthropicMessage = await buildAnthropicMessage(message);
const anthropicMessage = await buildAnthropicMessage(message, options.customFetch);
messages.push({ ...anthropicMessage, role: anthropicMessage.role });
}
}
@@ -1,9 +1,9 @@
// @vitest-environment node
import { Type as SchemaType } from '@google/genai';
import * as imageToBase64Module from '@lobechat/utils/imageToBase64';
import { describe, expect, it, vi } from 'vitest';
import { ChatCompletionTool, OpenAIChatMessage, UserMessageContentPart } from '../../types';
import * as imageToBase64Module from '../../utils/imageToBase64';
import { parseDataUri } from '../../utils/uriParser';
import {
buildGoogleMessage,
@@ -18,7 +18,7 @@ vi.mock('../../utils/uriParser', () => ({
parseDataUri: vi.fn(),
}));
vi.mock('../../utils/imageToBase64', () => ({
vi.mock('@lobechat/utils/imageToBase64', () => ({
imageUrlToBase64: vi.fn(),
}));
@@ -102,7 +102,7 @@ describe('google contextBuilders', () => {
},
});
expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(imageUrl);
expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(imageUrl, undefined);
});
it('should throw TypeError for unsupported image URL types', async () => {
@@ -5,9 +5,9 @@ import {
Part,
Type as SchemaType,
} from '@google/genai';
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
import { ChatCompletionTool, OpenAIChatMessage, UserMessageContentPart } from '../../types';
import { imageUrlToBase64 } from '../../utils/imageToBase64';
import { safeParseJSON } from '../../utils/safeParseJSON';
import { parseDataUri } from '../../utils/uriParser';
@@ -16,6 +16,7 @@ import { parseDataUri } from '../../utils/uriParser';
*/
export const buildGooglePart = async (
content: UserMessageContentPart,
customFetch?: typeof fetch,
): Promise<Part | undefined> => {
switch (content.type) {
default: {
@@ -40,7 +41,7 @@ export const buildGooglePart = async (
}
if (type === 'url') {
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url);
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url, customFetch);
return {
inlineData: { data: base64, mimeType },
@@ -66,7 +67,12 @@ export const buildGooglePart = async (
if (type === 'url') {
// For video URLs, we need to fetch and convert to base64
// Note: This might need size/duration limits for practical use
const response = await fetch(content.video_url.url);
const fetchFn = customFetch || fetch;
const ssrfOptions = customFetch ? ({ ssrf: true } as RequestInit) : undefined;
const response =
ssrfOptions === undefined
? await fetchFn(content.video_url.url)
: await fetchFn(content.video_url.url, ssrfOptions);
const arrayBuffer = await response.arrayBuffer();
const base64 = Buffer.from(arrayBuffer).toString('base64');
const mimeType = response.headers.get('content-type') || 'video/mp4';
@@ -87,6 +93,7 @@ export const buildGooglePart = async (
export const buildGoogleMessage = async (
message: OpenAIChatMessage,
toolCallNameMap?: Map<string, string>,
customFetch?: typeof fetch,
): Promise<Content> => {
const content = message.content as string | UserMessageContentPart[];
@@ -124,7 +131,9 @@ export const buildGoogleMessage = async (
const getParts = async () => {
if (typeof content === 'string') return [{ text: content }];
const parts = await Promise.all(content.map(async (c) => await buildGooglePart(c)));
const parts = await Promise.all(
content.map(async (c) => await buildGooglePart(c, customFetch)),
);
return parts.filter(Boolean) as Part[];
};
@@ -137,7 +146,10 @@ export const buildGoogleMessage = async (
/**
* Convert messages from the OpenAI format to Google GenAI SDK format
*/
export const buildGoogleMessages = async (messages: OpenAIChatMessage[]): Promise<Content[]> => {
export const buildGoogleMessages = async (
messages: OpenAIChatMessage[],
customFetch?: typeof fetch,
): Promise<Content[]> => {
const toolCallNameMap = new Map<string, string>();
// Build tool call id to name mapping
@@ -153,7 +165,7 @@ export const buildGoogleMessages = async (messages: OpenAIChatMessage[]): Promis
const pools = messages
.filter((message) => message.role !== 'function')
.map(async (msg) => await buildGoogleMessage(msg, toolCallNameMap));
.map(async (msg) => await buildGoogleMessage(msg, toolCallNameMap, customFetch));
const contents = await Promise.all(pools);
@@ -1,7 +1,7 @@
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
import OpenAI from 'openai';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { imageUrlToBase64 } from '../../utils/imageToBase64';
import { parseDataUri } from '../../utils/uriParser';
import {
convertImageUrlToFile,
@@ -11,7 +11,7 @@ import {
} from './openai';
// 模拟依赖
vi.mock('../../utils/imageToBase64');
vi.mock('@lobechat/utils/imageToBase64');
vi.mock('../../utils/uriParser');
describe('convertMessageContent', () => {
@@ -52,7 +52,7 @@ describe('convertMessageContent', () => {
});
expect(parseDataUri).toHaveBeenCalledWith('https://example.com/image.jpg');
expect(imageUrlToBase64).toHaveBeenCalledWith('https://example.com/image.jpg');
expect(imageUrlToBase64).toHaveBeenCalledWith('https://example.com/image.jpg', undefined);
});
it('should not convert image URL when not necessary', async () => {
@@ -1,18 +1,19 @@
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
import OpenAI, { toFile } from 'openai';
import { disableStreamModels, systemToUserModels } from '../../const/models';
import { ChatStreamPayload, OpenAIChatMessage } from '../../types';
import { imageUrlToBase64 } from '../../utils/imageToBase64';
import { parseDataUri } from '../../utils/uriParser';
export const convertMessageContent = async (
content: OpenAI.ChatCompletionContentPart,
customFetch?: typeof fetch,
): Promise<OpenAI.ChatCompletionContentPart> => {
if (content.type === 'image_url') {
const { type } = parseDataUri(content.image_url.url);
if (type === 'url' && process.env.LLM_VISION_IMAGE_USE_BASE64 === '1') {
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url);
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url, customFetch);
return {
...content,
@@ -24,7 +25,10 @@ export const convertMessageContent = async (
return content;
};
export const convertOpenAIMessages = async (messages: OpenAI.ChatCompletionMessageParam[]) => {
export const convertOpenAIMessages = async (
messages: OpenAI.ChatCompletionMessageParam[],
customFetch?: typeof fetch,
) => {
return (await Promise.all(
messages.map(async (message) => ({
...message,
@@ -33,7 +37,7 @@ export const convertOpenAIMessages = async (messages: OpenAI.ChatCompletionMessa
? message.content
: await Promise.all(
(message.content || []).map((c) =>
convertMessageContent(c as OpenAI.ChatCompletionContentPart),
convertMessageContent(c as OpenAI.ChatCompletionContentPart, customFetch),
),
),
})),
@@ -42,6 +46,7 @@ export const convertOpenAIMessages = async (messages: OpenAI.ChatCompletionMessa
export const convertOpenAIResponseInputs = async (
messages: OpenAI.ChatCompletionMessageParam[],
customFetch?: typeof fetch,
) => {
let input: OpenAI.Responses.ResponseInputItem[] = [];
await Promise.all(
@@ -83,7 +88,10 @@ export const convertOpenAIResponseInputs = async (
return { ...c, type: 'input_text' };
}
const image = await convertMessageContent(c as OpenAI.ChatCompletionContentPart);
const image = await convertMessageContent(
c as OpenAI.ChatCompletionContentPart,
customFetch,
);
return {
image_url: (image as OpenAI.ChatCompletionContentPartImage).image_url?.url,
type: 'input_image',
@@ -127,7 +135,7 @@ export const pruneReasoningPayload = (payload: ChatStreamPayload) => {
/**
* Convert image URL (data URL or HTTP URL) to File object for OpenAI API
*/
export const convertImageUrlToFile = async (imageUrl: string) => {
export const convertImageUrlToFile = async (imageUrl: string, customFetch?: typeof fetch) => {
let buffer: Buffer;
let mimeType: string;
@@ -138,7 +146,10 @@ export const convertImageUrlToFile = async (imageUrl: string) => {
buffer = Buffer.from(base64Data, 'base64');
} else {
// a http url
const response = await fetch(imageUrl);
const fetchFn = customFetch || fetch;
const ssrfOptions = customFetch ? ({ ssrf: true } as RequestInit) : undefined;
const response =
ssrfOptions === undefined ? await fetchFn(imageUrl) : await fetchFn(imageUrl, ssrfOptions);
if (!response.ok) {
throw new Error(`Failed to fetch image from ${imageUrl}: ${response.statusText}`);
}
@@ -1,9 +1,9 @@
// @vitest-environment node
import * as imageToBase64Module from '@lobechat/utils/imageToBase64';
import OpenAI from 'openai';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { CreateImagePayload } from '../../types/image';
import * as imageToBase64Module from '../../utils/imageToBase64';
import * as uriParserModule from '../../utils/uriParser';
import { createOpenAICompatibleImage } from './createImage';
@@ -81,7 +81,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'openrouter');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'openrouter',
});
expect(result.imageUrl).toBe('data:image/png;base64,generatedImageData');
expect(mockClient.chat.completions.create).toHaveBeenCalled();
@@ -122,7 +126,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'test-provider');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'test-provider',
});
expect(result.imageUrl).toBe('data:image/png;base64,result');
});
@@ -145,7 +153,7 @@ describe('createOpenAICompatibleImage', () => {
};
await expect(
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
).rejects.toThrow(
"Failed to process image URL: TypeError: Image URL doesn't contain base64 data",
);
@@ -191,9 +199,16 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'test-provider');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'test-provider',
});
expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(mockHttpImageUrl);
expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(
mockHttpImageUrl,
undefined,
);
expect(result.imageUrl).toBe('data:image/png;base64,output');
});
@@ -215,7 +230,7 @@ describe('createOpenAICompatibleImage', () => {
};
await expect(
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
).rejects.toThrow(
`Failed to process image URL: TypeError: Currently we don't support image url: ${mockInvalidUrl}`,
);
@@ -249,7 +264,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'openrouter');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'openrouter',
});
expect(result.imageUrl).toBe('data:image/png;base64,generatedWithoutInputImage');
expect(mockClient.chat.completions.create).toHaveBeenCalledWith({
@@ -296,7 +315,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'test-provider');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'test-provider',
});
expect(result.imageUrl).toBe('data:image/png;base64,generatedImage');
// Should not include image in content array
@@ -324,7 +347,7 @@ describe('createOpenAICompatibleImage', () => {
};
await expect(
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
).rejects.toThrow('No message in chat completion response');
});
@@ -349,7 +372,7 @@ describe('createOpenAICompatibleImage', () => {
};
await expect(
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
).rejects.toThrow('No image generated in chat completion response');
});
@@ -374,7 +397,7 @@ describe('createOpenAICompatibleImage', () => {
};
await expect(
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
).rejects.toThrow('No image generated in chat completion response');
});
@@ -404,7 +427,7 @@ describe('createOpenAICompatibleImage', () => {
};
await expect(
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
).rejects.toThrow('No image generated in chat completion response');
});
@@ -436,7 +459,7 @@ describe('createOpenAICompatibleImage', () => {
};
await expect(
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
).rejects.toThrow('No image generated in chat completion response');
});
@@ -475,7 +498,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'openrouter');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'openrouter',
});
expect(result.imageUrl).toBe('data:image/png;base64,processedResult');
@@ -522,7 +549,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'test-provider');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'test-provider',
});
expect(result.imageUrl).toBe('data:image/png;base64,chatModelResult');
expect(mockClient.chat.completions.create).toHaveBeenCalled();
@@ -548,7 +579,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'openai',
});
expect(result.imageUrl).toBe('data:image/png;base64,imageModelBase64Result');
expect(mockClient.images.generate).toHaveBeenCalled();
@@ -586,7 +621,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'openai',
});
expect(result.imageUrl).toBe('data:image/png;base64,editedImageResult');
expect(mockClient.images.edit).toHaveBeenCalled();
@@ -611,7 +650,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'openai',
});
expect(result.imageUrl).toBe('data:image/png;base64,generatedImage');
expect(mockClient.images.generate).toHaveBeenCalled();
@@ -637,7 +680,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'openai',
});
expect(result.imageUrl).toBe('data:image/png;base64,generatedImage');
expect(mockClient.images.generate).toHaveBeenCalled();
@@ -665,7 +712,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'openai',
});
expect(result.imageUrl).toBe(mockImageUrl);
expect(mockClient.images.generate).toHaveBeenCalled();
@@ -690,9 +741,9 @@ describe('createOpenAICompatibleImage', () => {
},
};
await expect(createOpenAICompatibleImage(mockClient, payload, 'openai')).rejects.toThrow(
'Invalid image response: missing both b64_json and url fields',
);
await expect(
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'openai' }),
).rejects.toThrow('Invalid image response: missing both b64_json and url fields');
});
it('should throw error when response data is not an array', async () => {
@@ -709,9 +760,9 @@ describe('createOpenAICompatibleImage', () => {
},
};
await expect(createOpenAICompatibleImage(mockClient, payload, 'openai')).rejects.toThrow(
'Invalid image response: missing or empty data array',
);
await expect(
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'openai' }),
).rejects.toThrow('Invalid image response: missing or empty data array');
});
it('should throw error when imageData is undefined in array', async () => {
@@ -728,9 +779,9 @@ describe('createOpenAICompatibleImage', () => {
},
};
await expect(createOpenAICompatibleImage(mockClient, payload, 'openai')).rejects.toThrow(
'Invalid image response: first data item is null or undefined',
);
await expect(
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'openai' }),
).rejects.toThrow('Invalid image response: first data item is null or undefined');
});
});
@@ -762,7 +813,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'openai',
});
expect(result.imageUrl).toBe('data:image/png;base64,imageWithUsage');
expect(result.modelUsage).toBeDefined();
@@ -790,7 +845,11 @@ describe('createOpenAICompatibleImage', () => {
},
};
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
const result = await createOpenAICompatibleImage({
client: mockClient,
payload,
provider: 'openai',
});
expect(result.imageUrl).toBe('data:image/png;base64,imageWithoutUsage');
expect(result.modelUsage).toBeUndefined();
@@ -1,3 +1,4 @@
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
import { cleanObject } from '@lobechat/utils/object';
import createDebug from 'debug';
import { RuntimeImageGenParamsValue } from 'model-bank';
@@ -5,24 +6,31 @@ import OpenAI from 'openai';
import { CreateImagePayload, CreateImageResponse } from '../../types/image';
import { getModelPricing } from '../../utils/getModelPricing';
import { imageUrlToBase64 } from '../../utils/imageToBase64';
import { parseDataUri } from '../../utils/uriParser';
import { convertImageUrlToFile } from '../contextBuilders/openai';
import { convertOpenAIImageUsage } from '../usageConverters/openai';
interface CreateImageContext {
client: OpenAI;
fetchImpl?: typeof fetch;
payload: CreateImagePayload;
provider: string;
}
const log = createDebug('lobe-image:openai-compatible');
/**
* Generate images using traditional OpenAI images API (DALL-E, etc.)
*/
async function generateByImageMode(
client: OpenAI,
payload: CreateImagePayload,
provider: string,
): Promise<CreateImageResponse> {
async function generateByImageMode({
client,
payload,
provider,
fetchImpl,
}: CreateImageContext): Promise<CreateImageResponse> {
const { model, params } = payload;
log('Creating image with model: %s and params: %O', model, params);
log('Creating image with provider: %s, model: %s and params: %O', provider, model, params);
// Map parameter names, mapping imageUrls to image
const paramsMap = new Map<RuntimeImageGenParamsValue, string>([
@@ -48,7 +56,7 @@ async function generateByImageMode(
try {
// Convert all image URLs to File objects
const imageFiles = await Promise.all(
userInput.image.map((url: string) => convertImageUrlToFile(url)),
userInput.image.map((url: string) => convertImageUrlToFile(url, fetchImpl)),
);
// According to official docs, if there are multiple images, pass an array; if only one, pass a single File
@@ -127,7 +135,7 @@ async function generateByImageMode(
/**
* Process image URL for chat model input
*/
async function processImageUrlForChat(imageUrl: string): Promise<string> {
async function processImageUrlForChat(imageUrl: string, fetchImpl?: typeof fetch): Promise<string> {
const { type, base64, mimeType } = parseDataUri(imageUrl);
if (type === 'base64') {
@@ -137,7 +145,10 @@ async function processImageUrlForChat(imageUrl: string): Promise<string> {
return `data:${mimeType || 'image/png'};base64,${base64}`;
} else if (type === 'url') {
// For URL type, convert to base64 first
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(imageUrl);
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(
imageUrl,
fetchImpl,
);
return `data:${urlMimeType};base64,${urlBase64}`;
} else {
throw new TypeError(`Currently we don't support image url: ${imageUrl}`);
@@ -147,14 +158,21 @@ async function processImageUrlForChat(imageUrl: string): Promise<string> {
/**
* Generate images using chat completion API (OpenRouter Gemini, etc.)
*/
async function generateByChatModel(
client: OpenAI,
payload: CreateImagePayload,
): Promise<CreateImageResponse> {
async function generateByChatModel({
client,
payload,
provider,
fetchImpl,
}: CreateImageContext): Promise<CreateImageResponse> {
const { model, params } = payload;
const actualModel = model.replace(':image', ''); // Remove :image suffix
log('Creating image via chat API with model: %s and params: %O', actualModel, params);
log(
'Creating image via chat API with provider: %s, model: %s and params: %O',
provider,
actualModel,
params,
);
// Build message content array
const content: Array<any> = [
@@ -168,7 +186,7 @@ async function generateByChatModel(
if (params.imageUrl && params.imageUrl !== null) {
log('Processing image URL for editing mode: %s', params.imageUrl);
try {
const processedImageUrl = await processImageUrlForChat(params.imageUrl);
const processedImageUrl = await processImageUrlForChat(params.imageUrl, fetchImpl);
content.push({
image_url: {
url: processedImageUrl,
@@ -220,18 +238,19 @@ async function generateByChatModel(
/**
* Create image using OpenAI Compatible API
*/
export async function createOpenAICompatibleImage(
client: OpenAI,
payload: CreateImagePayload,
provider: string,
): Promise<CreateImageResponse> {
export async function createOpenAICompatibleImage({
client,
payload,
provider,
fetchImpl,
}: CreateImageContext): Promise<CreateImageResponse> {
const { model } = payload;
// Check if it's a chat model for image generation (via :image suffix)
if (model.endsWith(':image')) {
return await generateByChatModel(client, payload);
return await generateByChatModel({ client, fetchImpl, payload, provider });
}
// Default to traditional images API
return await generateByImageMode(client, payload, provider);
return await generateByImageMode({ client, fetchImpl, payload, provider });
}
@@ -1251,6 +1251,7 @@ describe('LobeOpenAICompatibleFactory', () => {
expect(openaiHelpers.convertImageUrlToFile).toHaveBeenCalledWith(
'https://example.com/image1.jpg',
undefined,
);
expect(instance['client'].images.edit).toHaveBeenCalledWith({
image: expect.any(File),
@@ -1293,9 +1294,11 @@ describe('LobeOpenAICompatibleFactory', () => {
expect(openaiHelpers.convertImageUrlToFile).toHaveBeenCalledTimes(2);
expect(openaiHelpers.convertImageUrlToFile).toHaveBeenCalledWith(
'https://example.com/image1.jpg',
undefined,
);
expect(openaiHelpers.convertImageUrlToFile).toHaveBeenCalledWith(
'https://example.com/image2.jpg',
undefined,
);
expect(instance['client'].images.edit).toHaveBeenCalledWith({
@@ -57,6 +57,7 @@ export const CHAT_MODELS_BLOCK_LIST = [
type ConstructorOptions<T extends Record<string, any> = any> = ClientOptions & T;
export type CreateImageOptions = Omit<ClientOptions, 'apiKey'> & {
apiKey: string;
fetch?: typeof fetch;
provider: string;
};
@@ -178,6 +179,7 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
baseURL!: string;
protected _options: ConstructorOptions<T>;
protected _fetch?: typeof fetch;
constructor(options: ClientOptions & Record<string, any> = {}) {
const _options = {
@@ -187,6 +189,7 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
};
const { apiKey, baseURL = DEFAULT_BASE_URL, ...res } = _options;
this._options = _options as ConstructorOptions<T>;
this._fetch = options.fetch as any;
if (!apiKey) throw AgentRuntimeError.createError(ErrorType?.invalidAPIKey);
@@ -252,7 +255,7 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
return this.handleResponseAPIMode(processedPayload, options);
}
const messages = await convertOpenAIMessages(postPayload.messages);
const messages = await convertOpenAIMessages(postPayload.messages, this._fetch);
let response: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>;
@@ -361,16 +364,27 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
// If custom createImage implementation is provided, use it
if (customCreateImage) {
log('using custom createImage implementation');
return customCreateImage(payload, {
const fetchImpl = (this._fetch ?? (this._options as any).fetch) as typeof fetch | undefined;
const imageOptions: CreateImageOptions = {
...this._options,
apiKey: this._options.apiKey!,
provider,
});
} as CreateImageOptions;
if (fetchImpl) imageOptions.fetch = fetchImpl;
return customCreateImage(payload, imageOptions);
}
log('using default createOpenAICompatibleImage');
// Use the new createOpenAICompatibleImage function
return createOpenAICompatibleImage(this.client, payload, this.id);
return createOpenAICompatibleImage({
client: this.client,
fetchImpl: this._fetch,
payload,
provider: this.id,
});
}
async models() {
@@ -776,7 +790,7 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
delete res.frequency_penalty;
delete res.presence_penalty;
const input = await convertOpenAIResponseInputs(messages as any);
const input = await convertOpenAIResponseInputs(messages as any, this._fetch);
const isStreaming = payload.stream !== false;
log(
@@ -928,7 +942,7 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
if (shouldUseResponses) {
log('calling responses.create for tool calling');
const input = await convertOpenAIResponseInputs(messages as any);
const input = await convertOpenAIResponseInputs(messages as any, this._fetch);
const res = await this.client.responses.create(
{
@@ -73,6 +73,7 @@ const resolveCacheTTL = (
};
/**
 * Constructor parameters for `LobeAnthropicAI`: the client options plus an
 * optional custom `fetch` and an optional `id`.
 */
interface AnthropicAIParams extends ClientOptions {
// Custom fetch; the constructor keeps it on the instance and also hands it to the Anthropic client.
fetch?: typeof fetch;
id?: string;
}
@@ -82,6 +83,7 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
baseURL: string;
apiKey?: string;
private id: string;
fetch?: typeof fetch;
private isDebug() {
return process.env.DEBUG_ANTHROPIC_CHAT_COMPLETION === '1';
@@ -92,16 +94,20 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
baseURL = DEFAULT_BASE_URL,
id,
defaultHeaders,
fetch: customFetch,
...res
}: AnthropicAIParams = {}) {
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
const betaHeaders = process.env.ANTHROPIC_BETA_HEADERS;
this.fetch = customFetch;
this.client = new Anthropic({
apiKey,
baseURL,
defaultHeaders: { ...defaultHeaders, 'anthropic-beta': betaHeaders },
fetch: customFetch,
...res,
});
this.baseURL = this.client.baseURL;
@@ -200,7 +206,10 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
] as Anthropic.TextBlockParam[])
: undefined;
const postMessages = await buildAnthropicMessages(user_messages, { enabledContextCaching });
const postMessages = await buildAnthropicMessages(user_messages, {
customFetch: this.fetch,
enabledContextCaching,
});
let postTools: anthropicTools[] | undefined = buildAnthropicTools(tools, {
enabledContextCaching,
@@ -1,6 +1,7 @@
import createClient, { ModelClient } from '@azure-rest/ai-inference';
import { AzureKeyCredential } from '@azure/core-auth';
import { ModelProvider } from 'model-bank';
import type { Readable as NodeReadable } from 'node:stream';
import OpenAI from 'openai';
import { systemToUserModels } from '../../const/models';
@@ -64,9 +65,40 @@ export class LobeAzureAI implements LobeRuntimeAI {
});
if (enableStreaming) {
const stream = await response.asBrowserStream();
const unifiedStream = await (async () => {
if (typeof window === 'undefined') {
/**
* In Node.js the SDK exposes a Node readable stream, so we convert it to a Web ReadableStream
* to reuse the same streaming pipeline used by Edge/browser runtimes.
*/
const streamModule = await import('node:stream');
const Readable = streamModule.Readable ?? streamModule.default.Readable;
const [prod, debug] = stream.body!.tee();
if (!Readable) throw new Error('node:stream module missing Readable export');
if (typeof Readable.toWeb !== 'function')
throw new Error('Readable.toWeb is not a function');
const nodeResponse = await response.asNodeStream();
const nodeStream = nodeResponse.body;
if (!nodeStream) {
throw new Error('Azure AI response body is empty');
}
return Readable.toWeb(nodeStream as unknown as NodeReadable) as ReadableStream;
}
const browserResponse = await response.asBrowserStream();
const browserStream = browserResponse.body;
if (!browserStream) {
throw new Error('Azure AI response body is empty');
}
return browserStream;
})();
const [prod, debug] = unifiedStream.tee();
if (process.env.DEBUG_AZURE_AI_CHAT_COMPLETION === '1') {
debugStream(debug).catch(console.error);
@@ -130,7 +162,7 @@ export class LobeAzureAI implements LobeRuntimeAI {
const regex = /^(https:\/\/)([^.]+)(\.cognitiveservices\.azure\.com\/.*)$/;
// Use a replacer function
return url.replace(regex, (match, protocol, subdomain, rest) => {
return url.replace(regex, (_match, protocol, _subdomain, rest) => {
// Replace the subdomain with '***'
return `${protocol}***${rest}`;
});
@@ -6,7 +6,7 @@ import { createBflImage } from './createImage';
import { BflStatusResponse } from './types';
// Mock external dependencies
vi.mock('../../utils/imageToBase64', () => ({
vi.mock('@lobechat/utils/imageToBase64', () => ({
imageUrlToBase64: vi.fn(),
}));
@@ -187,7 +187,7 @@ describe('createBflImage', () => {
it('should convert single imageUrl to image_prompt base64', async () => {
// Arrange
const { parseDataUri } = await import('../../utils/uriParser');
const { imageUrlToBase64 } = await import('../../utils/imageToBase64');
const { imageUrlToBase64 } = await import('@lobechat/utils/imageToBase64');
const { asyncifyPolling } = await import('../../utils/asyncifyPolling');
const mockParseDataUri = vi.mocked(parseDataUri);
@@ -226,7 +226,7 @@ describe('createBflImage', () => {
// Assert
expect(mockParseDataUri).toHaveBeenCalledWith('https://example.com/input.jpg');
expect(mockImageUrlToBase64).toHaveBeenCalledWith('https://example.com/input.jpg');
expect(mockImageUrlToBase64).toHaveBeenCalledWith('https://example.com/input.jpg', undefined);
const callArgs = mockFetch.mock.calls[0][1];
const requestBody = JSON.parse(callArgs?.body as string);
@@ -290,7 +290,7 @@ describe('createBflImage', () => {
it('should convert multiple imageUrls for Kontext models', async () => {
// Arrange
const { parseDataUri } = await import('../../utils/uriParser');
const { imageUrlToBase64 } = await import('../../utils/imageToBase64');
const { imageUrlToBase64 } = await import('@lobechat/utils/imageToBase64');
const { asyncifyPolling } = await import('../../utils/asyncifyPolling');
const mockParseDataUri = vi.mocked(parseDataUri);
@@ -350,7 +350,7 @@ describe('createBflImage', () => {
it('should limit imageUrls to maximum 4 images', async () => {
// Arrange
const { parseDataUri } = await import('../../utils/uriParser');
const { imageUrlToBase64 } = await import('../../utils/imageToBase64');
const { imageUrlToBase64 } = await import('@lobechat/utils/imageToBase64');
const { asyncifyPolling } = await import('../../utils/asyncifyPolling');
const mockParseDataUri = vi.mocked(parseDataUri);
@@ -1,3 +1,4 @@
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
import createDebug from 'debug';
import { RuntimeImageGenParamsValue } from 'model-bank';
@@ -5,7 +6,6 @@ import { AgentRuntimeErrorType } from '../../types/error';
import { CreateImagePayload, CreateImageResponse } from '../../types/image';
import { type TaskResult, asyncifyPolling } from '../../utils/asyncifyPolling';
import { AgentRuntimeError } from '../../utils/createError';
import { imageUrlToBase64 } from '../../utils/imageToBase64';
import { parseDataUri } from '../../utils/uriParser';
import {
BFL_ENDPOINTS,
@@ -23,13 +23,14 @@ const BASE_URL = 'https://api.bfl.ai';
/** Options accepted by the BFL image-creation helpers in this module. */
interface BflCreateImageOptions {
// Sent to the BFL API as the `x-key` request header.
apiKey: string;
baseURL?: string;
// Optional fetch override; request helpers use `options.fetch ?? fetch`, so the global fetch is the fallback.
fetch?: typeof fetch;
provider: string;
}
/**
* Convert image URL to base64 format required by BFL API
*/
async function convertImageToBase64(imageUrl: string): Promise<string> {
async function convertImageToBase64(imageUrl: string, fetchImpl?: typeof fetch): Promise<string> {
try {
const { type } = parseDataUri(imageUrl);
@@ -44,7 +45,7 @@ async function convertImageToBase64(imageUrl: string): Promise<string> {
if (type === 'url') {
// Convert URL to base64
const { base64 } = await imageUrlToBase64(imageUrl);
const { base64 } = await imageUrlToBase64(imageUrl, fetchImpl);
return base64;
}
@@ -61,6 +62,7 @@ async function convertImageToBase64(imageUrl: string): Promise<string> {
async function buildRequestPayload(
model: BflModelId,
params: CreateImagePayload['params'],
fetchImpl?: typeof fetch,
): Promise<BflRequest> {
log('Building request payload for model: %s', model);
@@ -88,7 +90,7 @@ async function buildRequestPayload(
if (params.imageUrls && params.imageUrls.length > 0) {
for (let i = 0; i < Math.min(params.imageUrls.length, 4); i++) {
const fieldName = i === 0 ? 'input_image' : `input_image_${i + 1}`;
userPayload[fieldName] = await convertImageToBase64(params.imageUrls[i]);
userPayload[fieldName] = await convertImageToBase64(params.imageUrls[i], fetchImpl);
}
// Remove the original imageUrls field as it's now mapped to input_image_*
delete userPayload.imageUrls;
@@ -96,7 +98,7 @@ async function buildRequestPayload(
// Handle single image input (imageUrl)
if (params.imageUrl) {
userPayload.image_prompt = await convertImageToBase64(params.imageUrl);
userPayload.image_prompt = await convertImageToBase64(params.imageUrl, fetchImpl);
// Remove the original imageUrl field as it's now mapped to image_prompt
delete userPayload.imageUrl;
}
@@ -123,7 +125,9 @@ async function submitTask(
log('Submitting task to: %s', url);
const response = await fetch(url, {
const fetchImpl = options.fetch ?? fetch;
const response = await fetchImpl(url, {
body: JSON.stringify(payload),
headers: {
'Content-Type': 'application/json',
@@ -160,7 +164,9 @@ async function queryTaskStatus(
): Promise<BflResultResponse> {
log('Querying task status using polling URL: %s', pollingUrl);
const response = await fetch(pollingUrl, {
const fetchImpl = options.fetch ?? fetch;
const response = await fetchImpl(pollingUrl, {
headers: {
'accept': 'application/json',
'x-key': options.apiKey,
@@ -203,7 +209,7 @@ export async function createBflImage(
try {
// 1. Build request payload
const requestPayload = await buildRequestPayload(model as BflModelId, params);
const requestPayload = await buildRequestPayload(model as BflModelId, params, options.fetch);
// 2. Submit image generation task
const taskResponse = await submitTask(model as BflModelId, requestPayload, options);
@@ -9,15 +9,21 @@ import { createBflImage } from './createImage';
const log = createDebug('lobe-image:bfl');
/** Constructor parameters for `LobeBflAI`: client options plus an optional custom `fetch`. */
interface BflAIParams extends ClientOptions {
// Kept on the instance and forwarded to `createBflImage` as `options.fetch`.
fetch?: typeof fetch;
}
export class LobeBflAI implements LobeRuntimeAI {
private apiKey: string;
fetch?: typeof fetch;
baseURL?: string;
constructor({ apiKey, baseURL }: ClientOptions = {}) {
constructor({ apiKey, baseURL, fetch: customFetch }: BflAIParams = {}) {
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
this.apiKey = apiKey;
this.baseURL = baseURL || undefined;
this.fetch = customFetch;
log('BFL AI initialized');
}
@@ -30,6 +36,7 @@ export class LobeBflAI implements LobeRuntimeAI {
return await createBflImage(payload, {
apiKey: this.apiKey,
baseURL: this.baseURL,
fetch: this.fetch,
provider: 'bfl',
});
} catch (error) {
@@ -1,9 +1,9 @@
// @vitest-environment edge-runtime
import { GoogleGenAI } from '@google/genai';
import * as imageToBase64Module from '@lobechat/utils/imageToBase64';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { CreateImagePayload } from '../../types/image';
import * as imageToBase64Module from '../../utils/imageToBase64';
import { createGoogleImage } from './createImage';
const provider = 'google';
@@ -494,6 +494,7 @@ describe('createGoogleImage', () => {
// Assert
expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(
'https://example.com/image.jpg',
undefined,
);
expect(mockClient.models.generateContent).toHaveBeenCalledWith({
contents: [
@@ -1,11 +1,11 @@
import { Content, GenerateContentConfig, GoogleGenAI, Part } from '@google/genai';
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
import { convertGoogleAIUsage } from '../../core/usageConverters/google-ai';
import { CreateImagePayload, CreateImageResponse } from '../../types/image';
import { AgentRuntimeError } from '../../utils/createError';
import { getModelPricing } from '../../utils/getModelPricing';
import { parseGoogleErrorMessage } from '../../utils/googleErrorParser';
import { imageUrlToBase64 } from '../../utils/imageToBase64';
import { parseDataUri } from '../../utils/uriParser';
// Maximum number of images allowed for processing
@@ -14,7 +14,7 @@ const MAX_IMAGE_COUNT = 10;
/**
* Process a single image URL and convert it to Google AI Part format
*/
async function processImageForParts(imageUrl: string): Promise<Part> {
async function processImageForParts(imageUrl: string, fetchImpl?: typeof fetch): Promise<Part> {
const { mimeType, base64, type } = parseDataUri(imageUrl);
if (type === 'base64') {
@@ -29,7 +29,10 @@ async function processImageForParts(imageUrl: string): Promise<Part> {
},
};
} else if (type === 'url') {
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(imageUrl);
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(
imageUrl,
fetchImpl,
);
return {
inlineData: {
@@ -104,6 +107,7 @@ async function generateImageByChatModel(
client: GoogleGenAI,
payload: CreateImagePayload,
provider: string,
fetchImpl?: typeof fetch,
): Promise<CreateImageResponse> {
const { model, params } = payload;
const actualModel = model.replace(':image', '');
@@ -118,7 +122,7 @@ async function generateImageByChatModel(
// Add image for editing if provided
if (params.imageUrl && params.imageUrl !== null) {
const imagePart = await processImageForParts(params.imageUrl);
const imagePart = await processImageForParts(params.imageUrl, fetchImpl);
parts.push(imagePart);
}
@@ -129,7 +133,7 @@ async function generateImageByChatModel(
}
const imageParts = await Promise.all(
params.imageUrls.map((imageUrl) => processImageForParts(imageUrl)),
params.imageUrls.map((imageUrl) => processImageForParts(imageUrl, fetchImpl)),
);
parts.push(...imageParts);
}
@@ -174,13 +178,14 @@ export async function createGoogleImage(
client: GoogleGenAI,
provider: string,
payload: CreateImagePayload,
fetchImpl?: typeof fetch,
): Promise<CreateImageResponse> {
try {
const { model } = payload;
// Handle Gemini 2.5 Flash Image models that use generateContent
if (model.endsWith(':image')) {
return await generateImageByChatModel(client, payload, provider);
return await generateImageByChatModel(client, payload, provider, fetchImpl);
}
// Handle traditional Imagen models that use generateImages
@@ -2,13 +2,13 @@
import { GenerateContentResponse, Tool } from '@google/genai';
import { OpenAIChatMessage } from '@lobechat/model-runtime';
import { ChatStreamPayload } from '@lobechat/types';
import * as imageToBase64Module from '@lobechat/utils/imageToBase64';
import OpenAI from 'openai';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { LOBE_ERROR_KEY } from '../../core/streams';
import { AgentRuntimeErrorType } from '../../types/error';
import * as debugStreamModule from '../../utils/debugStream';
import * as imageToBase64Module from '../../utils/imageToBase64';
import { LobeGoogleAI, resolveModelThinkingBudget } from './index';
const provider = 'google';
@@ -147,6 +147,7 @@ interface LobeGoogleAIParams {
baseURL?: string;
client?: GoogleGenAI;
defaultHeaders?: Record<string, any>;
fetch?: typeof fetch;
id?: string;
isVertexAi?: boolean;
}
@@ -165,6 +166,7 @@ const isAbortError = (error: Error): boolean => {
export class LobeGoogleAI implements LobeRuntimeAI {
private client: GoogleGenAI;
private isVertexAi: boolean;
fetch?: typeof fetch;
baseURL?: string;
apiKey?: string;
provider: string;
@@ -176,6 +178,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
isVertexAi,
id,
defaultHeaders,
fetch: customFetch,
}: LobeGoogleAIParams = {}) {
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
@@ -184,7 +187,22 @@ export class LobeGoogleAI implements LobeRuntimeAI {
: undefined;
this.apiKey = apiKey;
this.client = client ? client : new GoogleGenAI({ apiKey, httpOptions });
if (client) {
this.client = client;
this.fetch =
customFetch ??
(typeof (client as any).fetch === 'function'
? ((client as any).fetch as typeof fetch)
: undefined);
} else {
const clientOptions: Record<string, any> = { apiKey };
if (httpOptions) clientOptions.httpOptions = httpOptions;
if (customFetch) clientOptions.fetch = customFetch;
this.client = new GoogleGenAI(clientOptions);
this.fetch = customFetch;
}
this.baseURL = client ? undefined : baseURL || DEFAULT_BASE_URL;
this.isVertexAi = isVertexAi || false;
@@ -209,7 +227,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
thinkingBudget: resolvedThinkingBudget,
};
const contents = await buildGoogleMessages(payload.messages);
const contents = await buildGoogleMessages(payload.messages, this.fetch);
const controller = new AbortController();
const originalSignal = options?.signal;
@@ -316,7 +334,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
* @see https://ai.google.dev/gemini-api/docs/image-generation#imagen
*/
async createImage(payload: CreateImagePayload): Promise<CreateImageResponse> {
return createGoogleImage(this.client, this.provider, payload);
return createGoogleImage(this.client, this.provider, payload, this.fetch);
}
/**
@@ -326,7 +344,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
*/
async generateObject(payload: GenerateObjectPayload, options?: GenerateObjectOptions) {
// Convert OpenAI messages to Google format
const contents = await buildGoogleMessages(payload.messages);
const contents = await buildGoogleMessages(payload.messages, this.fetch);
// Handle tools-based structured output
if (payload.tools && payload.tools.length > 0) {
@@ -1,4 +1,5 @@
// @vitest-environment node
import { imageUrlToBase64 } from '@lobechat/utils';
import { ModelProvider } from 'model-bank';
import { Ollama } from 'ollama/browser';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
@@ -9,6 +10,9 @@ import * as debugStreamModule from '../../utils/debugStream';
import { LobeOllamaAI, params } from './index';
vi.mock('ollama/browser');
// Replace '@lobechat/utils' with a stub that exposes only `imageUrlToBase64`,
// so URL-image conversion in these tests is driven by the mock rather than a real request.
vi.mock('@lobechat/utils', () => ({
imageUrlToBase64: vi.fn(),
}));
// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});
@@ -462,13 +466,13 @@ describe('LobeOllamaAI', () => {
});
describe('buildOllamaMessages', () => {
it('should convert OpenAIChatMessage array to OllamaMessage array', () => {
it('should convert OpenAIChatMessage array to OllamaMessage array', async () => {
const messages = [
{ content: 'Hello', role: 'user' },
{ content: 'Hi there!', role: 'assistant' },
];
const ollamaMessages = ollamaAI['buildOllamaMessages'](messages as any);
const ollamaMessages = await ollamaAI['buildOllamaMessages'](messages as any);
expect(ollamaMessages).toEqual([
{ content: 'Hello', role: 'user' },
@@ -476,15 +480,15 @@ describe('LobeOllamaAI', () => {
]);
});
it('should handle empty message array', () => {
it('should handle empty message array', async () => {
const messages: any[] = [];
const ollamaMessages = ollamaAI['buildOllamaMessages'](messages);
const ollamaMessages = await ollamaAI['buildOllamaMessages'](messages);
expect(ollamaMessages).toEqual([]);
});
it('should handle multiple messages with different roles', () => {
it('should handle multiple messages with different roles', async () => {
const messages = [
{ content: 'Hello', role: 'system' },
{ content: 'Hi', role: 'user' },
@@ -492,7 +496,7 @@ describe('LobeOllamaAI', () => {
{ content: 'How are you?', role: 'user' },
];
const ollamaMessages = ollamaAI['buildOllamaMessages'](messages as any);
const ollamaMessages = await ollamaAI['buildOllamaMessages'](messages as any);
expect(ollamaMessages).toHaveLength(4);
expect(ollamaMessages[0].role).toBe('system');
@@ -503,26 +507,26 @@ describe('LobeOllamaAI', () => {
});
describe('convertContentToOllamaMessage', () => {
it('should convert string content to OllamaMessage', () => {
it('should convert string content to OllamaMessage', async () => {
const message = { content: 'Hello', role: 'user' };
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({ content: 'Hello', role: 'user' });
});
it('should convert text content to OllamaMessage', () => {
it('should convert text content to OllamaMessage', async () => {
const message = {
content: [{ type: 'text', text: 'Hello' }],
role: 'user',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({ content: 'Hello', role: 'user' });
});
it('should convert image_url content to OllamaMessage with images', () => {
it('should convert image_url content to OllamaMessage with images', async () => {
const message = {
content: [
{
@@ -533,7 +537,7 @@ describe('LobeOllamaAI', () => {
role: 'user',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({
content: '',
@@ -542,7 +546,7 @@ describe('LobeOllamaAI', () => {
});
});
it('should ignore invalid image_url content', () => {
it('should ignore invalid image_url content', async () => {
const message = {
content: [
{
@@ -553,7 +557,7 @@ describe('LobeOllamaAI', () => {
role: 'user',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({
content: '',
@@ -561,7 +565,7 @@ describe('LobeOllamaAI', () => {
});
});
it('should handle mixed text and image content', () => {
it('should handle mixed text and image content', async () => {
const message = {
content: [
{ type: 'text', text: 'First text' },
@@ -578,7 +582,7 @@ describe('LobeOllamaAI', () => {
role: 'user',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({
content: 'Second text', // Should keep latest text
@@ -587,13 +591,13 @@ describe('LobeOllamaAI', () => {
});
});
it('should handle content with empty text', () => {
it('should handle content with empty text', async () => {
const message = {
content: [{ type: 'text', text: '' }],
role: 'user',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({
content: '',
@@ -601,7 +605,7 @@ describe('LobeOllamaAI', () => {
});
});
it('should handle content with only images (no text)', () => {
it('should handle content with only images (no text)', async () => {
const message = {
content: [
{
@@ -612,7 +616,7 @@ describe('LobeOllamaAI', () => {
role: 'user',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({
content: '',
@@ -621,7 +625,7 @@ describe('LobeOllamaAI', () => {
});
});
it('should handle multiple images without text', () => {
it('should handle multiple images without text', async () => {
const message = {
content: [
{
@@ -640,7 +644,7 @@ describe('LobeOllamaAI', () => {
role: 'user',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({
content: '',
@@ -649,7 +653,10 @@ describe('LobeOllamaAI', () => {
});
});
it('should ignore images with invalid data URIs', () => {
it('should handle URL image conversion failure gracefully', async () => {
// Mock imageUrlToBase64 to simulate conversion failure
vi.mocked(imageUrlToBase64).mockRejectedValue(new Error('Network error'));
const message = {
content: [
{ type: 'text', text: 'Hello' },
@@ -665,16 +672,18 @@ describe('LobeOllamaAI', () => {
role: 'user',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({
content: 'Hello',
role: 'user',
images: ['valid123'],
});
// When URL conversion fails, it should continue processing other images
// The mock is set up to fail, so only the base64 image should be included
expect(ollamaMessage.content).toBe('Hello');
expect(ollamaMessage.role).toBe('user');
expect(ollamaMessage.images).toBeDefined();
// Should have at least the base64 image
expect(ollamaMessage.images).toContain('valid123');
});
it('should handle complex interleaved content', () => {
it('should handle complex interleaved content', async () => {
const message = {
content: [
{ type: 'text', text: 'Text 1' },
@@ -692,7 +701,7 @@ describe('LobeOllamaAI', () => {
role: 'user',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({
content: 'Text 3', // Should keep latest text
@@ -701,7 +710,7 @@ describe('LobeOllamaAI', () => {
});
});
it('should handle assistant role with images', () => {
it('should handle assistant role with images', async () => {
const message = {
content: [
{ type: 'text', text: 'Here is the image' },
@@ -713,7 +722,7 @@ describe('LobeOllamaAI', () => {
role: 'assistant',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({
content: 'Here is the image',
@@ -722,13 +731,13 @@ describe('LobeOllamaAI', () => {
});
});
it('should handle system role with text', () => {
it('should handle system role with text', async () => {
const message = {
content: [{ type: 'text', text: 'You are a helpful assistant' }],
role: 'system',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({
content: 'You are a helpful assistant',
@@ -736,13 +745,13 @@ describe('LobeOllamaAI', () => {
});
});
it('should handle empty content array', () => {
it('should handle empty content array', async () => {
const message = {
content: [],
role: 'user',
};
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
expect(ollamaMessage).toEqual({
content: '',
@@ -1,4 +1,5 @@
import type { ChatModelCard } from '@lobechat/types';
import { imageUrlToBase64 } from '@lobechat/utils';
import { ModelProvider } from 'model-bank';
import { Ollama, Tool } from 'ollama/browser';
import { ClientOptions } from 'openai';
@@ -34,12 +35,17 @@ export const params = {
provider: ModelProvider.Ollama,
};
/** Constructor parameters for `LobeOllamaAI`: client options plus an optional custom `fetch`. */
interface OllamaAIParams extends ClientOptions {
// Kept on the instance and passed to `imageUrlToBase64` when URL images are converted.
fetch?: typeof fetch;
}
export class LobeOllamaAI implements LobeRuntimeAI {
private client: Ollama;
baseURL?: string;
fetch?: typeof fetch;
constructor({ baseURL }: ClientOptions = {}) {
constructor({ baseURL, fetch: customFetch }: OllamaAIParams = {}) {
try {
if (baseURL) new URL(baseURL);
} catch (e) {
@@ -49,6 +55,7 @@ export class LobeOllamaAI implements LobeRuntimeAI {
this.client = new Ollama(!baseURL ? undefined : { host: baseURL });
if (baseURL) this.baseURL = baseURL;
this.fetch = customFetch;
}
async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) {
@@ -61,7 +68,7 @@ export class LobeOllamaAI implements LobeRuntimeAI {
options?.signal?.addEventListener('abort', abort);
const response = await this.client.chat({
messages: this.buildOllamaMessages(payload.messages),
messages: await this.buildOllamaMessages(payload.messages),
model: payload.model,
options: {
frequency_penalty: payload.frequency_penalty,
@@ -169,11 +176,13 @@ export class LobeOllamaAI implements LobeRuntimeAI {
}
};
private buildOllamaMessages(messages: OpenAIChatMessage[]) {
return messages.map((message) => this.convertContentToOllamaMessage(message));
private async buildOllamaMessages(messages: OpenAIChatMessage[]) {
return Promise.all(messages.map((message) => this.convertContentToOllamaMessage(message)));
}
private convertContentToOllamaMessage = (message: OpenAIChatMessage): OllamaMessage => {
private convertContentToOllamaMessage = async (
message: OpenAIChatMessage,
): Promise<OllamaMessage> => {
if (typeof message.content === 'string') {
return { content: message.content, role: message.role };
}
@@ -183,6 +192,9 @@ export class LobeOllamaAI implements LobeRuntimeAI {
role: message.role,
};
// Collect all URL images to convert them in parallel
const urlImagePromises: Promise<string | null>[] = [];
for (const content of message.content) {
switch (content.type) {
case 'text': {
@@ -191,16 +203,39 @@ export class LobeOllamaAI implements LobeRuntimeAI {
break;
}
case 'image_url': {
const { base64 } = parseDataUri(content.image_url.url);
const { base64, type } = parseDataUri(content.image_url.url);
// If already base64 format, use it directly
if (base64) {
ollamaMessage.images ??= [];
ollamaMessage.images.push(base64);
}
// If it's a URL, collect the promise for parallel processing
else if (type === 'url') {
urlImagePromises.push(
imageUrlToBase64(content.image_url.url, this.fetch)
.then((result) =>
result.base64 && result.base64.trim() !== '' ? result.base64 : null,
)
.catch(() => null), // Return null on error to continue processing other images
);
}
break;
}
}
}
// Process all URL images in parallel
if (urlImagePromises.length > 0) {
const urlImages = await Promise.all(urlImagePromises);
const validUrlImages = urlImages.filter((img): img is string => img !== null);
if (validUrlImages.length > 0) {
ollamaMessage.images ??= [];
ollamaMessage.images.push(...validUrlImages);
}
}
return ollamaMessage;
};
@@ -1,3 +1,4 @@
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
import createDebug from 'debug';
import { RuntimeImageGenParamsValue } from 'model-bank';
@@ -5,7 +6,6 @@ import { CreateImageOptions } from '../../core/openaiCompatibleFactory';
import { CreateImagePayload, CreateImageResponse } from '../../types';
import { AgentRuntimeErrorType } from '../../types/error';
import { AgentRuntimeError } from '../../utils/createError';
import { imageUrlToBase64 } from '../../utils/imageToBase64';
import { parseDataUri } from '../../utils/uriParser';
const log = createDebug('lobe-image:siliconcloud');
@@ -16,7 +16,7 @@ interface SiliconCloudImageResponse {
timings: { inference: number };
}
async function convertToDataURI(imageUrl: string): Promise<string> {
async function convertToDataURI(imageUrl: string, fetchImpl?: typeof fetch): Promise<string> {
const { type, base64, mimeType } = parseDataUri(imageUrl);
if (type === 'base64') {
@@ -27,7 +27,10 @@ async function convertToDataURI(imageUrl: string): Promise<string> {
}
if (type === 'url') {
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(imageUrl);
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(
imageUrl,
fetchImpl,
);
return `data:${urlMimeType};base64,${urlBase64}`;
}
@@ -39,7 +42,7 @@ export async function createSiliconCloudImage(
options: CreateImageOptions,
): Promise<CreateImageResponse> {
const { model, params } = payload;
const { apiKey, baseURL, provider } = options;
const { apiKey, baseURL, provider, fetch: customFetch } = options;
log('Creating image with SiliconCloud model: %s, params: %O', model, params);
@@ -72,17 +75,17 @@ export async function createSiliconCloudImage(
if (key === 'imageUrl') {
if (typeof value === 'string') {
body['image'] = await convertToDataURI(value);
body['image'] = await convertToDataURI(value, customFetch);
}
continue;
}
if (key === 'imageUrls') {
if (Array.isArray(value) && value.length > 0) {
body['image'] = await convertToDataURI(value[0]);
body['image'] = await convertToDataURI(value[0], customFetch);
if (model === 'Qwen/Qwen-Image-Edit-2509') {
if (value.length > 1) body['image2'] = await convertToDataURI(value[1]);
if (value.length > 2) body['image3'] = await convertToDataURI(value[2]);
if (value.length > 1) body['image2'] = await convertToDataURI(value[1], customFetch);
if (value.length > 2) body['image3'] = await convertToDataURI(value[2], customFetch);
}
}
continue;
@@ -94,7 +97,8 @@ export async function createSiliconCloudImage(
log('Request body: %O', body);
const response = await fetch(endpoint, {
const fetchFn = customFetch ?? fetch;
const response = await fetchFn(endpoint, {
body: JSON.stringify(body),
headers: {
'Authorization': `Bearer ${apiKey}`,
@@ -7,7 +7,7 @@ import { LobeGoogleAI } from '../google';
const DEFAULT_VERTEXAI_LOCATION = 'global';
export class LobeVertexAI extends LobeGoogleAI {
static initFromVertexAI(params?: GoogleGenAIOptions) {
static initFromVertexAI(params?: GoogleGenAIOptions, customFetch?: typeof fetch) {
try {
const client = new GoogleGenAI({
...params,
@@ -15,7 +15,12 @@ export class LobeVertexAI extends LobeGoogleAI {
vertexai: true,
});
return new LobeGoogleAI({ apiKey: 'avoid-error', client, isVertexAi: true });
return new LobeGoogleAI({
apiKey: 'avoid-error',
client,
fetch: customFetch,
isVertexAi: true,
});
} catch (e) {
const err = e as Error;
@@ -1,91 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { imageToBase64, imageUrlToBase64 } from './imageToBase64';
// Unit tests for imageToBase64. The canvas element and its 2D context are replaced with
// stubs, so the crop rectangle and the requested output type can be asserted directly.
describe('imageToBase64', () => {
  let sourceImage: HTMLImageElement;
  let canvasStub: HTMLCanvasElement;
  let contextStub: CanvasRenderingContext2D;

  beforeEach(() => {
    // Default fixture is a 200x100 (wider than tall) image; tests may override the dimensions.
    sourceImage = { width: 200, height: 100 } as HTMLImageElement;

    contextStub = { drawImage: vi.fn() } as unknown as CanvasRenderingContext2D;

    canvasStub = {
      width: 0,
      height: 0,
      getContext: vi.fn().mockReturnValue(contextStub),
      toDataURL: vi.fn().mockReturnValue('data:image/webp;base64,mockBase64Data'),
    } as unknown as HTMLCanvasElement;

    // Every document.createElement call hands back the stub canvas.
    vi.spyOn(document, 'createElement').mockReturnValue(canvasStub);
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  it('should convert image to base64 with correct size and type', () => {
    const dataUrl = imageToBase64({ img: sourceImage, size: 100, type: 'image/jpeg' });

    expect(document.createElement).toHaveBeenCalledWith('canvas');
    expect(canvasStub.width).toBe(100);
    expect(canvasStub.height).toBe(100);
    expect(canvasStub.getContext).toHaveBeenCalledWith('2d');
    // Wider image: the 100x100 crop starts 50px in from the left edge.
    expect(contextStub.drawImage).toHaveBeenCalledWith(
      sourceImage,
      50,
      0,
      100,
      100,
      0,
      0,
      100,
      100,
    );
    expect(canvasStub.toDataURL).toHaveBeenCalledWith('image/jpeg');
    expect(dataUrl).toBe('data:image/webp;base64,mockBase64Data');
  });

  it('should use default type when not specified', () => {
    imageToBase64({ img: sourceImage, size: 100 });

    expect(canvasStub.toDataURL).toHaveBeenCalledWith('image/webp');
  });

  it('should handle taller images correctly', () => {
    sourceImage.width = 100;
    sourceImage.height = 200;

    imageToBase64({ img: sourceImage, size: 100 });

    // Taller image: the crop starts 50px down from the top edge.
    expect(contextStub.drawImage).toHaveBeenCalledWith(
      sourceImage,
      0,
      50,
      100,
      100,
      0,
      0,
      100,
      100,
    );
  });
});
// Unit tests for imageUrlToBase64. `fetch` and `btoa` are swapped in through vi.stubGlobal so
// that vi.unstubAllGlobals() puts the real globals back after each test; a plain assignment
// to `global.fetch` / `global.btoa` is not undone by vi.restoreAllMocks().
describe('imageUrlToBase64', () => {
  const mockFetch = vi.fn();
  const mockArrayBuffer = new ArrayBuffer(8);

  beforeEach(() => {
    vi.stubGlobal('fetch', mockFetch);
    vi.stubGlobal('btoa', vi.fn().mockReturnValue('mockBase64String'));
  });

  afterEach(() => {
    vi.unstubAllGlobals();
    vi.restoreAllMocks();
  });

  it('should convert image URL to base64 string', async () => {
    mockFetch.mockResolvedValue({
      arrayBuffer: () => Promise.resolve(mockArrayBuffer),
      blob: () => Promise.resolve(new Blob([mockArrayBuffer], { type: 'image/jpg' })),
    });

    const result = await imageUrlToBase64('https://example.com/image.jpg');

    expect(mockFetch).toHaveBeenCalledWith('https://example.com/image.jpg');
    expect(global.btoa).toHaveBeenCalled();
    expect(result).toEqual({ base64: 'mockBase64String', mimeType: 'image/jpg' });
  });

  it('should throw an error when fetch fails', async () => {
    const mockError = new Error('Fetch failed');
    mockFetch.mockRejectedValue(mockError);

    await expect(imageUrlToBase64('https://example.com/image.jpg')).rejects.toThrow('Fetch failed');
  });
});
@@ -1,62 +0,0 @@
/**
 * Draw the centered square region of an image onto a `size` x `size` canvas and
 * return the canvas contents as a data URL.
 *
 * @param size - Width and height of the output canvas.
 * @param img - Source image; the crop side is the smaller of its width and height.
 * @param type - Type passed to `canvas.toDataURL`; defaults to `image/webp`.
 * @returns The value returned by `canvas.toDataURL(type)`.
 */
export const imageToBase64 = ({
  size,
  img,
  type = 'image/webp',
}: {
  img: HTMLImageElement;
  size: number;
  type?: string;
}) => {
  const canvas = document.createElement('canvas');
  const context = canvas.getContext('2d') as CanvasRenderingContext2D;

  // Center the crop along the longer axis; the shorter axis always starts at 0.
  const isWiderThanTall = img.width > img.height;
  const offsetX = isWiderThanTall ? (img.width - img.height) / 2 : 0;
  const offsetY = isWiderThanTall ? 0 : (img.height - img.width) / 2;
  const cropSide = Math.min(img.width, img.height);

  canvas.width = size;
  canvas.height = size;

  context.drawImage(img, offsetX, offsetY, cropSide, cropSide, 0, 0, size, size);

  return canvas.toDataURL(type);
};
/**
 * Fetch an image by URL and return its bytes as a base64 string together with
 * the MIME type reported by the response blob.
 *
 * Uses `btoa` when it is available and falls back to `Buffer` otherwise.
 * Any failure is logged with `console.error` and then rethrown unchanged.
 *
 * @param imageUrl - URL passed to the global `fetch`.
 * @returns The base64 payload (no `data:` prefix) and the blob's `type`.
 */
export const imageUrlToBase64 = async (
  imageUrl: string,
): Promise<{ base64: string; mimeType: string }> => {
  try {
    const response = await fetch(imageUrl);
    const imageBlob = await response.blob();
    const rawBuffer = await imageBlob.arrayBuffer();

    let base64: string;
    if (typeof btoa === 'function') {
      // btoa expects a binary string, so map every byte to one character first.
      let binary = '';
      for (const byte of new Uint8Array(rawBuffer)) {
        binary += String.fromCharCode(byte);
      }
      base64 = btoa(binary);
    } else {
      base64 = Buffer.from(rawBuffer).toString('base64');
    }

    return { base64, mimeType: imageBlob.type };
  } catch (error) {
    console.error('Error converting image to base64:', error);
    throw error;
  }
};
+25 -9
View File
@@ -1,13 +1,34 @@
import fetch from 'node-fetch';
import type { RequestInit as NodeFetchOptions } from 'node-fetch';
import { RequestFilteringAgentOptions, useAgent as ssrfAgent } from 'request-filtering-agent';
interface FetchOptions extends RequestInit {
ssrf?: boolean;
}
/**
 * Rebuild a fetched response as a standard `Response`, copying its buffered body,
 * headers, status and status text.
 *
 * Responses with a null-body status (204, 205, 304) are rebuilt with a `null` body:
 * the `Response` constructor throws a TypeError when a non-null body (even an empty
 * ArrayBuffer) is combined with one of those statuses.
 *
 * @param response - The response returned by the `fetch` call in this module.
 * @returns A new `Response` carrying the same status, headers and body bytes.
 */
const toStandardResponse = async (response: Awaited<ReturnType<typeof fetch>>) => {
  // The body is read unconditionally, exactly as before; only what is forwarded changes.
  const payload = await response.arrayBuffer();
  const { status } = response;
  const isNullBodyStatus = status === 204 || status === 205 || status === 304;

  return new Response(isNullBodyStatus ? null : payload, {
    headers: response.headers as any,
    status,
    statusText: response.statusText,
  });
};
/**
* SSRF-safe fetch implementation for server-side use
* Uses request-filtering-agent to prevent requests to private IP addresses
*/
// eslint-disable-next-line no-undef
export const ssrfSafeFetch = async (url: string, options?: RequestInit): Promise<Response> => {
export const ssrfSafeFetch = async (url: string, options?: FetchOptions): Promise<Response> => {
const { ssrf, ...restOptions } = options ?? {};
const fetchOptions = restOptions as NodeFetchOptions;
try {
if (!ssrf) {
const response = await fetch(url, fetchOptions);
return await toStandardResponse(response);
}
// Configure SSRF protection options
const agentOptions: RequestFilteringAgentOptions = {
allowIPAddressList: process.env.SSRF_ALLOW_IP_ADDRESS_LIST?.split(',') || [],
@@ -18,16 +39,11 @@ export const ssrfSafeFetch = async (url: string, options?: RequestInit): Promise
// Use node-fetch with SSRF protection agent
const response = await fetch(url, {
...options,
...fetchOptions,
agent: ssrfAgent(url, agentOptions),
} as any);
// Convert node-fetch Response to standard Response
return new Response(await response.arrayBuffer(), {
headers: response.headers as any,
status: response.status,
statusText: response.statusText,
});
return await toStandardResponse(response);
} catch (error) {
console.error('SSRF-safe fetch error:', error);
throw new Error(
+2
View File
@@ -3,7 +3,9 @@
"version": "1.0.0",
"private": true,
"description": "",
"exports": "./index.ts",
"main": "index.ts",
"types": "index.ts",
"scripts": {
"test": "vitest run"
},
+2 -1
View File
@@ -17,7 +17,8 @@
"@lobechat/const": "workspace:*",
"@lobechat/types": "workspace:*",
"dayjs": "^1.11.18",
"dompurify": "^3.2.7"
"dompurify": "^3.2.7",
"ssrf-safe-fetch": "workspace:*"
},
"devDependencies": {
"vitest-canvas-mock": "^0.3.3"
+10 -1
View File
@@ -36,11 +36,20 @@ export const imageToBase64 = ({
return canvas.toDataURL(type);
};
/**
* Convert image URL to base64.
* Accepts an optional custom fetch implementation (e.g., SSRF-safe fetch) for server environments.
*/
export const imageUrlToBase64 = async (
imageUrl: string,
customFetch?: typeof fetch,
): Promise<{ base64: string; mimeType: string }> => {
try {
const res = await fetch(imageUrl);
const fetchFn = customFetch || fetch;
const ssrfOptions = customFetch ? ({ ssrf: true } as RequestInit) : undefined;
const res =
ssrfOptions === undefined ? await fetchFn(imageUrl) : await fetchFn(imageUrl, ssrfOptions);
const blob = await res.blob();
const arrayBuffer = await blob.arrayBuffer();
+22 -7
View File
@@ -4,6 +4,7 @@ import { LobeVertexAI } from '@lobechat/model-runtime/vertexai';
import { ClientSecretPayload } from '@lobechat/types';
import { safeParseJSON } from '@lobechat/utils';
import { ModelProvider } from 'model-bank';
import { ssrfSafeFetch } from 'ssrf-safe-fetch';
import { getLLMConfig } from '@/envs/llm';
@@ -11,6 +12,8 @@ import apiKeyManager from './apiKeyManager';
export * from './trace';
type RuntimeInitializeParams = Parameters<typeof ModelRuntime.initializeWithProvider>[1];
/**
* Retrieves the options object from environment and apikeymanager
* based on the provider and payload.
@@ -171,7 +174,10 @@ const buildVertexOptions = (
const project = projectFromParams ?? projectFromCredentials ?? projectFromEnv;
const location =
(params.location as string | undefined) ?? payload.vertexAIRegion ?? process.env.VERTEXAI_LOCATION ?? undefined;
(params.location as string | undefined) ??
payload.vertexAIRegion ??
process.env.VERTEXAI_LOCATION ??
undefined;
const googleAuthOptions = params.googleAuthOptions ?? (credentials ? { credentials } : undefined);
@@ -197,19 +203,28 @@ const buildVertexOptions = (
export const initModelRuntimeWithUserPayload = (
provider: string,
payload: ClientSecretPayload,
params: any = {},
params: RuntimeInitializeParams = {},
) => {
const { fetch: customFetch, ...restParams } = params ?? {};
const fetchImpl = (customFetch ?? ssrfSafeFetch) as typeof fetch;
const runtimeProvider = payload.runtimeProvider ?? provider;
if (runtimeProvider === ModelProvider.VertexAI) {
const vertexOptions = buildVertexOptions(payload, params);
const runtime = LobeVertexAI.initFromVertexAI(vertexOptions);
const vertexOptions = buildVertexOptions(payload, restParams as Partial<GoogleGenAIOptions>);
const runtime = LobeVertexAI.initFromVertexAI(vertexOptions, fetchImpl);
return new ModelRuntime(runtime);
}
return ModelRuntime.initializeWithProvider(runtimeProvider, {
const mergedParams = {
...getParamsFromPayload(runtimeProvider, payload),
...params,
});
...restParams,
};
const runtimeParams = {
...mergedParams,
fetch: (mergedParams as Record<string, any>).fetch ?? fetchImpl,
};
return ModelRuntime.initializeWithProvider(runtimeProvider, runtimeParams);
};