mirror of
https://github.com/lobehub/lobe-chat.git
synced 2026-06-17 04:55:51 +00:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 976155f0e8 | |||
| c73a15d4a5 | |||
| 2187409dcf | |||
| 0274a623fa | |||
| 55d6597644 | |||
| 2bdf218c57 | |||
| aedfc5d59c | |||
| 81deeb7374 | |||
| 8312f34703 | |||
| 751d09b14a | |||
| 417182b31b | |||
| f91c036c6c | |||
| f6c5ad1498 | |||
| 5612dfe5a4 | |||
| ade6faa020 | |||
| cbd0346ea0 | |||
| de8e4e4a1a |
@@ -22,6 +22,9 @@ config.rules['unicorn/no-array-callback-reference'] = 0;
|
||||
// FIXME: Linting error in src/app/[variants]/(main)/chat/features/Migration/DBReader.ts, the fundamental solution should be upgrading typescript-eslint
|
||||
config.rules['@typescript-eslint/no-useless-constructor'] = 0;
|
||||
|
||||
if (!config.globals) config.globals = {};
|
||||
config.globals.RequestInit = true;
|
||||
|
||||
config.overrides = [
|
||||
{
|
||||
extends: ['plugin:mdx/recommended'],
|
||||
|
||||
@@ -25,6 +25,16 @@ export interface AgentChatOptions {
|
||||
trace?: TracePayload;
|
||||
}
|
||||
|
||||
export interface ModelRuntimeOptions
|
||||
extends ClientOptions,
|
||||
LobeBedrockAIParams,
|
||||
LobeCloudflareParams {
|
||||
apiKey?: string;
|
||||
apiVersion?: string;
|
||||
baseURL?: string;
|
||||
fetch?: typeof fetch;
|
||||
}
|
||||
|
||||
export class ModelRuntime {
|
||||
private _runtime: LobeRuntimeAI;
|
||||
|
||||
@@ -114,14 +124,7 @@ export class ModelRuntime {
|
||||
* - `src/app/api/chat/agentRuntime.ts: initAgentRuntimeWithUserPayload` on server
|
||||
* - `src/services/chat.ts: initializeWithClientStore` on client
|
||||
*/
|
||||
static initializeWithProvider(
|
||||
provider: string,
|
||||
params: Partial<
|
||||
ClientOptions &
|
||||
LobeBedrockAIParams &
|
||||
LobeCloudflareParams & { apiKey?: string; apiVersion?: string; baseURL?: string }
|
||||
>,
|
||||
) {
|
||||
static initializeWithProvider(provider: string, params: Partial<ModelRuntimeOptions>) {
|
||||
// @ts-expect-error runtime map not include vertex so it will be undefined
|
||||
const providerAI = providerRuntimeMap[provider] ?? LobeOpenAI;
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
|
||||
import { OpenAI } from 'openai';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OpenAIChatMessage, UserMessageContentPart } from '../../types/chat';
|
||||
import { imageUrlToBase64 } from '../../utils/imageToBase64';
|
||||
import { parseDataUri } from '../../utils/uriParser';
|
||||
import {
|
||||
buildAnthropicBlock,
|
||||
@@ -19,7 +19,7 @@ vi.mock('../../utils/uriParser', () => ({
|
||||
type: 'base64',
|
||||
}),
|
||||
}));
|
||||
vi.mock('../../utils/imageToBase64');
|
||||
vi.mock('@lobechat/utils/imageToBase64');
|
||||
|
||||
describe('anthropicHelpers', () => {
|
||||
describe('buildAnthropicBlock', () => {
|
||||
@@ -65,7 +65,7 @@ describe('anthropicHelpers', () => {
|
||||
const result = await buildAnthropicBlock(content);
|
||||
|
||||
expect(parseDataUri).toHaveBeenCalledWith(content.image_url.url);
|
||||
expect(imageUrlToBase64).toHaveBeenCalledWith(content.image_url.url);
|
||||
expect(imageUrlToBase64).toHaveBeenCalledWith(content.image_url.url, undefined);
|
||||
expect(result).toEqual({
|
||||
source: {
|
||||
data: 'convertedBase64String',
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import Anthropic from '@anthropic-ai/sdk';
|
||||
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
|
||||
import OpenAI from 'openai';
|
||||
|
||||
import { OpenAIChatMessage, UserMessageContentPart } from '../../types';
|
||||
import { imageUrlToBase64 } from '../../utils/imageToBase64';
|
||||
import { parseDataUri } from '../../utils/uriParser';
|
||||
|
||||
export const buildAnthropicBlock = async (
|
||||
content: UserMessageContentPart,
|
||||
customFetch?: typeof fetch,
|
||||
): Promise<Anthropic.ContentBlock | Anthropic.ImageBlockParam | undefined> => {
|
||||
switch (content.type) {
|
||||
case 'thinking': {
|
||||
@@ -34,7 +35,7 @@ export const buildAnthropicBlock = async (
|
||||
};
|
||||
|
||||
if (type === 'url') {
|
||||
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url);
|
||||
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url, customFetch);
|
||||
return {
|
||||
source: {
|
||||
data: base64 as string,
|
||||
@@ -50,9 +51,11 @@ export const buildAnthropicBlock = async (
|
||||
}
|
||||
};
|
||||
|
||||
const buildArrayContent = async (content: UserMessageContentPart[]) => {
|
||||
const buildArrayContent = async (content: UserMessageContentPart[], customFetch?: typeof fetch) => {
|
||||
let messageContent = (await Promise.all(
|
||||
(content as UserMessageContentPart[]).map(async (c) => await buildAnthropicBlock(c)),
|
||||
(content as UserMessageContentPart[]).map(
|
||||
async (c) => await buildAnthropicBlock(c, customFetch),
|
||||
),
|
||||
)) as Anthropic.Messages.ContentBlockParam[];
|
||||
|
||||
messageContent = messageContent.filter(Boolean);
|
||||
@@ -62,6 +65,7 @@ const buildArrayContent = async (content: UserMessageContentPart[]) => {
|
||||
|
||||
export const buildAnthropicMessage = async (
|
||||
message: OpenAIChatMessage,
|
||||
customFetch?: typeof fetch,
|
||||
): Promise<Anthropic.Messages.MessageParam> => {
|
||||
const content = message.content as string | UserMessageContentPart[];
|
||||
|
||||
@@ -72,7 +76,8 @@ export const buildAnthropicMessage = async (
|
||||
|
||||
case 'user': {
|
||||
return {
|
||||
content: typeof content === 'string' ? content : await buildArrayContent(content),
|
||||
content:
|
||||
typeof content === 'string' ? content : await buildArrayContent(content, customFetch),
|
||||
role: 'user',
|
||||
};
|
||||
}
|
||||
@@ -100,7 +105,7 @@ export const buildAnthropicMessage = async (
|
||||
? ([{ text: message.content, type: 'text' }] as UserMessageContentPart[])
|
||||
: content;
|
||||
|
||||
const messageContent = await buildArrayContent(rawContent);
|
||||
const messageContent = await buildArrayContent(rawContent, customFetch);
|
||||
|
||||
return {
|
||||
content: [
|
||||
@@ -129,7 +134,7 @@ export const buildAnthropicMessage = async (
|
||||
|
||||
export const buildAnthropicMessages = async (
|
||||
oaiMessages: OpenAIChatMessage[],
|
||||
options: { enabledContextCaching?: boolean } = {},
|
||||
options: { customFetch?: typeof fetch; enabledContextCaching?: boolean } = {},
|
||||
): Promise<Anthropic.Messages.MessageParam[]> => {
|
||||
const messages: Anthropic.Messages.MessageParam[] = [];
|
||||
let pendingToolResults: Anthropic.ToolResultBlockParam[] = [];
|
||||
@@ -175,7 +180,7 @@ export const buildAnthropicMessages = async (
|
||||
});
|
||||
}
|
||||
} else {
|
||||
const anthropicMessage = await buildAnthropicMessage(message);
|
||||
const anthropicMessage = await buildAnthropicMessage(message, options.customFetch);
|
||||
messages.push({ ...anthropicMessage, role: anthropicMessage.role });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// @vitest-environment node
|
||||
import { Type as SchemaType } from '@google/genai';
|
||||
import * as imageToBase64Module from '@lobechat/utils/imageToBase64';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { ChatCompletionTool, OpenAIChatMessage, UserMessageContentPart } from '../../types';
|
||||
import * as imageToBase64Module from '../../utils/imageToBase64';
|
||||
import { parseDataUri } from '../../utils/uriParser';
|
||||
import {
|
||||
buildGoogleMessage,
|
||||
@@ -18,7 +18,7 @@ vi.mock('../../utils/uriParser', () => ({
|
||||
parseDataUri: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../../utils/imageToBase64', () => ({
|
||||
vi.mock('@lobechat/utils/imageToBase64', () => ({
|
||||
imageUrlToBase64: vi.fn(),
|
||||
}));
|
||||
|
||||
@@ -102,7 +102,7 @@ describe('google contextBuilders', () => {
|
||||
},
|
||||
});
|
||||
|
||||
expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(imageUrl);
|
||||
expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(imageUrl, undefined);
|
||||
});
|
||||
|
||||
it('should throw TypeError for unsupported image URL types', async () => {
|
||||
|
||||
@@ -5,9 +5,9 @@ import {
|
||||
Part,
|
||||
Type as SchemaType,
|
||||
} from '@google/genai';
|
||||
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
|
||||
|
||||
import { ChatCompletionTool, OpenAIChatMessage, UserMessageContentPart } from '../../types';
|
||||
import { imageUrlToBase64 } from '../../utils/imageToBase64';
|
||||
import { safeParseJSON } from '../../utils/safeParseJSON';
|
||||
import { parseDataUri } from '../../utils/uriParser';
|
||||
|
||||
@@ -16,6 +16,7 @@ import { parseDataUri } from '../../utils/uriParser';
|
||||
*/
|
||||
export const buildGooglePart = async (
|
||||
content: UserMessageContentPart,
|
||||
customFetch?: typeof fetch,
|
||||
): Promise<Part | undefined> => {
|
||||
switch (content.type) {
|
||||
default: {
|
||||
@@ -40,7 +41,7 @@ export const buildGooglePart = async (
|
||||
}
|
||||
|
||||
if (type === 'url') {
|
||||
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url);
|
||||
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url, customFetch);
|
||||
|
||||
return {
|
||||
inlineData: { data: base64, mimeType },
|
||||
@@ -66,7 +67,12 @@ export const buildGooglePart = async (
|
||||
if (type === 'url') {
|
||||
// For video URLs, we need to fetch and convert to base64
|
||||
// Note: This might need size/duration limits for practical use
|
||||
const response = await fetch(content.video_url.url);
|
||||
const fetchFn = customFetch || fetch;
|
||||
const ssrfOptions = customFetch ? ({ ssrf: true } as RequestInit) : undefined;
|
||||
const response =
|
||||
ssrfOptions === undefined
|
||||
? await fetchFn(content.video_url.url)
|
||||
: await fetchFn(content.video_url.url, ssrfOptions);
|
||||
const arrayBuffer = await response.arrayBuffer();
|
||||
const base64 = Buffer.from(arrayBuffer).toString('base64');
|
||||
const mimeType = response.headers.get('content-type') || 'video/mp4';
|
||||
@@ -87,6 +93,7 @@ export const buildGooglePart = async (
|
||||
export const buildGoogleMessage = async (
|
||||
message: OpenAIChatMessage,
|
||||
toolCallNameMap?: Map<string, string>,
|
||||
customFetch?: typeof fetch,
|
||||
): Promise<Content> => {
|
||||
const content = message.content as string | UserMessageContentPart[];
|
||||
|
||||
@@ -124,7 +131,9 @@ export const buildGoogleMessage = async (
|
||||
const getParts = async () => {
|
||||
if (typeof content === 'string') return [{ text: content }];
|
||||
|
||||
const parts = await Promise.all(content.map(async (c) => await buildGooglePart(c)));
|
||||
const parts = await Promise.all(
|
||||
content.map(async (c) => await buildGooglePart(c, customFetch)),
|
||||
);
|
||||
return parts.filter(Boolean) as Part[];
|
||||
};
|
||||
|
||||
@@ -137,7 +146,10 @@ export const buildGoogleMessage = async (
|
||||
/**
|
||||
* Convert messages from the OpenAI format to Google GenAI SDK format
|
||||
*/
|
||||
export const buildGoogleMessages = async (messages: OpenAIChatMessage[]): Promise<Content[]> => {
|
||||
export const buildGoogleMessages = async (
|
||||
messages: OpenAIChatMessage[],
|
||||
customFetch?: typeof fetch,
|
||||
): Promise<Content[]> => {
|
||||
const toolCallNameMap = new Map<string, string>();
|
||||
|
||||
// Build tool call id to name mapping
|
||||
@@ -153,7 +165,7 @@ export const buildGoogleMessages = async (messages: OpenAIChatMessage[]): Promis
|
||||
|
||||
const pools = messages
|
||||
.filter((message) => message.role !== 'function')
|
||||
.map(async (msg) => await buildGoogleMessage(msg, toolCallNameMap));
|
||||
.map(async (msg) => await buildGoogleMessage(msg, toolCallNameMap, customFetch));
|
||||
|
||||
const contents = await Promise.all(pools);
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
|
||||
import OpenAI from 'openai';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { imageUrlToBase64 } from '../../utils/imageToBase64';
|
||||
import { parseDataUri } from '../../utils/uriParser';
|
||||
import {
|
||||
convertImageUrlToFile,
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
} from './openai';
|
||||
|
||||
// 模拟依赖
|
||||
vi.mock('../../utils/imageToBase64');
|
||||
vi.mock('@lobechat/utils/imageToBase64');
|
||||
vi.mock('../../utils/uriParser');
|
||||
|
||||
describe('convertMessageContent', () => {
|
||||
@@ -52,7 +52,7 @@ describe('convertMessageContent', () => {
|
||||
});
|
||||
|
||||
expect(parseDataUri).toHaveBeenCalledWith('https://example.com/image.jpg');
|
||||
expect(imageUrlToBase64).toHaveBeenCalledWith('https://example.com/image.jpg');
|
||||
expect(imageUrlToBase64).toHaveBeenCalledWith('https://example.com/image.jpg', undefined);
|
||||
});
|
||||
|
||||
it('should not convert image URL when not necessary', async () => {
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
|
||||
import OpenAI, { toFile } from 'openai';
|
||||
|
||||
import { disableStreamModels, systemToUserModels } from '../../const/models';
|
||||
import { ChatStreamPayload, OpenAIChatMessage } from '../../types';
|
||||
import { imageUrlToBase64 } from '../../utils/imageToBase64';
|
||||
import { parseDataUri } from '../../utils/uriParser';
|
||||
|
||||
export const convertMessageContent = async (
|
||||
content: OpenAI.ChatCompletionContentPart,
|
||||
customFetch?: typeof fetch,
|
||||
): Promise<OpenAI.ChatCompletionContentPart> => {
|
||||
if (content.type === 'image_url') {
|
||||
const { type } = parseDataUri(content.image_url.url);
|
||||
|
||||
if (type === 'url' && process.env.LLM_VISION_IMAGE_USE_BASE64 === '1') {
|
||||
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url);
|
||||
const { base64, mimeType } = await imageUrlToBase64(content.image_url.url, customFetch);
|
||||
|
||||
return {
|
||||
...content,
|
||||
@@ -24,7 +25,10 @@ export const convertMessageContent = async (
|
||||
return content;
|
||||
};
|
||||
|
||||
export const convertOpenAIMessages = async (messages: OpenAI.ChatCompletionMessageParam[]) => {
|
||||
export const convertOpenAIMessages = async (
|
||||
messages: OpenAI.ChatCompletionMessageParam[],
|
||||
customFetch?: typeof fetch,
|
||||
) => {
|
||||
return (await Promise.all(
|
||||
messages.map(async (message) => ({
|
||||
...message,
|
||||
@@ -33,7 +37,7 @@ export const convertOpenAIMessages = async (messages: OpenAI.ChatCompletionMessa
|
||||
? message.content
|
||||
: await Promise.all(
|
||||
(message.content || []).map((c) =>
|
||||
convertMessageContent(c as OpenAI.ChatCompletionContentPart),
|
||||
convertMessageContent(c as OpenAI.ChatCompletionContentPart, customFetch),
|
||||
),
|
||||
),
|
||||
})),
|
||||
@@ -42,6 +46,7 @@ export const convertOpenAIMessages = async (messages: OpenAI.ChatCompletionMessa
|
||||
|
||||
export const convertOpenAIResponseInputs = async (
|
||||
messages: OpenAI.ChatCompletionMessageParam[],
|
||||
customFetch?: typeof fetch,
|
||||
) => {
|
||||
let input: OpenAI.Responses.ResponseInputItem[] = [];
|
||||
await Promise.all(
|
||||
@@ -83,7 +88,10 @@ export const convertOpenAIResponseInputs = async (
|
||||
return { ...c, type: 'input_text' };
|
||||
}
|
||||
|
||||
const image = await convertMessageContent(c as OpenAI.ChatCompletionContentPart);
|
||||
const image = await convertMessageContent(
|
||||
c as OpenAI.ChatCompletionContentPart,
|
||||
customFetch,
|
||||
);
|
||||
return {
|
||||
image_url: (image as OpenAI.ChatCompletionContentPartImage).image_url?.url,
|
||||
type: 'input_image',
|
||||
@@ -127,7 +135,7 @@ export const pruneReasoningPayload = (payload: ChatStreamPayload) => {
|
||||
/**
|
||||
* Convert image URL (data URL or HTTP URL) to File object for OpenAI API
|
||||
*/
|
||||
export const convertImageUrlToFile = async (imageUrl: string) => {
|
||||
export const convertImageUrlToFile = async (imageUrl: string, customFetch?: typeof fetch) => {
|
||||
let buffer: Buffer;
|
||||
let mimeType: string;
|
||||
|
||||
@@ -138,7 +146,10 @@ export const convertImageUrlToFile = async (imageUrl: string) => {
|
||||
buffer = Buffer.from(base64Data, 'base64');
|
||||
} else {
|
||||
// a http url
|
||||
const response = await fetch(imageUrl);
|
||||
const fetchFn = customFetch || fetch;
|
||||
const ssrfOptions = customFetch ? ({ ssrf: true } as RequestInit) : undefined;
|
||||
const response =
|
||||
ssrfOptions === undefined ? await fetchFn(imageUrl) : await fetchFn(imageUrl, ssrfOptions);
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch image from ${imageUrl}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// @vitest-environment node
|
||||
import * as imageToBase64Module from '@lobechat/utils/imageToBase64';
|
||||
import OpenAI from 'openai';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { CreateImagePayload } from '../../types/image';
|
||||
import * as imageToBase64Module from '../../utils/imageToBase64';
|
||||
import * as uriParserModule from '../../utils/uriParser';
|
||||
import { createOpenAICompatibleImage } from './createImage';
|
||||
|
||||
@@ -81,7 +81,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'openrouter');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'openrouter',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,generatedImageData');
|
||||
expect(mockClient.chat.completions.create).toHaveBeenCalled();
|
||||
@@ -122,7 +126,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'test-provider');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'test-provider',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,result');
|
||||
});
|
||||
@@ -145,7 +153,7 @@ describe('createOpenAICompatibleImage', () => {
|
||||
};
|
||||
|
||||
await expect(
|
||||
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
|
||||
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
|
||||
).rejects.toThrow(
|
||||
"Failed to process image URL: TypeError: Image URL doesn't contain base64 data",
|
||||
);
|
||||
@@ -191,9 +199,16 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'test-provider');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'test-provider',
|
||||
});
|
||||
|
||||
expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(mockHttpImageUrl);
|
||||
expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(
|
||||
mockHttpImageUrl,
|
||||
undefined,
|
||||
);
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,output');
|
||||
});
|
||||
|
||||
@@ -215,7 +230,7 @@ describe('createOpenAICompatibleImage', () => {
|
||||
};
|
||||
|
||||
await expect(
|
||||
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
|
||||
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
|
||||
).rejects.toThrow(
|
||||
`Failed to process image URL: TypeError: Currently we don't support image url: ${mockInvalidUrl}`,
|
||||
);
|
||||
@@ -249,7 +264,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'openrouter');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'openrouter',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,generatedWithoutInputImage');
|
||||
expect(mockClient.chat.completions.create).toHaveBeenCalledWith({
|
||||
@@ -296,7 +315,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'test-provider');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'test-provider',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,generatedImage');
|
||||
// Should not include image in content array
|
||||
@@ -324,7 +347,7 @@ describe('createOpenAICompatibleImage', () => {
|
||||
};
|
||||
|
||||
await expect(
|
||||
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
|
||||
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
|
||||
).rejects.toThrow('No message in chat completion response');
|
||||
});
|
||||
|
||||
@@ -349,7 +372,7 @@ describe('createOpenAICompatibleImage', () => {
|
||||
};
|
||||
|
||||
await expect(
|
||||
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
|
||||
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
|
||||
).rejects.toThrow('No image generated in chat completion response');
|
||||
});
|
||||
|
||||
@@ -374,7 +397,7 @@ describe('createOpenAICompatibleImage', () => {
|
||||
};
|
||||
|
||||
await expect(
|
||||
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
|
||||
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
|
||||
).rejects.toThrow('No image generated in chat completion response');
|
||||
});
|
||||
|
||||
@@ -404,7 +427,7 @@ describe('createOpenAICompatibleImage', () => {
|
||||
};
|
||||
|
||||
await expect(
|
||||
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
|
||||
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
|
||||
).rejects.toThrow('No image generated in chat completion response');
|
||||
});
|
||||
|
||||
@@ -436,7 +459,7 @@ describe('createOpenAICompatibleImage', () => {
|
||||
};
|
||||
|
||||
await expect(
|
||||
createOpenAICompatibleImage(mockClient, payload, 'test-provider'),
|
||||
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'test-provider' }),
|
||||
).rejects.toThrow('No image generated in chat completion response');
|
||||
});
|
||||
|
||||
@@ -475,7 +498,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'openrouter');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'openrouter',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,processedResult');
|
||||
|
||||
@@ -522,7 +549,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'test-provider');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'test-provider',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,chatModelResult');
|
||||
expect(mockClient.chat.completions.create).toHaveBeenCalled();
|
||||
@@ -548,7 +579,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'openai',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,imageModelBase64Result');
|
||||
expect(mockClient.images.generate).toHaveBeenCalled();
|
||||
@@ -586,7 +621,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'openai',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,editedImageResult');
|
||||
expect(mockClient.images.edit).toHaveBeenCalled();
|
||||
@@ -611,7 +650,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'openai',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,generatedImage');
|
||||
expect(mockClient.images.generate).toHaveBeenCalled();
|
||||
@@ -637,7 +680,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'openai',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,generatedImage');
|
||||
expect(mockClient.images.generate).toHaveBeenCalled();
|
||||
@@ -665,7 +712,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'openai',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe(mockImageUrl);
|
||||
expect(mockClient.images.generate).toHaveBeenCalled();
|
||||
@@ -690,9 +741,9 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
await expect(createOpenAICompatibleImage(mockClient, payload, 'openai')).rejects.toThrow(
|
||||
'Invalid image response: missing both b64_json and url fields',
|
||||
);
|
||||
await expect(
|
||||
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'openai' }),
|
||||
).rejects.toThrow('Invalid image response: missing both b64_json and url fields');
|
||||
});
|
||||
|
||||
it('should throw error when response data is not an array', async () => {
|
||||
@@ -709,9 +760,9 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
await expect(createOpenAICompatibleImage(mockClient, payload, 'openai')).rejects.toThrow(
|
||||
'Invalid image response: missing or empty data array',
|
||||
);
|
||||
await expect(
|
||||
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'openai' }),
|
||||
).rejects.toThrow('Invalid image response: missing or empty data array');
|
||||
});
|
||||
|
||||
it('should throw error when imageData is undefined in array', async () => {
|
||||
@@ -728,9 +779,9 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
await expect(createOpenAICompatibleImage(mockClient, payload, 'openai')).rejects.toThrow(
|
||||
'Invalid image response: first data item is null or undefined',
|
||||
);
|
||||
await expect(
|
||||
createOpenAICompatibleImage({ client: mockClient, payload, provider: 'openai' }),
|
||||
).rejects.toThrow('Invalid image response: first data item is null or undefined');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -762,7 +813,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'openai',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,imageWithUsage');
|
||||
expect(result.modelUsage).toBeDefined();
|
||||
@@ -790,7 +845,11 @@ describe('createOpenAICompatibleImage', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createOpenAICompatibleImage(mockClient, payload, 'openai');
|
||||
const result = await createOpenAICompatibleImage({
|
||||
client: mockClient,
|
||||
payload,
|
||||
provider: 'openai',
|
||||
});
|
||||
|
||||
expect(result.imageUrl).toBe('data:image/png;base64,imageWithoutUsage');
|
||||
expect(result.modelUsage).toBeUndefined();
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
|
||||
import { cleanObject } from '@lobechat/utils/object';
|
||||
import createDebug from 'debug';
|
||||
import { RuntimeImageGenParamsValue } from 'model-bank';
|
||||
@@ -5,24 +6,31 @@ import OpenAI from 'openai';
|
||||
|
||||
import { CreateImagePayload, CreateImageResponse } from '../../types/image';
|
||||
import { getModelPricing } from '../../utils/getModelPricing';
|
||||
import { imageUrlToBase64 } from '../../utils/imageToBase64';
|
||||
import { parseDataUri } from '../../utils/uriParser';
|
||||
import { convertImageUrlToFile } from '../contextBuilders/openai';
|
||||
import { convertOpenAIImageUsage } from '../usageConverters/openai';
|
||||
|
||||
interface CreateImageContext {
|
||||
client: OpenAI;
|
||||
fetchImpl?: typeof fetch;
|
||||
payload: CreateImagePayload;
|
||||
provider: string;
|
||||
}
|
||||
|
||||
const log = createDebug('lobe-image:openai-compatible');
|
||||
|
||||
/**
|
||||
* Generate images using traditional OpenAI images API (DALL-E, etc.)
|
||||
*/
|
||||
async function generateByImageMode(
|
||||
client: OpenAI,
|
||||
payload: CreateImagePayload,
|
||||
provider: string,
|
||||
): Promise<CreateImageResponse> {
|
||||
async function generateByImageMode({
|
||||
client,
|
||||
payload,
|
||||
provider,
|
||||
fetchImpl,
|
||||
}: CreateImageContext): Promise<CreateImageResponse> {
|
||||
const { model, params } = payload;
|
||||
|
||||
log('Creating image with model: %s and params: %O', model, params);
|
||||
log('Creating image with provider: %s, model: %s and params: %O', provider, model, params);
|
||||
|
||||
// Map parameter names, mapping imageUrls to image
|
||||
const paramsMap = new Map<RuntimeImageGenParamsValue, string>([
|
||||
@@ -48,7 +56,7 @@ async function generateByImageMode(
|
||||
try {
|
||||
// Convert all image URLs to File objects
|
||||
const imageFiles = await Promise.all(
|
||||
userInput.image.map((url: string) => convertImageUrlToFile(url)),
|
||||
userInput.image.map((url: string) => convertImageUrlToFile(url, fetchImpl)),
|
||||
);
|
||||
|
||||
// According to official docs, if there are multiple images, pass an array; if only one, pass a single File
|
||||
@@ -127,7 +135,7 @@ async function generateByImageMode(
|
||||
/**
|
||||
* Process image URL for chat model input
|
||||
*/
|
||||
async function processImageUrlForChat(imageUrl: string): Promise<string> {
|
||||
async function processImageUrlForChat(imageUrl: string, fetchImpl?: typeof fetch): Promise<string> {
|
||||
const { type, base64, mimeType } = parseDataUri(imageUrl);
|
||||
|
||||
if (type === 'base64') {
|
||||
@@ -137,7 +145,10 @@ async function processImageUrlForChat(imageUrl: string): Promise<string> {
|
||||
return `data:${mimeType || 'image/png'};base64,${base64}`;
|
||||
} else if (type === 'url') {
|
||||
// For URL type, convert to base64 first
|
||||
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(imageUrl);
|
||||
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(
|
||||
imageUrl,
|
||||
fetchImpl,
|
||||
);
|
||||
return `data:${urlMimeType};base64,${urlBase64}`;
|
||||
} else {
|
||||
throw new TypeError(`Currently we don't support image url: ${imageUrl}`);
|
||||
@@ -147,14 +158,21 @@ async function processImageUrlForChat(imageUrl: string): Promise<string> {
|
||||
/**
|
||||
* Generate images using chat completion API (OpenRouter Gemini, etc.)
|
||||
*/
|
||||
async function generateByChatModel(
|
||||
client: OpenAI,
|
||||
payload: CreateImagePayload,
|
||||
): Promise<CreateImageResponse> {
|
||||
async function generateByChatModel({
|
||||
client,
|
||||
payload,
|
||||
provider,
|
||||
fetchImpl,
|
||||
}: CreateImageContext): Promise<CreateImageResponse> {
|
||||
const { model, params } = payload;
|
||||
const actualModel = model.replace(':image', ''); // Remove :image suffix
|
||||
|
||||
log('Creating image via chat API with model: %s and params: %O', actualModel, params);
|
||||
log(
|
||||
'Creating image via chat API with provider: %s, model: %s and params: %O',
|
||||
provider,
|
||||
actualModel,
|
||||
params,
|
||||
);
|
||||
|
||||
// Build message content array
|
||||
const content: Array<any> = [
|
||||
@@ -168,7 +186,7 @@ async function generateByChatModel(
|
||||
if (params.imageUrl && params.imageUrl !== null) {
|
||||
log('Processing image URL for editing mode: %s', params.imageUrl);
|
||||
try {
|
||||
const processedImageUrl = await processImageUrlForChat(params.imageUrl);
|
||||
const processedImageUrl = await processImageUrlForChat(params.imageUrl, fetchImpl);
|
||||
content.push({
|
||||
image_url: {
|
||||
url: processedImageUrl,
|
||||
@@ -220,18 +238,19 @@ async function generateByChatModel(
|
||||
/**
|
||||
* Create image using OpenAI Compatible API
|
||||
*/
|
||||
export async function createOpenAICompatibleImage(
|
||||
client: OpenAI,
|
||||
payload: CreateImagePayload,
|
||||
provider: string,
|
||||
): Promise<CreateImageResponse> {
|
||||
export async function createOpenAICompatibleImage({
|
||||
client,
|
||||
payload,
|
||||
provider,
|
||||
fetchImpl,
|
||||
}: CreateImageContext): Promise<CreateImageResponse> {
|
||||
const { model } = payload;
|
||||
|
||||
// Check if it's a chat model for image generation (via :image suffix)
|
||||
if (model.endsWith(':image')) {
|
||||
return await generateByChatModel(client, payload);
|
||||
return await generateByChatModel({ client, fetchImpl, payload, provider });
|
||||
}
|
||||
|
||||
// Default to traditional images API
|
||||
return await generateByImageMode(client, payload, provider);
|
||||
return await generateByImageMode({ client, fetchImpl, payload, provider });
|
||||
}
|
||||
|
||||
@@ -1251,6 +1251,7 @@ describe('LobeOpenAICompatibleFactory', () => {
|
||||
|
||||
expect(openaiHelpers.convertImageUrlToFile).toHaveBeenCalledWith(
|
||||
'https://example.com/image1.jpg',
|
||||
undefined,
|
||||
);
|
||||
expect(instance['client'].images.edit).toHaveBeenCalledWith({
|
||||
image: expect.any(File),
|
||||
@@ -1293,9 +1294,11 @@ describe('LobeOpenAICompatibleFactory', () => {
|
||||
expect(openaiHelpers.convertImageUrlToFile).toHaveBeenCalledTimes(2);
|
||||
expect(openaiHelpers.convertImageUrlToFile).toHaveBeenCalledWith(
|
||||
'https://example.com/image1.jpg',
|
||||
undefined,
|
||||
);
|
||||
expect(openaiHelpers.convertImageUrlToFile).toHaveBeenCalledWith(
|
||||
'https://example.com/image2.jpg',
|
||||
undefined,
|
||||
);
|
||||
|
||||
expect(instance['client'].images.edit).toHaveBeenCalledWith({
|
||||
|
||||
@@ -57,6 +57,7 @@ export const CHAT_MODELS_BLOCK_LIST = [
|
||||
type ConstructorOptions<T extends Record<string, any> = any> = ClientOptions & T;
|
||||
export type CreateImageOptions = Omit<ClientOptions, 'apiKey'> & {
|
||||
apiKey: string;
|
||||
fetch?: typeof fetch;
|
||||
provider: string;
|
||||
};
|
||||
|
||||
@@ -178,6 +179,7 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
|
||||
|
||||
baseURL!: string;
|
||||
protected _options: ConstructorOptions<T>;
|
||||
protected _fetch?: typeof fetch;
|
||||
|
||||
constructor(options: ClientOptions & Record<string, any> = {}) {
|
||||
const _options = {
|
||||
@@ -187,6 +189,7 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
|
||||
};
|
||||
const { apiKey, baseURL = DEFAULT_BASE_URL, ...res } = _options;
|
||||
this._options = _options as ConstructorOptions<T>;
|
||||
this._fetch = options.fetch as any;
|
||||
|
||||
if (!apiKey) throw AgentRuntimeError.createError(ErrorType?.invalidAPIKey);
|
||||
|
||||
@@ -252,7 +255,7 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
|
||||
return this.handleResponseAPIMode(processedPayload, options);
|
||||
}
|
||||
|
||||
const messages = await convertOpenAIMessages(postPayload.messages);
|
||||
const messages = await convertOpenAIMessages(postPayload.messages, this._fetch);
|
||||
|
||||
let response: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>;
|
||||
|
||||
@@ -361,16 +364,27 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
|
||||
// If custom createImage implementation is provided, use it
|
||||
if (customCreateImage) {
|
||||
log('using custom createImage implementation');
|
||||
return customCreateImage(payload, {
|
||||
const fetchImpl = (this._fetch ?? (this._options as any).fetch) as typeof fetch | undefined;
|
||||
|
||||
const imageOptions: CreateImageOptions = {
|
||||
...this._options,
|
||||
apiKey: this._options.apiKey!,
|
||||
provider,
|
||||
});
|
||||
} as CreateImageOptions;
|
||||
|
||||
if (fetchImpl) imageOptions.fetch = fetchImpl;
|
||||
|
||||
return customCreateImage(payload, imageOptions);
|
||||
}
|
||||
|
||||
log('using default createOpenAICompatibleImage');
|
||||
// Use the new createOpenAICompatibleImage function
|
||||
return createOpenAICompatibleImage(this.client, payload, this.id);
|
||||
return createOpenAICompatibleImage({
|
||||
client: this.client,
|
||||
fetchImpl: this._fetch,
|
||||
payload,
|
||||
provider: this.id,
|
||||
});
|
||||
}
|
||||
|
||||
async models() {
|
||||
@@ -776,7 +790,7 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
|
||||
delete res.frequency_penalty;
|
||||
delete res.presence_penalty;
|
||||
|
||||
const input = await convertOpenAIResponseInputs(messages as any);
|
||||
const input = await convertOpenAIResponseInputs(messages as any, this._fetch);
|
||||
|
||||
const isStreaming = payload.stream !== false;
|
||||
log(
|
||||
@@ -928,7 +942,7 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
|
||||
|
||||
if (shouldUseResponses) {
|
||||
log('calling responses.create for tool calling');
|
||||
const input = await convertOpenAIResponseInputs(messages as any);
|
||||
const input = await convertOpenAIResponseInputs(messages as any, this._fetch);
|
||||
|
||||
const res = await this.client.responses.create(
|
||||
{
|
||||
|
||||
@@ -73,6 +73,7 @@ const resolveCacheTTL = (
|
||||
};
|
||||
|
||||
interface AnthropicAIParams extends ClientOptions {
|
||||
fetch?: typeof fetch;
|
||||
id?: string;
|
||||
}
|
||||
|
||||
@@ -82,6 +83,7 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
|
||||
baseURL: string;
|
||||
apiKey?: string;
|
||||
private id: string;
|
||||
fetch?: typeof fetch;
|
||||
|
||||
private isDebug() {
|
||||
return process.env.DEBUG_ANTHROPIC_CHAT_COMPLETION === '1';
|
||||
@@ -92,16 +94,20 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
|
||||
baseURL = DEFAULT_BASE_URL,
|
||||
id,
|
||||
defaultHeaders,
|
||||
fetch: customFetch,
|
||||
...res
|
||||
}: AnthropicAIParams = {}) {
|
||||
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
|
||||
|
||||
const betaHeaders = process.env.ANTHROPIC_BETA_HEADERS;
|
||||
|
||||
this.fetch = customFetch;
|
||||
|
||||
this.client = new Anthropic({
|
||||
apiKey,
|
||||
baseURL,
|
||||
defaultHeaders: { ...defaultHeaders, 'anthropic-beta': betaHeaders },
|
||||
fetch: customFetch,
|
||||
...res,
|
||||
});
|
||||
this.baseURL = this.client.baseURL;
|
||||
@@ -200,7 +206,10 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
|
||||
] as Anthropic.TextBlockParam[])
|
||||
: undefined;
|
||||
|
||||
const postMessages = await buildAnthropicMessages(user_messages, { enabledContextCaching });
|
||||
const postMessages = await buildAnthropicMessages(user_messages, {
|
||||
customFetch: this.fetch,
|
||||
enabledContextCaching,
|
||||
});
|
||||
|
||||
let postTools: anthropicTools[] | undefined = buildAnthropicTools(tools, {
|
||||
enabledContextCaching,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import createClient, { ModelClient } from '@azure-rest/ai-inference';
|
||||
import { AzureKeyCredential } from '@azure/core-auth';
|
||||
import { ModelProvider } from 'model-bank';
|
||||
import type { Readable as NodeReadable } from 'node:stream';
|
||||
import OpenAI from 'openai';
|
||||
|
||||
import { systemToUserModels } from '../../const/models';
|
||||
@@ -64,9 +65,40 @@ export class LobeAzureAI implements LobeRuntimeAI {
|
||||
});
|
||||
|
||||
if (enableStreaming) {
|
||||
const stream = await response.asBrowserStream();
|
||||
const unifiedStream = await (async () => {
|
||||
if (typeof window === 'undefined') {
|
||||
/**
|
||||
* In Node.js the SDK exposes a Node readable stream, so we convert it to a Web ReadableStream
|
||||
* to reuse the same streaming pipeline used by Edge/browser runtimes.
|
||||
*/
|
||||
const streamModule = await import('node:stream');
|
||||
const Readable = streamModule.Readable ?? streamModule.default.Readable;
|
||||
|
||||
const [prod, debug] = stream.body!.tee();
|
||||
if (!Readable) throw new Error('node:stream module missing Readable export');
|
||||
if (typeof Readable.toWeb !== 'function')
|
||||
throw new Error('Readable.toWeb is not a function');
|
||||
|
||||
const nodeResponse = await response.asNodeStream();
|
||||
const nodeStream = nodeResponse.body;
|
||||
|
||||
if (!nodeStream) {
|
||||
throw new Error('Azure AI response body is empty');
|
||||
}
|
||||
|
||||
return Readable.toWeb(nodeStream as unknown as NodeReadable) as ReadableStream;
|
||||
}
|
||||
|
||||
const browserResponse = await response.asBrowserStream();
|
||||
const browserStream = browserResponse.body;
|
||||
|
||||
if (!browserStream) {
|
||||
throw new Error('Azure AI response body is empty');
|
||||
}
|
||||
|
||||
return browserStream;
|
||||
})();
|
||||
|
||||
const [prod, debug] = unifiedStream.tee();
|
||||
|
||||
if (process.env.DEBUG_AZURE_AI_CHAT_COMPLETION === '1') {
|
||||
debugStream(debug).catch(console.error);
|
||||
@@ -130,7 +162,7 @@ export class LobeAzureAI implements LobeRuntimeAI {
|
||||
const regex = /^(https:\/\/)([^.]+)(\.cognitiveservices\.azure\.com\/.*)$/;
|
||||
|
||||
// 使用替换函数
|
||||
return url.replace(regex, (match, protocol, subdomain, rest) => {
|
||||
return url.replace(regex, (_match, protocol, _subdomain, rest) => {
|
||||
// 将子域名替换为 '***'
|
||||
return `${protocol}***${rest}`;
|
||||
});
|
||||
|
||||
@@ -6,7 +6,7 @@ import { createBflImage } from './createImage';
|
||||
import { BflStatusResponse } from './types';
|
||||
|
||||
// Mock external dependencies
|
||||
vi.mock('../../utils/imageToBase64', () => ({
|
||||
vi.mock('@lobechat/utils/imageToBase64', () => ({
|
||||
imageUrlToBase64: vi.fn(),
|
||||
}));
|
||||
|
||||
@@ -187,7 +187,7 @@ describe('createBflImage', () => {
|
||||
it('should convert single imageUrl to image_prompt base64', async () => {
|
||||
// Arrange
|
||||
const { parseDataUri } = await import('../../utils/uriParser');
|
||||
const { imageUrlToBase64 } = await import('../../utils/imageToBase64');
|
||||
const { imageUrlToBase64 } = await import('@lobechat/utils/imageToBase64');
|
||||
const { asyncifyPolling } = await import('../../utils/asyncifyPolling');
|
||||
|
||||
const mockParseDataUri = vi.mocked(parseDataUri);
|
||||
@@ -226,7 +226,7 @@ describe('createBflImage', () => {
|
||||
|
||||
// Assert
|
||||
expect(mockParseDataUri).toHaveBeenCalledWith('https://example.com/input.jpg');
|
||||
expect(mockImageUrlToBase64).toHaveBeenCalledWith('https://example.com/input.jpg');
|
||||
expect(mockImageUrlToBase64).toHaveBeenCalledWith('https://example.com/input.jpg', undefined);
|
||||
|
||||
const callArgs = mockFetch.mock.calls[0][1];
|
||||
const requestBody = JSON.parse(callArgs?.body as string);
|
||||
@@ -290,7 +290,7 @@ describe('createBflImage', () => {
|
||||
it('should convert multiple imageUrls for Kontext models', async () => {
|
||||
// Arrange
|
||||
const { parseDataUri } = await import('../../utils/uriParser');
|
||||
const { imageUrlToBase64 } = await import('../../utils/imageToBase64');
|
||||
const { imageUrlToBase64 } = await import('@lobechat/utils/imageToBase64');
|
||||
const { asyncifyPolling } = await import('../../utils/asyncifyPolling');
|
||||
|
||||
const mockParseDataUri = vi.mocked(parseDataUri);
|
||||
@@ -350,7 +350,7 @@ describe('createBflImage', () => {
|
||||
it('should limit imageUrls to maximum 4 images', async () => {
|
||||
// Arrange
|
||||
const { parseDataUri } = await import('../../utils/uriParser');
|
||||
const { imageUrlToBase64 } = await import('../../utils/imageToBase64');
|
||||
const { imageUrlToBase64 } = await import('@lobechat/utils/imageToBase64');
|
||||
const { asyncifyPolling } = await import('../../utils/asyncifyPolling');
|
||||
|
||||
const mockParseDataUri = vi.mocked(parseDataUri);
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
|
||||
import createDebug from 'debug';
|
||||
import { RuntimeImageGenParamsValue } from 'model-bank';
|
||||
|
||||
@@ -5,7 +6,6 @@ import { AgentRuntimeErrorType } from '../../types/error';
|
||||
import { CreateImagePayload, CreateImageResponse } from '../../types/image';
|
||||
import { type TaskResult, asyncifyPolling } from '../../utils/asyncifyPolling';
|
||||
import { AgentRuntimeError } from '../../utils/createError';
|
||||
import { imageUrlToBase64 } from '../../utils/imageToBase64';
|
||||
import { parseDataUri } from '../../utils/uriParser';
|
||||
import {
|
||||
BFL_ENDPOINTS,
|
||||
@@ -23,13 +23,14 @@ const BASE_URL = 'https://api.bfl.ai';
|
||||
interface BflCreateImageOptions {
|
||||
apiKey: string;
|
||||
baseURL?: string;
|
||||
fetch?: typeof fetch;
|
||||
provider: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert image URL to base64 format required by BFL API
|
||||
*/
|
||||
async function convertImageToBase64(imageUrl: string): Promise<string> {
|
||||
async function convertImageToBase64(imageUrl: string, fetchImpl?: typeof fetch): Promise<string> {
|
||||
try {
|
||||
const { type } = parseDataUri(imageUrl);
|
||||
|
||||
@@ -44,7 +45,7 @@ async function convertImageToBase64(imageUrl: string): Promise<string> {
|
||||
|
||||
if (type === 'url') {
|
||||
// Convert URL to base64
|
||||
const { base64 } = await imageUrlToBase64(imageUrl);
|
||||
const { base64 } = await imageUrlToBase64(imageUrl, fetchImpl);
|
||||
return base64;
|
||||
}
|
||||
|
||||
@@ -61,6 +62,7 @@ async function convertImageToBase64(imageUrl: string): Promise<string> {
|
||||
async function buildRequestPayload(
|
||||
model: BflModelId,
|
||||
params: CreateImagePayload['params'],
|
||||
fetchImpl?: typeof fetch,
|
||||
): Promise<BflRequest> {
|
||||
log('Building request payload for model: %s', model);
|
||||
|
||||
@@ -88,7 +90,7 @@ async function buildRequestPayload(
|
||||
if (params.imageUrls && params.imageUrls.length > 0) {
|
||||
for (let i = 0; i < Math.min(params.imageUrls.length, 4); i++) {
|
||||
const fieldName = i === 0 ? 'input_image' : `input_image_${i + 1}`;
|
||||
userPayload[fieldName] = await convertImageToBase64(params.imageUrls[i]);
|
||||
userPayload[fieldName] = await convertImageToBase64(params.imageUrls[i], fetchImpl);
|
||||
}
|
||||
// Remove the original imageUrls field as it's now mapped to input_image_*
|
||||
delete userPayload.imageUrls;
|
||||
@@ -96,7 +98,7 @@ async function buildRequestPayload(
|
||||
|
||||
// Handle single image input (imageUrl)
|
||||
if (params.imageUrl) {
|
||||
userPayload.image_prompt = await convertImageToBase64(params.imageUrl);
|
||||
userPayload.image_prompt = await convertImageToBase64(params.imageUrl, fetchImpl);
|
||||
// Remove the original imageUrl field as it's now mapped to image_prompt
|
||||
delete userPayload.imageUrl;
|
||||
}
|
||||
@@ -123,7 +125,9 @@ async function submitTask(
|
||||
|
||||
log('Submitting task to: %s', url);
|
||||
|
||||
const response = await fetch(url, {
|
||||
const fetchImpl = options.fetch ?? fetch;
|
||||
|
||||
const response = await fetchImpl(url, {
|
||||
body: JSON.stringify(payload),
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -160,7 +164,9 @@ async function queryTaskStatus(
|
||||
): Promise<BflResultResponse> {
|
||||
log('Querying task status using polling URL: %s', pollingUrl);
|
||||
|
||||
const response = await fetch(pollingUrl, {
|
||||
const fetchImpl = options.fetch ?? fetch;
|
||||
|
||||
const response = await fetchImpl(pollingUrl, {
|
||||
headers: {
|
||||
'accept': 'application/json',
|
||||
'x-key': options.apiKey,
|
||||
@@ -203,7 +209,7 @@ export async function createBflImage(
|
||||
|
||||
try {
|
||||
// 1. Build request payload
|
||||
const requestPayload = await buildRequestPayload(model as BflModelId, params);
|
||||
const requestPayload = await buildRequestPayload(model as BflModelId, params, options.fetch);
|
||||
|
||||
// 2. Submit image generation task
|
||||
const taskResponse = await submitTask(model as BflModelId, requestPayload, options);
|
||||
|
||||
@@ -9,15 +9,21 @@ import { createBflImage } from './createImage';
|
||||
|
||||
const log = createDebug('lobe-image:bfl');
|
||||
|
||||
interface BflAIParams extends ClientOptions {
|
||||
fetch?: typeof fetch;
|
||||
}
|
||||
|
||||
export class LobeBflAI implements LobeRuntimeAI {
|
||||
private apiKey: string;
|
||||
fetch?: typeof fetch;
|
||||
baseURL?: string;
|
||||
|
||||
constructor({ apiKey, baseURL }: ClientOptions = {}) {
|
||||
constructor({ apiKey, baseURL, fetch: customFetch }: BflAIParams = {}) {
|
||||
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
|
||||
|
||||
this.apiKey = apiKey;
|
||||
this.baseURL = baseURL || undefined;
|
||||
this.fetch = customFetch;
|
||||
|
||||
log('BFL AI initialized');
|
||||
}
|
||||
@@ -30,6 +36,7 @@ export class LobeBflAI implements LobeRuntimeAI {
|
||||
return await createBflImage(payload, {
|
||||
apiKey: this.apiKey,
|
||||
baseURL: this.baseURL,
|
||||
fetch: this.fetch,
|
||||
provider: 'bfl',
|
||||
});
|
||||
} catch (error) {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// @vitest-environment edge-runtime
|
||||
import { GoogleGenAI } from '@google/genai';
|
||||
import * as imageToBase64Module from '@lobechat/utils/imageToBase64';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { CreateImagePayload } from '../../types/image';
|
||||
import * as imageToBase64Module from '../../utils/imageToBase64';
|
||||
import { createGoogleImage } from './createImage';
|
||||
|
||||
const provider = 'google';
|
||||
@@ -494,6 +494,7 @@ describe('createGoogleImage', () => {
|
||||
// Assert
|
||||
expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(
|
||||
'https://example.com/image.jpg',
|
||||
undefined,
|
||||
);
|
||||
expect(mockClient.models.generateContent).toHaveBeenCalledWith({
|
||||
contents: [
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Content, GenerateContentConfig, GoogleGenAI, Part } from '@google/genai';
|
||||
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
|
||||
|
||||
import { convertGoogleAIUsage } from '../../core/usageConverters/google-ai';
|
||||
import { CreateImagePayload, CreateImageResponse } from '../../types/image';
|
||||
import { AgentRuntimeError } from '../../utils/createError';
|
||||
import { getModelPricing } from '../../utils/getModelPricing';
|
||||
import { parseGoogleErrorMessage } from '../../utils/googleErrorParser';
|
||||
import { imageUrlToBase64 } from '../../utils/imageToBase64';
|
||||
import { parseDataUri } from '../../utils/uriParser';
|
||||
|
||||
// Maximum number of images allowed for processing
|
||||
@@ -14,7 +14,7 @@ const MAX_IMAGE_COUNT = 10;
|
||||
/**
|
||||
* Process a single image URL and convert it to Google AI Part format
|
||||
*/
|
||||
async function processImageForParts(imageUrl: string): Promise<Part> {
|
||||
async function processImageForParts(imageUrl: string, fetchImpl?: typeof fetch): Promise<Part> {
|
||||
const { mimeType, base64, type } = parseDataUri(imageUrl);
|
||||
|
||||
if (type === 'base64') {
|
||||
@@ -29,7 +29,10 @@ async function processImageForParts(imageUrl: string): Promise<Part> {
|
||||
},
|
||||
};
|
||||
} else if (type === 'url') {
|
||||
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(imageUrl);
|
||||
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(
|
||||
imageUrl,
|
||||
fetchImpl,
|
||||
);
|
||||
|
||||
return {
|
||||
inlineData: {
|
||||
@@ -104,6 +107,7 @@ async function generateImageByChatModel(
|
||||
client: GoogleGenAI,
|
||||
payload: CreateImagePayload,
|
||||
provider: string,
|
||||
fetchImpl?: typeof fetch,
|
||||
): Promise<CreateImageResponse> {
|
||||
const { model, params } = payload;
|
||||
const actualModel = model.replace(':image', '');
|
||||
@@ -118,7 +122,7 @@ async function generateImageByChatModel(
|
||||
|
||||
// Add image for editing if provided
|
||||
if (params.imageUrl && params.imageUrl !== null) {
|
||||
const imagePart = await processImageForParts(params.imageUrl);
|
||||
const imagePart = await processImageForParts(params.imageUrl, fetchImpl);
|
||||
parts.push(imagePart);
|
||||
}
|
||||
|
||||
@@ -129,7 +133,7 @@ async function generateImageByChatModel(
|
||||
}
|
||||
|
||||
const imageParts = await Promise.all(
|
||||
params.imageUrls.map((imageUrl) => processImageForParts(imageUrl)),
|
||||
params.imageUrls.map((imageUrl) => processImageForParts(imageUrl, fetchImpl)),
|
||||
);
|
||||
parts.push(...imageParts);
|
||||
}
|
||||
@@ -174,13 +178,14 @@ export async function createGoogleImage(
|
||||
client: GoogleGenAI,
|
||||
provider: string,
|
||||
payload: CreateImagePayload,
|
||||
fetchImpl?: typeof fetch,
|
||||
): Promise<CreateImageResponse> {
|
||||
try {
|
||||
const { model } = payload;
|
||||
|
||||
// Handle Gemini 2.5 Flash Image models that use generateContent
|
||||
if (model.endsWith(':image')) {
|
||||
return await generateImageByChatModel(client, payload, provider);
|
||||
return await generateImageByChatModel(client, payload, provider, fetchImpl);
|
||||
}
|
||||
|
||||
// Handle traditional Imagen models that use generateImages
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
import { GenerateContentResponse, Tool } from '@google/genai';
|
||||
import { OpenAIChatMessage } from '@lobechat/model-runtime';
|
||||
import { ChatStreamPayload } from '@lobechat/types';
|
||||
import * as imageToBase64Module from '@lobechat/utils/imageToBase64';
|
||||
import OpenAI from 'openai';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { LOBE_ERROR_KEY } from '../../core/streams';
|
||||
import { AgentRuntimeErrorType } from '../../types/error';
|
||||
import * as debugStreamModule from '../../utils/debugStream';
|
||||
import * as imageToBase64Module from '../../utils/imageToBase64';
|
||||
import { LobeGoogleAI, resolveModelThinkingBudget } from './index';
|
||||
|
||||
const provider = 'google';
|
||||
|
||||
@@ -147,6 +147,7 @@ interface LobeGoogleAIParams {
|
||||
baseURL?: string;
|
||||
client?: GoogleGenAI;
|
||||
defaultHeaders?: Record<string, any>;
|
||||
fetch?: typeof fetch;
|
||||
id?: string;
|
||||
isVertexAi?: boolean;
|
||||
}
|
||||
@@ -165,6 +166,7 @@ const isAbortError = (error: Error): boolean => {
|
||||
export class LobeGoogleAI implements LobeRuntimeAI {
|
||||
private client: GoogleGenAI;
|
||||
private isVertexAi: boolean;
|
||||
fetch?: typeof fetch;
|
||||
baseURL?: string;
|
||||
apiKey?: string;
|
||||
provider: string;
|
||||
@@ -176,6 +178,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
|
||||
isVertexAi,
|
||||
id,
|
||||
defaultHeaders,
|
||||
fetch: customFetch,
|
||||
}: LobeGoogleAIParams = {}) {
|
||||
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
|
||||
|
||||
@@ -184,7 +187,22 @@ export class LobeGoogleAI implements LobeRuntimeAI {
|
||||
: undefined;
|
||||
|
||||
this.apiKey = apiKey;
|
||||
this.client = client ? client : new GoogleGenAI({ apiKey, httpOptions });
|
||||
|
||||
if (client) {
|
||||
this.client = client;
|
||||
this.fetch =
|
||||
customFetch ??
|
||||
(typeof (client as any).fetch === 'function'
|
||||
? ((client as any).fetch as typeof fetch)
|
||||
: undefined);
|
||||
} else {
|
||||
const clientOptions: Record<string, any> = { apiKey };
|
||||
if (httpOptions) clientOptions.httpOptions = httpOptions;
|
||||
if (customFetch) clientOptions.fetch = customFetch;
|
||||
this.client = new GoogleGenAI(clientOptions);
|
||||
this.fetch = customFetch;
|
||||
}
|
||||
|
||||
this.baseURL = client ? undefined : baseURL || DEFAULT_BASE_URL;
|
||||
this.isVertexAi = isVertexAi || false;
|
||||
|
||||
@@ -209,7 +227,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
|
||||
thinkingBudget: resolvedThinkingBudget,
|
||||
};
|
||||
|
||||
const contents = await buildGoogleMessages(payload.messages);
|
||||
const contents = await buildGoogleMessages(payload.messages, this.fetch);
|
||||
|
||||
const controller = new AbortController();
|
||||
const originalSignal = options?.signal;
|
||||
@@ -316,7 +334,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
|
||||
* @see https://ai.google.dev/gemini-api/docs/image-generation#imagen
|
||||
*/
|
||||
async createImage(payload: CreateImagePayload): Promise<CreateImageResponse> {
|
||||
return createGoogleImage(this.client, this.provider, payload);
|
||||
return createGoogleImage(this.client, this.provider, payload, this.fetch);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -326,7 +344,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
|
||||
*/
|
||||
async generateObject(payload: GenerateObjectPayload, options?: GenerateObjectOptions) {
|
||||
// Convert OpenAI messages to Google format
|
||||
const contents = await buildGoogleMessages(payload.messages);
|
||||
const contents = await buildGoogleMessages(payload.messages, this.fetch);
|
||||
|
||||
// Handle tools-based structured output
|
||||
if (payload.tools && payload.tools.length > 0) {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// @vitest-environment node
|
||||
import { imageUrlToBase64 } from '@lobechat/utils';
|
||||
import { ModelProvider } from 'model-bank';
|
||||
import { Ollama } from 'ollama/browser';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
@@ -9,6 +10,9 @@ import * as debugStreamModule from '../../utils/debugStream';
|
||||
import { LobeOllamaAI, params } from './index';
|
||||
|
||||
vi.mock('ollama/browser');
|
||||
vi.mock('@lobechat/utils', () => ({
|
||||
imageUrlToBase64: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock the console.error to avoid polluting test output
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
@@ -462,13 +466,13 @@ describe('LobeOllamaAI', () => {
|
||||
});
|
||||
|
||||
describe('buildOllamaMessages', () => {
|
||||
it('should convert OpenAIChatMessage array to OllamaMessage array', () => {
|
||||
it('should convert OpenAIChatMessage array to OllamaMessage array', async () => {
|
||||
const messages = [
|
||||
{ content: 'Hello', role: 'user' },
|
||||
{ content: 'Hi there!', role: 'assistant' },
|
||||
];
|
||||
|
||||
const ollamaMessages = ollamaAI['buildOllamaMessages'](messages as any);
|
||||
const ollamaMessages = await ollamaAI['buildOllamaMessages'](messages as any);
|
||||
|
||||
expect(ollamaMessages).toEqual([
|
||||
{ content: 'Hello', role: 'user' },
|
||||
@@ -476,15 +480,15 @@ describe('LobeOllamaAI', () => {
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle empty message array', () => {
|
||||
it('should handle empty message array', async () => {
|
||||
const messages: any[] = [];
|
||||
|
||||
const ollamaMessages = ollamaAI['buildOllamaMessages'](messages);
|
||||
const ollamaMessages = await ollamaAI['buildOllamaMessages'](messages);
|
||||
|
||||
expect(ollamaMessages).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle multiple messages with different roles', () => {
|
||||
it('should handle multiple messages with different roles', async () => {
|
||||
const messages = [
|
||||
{ content: 'Hello', role: 'system' },
|
||||
{ content: 'Hi', role: 'user' },
|
||||
@@ -492,7 +496,7 @@ describe('LobeOllamaAI', () => {
|
||||
{ content: 'How are you?', role: 'user' },
|
||||
];
|
||||
|
||||
const ollamaMessages = ollamaAI['buildOllamaMessages'](messages as any);
|
||||
const ollamaMessages = await ollamaAI['buildOllamaMessages'](messages as any);
|
||||
|
||||
expect(ollamaMessages).toHaveLength(4);
|
||||
expect(ollamaMessages[0].role).toBe('system');
|
||||
@@ -503,26 +507,26 @@ describe('LobeOllamaAI', () => {
|
||||
});
|
||||
|
||||
describe('convertContentToOllamaMessage', () => {
|
||||
it('should convert string content to OllamaMessage', () => {
|
||||
it('should convert string content to OllamaMessage', async () => {
|
||||
const message = { content: 'Hello', role: 'user' };
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({ content: 'Hello', role: 'user' });
|
||||
});
|
||||
|
||||
it('should convert text content to OllamaMessage', () => {
|
||||
it('should convert text content to OllamaMessage', async () => {
|
||||
const message = {
|
||||
content: [{ type: 'text', text: 'Hello' }],
|
||||
role: 'user',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({ content: 'Hello', role: 'user' });
|
||||
});
|
||||
|
||||
it('should convert image_url content to OllamaMessage with images', () => {
|
||||
it('should convert image_url content to OllamaMessage with images', async () => {
|
||||
const message = {
|
||||
content: [
|
||||
{
|
||||
@@ -533,7 +537,7 @@ describe('LobeOllamaAI', () => {
|
||||
role: 'user',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({
|
||||
content: '',
|
||||
@@ -542,7 +546,7 @@ describe('LobeOllamaAI', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should ignore invalid image_url content', () => {
|
||||
it('should ignore invalid image_url content', async () => {
|
||||
const message = {
|
||||
content: [
|
||||
{
|
||||
@@ -553,7 +557,7 @@ describe('LobeOllamaAI', () => {
|
||||
role: 'user',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({
|
||||
content: '',
|
||||
@@ -561,7 +565,7 @@ describe('LobeOllamaAI', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle mixed text and image content', () => {
|
||||
it('should handle mixed text and image content', async () => {
|
||||
const message = {
|
||||
content: [
|
||||
{ type: 'text', text: 'First text' },
|
||||
@@ -578,7 +582,7 @@ describe('LobeOllamaAI', () => {
|
||||
role: 'user',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({
|
||||
content: 'Second text', // Should keep latest text
|
||||
@@ -587,13 +591,13 @@ describe('LobeOllamaAI', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle content with empty text', () => {
|
||||
it('should handle content with empty text', async () => {
|
||||
const message = {
|
||||
content: [{ type: 'text', text: '' }],
|
||||
role: 'user',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({
|
||||
content: '',
|
||||
@@ -601,7 +605,7 @@ describe('LobeOllamaAI', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle content with only images (no text)', () => {
|
||||
it('should handle content with only images (no text)', async () => {
|
||||
const message = {
|
||||
content: [
|
||||
{
|
||||
@@ -612,7 +616,7 @@ describe('LobeOllamaAI', () => {
|
||||
role: 'user',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({
|
||||
content: '',
|
||||
@@ -621,7 +625,7 @@ describe('LobeOllamaAI', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle multiple images without text', () => {
|
||||
it('should handle multiple images without text', async () => {
|
||||
const message = {
|
||||
content: [
|
||||
{
|
||||
@@ -640,7 +644,7 @@ describe('LobeOllamaAI', () => {
|
||||
role: 'user',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({
|
||||
content: '',
|
||||
@@ -649,7 +653,10 @@ describe('LobeOllamaAI', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should ignore images with invalid data URIs', () => {
|
||||
it('should handle URL image conversion failure gracefully', async () => {
|
||||
// Mock imageUrlToBase64 to simulate conversion failure
|
||||
vi.mocked(imageUrlToBase64).mockRejectedValue(new Error('Network error'));
|
||||
|
||||
const message = {
|
||||
content: [
|
||||
{ type: 'text', text: 'Hello' },
|
||||
@@ -665,16 +672,18 @@ describe('LobeOllamaAI', () => {
|
||||
role: 'user',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({
|
||||
content: 'Hello',
|
||||
role: 'user',
|
||||
images: ['valid123'],
|
||||
});
|
||||
// When URL conversion fails, it should continue processing other images
|
||||
// The mock is set up to fail, so only the base64 image should be included
|
||||
expect(ollamaMessage.content).toBe('Hello');
|
||||
expect(ollamaMessage.role).toBe('user');
|
||||
expect(ollamaMessage.images).toBeDefined();
|
||||
// Should have at least the base64 image
|
||||
expect(ollamaMessage.images).toContain('valid123');
|
||||
});
|
||||
|
||||
it('should handle complex interleaved content', () => {
|
||||
it('should handle complex interleaved content', async () => {
|
||||
const message = {
|
||||
content: [
|
||||
{ type: 'text', text: 'Text 1' },
|
||||
@@ -692,7 +701,7 @@ describe('LobeOllamaAI', () => {
|
||||
role: 'user',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({
|
||||
content: 'Text 3', // Should keep latest text
|
||||
@@ -701,7 +710,7 @@ describe('LobeOllamaAI', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle assistant role with images', () => {
|
||||
it('should handle assistant role with images', async () => {
|
||||
const message = {
|
||||
content: [
|
||||
{ type: 'text', text: 'Here is the image' },
|
||||
@@ -713,7 +722,7 @@ describe('LobeOllamaAI', () => {
|
||||
role: 'assistant',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({
|
||||
content: 'Here is the image',
|
||||
@@ -722,13 +731,13 @@ describe('LobeOllamaAI', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle system role with text', () => {
|
||||
it('should handle system role with text', async () => {
|
||||
const message = {
|
||||
content: [{ type: 'text', text: 'You are a helpful assistant' }],
|
||||
role: 'system',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({
|
||||
content: 'You are a helpful assistant',
|
||||
@@ -736,13 +745,13 @@ describe('LobeOllamaAI', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle empty content array', () => {
|
||||
it('should handle empty content array', async () => {
|
||||
const message = {
|
||||
content: [],
|
||||
role: 'user',
|
||||
};
|
||||
|
||||
const ollamaMessage = ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
const ollamaMessage = await ollamaAI['convertContentToOllamaMessage'](message as any);
|
||||
|
||||
expect(ollamaMessage).toEqual({
|
||||
content: '',
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { ChatModelCard } from '@lobechat/types';
|
||||
import { imageUrlToBase64 } from '@lobechat/utils';
|
||||
import { ModelProvider } from 'model-bank';
|
||||
import { Ollama, Tool } from 'ollama/browser';
|
||||
import { ClientOptions } from 'openai';
|
||||
@@ -34,12 +35,17 @@ export const params = {
|
||||
provider: ModelProvider.Ollama,
|
||||
};
|
||||
|
||||
interface OllamaAIParams extends ClientOptions {
|
||||
fetch?: typeof fetch;
|
||||
}
|
||||
|
||||
export class LobeOllamaAI implements LobeRuntimeAI {
|
||||
private client: Ollama;
|
||||
|
||||
baseURL?: string;
|
||||
fetch?: typeof fetch;
|
||||
|
||||
constructor({ baseURL }: ClientOptions = {}) {
|
||||
constructor({ baseURL, fetch: customFetch }: OllamaAIParams = {}) {
|
||||
try {
|
||||
if (baseURL) new URL(baseURL);
|
||||
} catch (e) {
|
||||
@@ -49,6 +55,7 @@ export class LobeOllamaAI implements LobeRuntimeAI {
|
||||
this.client = new Ollama(!baseURL ? undefined : { host: baseURL });
|
||||
|
||||
if (baseURL) this.baseURL = baseURL;
|
||||
this.fetch = customFetch;
|
||||
}
|
||||
|
||||
async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) {
|
||||
@@ -61,7 +68,7 @@ export class LobeOllamaAI implements LobeRuntimeAI {
|
||||
options?.signal?.addEventListener('abort', abort);
|
||||
|
||||
const response = await this.client.chat({
|
||||
messages: this.buildOllamaMessages(payload.messages),
|
||||
messages: await this.buildOllamaMessages(payload.messages),
|
||||
model: payload.model,
|
||||
options: {
|
||||
frequency_penalty: payload.frequency_penalty,
|
||||
@@ -169,11 +176,13 @@ export class LobeOllamaAI implements LobeRuntimeAI {
|
||||
}
|
||||
};
|
||||
|
||||
private buildOllamaMessages(messages: OpenAIChatMessage[]) {
|
||||
return messages.map((message) => this.convertContentToOllamaMessage(message));
|
||||
private async buildOllamaMessages(messages: OpenAIChatMessage[]) {
|
||||
return Promise.all(messages.map((message) => this.convertContentToOllamaMessage(message)));
|
||||
}
|
||||
|
||||
private convertContentToOllamaMessage = (message: OpenAIChatMessage): OllamaMessage => {
|
||||
private convertContentToOllamaMessage = async (
|
||||
message: OpenAIChatMessage,
|
||||
): Promise<OllamaMessage> => {
|
||||
if (typeof message.content === 'string') {
|
||||
return { content: message.content, role: message.role };
|
||||
}
|
||||
@@ -183,6 +192,9 @@ export class LobeOllamaAI implements LobeRuntimeAI {
|
||||
role: message.role,
|
||||
};
|
||||
|
||||
// Collect all URL images to convert them in parallel
|
||||
const urlImagePromises: Promise<string | null>[] = [];
|
||||
|
||||
for (const content of message.content) {
|
||||
switch (content.type) {
|
||||
case 'text': {
|
||||
@@ -191,16 +203,39 @@ export class LobeOllamaAI implements LobeRuntimeAI {
|
||||
break;
|
||||
}
|
||||
case 'image_url': {
|
||||
const { base64 } = parseDataUri(content.image_url.url);
|
||||
const { base64, type } = parseDataUri(content.image_url.url);
|
||||
|
||||
// If already base64 format, use it directly
|
||||
if (base64) {
|
||||
ollamaMessage.images ??= [];
|
||||
ollamaMessage.images.push(base64);
|
||||
}
|
||||
// If it's a URL, collect the promise for parallel processing
|
||||
else if (type === 'url') {
|
||||
urlImagePromises.push(
|
||||
imageUrlToBase64(content.image_url.url, this.fetch)
|
||||
.then((result) =>
|
||||
result.base64 && result.base64.trim() !== '' ? result.base64 : null,
|
||||
)
|
||||
.catch(() => null), // Return null on error to continue processing other images
|
||||
);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process all URL images in parallel
|
||||
if (urlImagePromises.length > 0) {
|
||||
const urlImages = await Promise.all(urlImagePromises);
|
||||
const validUrlImages = urlImages.filter((img): img is string => img !== null);
|
||||
|
||||
if (validUrlImages.length > 0) {
|
||||
ollamaMessage.images ??= [];
|
||||
ollamaMessage.images.push(...validUrlImages);
|
||||
}
|
||||
}
|
||||
|
||||
return ollamaMessage;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { imageUrlToBase64 } from '@lobechat/utils/imageToBase64';
|
||||
import createDebug from 'debug';
|
||||
import { RuntimeImageGenParamsValue } from 'model-bank';
|
||||
|
||||
@@ -5,7 +6,6 @@ import { CreateImageOptions } from '../../core/openaiCompatibleFactory';
|
||||
import { CreateImagePayload, CreateImageResponse } from '../../types';
|
||||
import { AgentRuntimeErrorType } from '../../types/error';
|
||||
import { AgentRuntimeError } from '../../utils/createError';
|
||||
import { imageUrlToBase64 } from '../../utils/imageToBase64';
|
||||
import { parseDataUri } from '../../utils/uriParser';
|
||||
|
||||
const log = createDebug('lobe-image:siliconcloud');
|
||||
@@ -16,7 +16,7 @@ interface SiliconCloudImageResponse {
|
||||
timings: { inference: number };
|
||||
}
|
||||
|
||||
async function convertToDataURI(imageUrl: string): Promise<string> {
|
||||
async function convertToDataURI(imageUrl: string, fetchImpl?: typeof fetch): Promise<string> {
|
||||
const { type, base64, mimeType } = parseDataUri(imageUrl);
|
||||
|
||||
if (type === 'base64') {
|
||||
@@ -27,7 +27,10 @@ async function convertToDataURI(imageUrl: string): Promise<string> {
|
||||
}
|
||||
|
||||
if (type === 'url') {
|
||||
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(imageUrl);
|
||||
const { base64: urlBase64, mimeType: urlMimeType } = await imageUrlToBase64(
|
||||
imageUrl,
|
||||
fetchImpl,
|
||||
);
|
||||
return `data:${urlMimeType};base64,${urlBase64}`;
|
||||
}
|
||||
|
||||
@@ -39,7 +42,7 @@ export async function createSiliconCloudImage(
|
||||
options: CreateImageOptions,
|
||||
): Promise<CreateImageResponse> {
|
||||
const { model, params } = payload;
|
||||
const { apiKey, baseURL, provider } = options;
|
||||
const { apiKey, baseURL, provider, fetch: customFetch } = options;
|
||||
|
||||
log('Creating image with SiliconCloud model: %s, params: %O', model, params);
|
||||
|
||||
@@ -72,17 +75,17 @@ export async function createSiliconCloudImage(
|
||||
|
||||
if (key === 'imageUrl') {
|
||||
if (typeof value === 'string') {
|
||||
body['image'] = await convertToDataURI(value);
|
||||
body['image'] = await convertToDataURI(value, customFetch);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (key === 'imageUrls') {
|
||||
if (Array.isArray(value) && value.length > 0) {
|
||||
body['image'] = await convertToDataURI(value[0]);
|
||||
body['image'] = await convertToDataURI(value[0], customFetch);
|
||||
if (model === 'Qwen/Qwen-Image-Edit-2509') {
|
||||
if (value.length > 1) body['image2'] = await convertToDataURI(value[1]);
|
||||
if (value.length > 2) body['image3'] = await convertToDataURI(value[2]);
|
||||
if (value.length > 1) body['image2'] = await convertToDataURI(value[1], customFetch);
|
||||
if (value.length > 2) body['image3'] = await convertToDataURI(value[2], customFetch);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
@@ -94,7 +97,8 @@ export async function createSiliconCloudImage(
|
||||
|
||||
log('Request body: %O', body);
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
const fetchFn = customFetch ?? fetch;
|
||||
const response = await fetchFn(endpoint, {
|
||||
body: JSON.stringify(body),
|
||||
headers: {
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
|
||||
@@ -7,7 +7,7 @@ import { LobeGoogleAI } from '../google';
|
||||
const DEFAULT_VERTEXAI_LOCATION = 'global';
|
||||
|
||||
export class LobeVertexAI extends LobeGoogleAI {
|
||||
static initFromVertexAI(params?: GoogleGenAIOptions) {
|
||||
static initFromVertexAI(params?: GoogleGenAIOptions, customFetch?: typeof fetch) {
|
||||
try {
|
||||
const client = new GoogleGenAI({
|
||||
...params,
|
||||
@@ -15,7 +15,12 @@ export class LobeVertexAI extends LobeGoogleAI {
|
||||
vertexai: true,
|
||||
});
|
||||
|
||||
return new LobeGoogleAI({ apiKey: 'avoid-error', client, isVertexAi: true });
|
||||
return new LobeGoogleAI({
|
||||
apiKey: 'avoid-error',
|
||||
client,
|
||||
fetch: customFetch,
|
||||
isVertexAi: true,
|
||||
});
|
||||
} catch (e) {
|
||||
const err = e as Error;
|
||||
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { imageToBase64, imageUrlToBase64 } from './imageToBase64';
|
||||
|
||||
describe('imageToBase64', () => {
|
||||
let mockImage: HTMLImageElement;
|
||||
let mockCanvas: HTMLCanvasElement;
|
||||
let mockContext: CanvasRenderingContext2D;
|
||||
|
||||
beforeEach(() => {
|
||||
mockImage = {
|
||||
width: 200,
|
||||
height: 100,
|
||||
} as HTMLImageElement;
|
||||
|
||||
mockContext = {
|
||||
drawImage: vi.fn(),
|
||||
} as unknown as CanvasRenderingContext2D;
|
||||
|
||||
mockCanvas = {
|
||||
width: 0,
|
||||
height: 0,
|
||||
getContext: vi.fn().mockReturnValue(mockContext),
|
||||
toDataURL: vi.fn().mockReturnValue('data:image/webp;base64,mockBase64Data'),
|
||||
} as unknown as HTMLCanvasElement;
|
||||
|
||||
vi.spyOn(document, 'createElement').mockReturnValue(mockCanvas);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should convert image to base64 with correct size and type', () => {
|
||||
const result = imageToBase64({ img: mockImage, size: 100, type: 'image/jpeg' });
|
||||
|
||||
expect(document.createElement).toHaveBeenCalledWith('canvas');
|
||||
expect(mockCanvas.width).toBe(100);
|
||||
expect(mockCanvas.height).toBe(100);
|
||||
expect(mockCanvas.getContext).toHaveBeenCalledWith('2d');
|
||||
expect(mockContext.drawImage).toHaveBeenCalledWith(mockImage, 50, 0, 100, 100, 0, 0, 100, 100);
|
||||
expect(mockCanvas.toDataURL).toHaveBeenCalledWith('image/jpeg');
|
||||
expect(result).toBe('data:image/webp;base64,mockBase64Data');
|
||||
});
|
||||
|
||||
it('should use default type when not specified', () => {
|
||||
imageToBase64({ img: mockImage, size: 100 });
|
||||
expect(mockCanvas.toDataURL).toHaveBeenCalledWith('image/webp');
|
||||
});
|
||||
|
||||
it('should handle taller images correctly', () => {
|
||||
mockImage.width = 100;
|
||||
mockImage.height = 200;
|
||||
imageToBase64({ img: mockImage, size: 100 });
|
||||
expect(mockContext.drawImage).toHaveBeenCalledWith(mockImage, 0, 50, 100, 100, 0, 0, 100, 100);
|
||||
});
|
||||
});
|
||||
|
||||
describe('imageUrlToBase64', () => {
|
||||
const mockFetch = vi.fn();
|
||||
const mockArrayBuffer = new ArrayBuffer(8);
|
||||
|
||||
beforeEach(() => {
|
||||
global.fetch = mockFetch;
|
||||
global.btoa = vi.fn().mockReturnValue('mockBase64String');
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should convert image URL to base64 string', async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
arrayBuffer: () => Promise.resolve(mockArrayBuffer),
|
||||
blob: () => Promise.resolve(new Blob([mockArrayBuffer], { type: 'image/jpg' })),
|
||||
});
|
||||
|
||||
const result = await imageUrlToBase64('https://example.com/image.jpg');
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith('https://example.com/image.jpg');
|
||||
expect(global.btoa).toHaveBeenCalled();
|
||||
expect(result).toEqual({ base64: 'mockBase64String', mimeType: 'image/jpg' });
|
||||
});
|
||||
|
||||
it('should throw an error when fetch fails', async () => {
|
||||
const mockError = new Error('Fetch failed');
|
||||
mockFetch.mockRejectedValue(mockError);
|
||||
|
||||
await expect(imageUrlToBase64('https://example.com/image.jpg')).rejects.toThrow('Fetch failed');
|
||||
});
|
||||
});
|
||||
@@ -1,62 +0,0 @@
|
||||
export const imageToBase64 = ({
|
||||
size,
|
||||
img,
|
||||
type = 'image/webp',
|
||||
}: {
|
||||
img: HTMLImageElement;
|
||||
size: number;
|
||||
type?: string;
|
||||
}) => {
|
||||
const canvas = document.createElement('canvas');
|
||||
const ctx = canvas.getContext('2d') as CanvasRenderingContext2D;
|
||||
let startX = 0;
|
||||
let startY = 0;
|
||||
|
||||
if (img.width > img.height) {
|
||||
startX = (img.width - img.height) / 2;
|
||||
} else {
|
||||
startY = (img.height - img.width) / 2;
|
||||
}
|
||||
|
||||
canvas.width = size;
|
||||
canvas.height = size;
|
||||
|
||||
ctx.drawImage(
|
||||
img,
|
||||
startX,
|
||||
startY,
|
||||
Math.min(img.width, img.height),
|
||||
Math.min(img.width, img.height),
|
||||
0,
|
||||
0,
|
||||
size,
|
||||
size,
|
||||
);
|
||||
|
||||
return canvas.toDataURL(type);
|
||||
};
|
||||
|
||||
export const imageUrlToBase64 = async (
|
||||
imageUrl: string,
|
||||
): Promise<{ base64: string; mimeType: string }> => {
|
||||
try {
|
||||
const res = await fetch(imageUrl);
|
||||
const blob = await res.blob();
|
||||
const arrayBuffer = await blob.arrayBuffer();
|
||||
|
||||
const base64 =
|
||||
typeof btoa === 'function'
|
||||
? btoa(
|
||||
new Uint8Array(arrayBuffer).reduce(
|
||||
(data, byte) => data + String.fromCharCode(byte),
|
||||
'',
|
||||
),
|
||||
)
|
||||
: Buffer.from(arrayBuffer).toString('base64');
|
||||
|
||||
return { base64, mimeType: blob.type };
|
||||
} catch (error) {
|
||||
console.error('Error converting image to base64:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
@@ -1,13 +1,34 @@
|
||||
import fetch from 'node-fetch';
|
||||
import type { RequestInit as NodeFetchOptions } from 'node-fetch';
|
||||
import { RequestFilteringAgentOptions, useAgent as ssrfAgent } from 'request-filtering-agent';
|
||||
|
||||
interface FetchOptions extends RequestInit {
|
||||
ssrf?: boolean;
|
||||
}
|
||||
|
||||
const toStandardResponse = async (response: Awaited<ReturnType<typeof fetch>>) => {
|
||||
return new Response(await response.arrayBuffer(), {
|
||||
headers: response.headers as any,
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* SSRF-safe fetch implementation for server-side use
|
||||
* Uses request-filtering-agent to prevent requests to private IP addresses
|
||||
*/
|
||||
// eslint-disable-next-line no-undef
|
||||
export const ssrfSafeFetch = async (url: string, options?: RequestInit): Promise<Response> => {
|
||||
export const ssrfSafeFetch = async (url: string, options?: FetchOptions): Promise<Response> => {
|
||||
const { ssrf, ...restOptions } = options ?? {};
|
||||
const fetchOptions = restOptions as NodeFetchOptions;
|
||||
|
||||
try {
|
||||
if (!ssrf) {
|
||||
const response = await fetch(url, fetchOptions);
|
||||
return await toStandardResponse(response);
|
||||
}
|
||||
|
||||
// Configure SSRF protection options
|
||||
const agentOptions: RequestFilteringAgentOptions = {
|
||||
allowIPAddressList: process.env.SSRF_ALLOW_IP_ADDRESS_LIST?.split(',') || [],
|
||||
@@ -18,16 +39,11 @@ export const ssrfSafeFetch = async (url: string, options?: RequestInit): Promise
|
||||
|
||||
// Use node-fetch with SSRF protection agent
|
||||
const response = await fetch(url, {
|
||||
...options,
|
||||
...fetchOptions,
|
||||
agent: ssrfAgent(url, agentOptions),
|
||||
} as any);
|
||||
|
||||
// Convert node-fetch Response to standard Response
|
||||
return new Response(await response.arrayBuffer(), {
|
||||
headers: response.headers as any,
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
});
|
||||
|
||||
return await toStandardResponse(response);
|
||||
} catch (error) {
|
||||
console.error('SSRF-safe fetch error:', error);
|
||||
throw new Error(
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
"version": "1.0.0",
|
||||
"private": true,
|
||||
"description": "",
|
||||
"exports": "./index.ts",
|
||||
"main": "index.ts",
|
||||
"types": "index.ts",
|
||||
"scripts": {
|
||||
"test": "vitest run"
|
||||
},
|
||||
|
||||
@@ -17,7 +17,8 @@
|
||||
"@lobechat/const": "workspace:*",
|
||||
"@lobechat/types": "workspace:*",
|
||||
"dayjs": "^1.11.18",
|
||||
"dompurify": "^3.2.7"
|
||||
"dompurify": "^3.2.7",
|
||||
"ssrf-safe-fetch": "workspace:*"
|
||||
},
|
||||
"devDependencies": {
|
||||
"vitest-canvas-mock": "^0.3.3"
|
||||
|
||||
@@ -36,11 +36,20 @@ export const imageToBase64 = ({
|
||||
return canvas.toDataURL(type);
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert image URL to base64.
|
||||
* Accepts an optional custom fetch implementation (e.g., SSRF-safe fetch) for server environments.
|
||||
*/
|
||||
export const imageUrlToBase64 = async (
|
||||
imageUrl: string,
|
||||
customFetch?: typeof fetch,
|
||||
): Promise<{ base64: string; mimeType: string }> => {
|
||||
try {
|
||||
const res = await fetch(imageUrl);
|
||||
const fetchFn = customFetch || fetch;
|
||||
const ssrfOptions = customFetch ? ({ ssrf: true } as RequestInit) : undefined;
|
||||
const res =
|
||||
ssrfOptions === undefined ? await fetchFn(imageUrl) : await fetchFn(imageUrl, ssrfOptions);
|
||||
|
||||
const blob = await res.blob();
|
||||
const arrayBuffer = await blob.arrayBuffer();
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import { LobeVertexAI } from '@lobechat/model-runtime/vertexai';
|
||||
import { ClientSecretPayload } from '@lobechat/types';
|
||||
import { safeParseJSON } from '@lobechat/utils';
|
||||
import { ModelProvider } from 'model-bank';
|
||||
import { ssrfSafeFetch } from 'ssrf-safe-fetch';
|
||||
|
||||
import { getLLMConfig } from '@/envs/llm';
|
||||
|
||||
@@ -11,6 +12,8 @@ import apiKeyManager from './apiKeyManager';
|
||||
|
||||
export * from './trace';
|
||||
|
||||
type RuntimeInitializeParams = Parameters<typeof ModelRuntime.initializeWithProvider>[1];
|
||||
|
||||
/**
|
||||
* Retrieves the options object from environment and apikeymanager
|
||||
* based on the provider and payload.
|
||||
@@ -171,7 +174,10 @@ const buildVertexOptions = (
|
||||
|
||||
const project = projectFromParams ?? projectFromCredentials ?? projectFromEnv;
|
||||
const location =
|
||||
(params.location as string | undefined) ?? payload.vertexAIRegion ?? process.env.VERTEXAI_LOCATION ?? undefined;
|
||||
(params.location as string | undefined) ??
|
||||
payload.vertexAIRegion ??
|
||||
process.env.VERTEXAI_LOCATION ??
|
||||
undefined;
|
||||
|
||||
const googleAuthOptions = params.googleAuthOptions ?? (credentials ? { credentials } : undefined);
|
||||
|
||||
@@ -197,19 +203,28 @@ const buildVertexOptions = (
|
||||
export const initModelRuntimeWithUserPayload = (
|
||||
provider: string,
|
||||
payload: ClientSecretPayload,
|
||||
params: any = {},
|
||||
params: RuntimeInitializeParams = {},
|
||||
) => {
|
||||
const { fetch: customFetch, ...restParams } = params ?? {};
|
||||
const fetchImpl = (customFetch ?? ssrfSafeFetch) as typeof fetch;
|
||||
const runtimeProvider = payload.runtimeProvider ?? provider;
|
||||
|
||||
if (runtimeProvider === ModelProvider.VertexAI) {
|
||||
const vertexOptions = buildVertexOptions(payload, params);
|
||||
const runtime = LobeVertexAI.initFromVertexAI(vertexOptions);
|
||||
const vertexOptions = buildVertexOptions(payload, restParams as Partial<GoogleGenAIOptions>);
|
||||
const runtime = LobeVertexAI.initFromVertexAI(vertexOptions, fetchImpl);
|
||||
|
||||
return new ModelRuntime(runtime);
|
||||
}
|
||||
|
||||
return ModelRuntime.initializeWithProvider(runtimeProvider, {
|
||||
const mergedParams = {
|
||||
...getParamsFromPayload(runtimeProvider, payload),
|
||||
...params,
|
||||
});
|
||||
...restParams,
|
||||
};
|
||||
|
||||
const runtimeParams = {
|
||||
...mergedParams,
|
||||
fetch: (mergedParams as Record<string, any>).fetch ?? fetchImpl,
|
||||
};
|
||||
|
||||
return ModelRuntime.initializeWithProvider(runtimeProvider, runtimeParams);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user