🐛 fix: Google stream error unable to abort request (#9180)

* 🐛 fix: 优化 Gemini 流处理中的错误注入和终止事件管理

* 🐛 fix: 在流处理过程中注入提供者信息以增强错误报告

* 提取 lobe_error;添加单测

* fix test

* 修复单元测试

* 优化 LobeGoogleAI 中的错误日志记录,使用 debug 替代 console.log;更新单元测试以验证流处理的错误和数据块

* 增强 createSSEProtocolTransformer,添加 requireTerminalEvent 选项以控制终端事件的强制要求;更新相关单元测试以验证新行为

* revert tests

* fix test
This commit is contained in:
sxjeru
2025-09-14 00:00:31 +08:00
committed by GitHub
parent 45fa4e01ae
commit 78eaead0d2
8 changed files with 452 additions and 24 deletions
@@ -58,6 +58,14 @@ afterEach(() => {
});
describe('LobeOpenAICompatibleFactory', () => {
// Polyfill File for Node environment used in image tests
if (typeof File === 'undefined') {
// @ts-ignore
global.File = class MockFile {
constructor(public parts: any[], public name: string, public opts?: any) {}
};
}
describe('init', () => {
it('should correctly initialize with an API key', async () => {
const instance = new LobeMockProvider({ apiKey: 'test_api_key' });
@@ -148,10 +156,22 @@ describe('LobeOpenAICompatibleFactory', () => {
const decoder = new TextDecoder();
const reader = result.body!.getReader();
expect(decoder.decode((await reader.read()).value)).toEqual('id: a\n');
expect(decoder.decode((await reader.read()).value)).toEqual('event: text\n');
expect(decoder.decode((await reader.read()).value)).toEqual('data: "hello"\n\n');
expect((await reader.read()).done).toBe(true);
// Collect all chunks
const chunks = [];
while (true) {
const { value, done } = await reader.read();
if (done) break;
chunks.push(decoder.decode(value));
}
// Assert that all expected chunk patterns are present
expect(chunks).toEqual(
expect.arrayContaining([
'id: a\n',
'event: text\n',
'data: "hello"\n\n',
]),
);
});
// https://github.com/lobehub/lobe-chat/issues/2752
@@ -2,7 +2,7 @@ import { GenerateContentResponse } from '@google/genai';
import { describe, expect, it, vi } from 'vitest';
import * as uuidModule from '../../utils/uuid';
import { GoogleGenerativeAIStream } from './google-ai';
import { GoogleGenerativeAIStream, LOBE_ERROR_KEY } from './google-ai';
describe('GoogleGenerativeAIStream', () => {
it('should transform Google Generative AI stream to protocol stream', async () => {
@@ -21,7 +21,21 @@ describe('GoogleGenerativeAIStream', () => {
controller.enqueue(
mockGenerateContentResponse('', [{ name: 'testFunction', args: { arg1: 'value1' } }]),
);
controller.enqueue(mockGenerateContentResponse(' world!'));
// final chunk should include finishReason and usageMetadata to mark terminal event
controller.enqueue({
text: ' world!',
candidates: [
{ content: { role: 'model' }, finishReason: 'STOP', index: 0 },
],
usageMetadata: {
promptTokenCount: 1,
totalTokenCount: 1,
promptTokensDetails: [{ modality: 'TEXT', tokenCount: 1 }],
},
modelVersion: 'gemini-test',
} as unknown as GenerateContentResponse);
controller.close();
},
});
@@ -63,6 +77,14 @@ describe('GoogleGenerativeAIStream', () => {
'id: chat_1\n',
'event: text\n',
`data: " world!"\n\n`,
// stop
'id: chat_1\n',
'event: stop\n',
`data: "STOP"\n\n`,
// usage
'id: chat_1\n',
'event: usage\n',
`data: {"inputTextTokens":1,"outputImageTokens":0,"outputTextTokens":0,"totalInputTokens":1,"totalOutputTokens":0,"totalTokens":1}\n\n`,
]);
expect(onStartMock).toHaveBeenCalledTimes(1);
@@ -73,8 +95,23 @@ describe('GoogleGenerativeAIStream', () => {
});
it('should handle empty stream', async () => {
vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('E5M9dFKw');
const mockGoogleStream = new ReadableStream({
start(controller) {
controller.enqueue({
candidates: [{ content: { role: 'model' }, finishReason: 'STOP', index: 0 }],
usageMetadata: {
promptTokenCount: 0,
cachedContentTokenCount: 0,
totalTokenCount: 0,
promptTokensDetails: [
{ modality: 'TEXT', tokenCount: 0 },
{ modality: 'IMAGE', tokenCount: 0 },
],
},
modelVersion: 'gemini-test',
} as unknown as GenerateContentResponse);
controller.close();
},
});
@@ -89,7 +126,14 @@ describe('GoogleGenerativeAIStream', () => {
chunks.push(decoder.decode(chunk, { stream: true }));
}
expect(chunks).toEqual([]);
expect(chunks).toEqual([
'id: chat_E5M9dFKw\n',
'event: stop\n',
`data: "STOP"\n\n`,
'id: chat_E5M9dFKw\n',
'event: usage\n',
`data: {"inputCachedTokens":0,"inputImageTokens":0,"inputTextTokens":0,"outputImageTokens":0,"outputTextTokens":0,"totalInputTokens":0,"totalOutputTokens":0,"totalTokens":0}\n\n`,
]);
});
it('should handle image', async () => {
@@ -102,13 +146,17 @@ describe('GoogleGenerativeAIStream', () => {
parts: [{ inlineData: { mimeType: 'image/png', data: 'iVBORw0KGgoAA' } }],
role: 'model',
},
finishReason: 'STOP',
index: 0,
},
],
usageMetadata: {
promptTokenCount: 6,
totalTokenCount: 6,
promptTokensDetails: [{ modality: 'TEXT', tokenCount: 6 }],
promptTokensDetails: [
{ modality: 'TEXT', tokenCount: 6 },
{ modality: 'IMAGE', tokenCount: 0 },
],
},
modelVersion: 'gemini-2.0-flash-exp',
};
@@ -136,6 +184,14 @@ describe('GoogleGenerativeAIStream', () => {
'id: chat_1\n',
'event: base64_image\n',
`data: "data:image/png;base64,iVBORw0KGgoAA"\n\n`,
// stop
'id: chat_1\n',
'event: stop\n',
`data: "STOP"\n\n`,
// usage
'id: chat_1\n',
'event: usage\n',
`data: {"inputImageTokens":0,"inputTextTokens":6,"outputImageTokens":0,"outputTextTokens":0,"totalInputTokens":6,"totalOutputTokens":0,"totalTokens":6}\n\n`,
]);
});
@@ -855,4 +911,33 @@ describe('GoogleGenerativeAIStream', () => {
`data: {"body":{"context":{"promptFeedback":{"blockReason":"PROHIBITED_CONTENT"}},"message":"您的请求可能包含违禁内容。请调整您的请求,确保内容符合使用规范。","provider":"google"},"type":"ProviderBizError"}\n\n`,
]);
});
it('should pass through injected lobe error marker', async () => {
vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');
const errorPayload = { message: 'internal error', code: 123 };
const mockGoogleStream = new ReadableStream({
start(controller) {
controller.enqueue({ [LOBE_ERROR_KEY]: errorPayload });
controller.close();
},
});
const protocolStream = GoogleGenerativeAIStream(mockGoogleStream);
const decoder = new TextDecoder();
const chunks = [];
// @ts-ignore
for await (const chunk of protocolStream) {
chunks.push(decoder.decode(chunk, { stream: true }));
}
expect(chunks).toEqual([
'id: chat_1\n',
'event: error\n',
`data: ${JSON.stringify(errorPayload)}\n\n`,
]);
});
});
@@ -16,6 +16,8 @@ import {
generateToolCallId,
} from './protocol';
export const LOBE_ERROR_KEY = '__lobe_error';
const getBlockReasonMessage = (blockReason: string): string => {
const blockReasonMessages = errorLocale.response.GoogleAIBlockReason;
@@ -29,6 +31,14 @@ const transformGoogleGenerativeAIStream = (
chunk: GenerateContentResponse,
context: StreamContext,
): StreamProtocolChunk | StreamProtocolChunk[] => {
// Handle injected internal error marker to pass through detailed error info
if ((chunk as any)?.[LOBE_ERROR_KEY]) {
return {
data: (chunk as any)[LOBE_ERROR_KEY],
id: context?.id || 'error',
type: 'error',
};
}
// Handle promptFeedback with blockReason (e.g., PROHIBITED_CONTENT)
if ('promptFeedback' in chunk && (chunk as any).promptFeedback?.blockReason) {
const blockReason = (chunk as any).promptFeedback.blockReason;
@@ -216,6 +226,8 @@ export const GoogleGenerativeAIStream = (
.pipeThrough(
createTokenSpeedCalculator(transformGoogleGenerativeAIStream, { inputStartAt, streamStack }),
)
.pipeThrough(createSSEProtocolTransformer((c) => c, streamStack))
.pipeThrough(
createSSEProtocolTransformer((c) => c, streamStack, { requireTerminalEvent: true }),
)
.pipeThrough(createCallbacksTransformer(callbacks));
};
@@ -2481,4 +2481,4 @@ describe('OpenAIStream', () => {
`data: "${base64_2}"\n\n`,
]);
});
});
});
@@ -1,6 +1,6 @@
import { describe, expect, it } from 'vitest';
import { createSSEDataExtractor, createTokenSpeedCalculator } from './protocol';
import { createSSEDataExtractor, createTokenSpeedCalculator, createSSEProtocolTransformer } from './protocol';
describe('createSSEDataExtractor', () => {
// Helper function to convert string to Uint8Array
@@ -233,3 +233,82 @@ describe('createTokenSpeedCalculator', async () => {
expect(speedChunk.data.ttft).not.toBeNaN();
});
});
describe('createSSEProtocolTransformer', () => {
const processChunk = async (transformer: TransformStream, chunk: any) => {
const results: any[] = [];
const readable = new ReadableStream({
start(controller) {
controller.enqueue(chunk);
controller.close();
},
});
const writable = new WritableStream({
write(chunk) {
results.push(chunk);
},
});
await readable.pipeThrough(transformer).pipeTo(writable);
return results;
};
it('should convert chunk into SSE formatted lines without enforcing terminal (default)', async () => {
const transformerFn = (chunk: any) => ({ type: 'text', id: chunk.id, data: chunk.data });
const transformer = createSSEProtocolTransformer(transformerFn as any);
const input = { id: '1', data: 'hello' };
const results = await processChunk(transformer, input);
// Should only output the text event, no injected error on flush (default not enforced)
expect(results).toEqual([
`id: 1\n`,
`event: text\n`,
`data: ${JSON.stringify('hello')}\n\n`,
]);
});
it('should not emit flush error if a terminal event was received (enforced)', async () => {
const transformerFn = (chunk: any) => ({ type: 'stop', id: chunk.id, data: chunk.data });
const transformer = createSSEProtocolTransformer(transformerFn as any, { id: 'stream_ok' }, { requireTerminalEvent: true });
const input = { id: 'ok', data: 'bye' };
const results = await processChunk(transformer, input);
// Only the stop event lines should be present (no extra error event from flush)
expect(results).toEqual([
`id: ok\n`,
`event: stop\n`,
`data: ${JSON.stringify('bye')}\n\n`,
]);
});
it('should emit an error event on flush when no terminal event received (enforced)', async () => {
const transformerFn = (chunk: any) => ({ type: 'text', id: chunk.id, data: chunk.data });
const streamStack = { id: 'stream_missing_term' } as any;
const transformer = createSSEProtocolTransformer(transformerFn as any, streamStack, { requireTerminalEvent: true });
const input = { id: '1', data: 'partial' };
const results = await processChunk(transformer, input);
// original 3 lines + 3 lines from flush error
expect(results).toHaveLength(6);
// last three lines should be the injected error event
const lastThree = results.slice(-3);
const expectedData = {
body: { name: 'Stream parsing error', reason: 'unexpected_end' },
message: 'Stream ended unexpectedly',
name: 'Stream parsing error',
type: 'StreamChunkError',
};
expect(lastThree).toEqual([
`id: ${streamStack.id}\n`,
`event: error\n`,
`data: ${JSON.stringify(expectedData)}\n\n`,
]);
});
});
@@ -166,8 +166,27 @@ export const convertIterableToStream = <T>(stream: AsyncIterable<T>) => {
export const createSSEProtocolTransformer = (
transformer: (chunk: any, stack: StreamContext) => StreamProtocolChunk | StreamProtocolChunk[],
streamStack?: StreamContext,
) =>
new TransformStream({
options?: { requireTerminalEvent?: boolean },
) => {
let hasTerminalEvent = false;
const requireTerminalEvent = Boolean(options?.requireTerminalEvent);
return new TransformStream({
flush(controller) {
// If the upstream closes without sending a terminal event, emit a final error event
if (requireTerminalEvent && !hasTerminalEvent) {
const id = streamStack?.id || 'stream_end';
const data = {
body: { name: 'Stream parsing error', reason: 'unexpected_end' },
message: 'Stream ended unexpectedly',
name: 'Stream parsing error',
type: 'StreamChunkError',
};
controller.enqueue(`id: ${id}\n`);
controller.enqueue(`event: error\n`);
controller.enqueue(`data: ${JSON.stringify(data)}\n\n`);
}
},
transform: (chunk, controller) => {
const result = transformer(chunk, streamStack || { id: '' });
@@ -177,9 +196,13 @@ export const createSSEProtocolTransformer = (
controller.enqueue(`id: ${id}\n`);
controller.enqueue(`event: ${type}\n`);
controller.enqueue(`data: ${JSON.stringify(data)}\n\n`);
// mark terminal when receiving any of these events
if (type === 'stop' || type === 'usage' || type === 'error') hasTerminalEvent = true;
});
},
});
};
export function createCallbacksTransformer(cb: ChatStreamCallbacks | undefined) {
const textEncoder = new TextEncoder();
@@ -8,6 +8,8 @@ import { ChatStreamPayload } from '@/types/openai/chat';
import * as debugStreamModule from '../../utils/debugStream';
import * as imageToBase64Module from '../../utils/imageToBase64';
import { LOBE_ERROR_KEY } from '../../core/streams/google-ai';
import { AgentRuntimeErrorType } from '../../types/error';
import { LobeGoogleAI } from './index';
const provider = 'google';
@@ -825,5 +827,158 @@ describe('LobeGoogleAI', () => {
});
});
});
describe('createEnhancedStream', () => {
it('should handle stream cancellation with data gracefully', async () => {
const mockStream = (async function* () {
yield { text: 'Hello' };
yield { text: ' world' };
})();
const abortController = new AbortController();
const enhancedStream = instance['createEnhancedStream'](mockStream, abortController.signal);
const reader = enhancedStream.getReader();
const chunks: any[] = [];
// Read first value then cancel to trigger error chunk
chunks.push((await reader.read()).value);
abortController.abort();
// Read all remaining chunks
let result;
while (!(result = await reader.read()).done) {
chunks.push(result.value);
}
// Batch-assert the entire chunks array
expect(chunks).toEqual([
{ text: 'Hello' },
{
[LOBE_ERROR_KEY]: {
body: { name: 'Stream cancelled', provider, reason: 'aborted' },
message: 'Stream cancelled',
name: 'Stream cancelled',
type: AgentRuntimeErrorType.StreamChunkError,
},
},
]);
});
it('should handle stream cancellation without data', async () => {
const mockStream = (async function* () {
// Empty stream
})();
const abortController = new AbortController();
const enhancedStream = instance['createEnhancedStream'](mockStream, abortController.signal);
const reader = enhancedStream.getReader();
// Cancel immediately
abortController.abort();
// Should be closed without any chunks
const chunk = await reader.read();
expect(chunk.done).toBe(true);
});
it('should handle AbortError with data', async () => {
const mockStream = (async function* () {
yield { text: 'Hello' };
throw new Error('aborted');
})();
const abortController = new AbortController();
const enhancedStream = instance['createEnhancedStream'](mockStream, abortController.signal);
const reader = enhancedStream.getReader();
const chunks: any[] = [];
// Read first value then collect remaining chunks (error included)
chunks.push((await reader.read()).value);
let result;
while (!(result = await reader.read()).done) {
chunks.push(result.value);
}
// Assert both data and error chunk together
expect(chunks).toEqual([
{ text: 'Hello' },
{
[LOBE_ERROR_KEY]: {
body: { name: 'Stream cancelled', provider, reason: 'aborted' },
message: 'Stream cancelled',
name: 'Stream cancelled',
type: AgentRuntimeErrorType.StreamChunkError,
},
},
]);
});
it('should handle AbortError without data', async () => {
const mockStream = (async function* () {
throw new Error('aborted');
})();
const abortController = new AbortController();
const enhancedStream = instance['createEnhancedStream'](mockStream, abortController.signal);
const reader = enhancedStream.getReader();
const chunks: any[] = [];
// Read error chunk
const chunk1 = await reader.read();
chunks.push(chunk1.value);
// Stream should be closed
const chunk2 = await reader.read();
expect(chunk2.done).toBe(true);
expect(chunks[0][LOBE_ERROR_KEY]).toEqual({
body: {
message: 'aborted',
name: 'AbortError',
provider,
stack: expect.any(String),
},
message: 'aborted',
name: 'AbortError',
type: AgentRuntimeErrorType.StreamChunkError,
});
});
it('should handle other stream parsing errors', async () => {
const mockStream = (async function* () {
yield { text: 'Hello' };
throw new Error('Network error');
})();
const abortController = new AbortController();
const enhancedStream = instance['createEnhancedStream'](mockStream, abortController.signal);
const reader = enhancedStream.getReader();
const chunks: any[] = [];
// Read first value then collect remaining chunks (parsing error)
chunks.push((await reader.read()).value);
let result;
while (!(result = await reader.read()).done) {
chunks.push(result.value);
}
expect(chunks).toEqual([
{ text: 'Hello' },
{
[LOBE_ERROR_KEY]: {
body: { message: 'Network error', provider },
message: 'Network error',
name: 'Stream parsing error',
type: AgentRuntimeErrorType.ProviderBizError,
},
},
]);
});
});
});
});
@@ -10,6 +10,7 @@ import {
ThinkingConfig,
} from '@google/genai';
import { LOBE_ERROR_KEY } from '../../core/streams/google-ai';
import { LobeRuntimeAI } from '../../core/BaseAI';
import { GoogleGenerativeAIStream, VertexAIStream } from '../../core/streams';
import {
@@ -29,6 +30,9 @@ import { StreamingResponse } from '../../utils/response';
import { safeParseJSON } from '../../utils/safeParseJSON';
import { parseDataUri } from '../../utils/uriParser';
import { createGoogleImage } from './createImage';
import debug from 'debug';
const log = debug('model-runtime:google');
const modelsOffSafetySettings = new Set(['gemini-2.0-flash-exp']);
@@ -244,7 +248,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
// 移除之前的静默处理,统一抛出错误
if (isAbortError(err)) {
console.log('Request was cancelled');
log('Request was cancelled');
throw AgentRuntimeError.chat({
error: { message: 'Request was cancelled' },
errorType: AgentRuntimeErrorType.ProviderBizError,
@@ -252,7 +256,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
});
}
console.log(err);
log('Error: %O', err);
const { errorType, error } = parseGoogleErrorMessage(err.message);
throw AgentRuntimeError.chat({ error, errorType, provider: this.provider });
@@ -268,6 +272,8 @@ export class LobeGoogleAI implements LobeRuntimeAI {
}
private createEnhancedStream(originalStream: any, signal: AbortSignal): ReadableStream {
// capture provider for error payloads inside the stream closure
const provider = this.provider;
return new ReadableStream({
async start(controller) {
let hasData = false;
@@ -277,12 +283,23 @@ export class LobeGoogleAI implements LobeRuntimeAI {
if (signal.aborted) {
// 如果有数据已经输出,优雅地关闭流而不是抛出错误
if (hasData) {
console.log('Stream cancelled gracefully, preserving existing output');
log('Stream cancelled gracefully, preserving existing output');
// 显式注入取消错误,避免走 SSE 兜底 unexpected_end
controller.enqueue({
[LOBE_ERROR_KEY]: {
body: { name: 'Stream cancelled', provider, reason: 'aborted' },
message: 'Stream cancelled',
name: 'Stream cancelled',
type: AgentRuntimeErrorType.StreamChunkError,
},
});
controller.close();
return;
} else {
// 如果还没有数据输出,则抛出取消错误
throw new Error('Stream cancelled');
// 如果还没有数据输出,直接关闭流,由下游 SSE 在 flush 阶段补发错误事件
log('Stream cancelled before any output');
controller.close();
return;
}
}
@@ -296,18 +313,55 @@ export class LobeGoogleAI implements LobeRuntimeAI {
if (isAbortError(err) || signal.aborted) {
// 如果有数据已经输出,优雅地关闭流
if (hasData) {
console.log('Stream reading cancelled gracefully, preserving existing output');
log('Stream reading cancelled gracefully, preserving existing output');
// 显式注入取消错误,避免走 SSE 兜底 unexpected_end
controller.enqueue({
[LOBE_ERROR_KEY]: {
body: { name: 'Stream cancelled', provider, reason: 'aborted' },
message: 'Stream cancelled',
name: 'Stream cancelled',
type: AgentRuntimeErrorType.StreamChunkError,
},
});
controller.close();
return;
} else {
console.log('Stream reading cancelled before any output');
controller.error(new Error('Stream cancelled'));
log('Stream reading cancelled before any output');
// 注入一个带详细错误信息的错误标记,交由下游 google-ai transformer 输出 error 事件
controller.enqueue({
[LOBE_ERROR_KEY]: {
body: {
message: err.message,
name: 'AbortError',
provider,
stack: err.stack,
},
message: err.message || 'Request was cancelled',
name: 'AbortError',
type: AgentRuntimeErrorType.StreamChunkError,
},
});
controller.close();
return;
}
} else {
// 处理其他流解析错误
console.error('Stream parsing error:', err);
controller.error(err);
log('Stream parsing error: %O', err);
// 尝试解析 Google 错误并提取 code/message/status
const { error: parsedError, errorType } = parseGoogleErrorMessage(
err?.message || String(err),
);
// 注入一个带详细错误信息的错误标记,交由下游 google-ai transformer 输出 error 事件
controller.enqueue({
[LOBE_ERROR_KEY]: {
body: { ...parsedError, provider },
message: parsedError?.message || err.message || 'Stream parsing error',
name: 'Stream parsing error',
type: errorType ?? AgentRuntimeErrorType.StreamChunkError,
},
});
controller.close();
return;
}
}
@@ -348,7 +402,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
return processModelList(processedModels, MODEL_LIST_CONFIGS.google);
} catch (error) {
console.error('Failed to fetch Google models:', error);
log('Failed to fetch Google models: %O', error);
throw error;
}
}