feat: add custom stream handle support for LobeOpenAICompatibleFactory (#5039)

* ♻️ refactor: add function call support for Spark

* ♻️ refactor: add non-stream mode support

* ⚡️ perf: using stream mode for tools call

* ✨ feat: add `handleStream` & `handleStreamResponse` for LobeOpenAICompatibleFactory, custom stream handle

* ✨ feat: add `handleTransformResponseToStream` for custom non-stream transform handle

* ♻️ refactor: refactor qwen to LobeOpenAICompatibleFactory, enable `enable_search` for Qwen LLM

* 🔨 chore: add unit test for LobeOpenAICompatibleFactory

* 🔨 chore: add unit test for SparkAIStream

* 🔨 chore: add unit test for Qwen & Spark

* 🐛 fix: fix Qwen param range error

* 🔨 chore: add `QwenLegacyModels` array, limit `presence_penalty`

* 🐛 fix: fix typo
This commit is contained in:
Zhijie He
2024-12-29 13:06:52 +08:00
committed by GitHub
parent cf0e8d8b48
commit ea7e732350
10 changed files with 568 additions and 349 deletions
+3 -6
View File
@@ -10,7 +10,6 @@ const Spark: ModelProviderCard = {
'Spark Lite 是一款轻量级大语言模型,具备极低的延迟与高效的处理能力,完全免费开放,支持实时在线搜索功能。其快速响应的特性使其在低算力设备上的推理应用和模型微调中表现出色,为用户带来出色的成本效益和智能体验,尤其在知识问答、内容生成及搜索场景下表现不俗。',
displayName: 'Spark Lite',
enabled: true,
functionCall: false,
id: 'lite',
maxOutput: 4096,
},
@@ -20,7 +19,6 @@ const Spark: ModelProviderCard = {
'Spark Pro 是一款为专业领域优化的高性能大语言模型,专注数学、编程、医疗、教育等多个领域,并支持联网搜索及内置天气、日期等插件。其优化后模型在复杂知识问答、语言理解及高层次文本创作中展现出色表现和高效性能,是适合专业应用场景的理想选择。',
displayName: 'Spark Pro',
enabled: true,
functionCall: false,
id: 'generalv3',
maxOutput: 8192,
},
@@ -30,7 +28,6 @@ const Spark: ModelProviderCard = {
'Spark Pro 128K 配置了特大上下文处理能力,能够处理多达128K的上下文信息,特别适合需通篇分析和长期逻辑关联处理的长文内容,可在复杂文本沟通中提供流畅一致的逻辑与多样的引用支持。',
displayName: 'Spark Pro 128K',
enabled: true,
functionCall: false,
id: 'pro-128k',
maxOutput: 4096,
},
@@ -40,7 +37,7 @@ const Spark: ModelProviderCard = {
'Spark Max 为功能最为全面的版本,支持联网搜索及众多内置插件。其全面优化的核心能力以及系统角色设定和函数调用功能,使其在各种复杂应用场景中的表现极为优异和出色。',
displayName: 'Spark Max',
enabled: true,
functionCall: false,
functionCall: true,
id: 'generalv3.5',
maxOutput: 8192,
},
@@ -50,7 +47,7 @@ const Spark: ModelProviderCard = {
'Spark Max 32K 配置了大上下文处理能力,更强的上下文理解和逻辑推理能力,支持32K tokens的文本输入,适用于长文档阅读、私有知识问答等场景',
displayName: 'Spark Max 32K',
enabled: true,
functionCall: false,
functionCall: true,
id: 'max-32k',
maxOutput: 8192,
},
@@ -60,7 +57,7 @@ const Spark: ModelProviderCard = {
'Spark Ultra 是星火大模型系列中最为强大的版本,在升级联网搜索链路同时,提升对文本内容的理解和总结能力。它是用于提升办公生产力和准确响应需求的全方位解决方案,是引领行业的智能产品。',
displayName: 'Spark 4.0 Ultra',
enabled: true,
functionCall: false,
functionCall: true,
id: '4.0Ultra',
maxOutput: 8192,
},
+13 -188
View File
@@ -2,8 +2,9 @@
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import Qwen from '@/config/modelProviders/qwen';
import { AgentRuntimeErrorType, ModelProvider } from '@/libs/agent-runtime';
import { LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
import { ModelProvider } from '@/libs/agent-runtime';
import { AgentRuntimeErrorType } from '@/libs/agent-runtime';
import * as debugStreamModule from '../utils/debugStream';
import { LobeQwenAI } from './index';
@@ -16,7 +17,7 @@ const invalidErrorType = AgentRuntimeErrorType.InvalidProviderAPIKey;
// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});
let instance: LobeQwenAI;
let instance: LobeOpenAICompatibleRuntime;
beforeEach(() => {
instance = new LobeQwenAI({ apiKey: 'test' });
@@ -40,183 +41,7 @@ describe('LobeQwenAI', () => {
});
});
describe('models', () => {
it('should correctly list available models', async () => {
const instance = new LobeQwenAI({ apiKey: 'test_api_key' });
vi.spyOn(instance, 'models').mockResolvedValue(Qwen.chatModels);
const models = await instance.models();
expect(models).toEqual(Qwen.chatModels);
});
});
describe('chat', () => {
describe('Params', () => {
it('should call llms with proper options', async () => {
const mockStream = new ReadableStream();
const mockResponse = Promise.resolve(mockStream);
(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);
const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
top_p: 0.7,
});
// Assert
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
{
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
stream: true,
top_p: 0.7,
result_format: 'message',
},
{ headers: { Accept: '*/*' } },
);
expect(result).toBeInstanceOf(Response);
});
it('should call vlms with proper options', async () => {
const mockStream = new ReadableStream();
const mockResponse = Promise.resolve(mockStream);
(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);
const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-vl-plus',
temperature: 0.6,
top_p: 0.7,
});
// Assert
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
{
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-vl-plus',
stream: true,
},
{ headers: { Accept: '*/*' } },
);
expect(result).toBeInstanceOf(Response);
});
it('should transform non-streaming response to stream correctly', async () => {
const mockResponse = {
id: 'chatcmpl-fc539f49-51a8-94be-8061',
object: 'chat.completion',
created: 1719901794,
model: 'qwen-turbo',
choices: [
{
index: 0,
message: { role: 'assistant', content: 'Hello' },
finish_reason: 'stop',
logprobs: null,
},
],
} as OpenAI.ChatCompletion;
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
mockResponse as any,
);
const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
stream: false,
});
const decoder = new TextDecoder();
const reader = result.body!.getReader();
const stream: string[] = [];
while (true) {
const { value, done } = await reader.read();
if (done) break;
stream.push(decoder.decode(value));
}
expect(stream).toEqual([
'id: chatcmpl-fc539f49-51a8-94be-8061\n',
'event: text\n',
'data: "Hello"\n\n',
'id: chatcmpl-fc539f49-51a8-94be-8061\n',
'event: stop\n',
'data: "stop"\n\n',
]);
expect((await reader.read()).done).toBe(true);
});
it('should set temperature to undefined if temperature is 0 or >= 2', async () => {
const temperatures = [0, 2, 3];
const expectedTemperature = undefined;
for (const temp of temperatures) {
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
new ReadableStream() as any,
);
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: temp,
});
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.any(Array),
model: 'qwen-turbo',
temperature: expectedTemperature,
}),
expect.any(Object),
);
}
});
it('should set temperature to original temperature', async () => {
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
new ReadableStream() as any,
);
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 1.5,
});
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.any(Array),
model: 'qwen-turbo',
temperature: 1.5,
}),
expect.any(Object),
);
});
it('should set temperature to Float', async () => {
const createMock = vi.fn().mockResolvedValue(new ReadableStream() as any);
vi.spyOn(instance['client'].chat.completions, 'create').mockImplementation(createMock);
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 1,
});
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.any(Array),
model: 'qwen-turbo',
temperature: expect.any(Number),
}),
expect.any(Object),
);
const callArgs = createMock.mock.calls[0][0];
expect(Number.isInteger(callArgs.temperature)).toBe(false); // Temperature is always not an integer
});
});
describe('Error', () => {
it('should return QwenBizError with an openai error response when OpenAI.APIError is thrown', async () => {
// Arrange
@@ -238,7 +63,7 @@ describe('LobeQwenAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
temperature: 0.999,
});
} catch (e) {
@@ -278,7 +103,7 @@ describe('LobeQwenAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
temperature: 0.999,
});
} catch (e) {
@@ -304,7 +129,8 @@ describe('LobeQwenAI', () => {
instance = new LobeQwenAI({
apiKey: 'test',
baseURL: defaultBaseURL,
baseURL: 'https://api.abc.com/v1',
});
vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
@@ -313,13 +139,12 @@ describe('LobeQwenAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
temperature: 0.999,
});
} catch (e) {
expect(e).toEqual({
/* Desensitizing is unnecessary for a public-accessible gateway endpoint. */
endpoint: defaultBaseURL,
endpoint: 'https://api.***.com/v1',
error: {
cause: { message: 'api is undefined' },
stack: 'abc',
@@ -339,7 +164,7 @@ describe('LobeQwenAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
temperature: 0.999,
});
} catch (e) {
@@ -362,7 +187,7 @@ describe('LobeQwenAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
temperature: 0.999,
});
} catch (e) {
@@ -410,7 +235,7 @@ describe('LobeQwenAI', () => {
// 假设的测试函数调用,你可能需要根据实际情况调整
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
model: 'qwen-turbo-latest',
stream: true,
temperature: 0.999,
});
+45 -124
View File
@@ -1,129 +1,50 @@
import { omit } from 'lodash-es';
import OpenAI, { ClientOptions } from 'openai';
import { ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';
import Qwen from '@/config/modelProviders/qwen';
import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { handleOpenAIError } from '../utils/handleOpenAIError';
import { transformResponseToStream } from '../utils/openaiCompatibleFactory';
import { StreamingResponse } from '../utils/response';
import { QwenAIStream } from '../utils/streams';
const DEFAULT_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1';
// Ids of the legacy Qwen models, kept in their original declaration order.
const legacyQwenModelIds: readonly string[] = [
  'qwen-72b-chat',
  'qwen-14b-chat',
  'qwen-7b-chat',
  'qwen-1.8b-chat',
  'qwen-1.8b-longcontext-chat',
];

/**
 * Legacy Qwen models that do not support the `presence_penalty` parameter.
 *
 * `presence_penalty` is currently supported only by Qwen commercial models and by
 * open-source models from Qwen 1.5 onwards.
 */
export const QwenLegacyModels = new Set<string>(legacyQwenModelIds);
/**
* Use DashScope OpenAI compatible mode for now.
* DashScope OpenAI [compatible mode](https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api) currently supports base64 image input for vision models e.g. qwen-vl-plus.
* You can use images input either:
* 1. Use qwen-vl-* out of box with base64 image_url input;
* or
* 2. Set S3-* environment variables properly to store all uploaded files.
*/
export class LobeQwenAI implements LobeRuntimeAI {
client: OpenAI;
baseURL: string;
export const LobeQwenAI = LobeOpenAICompatibleFactory({
baseURL: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
chatCompletion: {
handlePayload: (payload) => {
const { model, presence_penalty, temperature, top_p, ...rest } = payload;
constructor({
apiKey,
baseURL = DEFAULT_BASE_URL,
...res
}: ClientOptions & Record<string, any> = {}) {
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
this.client = new OpenAI({ apiKey, baseURL, ...res });
this.baseURL = this.client.baseURL;
}
async models() {
return Qwen.chatModels;
}
async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
try {
const params = this.buildCompletionParamsByModel(payload);
const response = await this.client.chat.completions.create(
params as OpenAI.ChatCompletionCreateParamsStreaming & { result_format: string },
{
headers: { Accept: '*/*' },
signal: options?.signal,
},
);
if (params.stream) {
const [prod, debug] = response.tee();
if (process.env.DEBUG_QWEN_CHAT_COMPLETION === '1') {
debugStream(debug.toReadableStream()).catch(console.error);
}
return StreamingResponse(QwenAIStream(prod, options?.callback), {
headers: options?.headers,
});
}
const stream = transformResponseToStream(response as unknown as OpenAI.ChatCompletion);
return StreamingResponse(QwenAIStream(stream, options?.callback), {
headers: options?.headers,
});
} catch (error) {
if ('status' in (error as any)) {
switch ((error as Response).status) {
case 401: {
throw AgentRuntimeError.chat({
endpoint: this.baseURL,
error: error as any,
errorType: AgentRuntimeErrorType.InvalidProviderAPIKey,
provider: ModelProvider.Qwen,
});
}
default: {
break;
}
}
}
const { errorResult, RuntimeError } = handleOpenAIError(error);
const errorType = RuntimeError || AgentRuntimeErrorType.ProviderBizError;
throw AgentRuntimeError.chat({
endpoint: this.baseURL,
error: errorResult,
errorType,
provider: ModelProvider.Qwen,
});
}
}
private buildCompletionParamsByModel(payload: ChatStreamPayload) {
const { model, temperature, top_p, stream, messages, tools } = payload;
const isVisionModel = model.startsWith('qwen-vl');
const params = {
...payload,
messages,
result_format: 'message',
stream: !!tools?.length ? false : (stream ?? true),
temperature:
temperature === 0 || temperature >= 2 ? undefined : temperature === 1 ? 0.999 : temperature, // 'temperature' must be Float
top_p: top_p && top_p >= 1 ? 0.999 : top_p,
};
/* Qwen-vl models temporarily do not support parameters below. */
/* Notice: `top_p` imposes significant impact on the resultthe default 1 or 0.999 is not a proper choice. */
return isVisionModel
? omit(
params,
'presence_penalty',
'frequency_penalty',
'temperature',
'result_format',
'top_p',
)
: omit(params, 'frequency_penalty');
}
}
return {
...rest,
frequency_penalty: undefined,
model,
presence_penalty:
QwenLegacyModels.has(model)
? undefined
: (presence_penalty !== undefined && presence_penalty >= -2 && presence_penalty <= 2)
? presence_penalty
: undefined,
stream: !payload.tools,
temperature: (temperature !== undefined && temperature >= 0 && temperature < 2) ? temperature : undefined,
...(model.startsWith('qwen-vl') ? {
top_p: (top_p !== undefined && top_p > 0 && top_p <= 1) ? top_p : undefined,
} : {
enable_search: true,
top_p: (top_p !== undefined && top_p > 0 && top_p < 1) ? top_p : undefined,
}),
} as any;
},
handleStream: QwenAIStream,
},
debug: {
chatCompletion: () => process.env.DEBUG_QWEN_CHAT_COMPLETION === '1',
},
provider: ModelProvider.Qwen,
});
+24 -28
View File
@@ -2,20 +2,17 @@
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import {
ChatStreamCallbacks,
LobeOpenAICompatibleRuntime,
ModelProvider,
} from '@/libs/agent-runtime';
import { LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
import { ModelProvider } from '@/libs/agent-runtime';
import { AgentRuntimeErrorType } from '@/libs/agent-runtime';
import * as debugStreamModule from '../utils/debugStream';
import { LobeSparkAI } from './index';
const provider = ModelProvider.Spark;
const defaultBaseURL = 'https://spark-api-open.xf-yun.com/v1';
const bizErrorType = 'ProviderBizError';
const invalidErrorType = 'InvalidProviderAPIKey';
const bizErrorType = AgentRuntimeErrorType.ProviderBizError;
const invalidErrorType = AgentRuntimeErrorType.InvalidProviderAPIKey;
// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});
@@ -46,7 +43,7 @@ describe('LobeSparkAI', () => {
describe('chat', () => {
describe('Error', () => {
it('should return OpenAIBizError with an openai error response when OpenAI.APIError is thrown', async () => {
it('should return QwenBizError with an openai error response when OpenAI.APIError is thrown', async () => {
// Arrange
const apiError = new OpenAI.APIError(
400,
@@ -66,8 +63,8 @@ describe('LobeSparkAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'general',
temperature: 0,
model: 'max-32k',
temperature: 0.999,
});
} catch (e) {
expect(e).toEqual({
@@ -82,7 +79,7 @@ describe('LobeSparkAI', () => {
}
});
it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => {
it('should throw AgentRuntimeError with InvalidQwenAPIKey if no apiKey is provided', async () => {
try {
new LobeSparkAI({});
} catch (e) {
@@ -90,7 +87,7 @@ describe('LobeSparkAI', () => {
}
});
it('should return OpenAIBizError with the cause when OpenAI.APIError is thrown with cause', async () => {
it('should return QwenBizError with the cause when OpenAI.APIError is thrown with cause', async () => {
// Arrange
const errorInfo = {
stack: 'abc',
@@ -106,8 +103,8 @@ describe('LobeSparkAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'general',
temperature: 0,
model: 'max-32k',
temperature: 0.999,
});
} catch (e) {
expect(e).toEqual({
@@ -122,7 +119,7 @@ describe('LobeSparkAI', () => {
}
});
it('should return OpenAIBizError with an cause response with desensitize Url', async () => {
it('should return QwenBizError with an cause response with desensitize Url', async () => {
// Arrange
const errorInfo = {
stack: 'abc',
@@ -142,8 +139,8 @@ describe('LobeSparkAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'general',
temperature: 0,
model: 'max-32k',
temperature: 0.999,
});
} catch (e) {
expect(e).toEqual({
@@ -158,23 +155,22 @@ describe('LobeSparkAI', () => {
}
});
it('should throw an InvalidSparkAPIKey error type on 401 status code', async () => {
it('should throw an InvalidQwenAPIKey error type on 401 status code', async () => {
// Mock the API call to simulate a 401 error
const error = new Error('Unauthorized') as any;
const error = new Error('InvalidApiKey') as any;
error.status = 401;
vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error);
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'general',
temperature: 0,
model: 'max-32k',
temperature: 0.999,
});
} catch (e) {
// Expect the chat method to throw an error with InvalidSparkAPIKey
expect(e).toEqual({
endpoint: defaultBaseURL,
error: new Error('Unauthorized'),
error: new Error('InvalidApiKey'),
errorType: invalidErrorType,
provider,
});
@@ -191,8 +187,8 @@ describe('LobeSparkAI', () => {
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'general',
temperature: 0,
model: 'max-32k',
temperature: 0.999,
});
} catch (e) {
expect(e).toEqual({
@@ -239,9 +235,9 @@ describe('LobeSparkAI', () => {
// 假设的测试函数调用,你可能需要根据实际情况调整
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'general',
model: 'max-32k',
stream: true,
temperature: 0,
temperature: 0.999,
});
// 验证 debugStream 被调用
+4
View File
@@ -1,9 +1,13 @@
import { ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';
import { transformSparkResponseToStream, SparkAIStream } from '../utils/streams';
export const LobeSparkAI = LobeOpenAICompatibleFactory({
baseURL: 'https://spark-api-open.xf-yun.com/v1',
chatCompletion: {
handleStream: SparkAIStream,
handleTransformResponseToStream: transformSparkResponseToStream,
noUserId: true,
},
debug: {
@@ -1,10 +1,13 @@
// @vitest-environment node
import OpenAI from 'openai';
import type { Stream } from 'openai/streaming';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import {
AgentRuntimeErrorType,
ChatStreamCallbacks,
ChatStreamPayload,
LobeOpenAICompatibleRuntime,
ModelProvider,
} from '@/libs/agent-runtime';
@@ -797,6 +800,134 @@ describe('LobeOpenAICompatibleFactory', () => {
});
});
it('should use custom stream handler when provided', async () => {
  // Pass-through handler that accepts either a web ReadableStream or an OpenAI SDK Stream.
  // It is written as a pull-based source so that:
  //  - read errors surface as stream errors instead of unhandled promise rejections,
  //  - no floating promise is started, and
  //  - no local shadows the Node.js global `process`.
  const customStreamHandler = vi.fn(
    (stream: ReadableStream | Stream<OpenAI.ChatCompletionChunk>) => {
      const source = stream instanceof ReadableStream ? stream : stream.toReadableStream();
      const reader = source.getReader();

      return new ReadableStream({
        // Propagate consumer cancellation to the wrapped stream.
        cancel: (reason) => reader.cancel(reason),
        async pull(controller) {
          const { done, value } = await reader.read();
          if (done) {
            controller.close();
            return;
          }
          controller.enqueue(value);
        },
      });
    },
  );

  const LobeMockProvider = LobeOpenAICompatibleFactory({
    baseURL: 'https://api.test.com/v1',
    chatCompletion: {
      handleStream: customStreamHandler,
    },
    provider: ModelProvider.OpenAI,
  });

  const instance = new LobeMockProvider({ apiKey: 'test' });

  // Single-chunk mock stream standing in for the provider response.
  const mockStream = new ReadableStream({
    start(controller) {
      controller.enqueue({
        id: 'test-id',
        choices: [{ delta: { content: 'Hello' }, index: 0 }],
        created: Date.now(),
        model: 'test-model',
        object: 'chat.completion.chunk',
      });
      controller.close();
    },
  });

  // The runtime calls `tee()` on the SDK response; hand back the same mock for both branches.
  vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue({
    tee: () => [mockStream, mockStream],
  } as any);

  const payload: ChatStreamPayload = {
    messages: [{ content: 'Test', role: 'user' }],
    model: 'test-model',
    temperature: 0.7,
  };

  await instance.chat(payload);

  expect(customStreamHandler).toHaveBeenCalled();
});
it('should use custom transform handler for non-streaming response', async () => {
  // Mock transform: re-emit the complete completion as one chunk-shaped object, then close.
  const customTransformHandler = vi.fn((data: OpenAI.ChatCompletion): ReadableStream => {
    const chunkChoices = data.choices.map(({ index, message }) => ({
      delta: { content: message.content },
      index,
    }));

    return new ReadableStream({
      start(controller) {
        controller.enqueue({
          id: data.id,
          choices: chunkChoices,
          created: data.created,
          model: data.model,
          object: 'chat.completion.chunk',
        });
        controller.close();
      },
    });
  });

  const LobeMockProvider = LobeOpenAICompatibleFactory({
    baseURL: 'https://api.test.com/v1',
    chatCompletion: { handleTransformResponseToStream: customTransformHandler },
    provider: ModelProvider.OpenAI,
  });
  const instance = new LobeMockProvider({ apiKey: 'test' });

  // Complete (non-streaming) completion returned by the mocked SDK client.
  const completion: OpenAI.ChatCompletion = {
    id: 'test-id',
    choices: [
      {
        index: 0,
        message: { role: 'assistant', content: 'Test response', refusal: null },
        logprobs: null,
        finish_reason: 'stop',
      },
    ],
    created: Date.now(),
    model: 'test-model',
    object: 'chat.completion',
    usage: { completion_tokens: 2, prompt_tokens: 1, total_tokens: 3 },
  };

  vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(completion as any);

  // `stream: false` selects the non-streaming request.
  const request: ChatStreamPayload = {
    messages: [{ content: 'Test', role: 'user' }],
    model: 'test-model',
    temperature: 0.7,
    stream: false,
  };
  await instance.chat(request);

  // The custom transform must receive the raw completion object.
  expect(customTransformHandler).toHaveBeenCalledWith(completion);
});
describe('DEBUG', () => {
it('should call debugStream and return StreamingTextResponse when DEBUG_OPENROUTER_CHAT_COMPLETION is 1', async () => {
// Arrange
@@ -25,6 +25,7 @@ import { handleOpenAIError } from '../handleOpenAIError';
import { convertOpenAIMessages } from '../openaiHelpers';
import { StreamingResponse } from '../response';
import { OpenAIStream, OpenAIStreamOptions } from '../streams';
import { ChatStreamCallbacks } from '../../types';
// the model contains the following keywords is not a chat model, so we should filter them out
export const CHAT_MODELS_BLOCK_LIST = [
@@ -62,10 +63,17 @@ interface OpenAICompatibleFactoryOptions<T extends Record<string, any> = any> {
payload: ChatStreamPayload,
options: ConstructorOptions<T>,
) => OpenAI.ChatCompletionCreateParamsStreaming;
/**
 * Optional provider-specific stream handler. When set, the factory calls it in place of
 * `OpenAIStream`, both for streaming responses and for non-streaming responses that
 * have first been converted to a stream.
 *
 * NOTE(review): the factory's call sites pass `streamOptions` as the second argument.
 * Confirm that value is a `ChatStreamCallbacks` (or unwrap its callbacks before the
 * call); otherwise a handler written against this signature will not see its callbacks.
 */
handleStream?: (
stream: Stream<OpenAI.ChatCompletionChunk> | ReadableStream,
callbacks?: ChatStreamCallbacks,
) => ReadableStream;
handleStreamBizErrorType?: (error: {
message: string;
name: string;
}) => ILobeAgentRuntimeErrorType | undefined;
/**
 * Optional converter for non-streaming responses. When set, the factory uses it in place
 * of the default `transformResponseToStream` to turn a complete `OpenAI.ChatCompletion`
 * into a stream of chunks, which is then passed to the stream handler.
 */
handleTransformResponseToStream?: (
data: OpenAI.ChatCompletion,
) => ReadableStream<OpenAI.ChatCompletionChunk>;
noUserId?: boolean;
};
constructorOptions?: ConstructorOptions<T>;
@@ -228,7 +236,8 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
debugStream(useForDebugStream).catch(console.error);
}
return StreamingResponse(OpenAIStream(prod, streamOptions), {
const streamHandler = chatCompletion?.handleStream || OpenAIStream;
return StreamingResponse(streamHandler(prod, streamOptions), {
headers: options?.headers,
});
}
@@ -239,9 +248,11 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
if (responseMode === 'json') return Response.json(response);
const stream = transformResponseToStream(response as unknown as OpenAI.ChatCompletion);
const transformHandler = chatCompletion?.handleTransformResponseToStream || transformResponseToStream;
const stream = transformHandler(response as unknown as OpenAI.ChatCompletion);
return StreamingResponse(OpenAIStream(stream, streamOptions), {
const streamHandler = chatCompletion?.handleStream || OpenAIStream;
return StreamingResponse(streamHandler(stream, streamOptions), {
headers: options?.headers,
});
} catch (error) {
@@ -7,3 +7,4 @@ export * from './ollama';
export * from './openai';
export * from './protocol';
export * from './qwen';
export * from './spark';
@@ -0,0 +1,199 @@
import { beforeAll, describe, expect, it, vi } from 'vitest';
import { SparkAIStream, transformSparkResponseToStream } from './spark';
import type OpenAI from 'openai';
describe('SparkAIStream', () => {
  // No shared setup is required; the hook is kept so the file's `beforeAll` import stays in use.
  beforeAll(() => {});

  /**
   * Drain a stream into an array of its chunks.
   * Relies on the stream being async-iterable at runtime, as the tests in this file always have.
   */
  const drain = async <T>(stream: ReadableStream<T>): Promise<T[]> => {
    const items: T[] = [];
    // @ts-ignore
    for await (const item of stream) {
      items.push(item);
    }
    return items;
  };

  /** Drain a byte stream and decode every chunk to text with one streaming decoder. */
  const drainAsText = async (stream: ReadableStream): Promise<string[]> => {
    const decoder = new TextDecoder();
    const encoded = await drain(stream);
    return encoded.map((bytes) => decoder.decode(bytes, { stream: true }));
  };

  it('should transform non-streaming response to stream', async () => {
    // `tool_calls` is deliberately a single object (not an array) to exercise normalisation.
    const mockResponse = {
      id: "cha000ceba6@dx193d200b580b8f3532",
      object: "chat.completion",
      created: 1734395014,
      model: "max-32k",
      choices: [
        {
          message: {
            role: "assistant",
            content: "",
            refusal: null,
            tool_calls: {
              type: "function",
              function: {
                arguments: '{"city":"Shanghai"}',
                name: "realtime-weather____fetchCurrentWeather"
              },
              id: "call_1"
            }
          },
          index: 0,
          logprobs: null,
          finish_reason: "tool_calls"
        }
      ],
      usage: {
        prompt_tokens: 8,
        completion_tokens: 0,
        total_tokens: 8
      }
    } as unknown as OpenAI.ChatCompletion;

    const chunks = await drain(transformSparkResponseToStream(mockResponse));

    expect(chunks).toHaveLength(2);
    expect(chunks[0].choices[0].delta.tool_calls).toEqual([{
      function: {
        arguments: '{"city":"Shanghai"}',
        name: "realtime-weather____fetchCurrentWeather"
      },
      id: "call_1",
      index: 0,
      type: "function"
    }]);
    // `toBeDefined()` would also accept `null`, so pin the exact values instead:
    // the content chunk carries no finish reason, the closing chunk carries the original one.
    expect(chunks[0].choices[0].finish_reason).toBeNull();
    expect(chunks[1].choices[0].finish_reason).toBe("tool_calls");
  });

  it('should transform streaming response with tool calls', async () => {
    const mockStream = new ReadableStream({
      start(controller) {
        controller.enqueue({
          id: "cha000b0bf9@dx193d1ffa61cb894532",
          object: "chat.completion.chunk",
          created: 1734395014,
          model: "max-32k",
          choices: [
            {
              delta: {
                role: "assistant",
                content: "",
                tool_calls: {
                  type: "function",
                  function: {
                    arguments: '{"city":"Shanghai"}',
                    name: "realtime-weather____fetchCurrentWeather"
                  },
                  id: "call_1"
                }
              },
              index: 0
            }
          ]
        } as unknown as OpenAI.ChatCompletionChunk);
        controller.close();
      }
    });

    const onToolCallMock = vi.fn();

    const protocolStream = SparkAIStream(mockStream, {
      onToolCall: onToolCallMock
    });

    const chunks = await drainAsText(protocolStream);

    expect(chunks).toEqual([
      'id: cha000b0bf9@dx193d1ffa61cb894532\n',
      'event: tool_calls\n',
      `data: [{\"function\":{\"arguments\":\"{\\\"city\\\":\\\"Shanghai\\\"}\",\"name\":\"realtime-weather____fetchCurrentWeather\"},\"id\":\"call_1\",\"index\":0,\"type\":\"function\"}]\n\n`
    ]);

    expect(onToolCallMock).toHaveBeenCalledTimes(1);
  });

  it('should handle text content in stream', async () => {
    const mockStream = new ReadableStream({
      start(controller) {
        controller.enqueue({
          id: "test-id",
          object: "chat.completion.chunk",
          created: 1734395014,
          model: "max-32k",
          choices: [
            {
              delta: {
                content: "Hello",
                role: "assistant"
              },
              index: 0
            }
          ]
        } as OpenAI.ChatCompletionChunk);
        controller.enqueue({
          id: "test-id",
          object: "chat.completion.chunk",
          created: 1734395014,
          model: "max-32k",
          choices: [
            {
              delta: {
                content: " World",
                role: "assistant"
              },
              index: 0
            }
          ]
        } as OpenAI.ChatCompletionChunk);
        controller.close();
      }
    });

    const onTextMock = vi.fn();

    const protocolStream = SparkAIStream(mockStream, {
      onText: onTextMock
    });

    const chunks = await drainAsText(protocolStream);

    expect(chunks).toEqual([
      'id: test-id\n',
      'event: text\n',
      'data: "Hello"\n\n',
      'id: test-id\n',
      'event: text\n',
      'data: " World"\n\n'
    ]);

    expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
    expect(onTextMock).toHaveBeenNthCalledWith(2, '" World"');
  });

  it('should handle empty stream', async () => {
    const mockStream = new ReadableStream({
      start(controller) {
        controller.close();
      }
    });

    const protocolStream = SparkAIStream(mockStream);

    const chunks = await drainAsText(protocolStream);

    expect(chunks).toEqual([]);
  });
});
@@ -0,0 +1,134 @@
import OpenAI from 'openai';
import type { Stream } from 'openai/streaming';
import { ChatStreamCallbacks } from '../../types';
import {
StreamProtocolChunk,
StreamProtocolToolCallChunk,
convertIterableToStream,
createCallbacksTransformer,
createSSEProtocolTransformer,
generateToolCallId,
} from './protocol';
/**
 * Convert a complete (non-streaming) chat completion into a stream of
 * OpenAI-style completion chunks.
 *
 * Exactly two chunks are enqueued before the stream is closed:
 *  1. a content chunk carrying each choice's `content`, `role` and tool calls,
 *     with `finish_reason` set to `null`;
 *  2. a closing chunk with `content: null` carrying each choice's original
 *     `finish_reason` and the response's `system_fingerprint`.
 *
 * `message.tool_calls` is accepted either as an array or as a single object;
 * a single object is wrapped into a one-element array. When a choice has no
 * tool calls, `delta.tool_calls` is left `undefined` rather than set to an
 * empty array, because an empty array is truthy and could be mistaken for a
 * tool-call delta by consumers that only test `if (delta.tool_calls)`.
 *
 * @param data - The complete chat completion returned by the provider.
 * @returns A stream that yields the two chunks described above.
 */
export function transformSparkResponseToStream(
  data: OpenAI.ChatCompletion,
): ReadableStream<OpenAI.ChatCompletionChunk> {
  return new ReadableStream<OpenAI.ChatCompletionChunk>({
    start(controller) {
      const contentChunk: OpenAI.ChatCompletionChunk = {
        choices: data.choices.map((choice: OpenAI.ChatCompletion.Choice) => {
          const rawToolCalls = choice.message.tool_calls;
          // If the value is not an array, wrap it in one.
          const toolCalls = rawToolCalls
            ? Array.isArray(rawToolCalls)
              ? rawToolCalls
              : [rawToolCalls]
            : [];

          return {
            delta: {
              content: choice.message.content,
              role: choice.message.role,
              tool_calls:
                toolCalls.length > 0
                  ? toolCalls.map(
                      (tool, index): OpenAI.ChatCompletionChunk.Choice.Delta.ToolCall => ({
                        function: tool.function,
                        id: tool.id,
                        index,
                        type: tool.type,
                      }),
                    )
                  : undefined,
            },
            finish_reason: null,
            index: choice.index,
            logprobs: choice.logprobs,
          };
        }),
        created: data.created,
        id: data.id,
        model: data.model,
        object: 'chat.completion.chunk',
      };

      controller.enqueue(contentChunk);

      // Closing chunk: no content, only the finish reason of each choice.
      controller.enqueue({
        choices: data.choices.map((choice: OpenAI.ChatCompletion.Choice) => ({
          delta: {
            content: null,
            role: choice.message.role,
          },
          finish_reason: choice.finish_reason,
          index: choice.index,
          logprobs: choice.logprobs,
        })),
        created: data.created,
        id: data.id,
        model: data.model,
        object: 'chat.completion.chunk',
        system_fingerprint: data.system_fingerprint,
      } as OpenAI.ChatCompletionChunk);
      controller.close();
    },
  });
}
/**
 * Map one OpenAI-style completion chunk to a stream-protocol chunk.
 *
 * Only the first choice is inspected. The resulting `type` is decided in this order:
 *  - `data`       — the chunk has no first choice (missing or empty `choices`);
 *  - `tool_calls` — the delta carries at least one tool call (a single tool-call
 *                   object is accepted and wrapped into an array);
 *  - `text`       — `finish_reason` is set but the delta still has non-empty text;
 *  - `stop`       — `finish_reason` is set;
 *  - `text`       — the delta content is a string;
 *  - `data`       — anything else (the delta itself when content is `null`).
 *
 * @param chunk - A single completion chunk from the provider stream.
 * @returns The protocol chunk; its `id` is always the incoming chunk id.
 */
export const transformSparkStream = (chunk: OpenAI.ChatCompletionChunk): StreamProtocolChunk => {
  // Guard with optional chaining so a chunk without `choices` is forwarded as
  // raw data instead of throwing and aborting the whole stream.
  const item = chunk.choices?.[0];

  if (!item) {
    return { data: chunk, id: chunk.id, type: 'data' };
  }

  const rawToolCalls = item.delta?.tool_calls;
  if (rawToolCalls) {
    // If the value is not an array, wrap it in one.
    const toolCalls = Array.isArray(rawToolCalls) ? rawToolCalls : [rawToolCalls];

    // An empty array falls through to the text / stop handling below.
    if (toolCalls.length > 0) {
      return {
        data: toolCalls.map((toolCall, index) => ({
          function: toolCall.function,
          id: toolCall.id || generateToolCallId(index, toolCall.function?.name),
          // Fall back to the positional index when the provider omits (or nulls) it.
          index: toolCall.index ?? index,
          type: toolCall.type || 'function',
        })),
        id: chunk.id,
        type: 'tool_calls',
      } as StreamProtocolToolCallChunk;
    }
  }

  if (item.finish_reason) {
    // The one-api streaming endpoint can send a chunk that has both a finish_reason and content, e.g.
    // {"id":"demo","model":"deepl-en","choices":[{"index":0,"delta":{"role":"assistant","content":"Introduce yourself."},"finish_reason":"stop"}]}
    if (typeof item.delta?.content === 'string' && !!item.delta.content) {
      return { data: item.delta.content, id: chunk.id, type: 'text' };
    }

    return { data: item.finish_reason, id: chunk.id, type: 'stop' };
  }

  if (typeof item.delta?.content === 'string') {
    return { data: item.delta.content, id: chunk.id, type: 'text' };
  }

  if (item.delta?.content === null) {
    return { data: item.delta, id: chunk.id, type: 'data' };
  }

  return {
    data: { delta: item.delta, id: chunk.id, index: item.index },
    id: chunk.id,
    type: 'data',
  };
};
/**
 * Build the response stream for a Spark chat completion.
 *
 * The input is used directly when it is already a `ReadableStream`; otherwise it is
 * converted with `convertIterableToStream`. The stream is then piped through the SSE
 * protocol transformer (driven by `transformSparkStream`) and finally through the
 * callbacks transformer.
 *
 * @param stream - The provider stream, either an SDK `Stream` or a `ReadableStream`.
 * @param callbacks - Optional stream callbacks handed to the callbacks transformer.
 * @returns The fully piped stream.
 */
export const SparkAIStream = (
  stream: Stream<OpenAI.ChatCompletionChunk> | ReadableStream,
  callbacks?: ChatStreamCallbacks,
) => {
  let source: ReadableStream;
  if (stream instanceof ReadableStream) {
    source = stream;
  } else {
    source = convertIterableToStream(stream);
  }

  const protocolStream = source.pipeThrough(createSSEProtocolTransformer(transformSparkStream));

  return protocolStream.pipeThrough(createCallbacksTransformer(callbacks));
};